934 lines
26 KiB
Go
934 lines
26 KiB
Go
/*
|
||
* Copyright (C) 2020-2022, IrineSistiana
|
||
*
|
||
* This file is part of mosdns.
|
||
*
|
||
* mosdns is free software: you can redistribute it and/or modify
|
||
* it under the terms of the GNU General Public License as published by
|
||
* the Free Software Foundation, either version 3 of the License, or
|
||
* (at your option) any later version.
|
||
*
|
||
* mosdns is distributed in the hope that it will be useful,
|
||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
* GNU General Public License for more details.
|
||
*
|
||
* You should have received a copy of the GNU General Public License
|
||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
*/
|
||
|
||
package mikrotik_addresslist
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
	"github.com/miekg/dns"
	"go.uber.org/zap"

	routeros "github.com/go-routeros/routeros/v3"
)
|
||
|
||
// verifyTask is one queued request to confirm, after an add, that an
// address-list entry is present on the router.
type verifyTask struct {
	listName, cidrAddr string // target list and the entry in CIDR form
	retries            int    // verification attempts already retried
}
|
||
|
||
type mikrotikAddressListPlugin struct {
|
||
args *Args
|
||
conn *routeros.Client
|
||
log *zap.Logger
|
||
|
||
// 并发控制
|
||
workerPool chan struct{}
|
||
verifyPool chan struct{} // 专门用于验证的工作池
|
||
wg sync.WaitGroup
|
||
mu sync.RWMutex // 保护连接的重连操作
|
||
isConnected bool // 连接状态标记
|
||
|
||
// 内存缓存
|
||
cache map[string]time.Time // key: "listName:cidrAddr", value: 添加时间
|
||
cacheMu sync.RWMutex // 保护缓存访问
|
||
cacheTTL time.Duration // 缓存 TTL,默认 1 小时
|
||
|
||
// 验证队列
|
||
verifyQueue chan verifyTask
|
||
stopVerify chan struct{}
|
||
}
|
||
|
||
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
|
||
// 设置默认值,优化为支持所有IP地址写入
|
||
if args.Mask4 == 0 {
|
||
args.Mask4 = 32 // 默认单个IP掩码,确保每个IP都被单独添加
|
||
}
|
||
if args.Mask6 == 0 {
|
||
args.Mask6 = 128 // 默认单个IP掩码,确保每个IP都被单独添加
|
||
}
|
||
if args.Port == 0 {
|
||
args.Port = 8728
|
||
}
|
||
if args.Timeout == 0 {
|
||
args.Timeout = 10
|
||
}
|
||
// 默认启用添加所有IP功能
|
||
if !args.AddAllIPs {
|
||
args.AddAllIPs = true
|
||
}
|
||
|
||
// 构建连接地址
|
||
addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
|
||
|
||
// 创建 MikroTik 连接
|
||
conn, err := routeros.Dial(addr, args.Username, args.Password)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := testMikrotikConnection(conn); err != nil {
|
||
conn.Close()
|
||
return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
|
||
}
|
||
|
||
// 设置工作池大小(可以根据需要调整)
|
||
workerCount := 10 // 并发工作线程数
|
||
verifyCount := 5 // 验证工作线程数
|
||
|
||
// 设置缓存 TTL
|
||
cacheTTL := time.Hour // 默认 1 小时
|
||
if args.CacheTTL > 0 {
|
||
cacheTTL = time.Duration(args.CacheTTL) * time.Second
|
||
}
|
||
|
||
plugin := &mikrotikAddressListPlugin{
|
||
args: args,
|
||
conn: conn,
|
||
log: zap.L().Named("mikrotik_addresslist"),
|
||
workerPool: make(chan struct{}, workerCount),
|
||
verifyPool: make(chan struct{}, verifyCount),
|
||
cache: make(map[string]time.Time),
|
||
cacheTTL: cacheTTL,
|
||
isConnected: true,
|
||
verifyQueue: make(chan verifyTask, 100), // 验证任务队列
|
||
stopVerify: make(chan struct{}),
|
||
}
|
||
|
||
// 启动验证工作协程
|
||
if args.VerifyAdd {
|
||
for i := 0; i < verifyCount; i++ {
|
||
go plugin.verifyWorker()
|
||
}
|
||
}
|
||
|
||
// 记录连接成功信息
|
||
plugin.log.Info("successfully connected to MikroTik",
|
||
zap.String("host", args.Host),
|
||
zap.Int("port", args.Port),
|
||
zap.String("username", args.Username),
|
||
zap.String("address_list4", args.AddressList4),
|
||
zap.Int("worker_count", workerCount),
|
||
zap.Int("verify_count", verifyCount),
|
||
zap.Bool("verify_add", args.VerifyAdd),
|
||
zap.Duration("cache_ttl", cacheTTL))
|
||
|
||
return plugin, nil
|
||
}
|
||
|
||
// testMikrotikConnection 测试 MikroTik 连接是否正常
|
||
func testMikrotikConnection(conn *routeros.Client) error {
|
||
// 尝试执行一个简单的命令来测试连接
|
||
resp, err := conn.Run("/system/resource/print")
|
||
if err != nil {
|
||
return fmt.Errorf("connection test failed: %w", err)
|
||
}
|
||
|
||
// 检查响应是否有效
|
||
if len(resp.Re) == 0 {
|
||
return fmt.Errorf("connection test failed: no response from MikroTik")
|
||
}
|
||
|
||
// 记录连接测试成功
|
||
zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
|
||
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) reconnect() error {
|
||
p.log.Info("attempting to reconnect to MikroTik")
|
||
|
||
// 关闭旧连接
|
||
if p.conn != nil {
|
||
p.conn.Close()
|
||
}
|
||
|
||
// 重新建立连接
|
||
var err error
|
||
p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := testMikrotikConnection(p.conn); err != nil {
|
||
return fmt.Errorf("failed to test reconnection: %w", err)
|
||
}
|
||
|
||
p.log.Info("successfully reconnected to MikroTik")
|
||
return nil
|
||
}
|
||
|
||
// 生成缓存键
|
||
func (p *mikrotikAddressListPlugin) cacheKey(listName, cidrAddr string) string {
|
||
return listName + ":" + cidrAddr
|
||
}
|
||
|
||
// 检查缓存中是否存在
|
||
func (p *mikrotikAddressListPlugin) isInCache(listName, cidrAddr string) bool {
|
||
p.cacheMu.RLock()
|
||
defer p.cacheMu.RUnlock()
|
||
|
||
key := p.cacheKey(listName, cidrAddr)
|
||
if addTime, exists := p.cache[key]; exists {
|
||
// 检查是否过期
|
||
if time.Since(addTime) < p.cacheTTL {
|
||
return true
|
||
}
|
||
// 过期了,删除
|
||
p.cacheMu.RUnlock()
|
||
p.cacheMu.Lock()
|
||
delete(p.cache, key)
|
||
p.cacheMu.Unlock()
|
||
p.cacheMu.RLock()
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 添加到缓存
|
||
func (p *mikrotikAddressListPlugin) addToCache(listName, cidrAddr string) {
|
||
p.cacheMu.Lock()
|
||
defer p.cacheMu.Unlock()
|
||
|
||
key := p.cacheKey(listName, cidrAddr)
|
||
p.cache[key] = time.Now()
|
||
|
||
// 清理过期的缓存项
|
||
p.cleanupExpiredCache()
|
||
}
|
||
|
||
// 清理过期的缓存项
|
||
func (p *mikrotikAddressListPlugin) cleanupExpiredCache() {
|
||
now := time.Now()
|
||
for key, addTime := range p.cache {
|
||
if now.Sub(addTime) >= p.cacheTTL {
|
||
delete(p.cache, key)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取缓存统计信息
|
||
func (p *mikrotikAddressListPlugin) getCacheStats() (int, int) {
|
||
p.cacheMu.RLock()
|
||
defer p.cacheMu.RUnlock()
|
||
|
||
total := len(p.cache)
|
||
valid := 0
|
||
now := time.Now()
|
||
for _, addTime := range p.cache {
|
||
if now.Sub(addTime) < p.cacheTTL {
|
||
valid++
|
||
}
|
||
}
|
||
return total, valid
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
|
||
// 检查连接是否正常
|
||
if p.conn == nil {
|
||
p.log.Error("MikroTik connection is nil")
|
||
// 不返回错误,避免影响DNS响应
|
||
return nil
|
||
}
|
||
|
||
r := qCtx.R()
|
||
if r != nil {
|
||
// 获取查询的域名
|
||
var domain string
|
||
if len(qCtx.Q().Question) > 0 {
|
||
domain = strings.TrimSuffix(qCtx.Q().Question[0].Name, ".")
|
||
}
|
||
|
||
p.log.Debug("processing DNS response",
|
||
zap.String("qname", domain),
|
||
zap.Int("answer_count", len(r.Answer)))
|
||
|
||
// 异步处理Mikrotik操作,不阻塞DNS响应
|
||
go func(response *dns.Msg, domainName string) {
|
||
if err := p.addToAddressList(response, domainName); err != nil {
|
||
p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
|
||
// 不返回错误到主流程,避免影响DNS响应
|
||
}
|
||
}(r, domain)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) Close() error {
|
||
// 停止验证工作器
|
||
if p.args.VerifyAdd {
|
||
close(p.stopVerify)
|
||
}
|
||
|
||
// 等待所有工作完成
|
||
p.wg.Wait()
|
||
|
||
// 清理缓存
|
||
p.cacheMu.Lock()
|
||
cacheSize := len(p.cache)
|
||
p.cache = nil
|
||
p.cacheMu.Unlock()
|
||
|
||
p.log.Info("plugin closed", zap.Int("cache_cleared", cacheSize))
|
||
|
||
// 关闭连接
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
if p.conn != nil {
|
||
return p.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg, domain string) error {
|
||
p.log.Debug("starting to process DNS response",
|
||
zap.String("configured_address_list4", p.args.AddressList4),
|
||
zap.String("configured_address_list6", p.args.AddressList6),
|
||
zap.Int("answer_count", len(r.Answer)),
|
||
zap.Bool("add_all_ips", p.args.AddAllIPs),
|
||
zap.Int("max_ips", p.args.MaxIPs))
|
||
|
||
// 如果未启用添加所有IP,只处理第一个IP(保持兼容性)
|
||
if !p.args.AddAllIPs {
|
||
p.log.Debug("add_all_ips disabled, processing only first IP")
|
||
return p.addFirstIPOnly(r, domain)
|
||
}
|
||
|
||
// 收集所有需要处理的 IPv4 和 IPv6 地址
|
||
var ipv4Addresses []netip.Addr
|
||
var ipv6Addresses []netip.Addr
|
||
|
||
for i := range r.Answer {
|
||
switch rr := r.Answer[i].(type) {
|
||
case *dns.A:
|
||
if len(p.args.AddressList4) == 0 {
|
||
p.log.Debug("skipping A record, no IPv4 address list configured")
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.A.To4())
|
||
if !ok {
|
||
p.log.Error("invalid A record", zap.String("ip", rr.A.String()))
|
||
continue // 跳过无效记录,不中断处理
|
||
}
|
||
ipv4Addresses = append(ipv4Addresses, addr)
|
||
p.log.Debug("queued A record for processing",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("address_list4", p.args.AddressList4))
|
||
|
||
case *dns.AAAA:
|
||
if len(p.args.AddressList6) == 0 {
|
||
p.log.Debug("skipping AAAA record, no IPv6 address list configured")
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.AAAA.To16())
|
||
if !ok {
|
||
p.log.Error("invalid AAAA record", zap.String("ip", rr.AAAA.String()))
|
||
continue // 跳过无效记录,不中断处理
|
||
}
|
||
ipv6Addresses = append(ipv6Addresses, addr)
|
||
p.log.Debug("queued AAAA record for processing",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("address_list6", p.args.AddressList6))
|
||
|
||
default:
|
||
p.log.Debug("skipping non-A/AAAA record", zap.String("type", fmt.Sprintf("%T", rr)))
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 应用IP数量限制
|
||
if p.args.MaxIPs > 0 {
|
||
if len(ipv4Addresses) > p.args.MaxIPs {
|
||
p.log.Info("limiting IPv4 addresses",
|
||
zap.Int("total", len(ipv4Addresses)),
|
||
zap.Int("limit", p.args.MaxIPs),
|
||
zap.String("domain", domain))
|
||
ipv4Addresses = ipv4Addresses[:p.args.MaxIPs]
|
||
}
|
||
if len(ipv6Addresses) > p.args.MaxIPs {
|
||
p.log.Info("limiting IPv6 addresses",
|
||
zap.Int("total", len(ipv6Addresses)),
|
||
zap.Int("limit", p.args.MaxIPs),
|
||
zap.String("domain", domain))
|
||
ipv6Addresses = ipv6Addresses[:p.args.MaxIPs]
|
||
}
|
||
}
|
||
|
||
totalAddresses := len(ipv4Addresses) + len(ipv6Addresses)
|
||
if totalAddresses == 0 {
|
||
p.log.Debug("no addresses to process")
|
||
return nil
|
||
}
|
||
|
||
// 立即记录并启动异步处理,不等待任何操作
|
||
p.log.Info("queuing addresses for async processing",
|
||
zap.Int("ipv4_count", len(ipv4Addresses)),
|
||
zap.Int("ipv6_count", len(ipv6Addresses)),
|
||
zap.Int("total_count", totalAddresses),
|
||
zap.String("domain", domain))
|
||
|
||
// 异步处理IPv4地址
|
||
if len(ipv4Addresses) > 0 && len(p.args.AddressList4) > 0 {
|
||
go func(addrs []netip.Addr, listName string, mask int, domainName string) {
|
||
// 在异步线程中调整工作池大小
|
||
p.adjustWorkerPoolSize(len(addrs))
|
||
|
||
// 启动批量处理
|
||
if err := p.batchAddAddresses(addrs, listName, mask, domainName); err != nil {
|
||
p.log.Error("async IPv4 batch processing failed", zap.Error(err))
|
||
}
|
||
|
||
// 记录缓存统计信息
|
||
total, valid := p.getCacheStats()
|
||
p.log.Debug("IPv4 async processing stats",
|
||
zap.Int("processed_count", len(addrs)),
|
||
zap.Int("cache_total", total),
|
||
zap.Int("cache_valid", valid),
|
||
zap.String("domain", domainName))
|
||
}(ipv4Addresses, p.args.AddressList4, p.args.Mask4, domain)
|
||
}
|
||
|
||
// 异步处理IPv6地址
|
||
if len(ipv6Addresses) > 0 && len(p.args.AddressList6) > 0 {
|
||
go func(addrs []netip.Addr, listName string, mask int, domainName string) {
|
||
// 在异步线程中调整工作池大小
|
||
p.adjustWorkerPoolSize(len(addrs))
|
||
|
||
// 启动批量处理
|
||
if err := p.batchAddAddresses(addrs, listName, mask, domainName); err != nil {
|
||
p.log.Error("async IPv6 batch processing failed", zap.Error(err))
|
||
}
|
||
|
||
// 记录缓存统计信息
|
||
total, valid := p.getCacheStats()
|
||
p.log.Debug("IPv6 async processing stats",
|
||
zap.Int("processed_count", len(addrs)),
|
||
zap.Int("cache_total", total),
|
||
zap.Int("cache_valid", valid),
|
||
zap.String("domain", domainName))
|
||
}(ipv6Addresses, p.args.AddressList6, p.args.Mask6, domain)
|
||
}
|
||
|
||
// 立即返回,不等待任何异步操作
|
||
return nil
|
||
}
|
||
|
||
// addFirstIPOnly 兼容性函数,只添加第一个IP地址(向后兼容)
|
||
func (p *mikrotikAddressListPlugin) addFirstIPOnly(r *dns.Msg, domain string) error {
|
||
for i := range r.Answer {
|
||
switch rr := r.Answer[i].(type) {
|
||
case *dns.A:
|
||
if len(p.args.AddressList4) == 0 {
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.A.To4())
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// 只处理第一个IPv4地址
|
||
p.log.Debug("processing first IPv4 address only",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("domain", domain))
|
||
|
||
go func(address netip.Addr, listName string, mask int, domainName string) {
|
||
if err := p.addAddressToMikrotik(address, listName, mask, domainName); err != nil {
|
||
p.log.Error("failed to add first IPv4 address", zap.Error(err))
|
||
}
|
||
}(addr, p.args.AddressList4, p.args.Mask4, domain)
|
||
|
||
return nil // 只处理第一个,然后返回
|
||
|
||
case *dns.AAAA:
|
||
if len(p.args.AddressList6) == 0 {
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.AAAA.To16())
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
// 只处理第一个IPv6地址
|
||
p.log.Debug("processing first IPv6 address only",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("domain", domain))
|
||
|
||
go func(address netip.Addr, listName string, mask int, domainName string) {
|
||
if err := p.addAddressToMikrotik(address, listName, mask, domainName); err != nil {
|
||
p.log.Error("failed to add first IPv6 address", zap.Error(err))
|
||
}
|
||
}(addr, p.args.AddressList6, p.args.Mask6, domain)
|
||
|
||
return nil // 只处理第一个,然后返回
|
||
}
|
||
}
|
||
|
||
p.log.Debug("no valid addresses found for first IP processing")
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int, domain string) error {
|
||
p.log.Debug("addAddressToMikrotik called",
|
||
zap.String("addr", addr.String()),
|
||
zap.String("listName", listName),
|
||
zap.Int("mask", mask))
|
||
|
||
// 构建 CIDR 格式的地址
|
||
// 为了支持多个IP地址写入,我们有两种选择:
|
||
// 1. 写入具体的IP地址(/32 for IPv4, /128 for IPv6)
|
||
// 2. 写入网段地址(使用配置的掩码)
|
||
// 根据用户需求,这里优化为支持两种模式
|
||
var cidrAddr string
|
||
if addr.Is4() {
|
||
if p.args.Mask4 == 32 {
|
||
// 如果掩码是32,直接使用IP地址,这样每个IP都会被单独添加
|
||
cidrAddr = addr.String() + "/32"
|
||
} else {
|
||
// 使用网段地址,将多个IP归并到同一网段
|
||
networkAddr := netip.PrefixFrom(addr, p.args.Mask4).Addr()
|
||
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask4)
|
||
}
|
||
} else {
|
||
if p.args.Mask6 == 128 {
|
||
// 如果掩码是128,直接使用IP地址,这样每个IP都会被单独添加
|
||
cidrAddr = addr.String() + "/128"
|
||
} else {
|
||
// 使用网段地址,将多个IP归并到同一网段
|
||
networkAddr := netip.PrefixFrom(addr, p.args.Mask6).Addr()
|
||
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask6)
|
||
}
|
||
}
|
||
|
||
p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))
|
||
|
||
// 首先检查内存缓存
|
||
if p.isInCache(listName, cidrAddr) {
|
||
p.log.Debug("address found in cache, skipping", zap.String("cidr", cidrAddr), zap.String("list", listName))
|
||
return nil
|
||
}
|
||
|
||
// 跳过耗时的 MikroTik 存在性检查,直接尝试添加
|
||
// 依赖 Mikrotik 的内置重复检查和错误处理
|
||
p.log.Debug("skipping existence check, will attempt direct add", zap.String("cidr", cidrAddr), zap.String("list", listName))
|
||
|
||
// 构造 RouterOS 参数,注意必须以 = 开头!
|
||
params := []string{
|
||
"=list=" + listName,
|
||
"=address=" + cidrAddr,
|
||
}
|
||
|
||
// 使用域名作为注释(优先级高于配置文件中的comment)
|
||
comment := domain
|
||
if comment == "" && p.args.Comment != "" {
|
||
comment = p.args.Comment
|
||
}
|
||
if comment != "" {
|
||
params = append(params, "=comment="+comment)
|
||
}
|
||
|
||
// 添加超时时间(如果配置了)
|
||
if p.args.TimeoutAddr > 0 {
|
||
params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
|
||
}
|
||
|
||
p.log.Info("adding address to MikroTik",
|
||
zap.String("original_ip", addr.String()),
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName),
|
||
zap.String("domain", domain),
|
||
zap.String("comment", comment),
|
||
zap.Int("timeout", p.args.TimeoutAddr))
|
||
|
||
p.log.Debug("Add to list: ", zap.Strings("params", params))
|
||
|
||
// 发送到 RouterOS,优化重试机制以减少延迟
|
||
maxRetries := 2 // 减少重试次数
|
||
backoffDuration := 50 * time.Millisecond // 减少退避时间
|
||
var err error // 声明 err 变量
|
||
|
||
for i := 0; i < maxRetries; i++ {
|
||
// 使用读锁保护连接访问
|
||
p.mu.RLock()
|
||
conn := p.conn
|
||
isConnected := p.isConnected
|
||
p.mu.RUnlock()
|
||
|
||
if conn == nil || !isConnected {
|
||
p.log.Debug("connection not available, attempting to reconnect")
|
||
p.mu.Lock()
|
||
p.isConnected = false
|
||
if err := p.reconnect(); err != nil {
|
||
p.mu.Unlock()
|
||
p.log.Error("failed to reconnect", zap.Error(err))
|
||
time.Sleep(backoffDuration)
|
||
backoffDuration += 25 * time.Millisecond // 线性增加,减少总延迟
|
||
continue
|
||
}
|
||
conn = p.conn
|
||
p.mu.Unlock()
|
||
}
|
||
|
||
args := append([]string{"/ip/firewall/address-list/add"}, params...)
|
||
_, err = conn.Run(args...)
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "already have such entry") {
|
||
p.log.Debug("Address already exists", zap.String("cidr", cidrAddr))
|
||
p.addToCache(listName, cidrAddr) // 添加到缓存
|
||
return nil
|
||
}
|
||
|
||
// 检查是否为连接错误
|
||
if p.isConnectionError(err) {
|
||
p.log.Warn("connection error, will retry",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.Int("retry", i+1),
|
||
zap.Error(err))
|
||
|
||
p.mu.Lock()
|
||
p.isConnected = false
|
||
p.mu.Unlock()
|
||
|
||
// 指数退避
|
||
time.Sleep(backoffDuration)
|
||
backoffDuration += 25 * time.Millisecond // 线性增加,减少总延迟
|
||
continue
|
||
}
|
||
|
||
// 其他错误,直接返回
|
||
p.log.Error("failed to add address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName),
|
||
zap.Error(err))
|
||
return fmt.Errorf("failed to add address %s to list %s: %v", cidrAddr, listName, err)
|
||
}
|
||
|
||
// 成功,跳出重试循环
|
||
break
|
||
}
|
||
|
||
p.log.Info("successfully added address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName))
|
||
|
||
// 添加到缓存
|
||
p.addToCache(listName, cidrAddr)
|
||
|
||
// 如果启用了验证,提交验证任务
|
||
if p.args.VerifyAdd {
|
||
select {
|
||
case p.verifyQueue <- verifyTask{
|
||
listName: listName,
|
||
cidrAddr: cidrAddr,
|
||
retries: 0,
|
||
}:
|
||
p.log.Debug("verification task queued", zap.String("cidr", cidrAddr))
|
||
default:
|
||
p.log.Warn("verification queue full, skipping verification", zap.String("cidr", cidrAddr))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
|
||
// 查询地址列表中是否已存在该地址
|
||
params := []string{
|
||
"=list=" + listName,
|
||
"=address=" + address,
|
||
}
|
||
|
||
p.log.Debug("checking address existence",
|
||
zap.String("list", listName),
|
||
zap.String("address", address))
|
||
|
||
args := append([]string{"/ip/firewall/address-list/print"}, params...)
|
||
resp, err := p.conn.Run(args...)
|
||
if err != nil {
|
||
p.log.Error("failed to check address existence",
|
||
zap.String("list", listName),
|
||
zap.String("address", address),
|
||
zap.Error(err))
|
||
return false, err
|
||
}
|
||
|
||
// 如果返回结果不为空,说明地址已存在
|
||
exists := len(resp.Re) > 0
|
||
p.log.Debug("address existence check",
|
||
zap.String("list", listName),
|
||
zap.String("address", address),
|
||
zap.Bool("exists", exists))
|
||
|
||
return exists, nil
|
||
}
|
||
|
||
// batchAddAddresses 批量添加地址到MikroTik(完全异步化)
|
||
func (p *mikrotikAddressListPlugin) batchAddAddresses(addresses []netip.Addr, listName string, mask int, domain string) error {
|
||
if len(addresses) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 分批处理,每批10个地址
|
||
batchSize := 10
|
||
totalBatches := (len(addresses) + batchSize - 1) / batchSize
|
||
|
||
p.log.Info("starting async batch processing",
|
||
zap.Int("total_addresses", len(addresses)),
|
||
zap.Int("batch_size", batchSize),
|
||
zap.Int("total_batches", totalBatches),
|
||
zap.String("domain", domain))
|
||
|
||
for i := 0; i < len(addresses); i += batchSize {
|
||
end := i + batchSize
|
||
if end > len(addresses) {
|
||
end = len(addresses)
|
||
}
|
||
|
||
batch := addresses[i:end]
|
||
batchIndex := i/batchSize + 1
|
||
|
||
// 异步处理每个批次,不等待完成
|
||
go func(batch []netip.Addr, batchIdx int) {
|
||
// 获取工作池槽位
|
||
select {
|
||
case p.workerPool <- struct{}{}:
|
||
defer func() { <-p.workerPool }()
|
||
default:
|
||
p.log.Debug("worker pool full, processing batch directly",
|
||
zap.Int("batch", batchIdx))
|
||
}
|
||
|
||
successCount := 0
|
||
errorCount := 0
|
||
|
||
for _, addr := range batch {
|
||
if err := p.addAddressToMikrotik(addr, listName, mask, domain); err != nil {
|
||
errorCount++
|
||
p.log.Debug("failed to add address in batch",
|
||
zap.String("addr", addr.String()),
|
||
zap.Int("batch", batchIdx),
|
||
zap.Error(err))
|
||
} else {
|
||
successCount++
|
||
}
|
||
}
|
||
|
||
p.log.Debug("batch processing completed",
|
||
zap.Int("batch", batchIdx),
|
||
zap.Int("success_count", successCount),
|
||
zap.Int("error_count", errorCount),
|
||
zap.String("domain", domain))
|
||
}(batch, batchIndex)
|
||
}
|
||
|
||
// 立即返回,不等待批次处理完成
|
||
p.log.Debug("all batches queued for async processing",
|
||
zap.Int("total_addresses", len(addresses)),
|
||
zap.String("domain", domain))
|
||
|
||
return nil
|
||
}
|
||
|
||
// adjustWorkerPoolSize 动态调整工作池大小
|
||
func (p *mikrotikAddressListPlugin) adjustWorkerPoolSize(addressCount int) {
|
||
var targetSize int
|
||
switch {
|
||
case addressCount <= 5:
|
||
targetSize = 3
|
||
case addressCount <= 20:
|
||
targetSize = 5
|
||
case addressCount <= 50:
|
||
targetSize = 10
|
||
default:
|
||
targetSize = 15
|
||
}
|
||
|
||
// 如果当前容量不够,创建新的工作池
|
||
if cap(p.workerPool) < targetSize {
|
||
p.log.Debug("adjusting worker pool size",
|
||
zap.Int("old_size", cap(p.workerPool)),
|
||
zap.Int("new_size", targetSize),
|
||
zap.Int("address_count", addressCount))
|
||
|
||
// 创建新的工作池
|
||
p.workerPool = make(chan struct{}, targetSize)
|
||
}
|
||
}
|
||
|
||
// isConnectionError 检查是否为连接错误
|
||
func (p *mikrotikAddressListPlugin) isConnectionError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
errStr := err.Error()
|
||
return strings.Contains(errStr, "EOF") ||
|
||
strings.Contains(errStr, "connection") ||
|
||
strings.Contains(errStr, "closed") ||
|
||
strings.Contains(errStr, "timeout")
|
||
}
|
||
|
||
// verifyWorker 验证工作器,独立处理验证任务,避免阻塞写入操作
|
||
func (p *mikrotikAddressListPlugin) verifyWorker() {
|
||
p.wg.Add(1)
|
||
defer p.wg.Done()
|
||
|
||
ticker := time.NewTicker(2 * time.Second) // 每2秒处理一批验证任务
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-p.stopVerify:
|
||
p.log.Info("verification worker stopping")
|
||
return
|
||
|
||
case <-ticker.C:
|
||
// 批量处理验证任务
|
||
p.processBatchVerification()
|
||
|
||
case task := <-p.verifyQueue:
|
||
// 获取验证工作池槽位
|
||
select {
|
||
case p.verifyPool <- struct{}{}:
|
||
go func(task verifyTask) {
|
||
defer func() { <-p.verifyPool }()
|
||
p.processVerificationTask(task)
|
||
}(task)
|
||
default:
|
||
// 验证池满,延迟处理
|
||
go func() {
|
||
time.Sleep(time.Second)
|
||
select {
|
||
case p.verifyQueue <- task:
|
||
default:
|
||
p.log.Warn("failed to requeue verification task",
|
||
zap.String("cidr", task.cidrAddr))
|
||
}
|
||
}()
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// processBatchVerification 批量处理验证队列中的任务
|
||
func (p *mikrotikAddressListPlugin) processBatchVerification() {
|
||
var tasks []verifyTask
|
||
|
||
// 收集最多10个任务进行批量处理
|
||
for i := 0; i < 10; i++ {
|
||
select {
|
||
case task := <-p.verifyQueue:
|
||
tasks = append(tasks, task)
|
||
default:
|
||
goto exitLoop
|
||
}
|
||
}
|
||
exitLoop:
|
||
|
||
if len(tasks) == 0 {
|
||
return
|
||
}
|
||
|
||
p.log.Debug("processing batch verification", zap.Int("task_count", len(tasks)))
|
||
|
||
for _, task := range tasks {
|
||
select {
|
||
case p.verifyPool <- struct{}{}:
|
||
go func(task verifyTask) {
|
||
defer func() { <-p.verifyPool }()
|
||
p.processVerificationTask(task)
|
||
}(task)
|
||
default:
|
||
// 如果池满,重新排队
|
||
select {
|
||
case p.verifyQueue <- task:
|
||
default:
|
||
p.log.Warn("verification queue full, dropping task",
|
||
zap.String("cidr", task.cidrAddr))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// processVerificationTask 处理单个验证任务
|
||
func (p *mikrotikAddressListPlugin) processVerificationTask(task verifyTask) {
|
||
// 等待一段时间再验证,让MikroTik有时间处理
|
||
time.Sleep(time.Duration(500+task.retries*200) * time.Millisecond)
|
||
|
||
exists, err := p.addressExists(task.listName, task.cidrAddr)
|
||
if err != nil {
|
||
if task.retries < 3 {
|
||
// 重试
|
||
task.retries++
|
||
p.log.Debug("verification failed, retrying",
|
||
zap.String("cidr", task.cidrAddr),
|
||
zap.Int("retries", task.retries),
|
||
zap.Error(err))
|
||
|
||
select {
|
||
case p.verifyQueue <- task:
|
||
default:
|
||
p.log.Warn("failed to requeue verification task for retry",
|
||
zap.String("cidr", task.cidrAddr))
|
||
}
|
||
} else {
|
||
p.log.Error("verification failed after max retries",
|
||
zap.String("cidr", task.cidrAddr),
|
||
zap.String("list", task.listName),
|
||
zap.Error(err))
|
||
}
|
||
return
|
||
}
|
||
|
||
if !exists {
|
||
p.log.Warn("address not found in MikroTik after add operation",
|
||
zap.String("cidr", task.cidrAddr),
|
||
zap.String("list", task.listName))
|
||
|
||
// 从缓存中移除,下次会重新尝试添加
|
||
p.cacheMu.Lock()
|
||
key := p.cacheKey(task.listName, task.cidrAddr)
|
||
delete(p.cache, key)
|
||
p.cacheMu.Unlock()
|
||
|
||
// 可以选择重新添加地址
|
||
if task.retries < 2 {
|
||
p.log.Info("attempting to re-add address",
|
||
zap.String("cidr", task.cidrAddr),
|
||
zap.String("list", task.listName))
|
||
|
||
// 这里可以重新调用添加逻辑,但要避免无限循环
|
||
// 暂时只记录警告,由下次DNS查询触发重新添加
|
||
}
|
||
} else {
|
||
p.log.Debug("address verification successful",
|
||
zap.String("cidr", task.cidrAddr),
|
||
zap.String("list", task.listName))
|
||
}
|
||
}
|