mosdns/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go
dengxiongjian 819576c450
Some checks failed
Test mosdns / build (push) Has been cancelled
优化项目
1. 增强 mikrotik_addresslist 插件
新增 domain_files 参数支持
自动域名匹配功能
保持原有所有功能不变
向后兼容,不影响现有用法
2. 核心功能实现
GFW 域名分流:gfwlist.out.txt 仅用于分流,不写入任何设备
多设备支持:a.txt → 设备A,b.txt → 设备B
自动匹配:插件自动检查域名是否在其域名文件中
性能优化:内存缓存、异步处理、智能跳过
3. 配置大幅简化
从 ~60 行复杂配置减少到 ~15 行
不需要手动定义 domain_set
不需要复杂的 sequence 逻辑
添加新设备只需要几行配置
2025-10-14 22:40:50 +08:00

985 lines
27 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package mikrotik_addresslist
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net/netip"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/IrineSistiana/mosdns/v5/pkg/matcher/domain"
	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
	routeros "github.com/go-routeros/routeros/v3"
	"github.com/miekg/dns"
	"go.uber.org/zap"
)
// verifyTask is a deferred check that an address-list entry which was just
// added is actually present on the router.
type verifyTask struct {
	retries  int    // how many times this check has already been re-queued after a lookup error
	listName string // address-list the entry was added to
	cidrAddr string // entry in CIDR form, exactly as sent to RouterOS
}
// loadDomainFile reads the domain list at path file and adds every rule it
// contains to m. An empty path means "nothing to load" and is not an error.
func loadDomainFile(m *domain.MixMatcher[struct{}], file string) error {
	if file == "" {
		return nil
	}

	content, readErr := os.ReadFile(file)
	if readErr != nil {
		return fmt.Errorf("failed to read domain file: %w", readErr)
	}

	src := bytes.NewReader(content)
	if parseErr := domain.LoadFromTextReader[struct{}](m, src, nil); parseErr != nil {
		return fmt.Errorf("failed to parse domain file: %w", parseErr)
	}
	return nil
}
// mikrotikAddressListPlugin writes the A/AAAA answers of DNS responses into
// MikroTik RouterOS firewall address lists through the RouterOS API.
type mikrotikAddressListPlugin struct {
	// Configuration and logging.
	args *Args
	log  *zap.Logger

	// RouterOS API session. mu protects conn and isConnected while the
	// session is being replaced by a reconnect.
	mu          sync.RWMutex
	conn        *routeros.Client
	isConnected bool

	// Optional query-name filter built from args.DomainFiles. When non-nil,
	// only responses whose query name matches are processed.
	domainMatcher domain.Matcher[struct{}]

	// Concurrency control: workerPool for add batches, verifyPool for
	// verification lookups, wg for the verification workers.
	workerPool chan struct{}
	verifyPool chan struct{}
	wg         sync.WaitGroup

	// In-memory cache: key "listName:cidrAddr" -> time the entry was added.
	// Entries younger than cacheTTL let us skip redundant RouterOS calls.
	cacheMu  sync.RWMutex
	cache    map[string]time.Time
	cacheTTL time.Duration

	// Verification pipeline: pending tasks and the workers' stop signal.
	verifyQueue chan verifyTask
	stopVerify  chan struct{}
}
// newMikrotikAddressListPlugin applies defaults to args, loads the optional
// domain filter, connects to RouterOS and starts the verification workers.
//
// Defaults applied in place: Mask4=32 and Mask6=128 (one entry per address),
// Port=8728, Timeout=10 (seconds; bounds the RouterOS dial + login). The cache
// TTL is 1h unless args.CacheTTL (seconds) is > 0.
//
// NOTE(review): AddAllIPs is forced to true unconditionally, so the
// first-IP-only mode (addFirstIPOnly) cannot be selected from configuration.
// Honouring an explicit "false" needs a tri-state field in Args. The behaviour
// is kept for compatibility.
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
	if args.Mask4 == 0 {
		args.Mask4 = 32 // single-address mask: every IPv4 address gets its own entry
	}
	if args.Mask6 == 0 {
		args.Mask6 = 128 // single-address mask: every IPv6 address gets its own entry
	}
	if args.Port == 0 {
		args.Port = 8728
	}
	if args.Timeout == 0 {
		args.Timeout = 10
	}
	args.AddAllIPs = true

	// Build the domain filter before dialing so that a broken domain file
	// fails fast without opening a RouterOS session. The interface-typed
	// variable stays a true nil when no files are configured. (A nil
	// *MixMatcher stored in the interface would not compare equal to nil.)
	var matcher domain.Matcher[struct{}]
	if len(args.DomainFiles) > 0 {
		mm := domain.NewMixMatcher[struct{}]()
		mm.SetDefaultMatcher(domain.MatcherDomain)
		for _, file := range args.DomainFiles {
			if err := loadDomainFile(mm, file); err != nil {
				return nil, fmt.Errorf("failed to load domain file %s: %w", file, err)
			}
		}
		matcher = mm
	}

	// Timeout was previously defaulted but never used, so an unreachable
	// router could block plugin start-up indefinitely.
	dialTimeout := 10 * time.Second
	if args.Timeout > 0 {
		dialTimeout = time.Duration(args.Timeout) * time.Second
	}
	addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
	conn, err := routeros.DialTimeout(addr, args.Username, args.Password, dialTimeout)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // best effort; the test error is what matters
		return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
	}

	const (
		workerCount = 10 // initial capacity of the add worker pool
		verifyCount = 5  // number of verification workers / verifyPool slots
	)

	cacheTTL := time.Hour
	if args.CacheTTL > 0 {
		cacheTTL = time.Duration(args.CacheTTL) * time.Second
	}

	plugin := &mikrotikAddressListPlugin{
		args:          args,
		conn:          conn,
		log:           zap.L().Named("mikrotik_addresslist"),
		domainMatcher: matcher,
		workerPool:    make(chan struct{}, workerCount),
		verifyPool:    make(chan struct{}, verifyCount),
		cache:         make(map[string]time.Time),
		cacheTTL:      cacheTTL,
		isConnected:   true,
		verifyQueue:   make(chan verifyTask, 100),
		stopVerify:    make(chan struct{}),
	}

	if matcher != nil {
		plugin.log.Info("domain matcher initialized",
			zap.Strings("domain_files", args.DomainFiles))
	}

	if args.VerifyAdd {
		for i := 0; i < verifyCount; i++ {
			// Register each worker with the WaitGroup before it starts, so
			// Close's wg.Wait cannot run while the counter is still zero.
			plugin.wg.Add(1)
			go func() {
				defer plugin.wg.Done()
				plugin.verifyWorker()
			}()
		}
	}

	plugin.log.Info("successfully connected to MikroTik",
		zap.String("host", args.Host),
		zap.Int("port", args.Port),
		zap.String("username", args.Username),
		zap.String("address_list4", args.AddressList4),
		zap.Int("worker_count", workerCount),
		zap.Int("verify_count", verifyCount),
		zap.Bool("verify_add", args.VerifyAdd),
		zap.Duration("cache_ttl", cacheTTL))
	return plugin, nil
}
// testMikrotikConnection checks that conn can execute a command by running
// "/system/resource/print". It fails if the command returns an error or
// produces no data (!re) sentences.
func testMikrotikConnection(conn *routeros.Client) error {
	reply, err := conn.Run("/system/resource/print")
	switch {
	case err != nil:
		return fmt.Errorf("connection test failed: %w", err)
	case len(reply.Re) == 0:
		return fmt.Errorf("connection test failed: no response from MikroTik")
	}

	zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
	return nil
}
// reconnect replaces p.conn with a freshly dialed and tested RouterOS session
// and marks the plugin as connected.
//
// Callers hold p.mu for writing (see addAddressToMikrotik). The new session
// is established before the old one is closed. On any failure p.conn is left
// as it was rather than becoming nil, and a half-open new session is closed
// instead of leaked, so the next attempt can simply try again.
func (p *mikrotikAddressListPlugin) reconnect() error {
	p.log.Info("attempting to reconnect to MikroTik")

	timeout := 10 * time.Second
	if p.args.Timeout > 0 {
		timeout = time.Duration(p.args.Timeout) * time.Second
	}
	addr := p.args.Host + ":" + strconv.Itoa(p.args.Port)

	conn, err := routeros.DialTimeout(addr, p.args.Username, p.args.Password, timeout)
	if err != nil {
		return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // do not leak a session that failed its health check
		return fmt.Errorf("failed to test reconnection: %w", err)
	}

	if p.conn != nil {
		_ = p.conn.Close() // the old session is being discarded; its close error is irrelevant
	}
	p.conn = conn
	// Without this, callers keep seeing isConnected == false and re-dial on
	// every add, closing sessions other goroutines are still using.
	p.isConnected = true

	p.log.Info("successfully reconnected to MikroTik")
	return nil
}
// cacheKey builds the cache map key for an address-list entry in the form
// "<listName>:<cidrAddr>".
func (p *mikrotikAddressListPlugin) cacheKey(listName, cidrAddr string) string {
	return strings.Join([]string{listName, cidrAddr}, ":")
}
// isInCache reports whether (listName, cidrAddr) was added less than cacheTTL
// ago. An expired entry is evicted as a side effect.
//
// The previous version released its read lock, took the write lock and
// deleted unconditionally. In that window another goroutine could refresh
// the entry, and the fresh entry was then deleted. The expiry is now
// re-checked under the write lock.
func (p *mikrotikAddressListPlugin) isInCache(listName, cidrAddr string) bool {
	key := p.cacheKey(listName, cidrAddr)

	p.cacheMu.RLock()
	addedAt, found := p.cache[key]
	p.cacheMu.RUnlock()

	if !found {
		return false
	}
	if time.Since(addedAt) < p.cacheTTL {
		return true
	}

	// Looks expired: evict under the write lock unless it was refreshed meanwhile.
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	current, ok := p.cache[key]
	fresh := ok && time.Since(current) < p.cacheTTL
	if ok && !fresh {
		delete(p.cache, key)
	}
	return fresh
}
// addToCache records (listName, cidrAddr) as added now, then evicts every
// expired entry.
//
// Add goroutines are not tracked by the WaitGroup, so they can finish after
// shutdown. If the map has been released (nil), writing to it would panic
// and take the whole process down, so the insert is skipped instead.
func (p *mikrotikAddressListPlugin) addToCache(listName, cidrAddr string) {
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()

	if p.cache == nil {
		return
	}

	p.cache[p.cacheKey(listName, cidrAddr)] = time.Now()
	p.cleanupExpiredCache()
}
// cleanupExpiredCache evicts every cache entry whose age is at least
// cacheTTL, measured against a single timestamp taken at the start.
// It takes no lock itself; its caller (addToCache) holds p.cacheMu for writing.
func (p *mikrotikAddressListPlugin) cleanupExpiredCache() {
	reference := time.Now()
	for key, addedAt := range p.cache {
		if age := reference.Sub(addedAt); age < p.cacheTTL {
			continue
		}
		delete(p.cache, key)
	}
}
// getCacheStats returns (total, valid): the number of cache entries and how
// many of them are still younger than cacheTTL. Nothing is evicted.
func (p *mikrotikAddressListPlugin) getCacheStats() (int, int) {
	p.cacheMu.RLock()
	defer p.cacheMu.RUnlock()

	reference := time.Now()
	expired := 0
	for _, addedAt := range p.cache {
		if reference.Sub(addedAt) >= p.cacheTTL {
			expired++
		}
	}

	total := len(p.cache)
	return total, total - expired
}
// Exec passes the answers of the current DNS response to addToAddressList in a
// background goroutine, so RouterOS latency never delays the reply. When a
// domain filter is configured, only matching query names are processed.
// Failures are only logged; Exec always returns nil so the DNS response is
// never affected.
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
	// There is deliberately no "connection is nil" early exit here.
	// addAddressToMikrotik reconnects on demand, so bailing out would
	// permanently disable the plugin after one failed reconnect.
	r := qCtx.R()
	if r == nil || len(r.Answer) == 0 {
		return nil
	}

	var qname string
	if q := qCtx.Q(); q != nil && len(q.Question) > 0 {
		qname = strings.TrimSuffix(q.Question[0].Name, ".")
	}

	if p.domainMatcher != nil {
		if _, matched := p.domainMatcher.Match(qname); !matched {
			p.log.Debug("domain not matched, skipping",
				zap.String("domain", qname))
			return nil
		}
		p.log.Debug("domain matched, processing",
			zap.String("domain", qname))
	}

	p.log.Debug("processing DNS response",
		zap.String("qname", qname),
		zap.Int("answer_count", len(r.Answer)))

	// The response belongs to the query pipeline and may be modified or
	// reused after Exec returns. The goroutine works on its own deep copy to
	// avoid a data race.
	resp := r.Copy()
	go func() {
		if err := p.addToAddressList(resp, qname); err != nil {
			p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
		}
	}()
	return nil
}
// Close stops the verification workers, waits for them, empties the cache and
// closes the RouterOS session.
//
// Calling Close again is safe: the stop channel is only closed once.
// Background add goroutines are not tracked by the WaitGroup and may still
// run afterwards. The cache is therefore replaced with an empty map rather
// than set to nil, which would make their later cache writes panic.
func (p *mikrotikAddressListPlugin) Close() error {
	if p.stopVerify != nil {
		select {
		case <-p.stopVerify:
			// Already closed by an earlier Close call.
		default:
			close(p.stopVerify)
		}
	}

	p.wg.Wait()

	p.cacheMu.Lock()
	cacheSize := len(p.cache)
	p.cache = make(map[string]time.Time)
	p.cacheMu.Unlock()
	p.log.Info("plugin closed", zap.Int("cache_cleared", cacheSize))

	p.mu.Lock()
	defer p.mu.Unlock()
	if p.conn != nil {
		return p.conn.Close()
	}
	return nil
}
// addToAddressList collects the A/AAAA addresses of r and schedules them for
// insertion into the configured IPv4/IPv6 address lists. MaxIPs, when > 0,
// caps each address family separately. All RouterOS work runs in background
// goroutines, so this returns as soon as the work is scheduled (always nil
// here). When AddAllIPs is false it delegates to addFirstIPOnly.
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg, domain string) error {
	p.log.Debug("starting to process DNS response",
		zap.String("configured_address_list4", p.args.AddressList4),
		zap.String("configured_address_list6", p.args.AddressList6),
		zap.Int("answer_count", len(r.Answer)),
		zap.Bool("add_all_ips", p.args.AddAllIPs),
		zap.Int("max_ips", p.args.MaxIPs))

	if !p.args.AddAllIPs {
		p.log.Debug("add_all_ips disabled, processing only first IP")
		return p.addFirstIPOnly(r, domain)
	}

	var v4, v6 []netip.Addr
	for _, answer := range r.Answer {
		switch rec := answer.(type) {
		case *dns.A:
			if p.args.AddressList4 == "" {
				p.log.Debug("skipping A record, no IPv4 address list configured")
				continue
			}
			ip, ok := netip.AddrFromSlice(rec.A.To4())
			if !ok {
				// A malformed record is skipped; the rest are still processed.
				p.log.Error("invalid A record", zap.String("ip", rec.A.String()))
				continue
			}
			v4 = append(v4, ip)
			p.log.Debug("queued A record for processing",
				zap.String("ip", ip.String()),
				zap.String("address_list4", p.args.AddressList4))
		case *dns.AAAA:
			if p.args.AddressList6 == "" {
				p.log.Debug("skipping AAAA record, no IPv6 address list configured")
				continue
			}
			ip, ok := netip.AddrFromSlice(rec.AAAA.To16())
			if !ok {
				p.log.Error("invalid AAAA record", zap.String("ip", rec.AAAA.String()))
				continue
			}
			v6 = append(v6, ip)
			p.log.Debug("queued AAAA record for processing",
				zap.String("ip", ip.String()),
				zap.String("address_list6", p.args.AddressList6))
		default:
			p.log.Debug("skipping non-A/AAAA record", zap.String("type", fmt.Sprintf("%T", rec)))
		}
	}

	v4 = p.capAddressCount(v4, "limiting IPv4 addresses", domain)
	v6 = p.capAddressCount(v6, "limiting IPv6 addresses", domain)

	total := len(v4) + len(v6)
	if total == 0 {
		p.log.Debug("no addresses to process")
		return nil
	}

	p.log.Info("queuing addresses for async processing",
		zap.Int("ipv4_count", len(v4)),
		zap.Int("ipv6_count", len(v6)),
		zap.Int("total_count", total),
		zap.String("domain", domain))

	if len(v4) > 0 && p.args.AddressList4 != "" {
		p.launchAsyncBatch("IPv4", v4, p.args.AddressList4, p.args.Mask4, domain)
	}
	if len(v6) > 0 && p.args.AddressList6 != "" {
		p.launchAsyncBatch("IPv6", v6, p.args.AddressList6, p.args.Mask6, domain)
	}
	return nil
}

// capAddressCount truncates addrs to args.MaxIPs entries when MaxIPs > 0,
// logging logMsg at Info level if anything was dropped.
func (p *mikrotikAddressListPlugin) capAddressCount(addrs []netip.Addr, logMsg, domain string) []netip.Addr {
	limit := p.args.MaxIPs
	if limit <= 0 || len(addrs) <= limit {
		return addrs
	}
	p.log.Info(logMsg,
		zap.Int("total", len(addrs)),
		zap.Int("limit", limit),
		zap.String("domain", domain))
	return addrs[:limit]
}

// launchAsyncBatch starts a goroutine that grows the worker pool for
// len(addrs), hands addrs to batchAddAddresses and then logs cache
// statistics. family ("IPv4"/"IPv6") only appears in log messages.
func (p *mikrotikAddressListPlugin) launchAsyncBatch(family string, addrs []netip.Addr, listName string, mask int, domainName string) {
	go func() {
		p.adjustWorkerPoolSize(len(addrs))

		if err := p.batchAddAddresses(addrs, listName, mask, domainName); err != nil {
			p.log.Error("async "+family+" batch processing failed", zap.Error(err))
		}

		total, valid := p.getCacheStats()
		p.log.Debug(family+" async processing stats",
			zap.Int("processed_count", len(addrs)),
			zap.Int("cache_total", total),
			zap.Int("cache_valid", valid),
			zap.String("domain", domainName))
	}()
}
// addFirstIPOnly is the legacy mode. It schedules only the first usable
// answer, meaning the first valid A or AAAA record whose address family has
// a configured list, and ignores everything after it. The RouterOS call runs
// in the background, so this always returns nil.
func (p *mikrotikAddressListPlugin) addFirstIPOnly(r *dns.Msg, domain string) error {
	for _, answer := range r.Answer {
		var (
			ip       netip.Addr
			valid    bool
			family   string
			listName string
			mask     int
		)

		switch rec := answer.(type) {
		case *dns.A:
			if p.args.AddressList4 == "" {
				continue
			}
			ip, valid = netip.AddrFromSlice(rec.A.To4())
			family, listName, mask = "IPv4", p.args.AddressList4, p.args.Mask4
		case *dns.AAAA:
			if p.args.AddressList6 == "" {
				continue
			}
			ip, valid = netip.AddrFromSlice(rec.AAAA.To16())
			family, listName, mask = "IPv6", p.args.AddressList6, p.args.Mask6
		default:
			continue
		}
		if !valid {
			continue
		}

		p.log.Debug("processing first "+family+" address only",
			zap.String("ip", ip.String()),
			zap.String("domain", domain))
		go func() {
			if err := p.addAddressToMikrotik(ip, listName, mask, domain); err != nil {
				p.log.Error("failed to add first "+family+" address", zap.Error(err))
			}
		}()
		return nil // only the first usable address is handled
	}

	p.log.Debug("no valid addresses found for first IP processing")
	return nil
}
// addAddressToMikrotik adds the /mask network containing addr to the RouterOS
// firewall address list listName.
//
// - IPv4 addresses go to /ip/firewall/address-list and everything else to
//   /ipv6/firewall/address-list.
// - The entry comment is domain, or args.Comment when domain is empty.
// - args.TimeoutAddr, if > 0, becomes the entry timeout.
//
// Entries already in the local cache are skipped. A router reply of
// "already have such entry" counts as success. Transport errors are retried
// up to maxRetries times, reconnecting when needed. Any other router error is
// returned immediately.
//
// Only a confirmed add is cached and, with VerifyAdd, queued for
// verification. Previously, exhausting the retries was reported and cached as
// a success, which hid the address for a whole cache TTL.
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int, domain string) error {
	p.log.Debug("addAddressToMikrotik called",
		zap.String("addr", addr.String()),
		zap.String("listName", listName),
		zap.Int("mask", mask))

	// Host bits are zeroed, e.g. 1.2.3.4 with mask 24 becomes "1.2.3.0/24",
	// so all addresses of that network share one entry and one cache key.
	// netip.PrefixFrom(...).Addr() used previously kept the host bits.
	cidrAddr, err := formatAddressListCIDR(addr, mask)
	if err != nil {
		return err
	}

	// RouterOS keeps IPv4 and IPv6 address lists in separate menus.
	addCmd := "/ip/firewall/address-list/add"
	if !addr.Is4() {
		addCmd = "/ipv6/firewall/address-list/add"
	}

	p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))
	if p.isInCache(listName, cidrAddr) {
		p.log.Debug("address found in cache, skipping", zap.String("cidr", cidrAddr), zap.String("list", listName))
		return nil
	}
	// No separate existence query: rely on RouterOS rejecting duplicates.
	p.log.Debug("skipping existence check, will attempt direct add", zap.String("cidr", cidrAddr), zap.String("list", listName))

	// RouterOS API attribute words must start with "=".
	params := []string{
		"=list=" + listName,
		"=address=" + cidrAddr,
	}
	comment := domain
	if comment == "" && p.args.Comment != "" {
		comment = p.args.Comment
	}
	if comment != "" {
		params = append(params, "=comment="+comment)
	}
	if p.args.TimeoutAddr > 0 {
		params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
	}

	p.log.Info("adding address to MikroTik",
		zap.String("original_ip", addr.String()),
		zap.String("cidr", cidrAddr),
		zap.String("list", listName),
		zap.String("domain", domain),
		zap.String("comment", comment),
		zap.Int("timeout", p.args.TimeoutAddr))
	p.log.Debug("Add to list: ", zap.Strings("params", params))

	const maxRetries = 2
	backoff := 50 * time.Millisecond
	// pause waits before the next attempt, growing the delay linearly.
	pause := func() {
		time.Sleep(backoff)
		backoff += 25 * time.Millisecond
	}

	var lastErr error
	added := false
	for attempt := 1; attempt <= maxRetries && !added; attempt++ {
		p.mu.RLock()
		conn, connected := p.conn, p.isConnected
		p.mu.RUnlock()

		if conn == nil || !connected {
			p.log.Debug("connection not available, attempting to reconnect")
			p.mu.Lock()
			// Re-check under the write lock: a concurrent caller may already
			// have reconnected, in which case its session is reused.
			if p.conn == nil || !p.isConnected {
				p.isConnected = false
				if rerr := p.reconnect(); rerr != nil {
					p.mu.Unlock()
					p.log.Error("failed to reconnect", zap.Error(rerr))
					lastErr = rerr
					pause()
					continue
				}
				// Mark the fresh session usable. Otherwise every later call
				// would see isConnected == false and dial again.
				p.isConnected = true
			}
			conn = p.conn
			p.mu.Unlock()
		}

		_, runErr := conn.Run(append([]string{addCmd}, params...)...)
		switch {
		case runErr == nil:
			added = true
		case strings.Contains(runErr.Error(), "already have such entry"):
			p.log.Debug("Address already exists", zap.String("cidr", cidrAddr))
			p.addToCache(listName, cidrAddr)
			return nil
		case p.isConnectionError(runErr):
			p.log.Warn("connection error, will retry",
				zap.String("cidr", cidrAddr),
				zap.Int("retry", attempt),
				zap.Error(runErr))
			p.mu.Lock()
			// Only flag the session we actually used. If another goroutine
			// already replaced it, the replacement is presumed healthy.
			if p.conn == conn {
				p.isConnected = false
			}
			p.mu.Unlock()
			lastErr = runErr
			pause()
		default:
			p.log.Error("failed to add address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName),
				zap.Error(runErr))
			return fmt.Errorf("failed to add address %s to list %s: %w", cidrAddr, listName, runErr)
		}
	}

	if !added {
		p.log.Error("giving up adding address to MikroTik",
			zap.String("cidr", cidrAddr),
			zap.String("list", listName),
			zap.Int("attempts", maxRetries),
			zap.Error(lastErr))
		return fmt.Errorf("failed to add address %s to list %s after %d attempts: %w", cidrAddr, listName, maxRetries, lastErr)
	}

	p.log.Info("successfully added address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName))
	p.addToCache(listName, cidrAddr)

	if p.args.VerifyAdd {
		select {
		case p.verifyQueue <- verifyTask{
			listName: listName,
			cidrAddr: cidrAddr,
			retries:  0,
		}:
			p.log.Debug("verification task queued", zap.String("cidr", cidrAddr))
		default:
			p.log.Warn("verification queue full, skipping verification", zap.String("cidr", cidrAddr))
		}
	}
	return nil
}

// formatAddressListCIDR renders the /mask network containing addr in CIDR
// notation, e.g. (1.2.3.4, 24) -> "1.2.3.0/24" and (1.2.3.4, 32) -> "1.2.3.4/32".
// It fails for an invalid addr or a mask outside 0..addr.BitLen().
func formatAddressListCIDR(addr netip.Addr, mask int) (string, error) {
	if !addr.IsValid() {
		return "", fmt.Errorf("invalid address %v", addr)
	}
	prefix, err := addr.Prefix(mask)
	if err != nil {
		return "", fmt.Errorf("invalid mask /%d for %s: %w", mask, addr, err)
	}
	return prefix.String(), nil
}
// addressExists reports whether the router's address list listName contains
// address, which is a CIDR string as produced for the add command.
//
// - Filtering must use RouterOS API query words ("?name=value"). The
//   attribute words ("=name=value") used previously do not filter print
//   output, so any non-empty table looked like a match.
// - Addresses containing ':' are looked up in the IPv6 menu.
// - RouterOS may display single-host entries without their /32 or /128
//   suffix, so the bare address is tried as a fallback.
// - The session is read under p.mu because a reconnect may replace it
//   concurrently.
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
	printCmd := "/ip/firewall/address-list/print"
	if strings.Contains(address, ":") {
		printCmd = "/ipv6/firewall/address-list/print"
	}

	candidates := []string{address}
	if prefix, err := netip.ParsePrefix(address); err == nil && prefix.IsSingleIP() {
		candidates = append(candidates, prefix.Addr().String())
	}

	p.mu.RLock()
	conn := p.conn
	p.mu.RUnlock()
	if conn == nil {
		return false, fmt.Errorf("MikroTik connection is not available")
	}

	p.log.Debug("checking address existence",
		zap.String("list", listName),
		zap.String("address", address))

	exists := false
	for _, candidate := range candidates {
		// .proplist keeps the reply small; only the number of rows matters.
		resp, err := conn.Run(printCmd, "=.proplist=.id", "?list="+listName, "?address="+candidate)
		if err != nil {
			p.log.Error("failed to check address existence",
				zap.String("list", listName),
				zap.String("address", address),
				zap.Error(err))
			return false, err
		}
		if len(resp.Re) > 0 {
			exists = true
			break
		}
	}

	p.log.Debug("address existence check",
		zap.String("list", listName),
		zap.String("address", address),
		zap.Bool("exists", exists))
	return exists, nil
}
// batchAddAddresses splits addresses into batches of 10 and processes each
// batch in its own goroutine. Within a batch, the addresses are added one
// after another via addAddressToMikrotik. It returns immediately without
// waiting for the batches; per-address failures are only logged.
//
// Each batch tries to take a workerPool slot. When the pool is full it still
// runs, so the pool does not strictly bound concurrency.
func (p *mikrotikAddressListPlugin) batchAddAddresses(addresses []netip.Addr, listName string, mask int, domain string) error {
	if len(addresses) == 0 {
		return nil
	}

	const batchSize = 10
	totalBatches := (len(addresses) + batchSize - 1) / batchSize
	p.log.Info("starting async batch processing",
		zap.Int("total_addresses", len(addresses)),
		zap.Int("batch_size", batchSize),
		zap.Int("total_batches", totalBatches),
		zap.String("domain", domain))

	for start := 0; start < len(addresses); start += batchSize {
		end := start + batchSize
		if end > len(addresses) {
			end = len(addresses)
		}
		batch := addresses[start:end]
		batchIdx := start/batchSize + 1

		go func(batch []netip.Addr, batchIdx int) {
			// Read the pool exactly once. adjustWorkerPoolSize may swap
			// p.workerPool concurrently, and releasing into a different
			// (empty) channel than the one acquired from would block this
			// goroutine forever.
			p.mu.RLock()
			pool := p.workerPool
			p.mu.RUnlock()

			select {
			case pool <- struct{}{}:
				defer func() { <-pool }()
			default:
				p.log.Debug("worker pool full, processing batch directly",
					zap.Int("batch", batchIdx))
			}

			successCount, errorCount := 0, 0
			for _, addr := range batch {
				if err := p.addAddressToMikrotik(addr, listName, mask, domain); err != nil {
					errorCount++
					p.log.Debug("failed to add address in batch",
						zap.String("addr", addr.String()),
						zap.Int("batch", batchIdx),
						zap.Error(err))
					continue
				}
				successCount++
			}

			p.log.Debug("batch processing completed",
				zap.Int("batch", batchIdx),
				zap.Int("success_count", successCount),
				zap.Int("error_count", errorCount),
				zap.String("domain", domain))
		}(batch, batchIdx)
	}

	p.log.Debug("all batches queued for async processing",
		zap.Int("total_addresses", len(addresses)),
		zap.String("domain", domain))
	return nil
}
// adjustWorkerPoolSize grows p.workerPool to a capacity chosen from
// addressCount (<=5: 3, <=20: 5, <=50: 10, otherwise 15). The pool never
// shrinks.
//
// It is called concurrently (one goroutine per address family and per DNS
// response). The check-and-replace is therefore done under p.mu; the
// unsynchronised field write it replaces was a data race.
func (p *mikrotikAddressListPlugin) adjustWorkerPoolSize(addressCount int) {
	var targetSize int
	switch {
	case addressCount <= 5:
		targetSize = 3
	case addressCount <= 20:
		targetSize = 5
	case addressCount <= 50:
		targetSize = 10
	default:
		targetSize = 15
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	if cap(p.workerPool) >= targetSize {
		return
	}
	p.log.Debug("adjusting worker pool size",
		zap.Int("old_size", cap(p.workerPool)),
		zap.Int("new_size", targetSize),
		zap.Int("address_count", addressCount))
	p.workerPool = make(chan struct{}, targetSize)
}
// isConnectionError reports whether err looks like a transport-level failure
// (worth reconnecting for) rather than an error reply from the router.
//
// RouterOS "!trap" replies arrive as *routeros.DeviceError. The device
// answered, so the session is healthy, even though the message text may
// contain words such as "timeout" (an address-list attribute name) that the
// substring heuristic below would otherwise mistake for a dead connection.
func (p *mikrotikAddressListPlugin) isConnectionError(err error) bool {
	if err == nil {
		return false
	}

	var deviceErr *routeros.DeviceError
	if errors.As(err, &deviceErr) {
		return false
	}

	msg := err.Error()
	for _, marker := range []string{"EOF", "connection", "closed", "timeout"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// verifyWorker runs verification tasks until stopVerify is closed, keeping
// verification off the add path.
//
// Each task received from verifyQueue is started right away if a verifyPool
// slot is free, and otherwise re-queued after one second. Every 2 seconds
// the worker also drains a batch via processBatchVerification.
//
// NOTE(review): wg.Add runs inside this goroutine, after it has started, so
// it is only reliably ordered before Close's wg.Wait if the spawner has
// already registered this goroutine with the WaitGroup.
func (p *mikrotikAddressListPlugin) verifyWorker() {
	p.wg.Add(1)
	defer p.wg.Done()

	batchTicker := time.NewTicker(2 * time.Second)
	defer batchTicker.Stop()

	for {
		select {
		case <-p.stopVerify:
			p.log.Info("verification worker stopping")
			return
		case <-batchTicker.C:
			p.processBatchVerification()
		case task := <-p.verifyQueue:
			p.dispatchVerifyTask(task)
		}
	}
}

// dispatchVerifyTask starts task on a free verifyPool slot. If the pool is
// full, it tries to put the task back on verifyQueue one second later and
// drops it with a warning if the queue is full by then.
func (p *mikrotikAddressListPlugin) dispatchVerifyTask(task verifyTask) {
	select {
	case p.verifyPool <- struct{}{}:
		go func() {
			defer func() { <-p.verifyPool }()
			p.processVerificationTask(task)
		}()
		return
	default:
	}

	go func() {
		time.Sleep(time.Second)
		select {
		case p.verifyQueue <- task:
		default:
			p.log.Warn("failed to requeue verification task",
				zap.String("cidr", task.cidrAddr))
		}
	}()
}
// processBatchVerification drains up to 10 pending tasks from verifyQueue
// without blocking and starts each on a free verifyPool slot. A task that
// finds the pool full is pushed back onto the queue, or dropped with a
// warning if the queue is full as well.
func (p *mikrotikAddressListPlugin) processBatchVerification() {
	const maxBatch = 10

	var pending []verifyTask
collect:
	for len(pending) < maxBatch {
		select {
		case task := <-p.verifyQueue:
			pending = append(pending, task)
		default:
			break collect
		}
	}

	if len(pending) == 0 {
		return
	}
	p.log.Debug("processing batch verification", zap.Int("task_count", len(pending)))

	for _, task := range pending {
		select {
		case p.verifyPool <- struct{}{}:
			go func(t verifyTask) {
				defer func() { <-p.verifyPool }()
				p.processVerificationTask(t)
			}(task)
			continue
		default:
		}

		// Pool is full: try to put the task back for a later round.
		select {
		case p.verifyQueue <- task:
		default:
			p.log.Warn("verification queue full, dropping task",
				zap.String("cidr", task.cidrAddr))
		}
	}
}
// processVerificationTask checks that task's entry is present on the router.
//
// - It first sleeps 500ms + 200ms per earlier retry to give RouterOS time to
//   apply the add.
// - A failed lookup is re-queued up to 3 times, then logged.
// - A missing entry is evicted from the local cache, so the next DNS answer
//   containing the address triggers a fresh add.
//
// Nothing is re-added from here; doing so could loop. The previous version
// logged "attempting to re-add address" even though no re-add ever happened.
func (p *mikrotikAddressListPlugin) processVerificationTask(task verifyTask) {
	time.Sleep(time.Duration(500+task.retries*200) * time.Millisecond)

	exists, err := p.addressExists(task.listName, task.cidrAddr)
	if err != nil {
		if task.retries >= 3 {
			p.log.Error("verification failed after max retries",
				zap.String("cidr", task.cidrAddr),
				zap.String("list", task.listName),
				zap.Error(err))
			return
		}

		task.retries++
		p.log.Debug("verification failed, retrying",
			zap.String("cidr", task.cidrAddr),
			zap.Int("retries", task.retries),
			zap.Error(err))
		select {
		case p.verifyQueue <- task:
		default:
			p.log.Warn("failed to requeue verification task for retry",
				zap.String("cidr", task.cidrAddr))
		}
		return
	}

	if exists {
		p.log.Debug("address verification successful",
			zap.String("cidr", task.cidrAddr),
			zap.String("list", task.listName))
		return
	}

	p.log.Warn("address not found in MikroTik after add operation",
		zap.String("cidr", task.cidrAddr),
		zap.String("list", task.listName))

	p.cacheMu.Lock()
	delete(p.cache, p.cacheKey(task.listName, task.cidrAddr))
	p.cacheMu.Unlock()

	p.log.Info("address evicted from cache, will be re-added on next DNS query",
		zap.String("cidr", task.cidrAddr),
		zap.String("list", task.listName))
}