mosdns/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go
dengxiongjian 59a5ef4aae
Some checks failed
Test mosdns / build (push) Has been cancelled
主要优化点:
1. 连接管理优化 mikrotik_addresslist_impl.go:132
   - 添加连接状态管理和重连锁机制
   - 改进重连逻辑,防止并发重连
2. 缓存机制增强 mikrotik_addresslist_impl.go:162-202
   - 优化缓存锁使用,避免死锁
   - 添加缓存大小限制和LRU驱逐策略(注:当前代码中并未实现缓存大小上限或 LRU 驱逐)
   - 定期清理过期缓存项(注:实际是在每次写入缓存时清理,而非定期执行)
3. 智能重试机制 mikrotik_addresslist_impl.go:420
   - 指数退避算法
   - 更智能的连接错误识别
   - 改进的错误处理
4. 动态并发控制 mikrotik_addresslist_impl.go:589
   - 根据地址数量动态调整工作池大小
   - 批量处理优化
5. 性能监控改进
   - 更详细的日志记录
   - 缓存统计信息
   - 处理过程可观察性
2025-08-04 09:02:30 +08:00

588 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package mikrotik_addresslist
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
	routeros "github.com/go-routeros/routeros/v3"
	"github.com/miekg/dns"
	"go.uber.org/zap"
)
// mikrotikAddressListPlugin adds the IPv4 addresses found in DNS answers to a
// MikroTik RouterOS firewall address list through the RouterOS API.
type mikrotikAddressListPlugin struct {
	args *Args
	conn *routeros.Client
	log  *zap.Logger

	// Concurrency control.
	workerPool chan struct{} // slots taken by batch goroutines; replaced by adjustWorkerPoolSize
	// NOTE(review): wg is waited on in Close, but nothing in the code visible
	// here calls wg.Add.
	wg sync.WaitGroup
	// mu guards conn and isConnected; reconnect runs with it held for writing.
	mu sync.RWMutex
	// isConnected is false when conn is known to be broken and has to be
	// re-established before the next command. Guarded by mu.
	isConnected bool

	// In-memory cache of entries known to exist on the router.
	cache    map[string]time.Time // key: "listName:cidrAddr", value: time the entry was cached
	cacheMu  sync.RWMutex         // guards cache
	cacheTTL time.Duration        // lifetime of a cache entry; defaults to 1 hour
}
// newMikrotikAddressListPlugin fills in defaults for unset arguments, validates
// the network masks, connects to the RouterOS API at Host:Port, runs a test
// command on the new connection and returns the ready plugin.
//
// Defaults: Mask4 24, Mask6 32, Port 8728, Timeout 10, cache TTL 1 hour (a
// positive args.CacheTTL is taken as seconds). args is modified in place.
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
	if args.Mask4 == 0 {
		args.Mask4 = 24
	}
	if args.Mask6 == 0 {
		args.Mask6 = 32
	}
	if args.Port == 0 {
		args.Port = 8728
	}
	if args.Timeout == 0 {
		// NOTE(review): Timeout is defaulted here but not passed to
		// routeros.Dial in this file, so dialing has no explicit timeout;
		// confirm how it is meant to be used.
		args.Timeout = 10
	}

	// Reject masks that cannot form a prefix up front, instead of failing to
	// convert every single address later.
	if args.Mask4 < 1 || args.Mask4 > 32 {
		return nil, fmt.Errorf("invalid mask4 %d: must be in [1, 32]", args.Mask4)
	}
	if args.Mask6 < 1 || args.Mask6 > 128 {
		return nil, fmt.Errorf("invalid mask6 %d: must be in [1, 128]", args.Mask6)
	}

	// JoinHostPort brackets IPv6 literal hosts ("[fe80::1]:8728").
	addr := net.JoinHostPort(args.Host, strconv.Itoa(args.Port))
	conn, err := routeros.Dial(addr, args.Username, args.Password)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // best effort; the test failure is the error worth reporting
		return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
	}

	// Capacity of the worker-slot channel used by batchAddAddresses.
	const workerCount = 10

	cacheTTL := time.Hour
	if args.CacheTTL > 0 {
		cacheTTL = time.Duration(args.CacheTTL) * time.Second
	}

	plugin := &mikrotikAddressListPlugin{
		args:       args,
		conn:       conn,
		log:        zap.L().Named("mikrotik_addresslist"),
		workerPool: make(chan struct{}, workerCount),
		cache:      make(map[string]time.Time),
		cacheTTL:   cacheTTL,
		// The connection has just passed testMikrotikConnection.
		isConnected: true,
	}

	plugin.log.Info("successfully connected to MikroTik",
		zap.String("host", args.Host),
		zap.Int("port", args.Port),
		zap.String("username", args.Username),
		zap.String("address_list4", args.AddressList4),
		zap.Int("worker_count", workerCount),
		zap.Duration("cache_ttl", cacheTTL))
	return plugin, nil
}
// testMikrotikConnection checks that conn can execute commands. It runs
// "/system/resource/print" and fails if the command errors or the reply
// carries no entries in reply.Re. On success it logs an info message.
func testMikrotikConnection(conn *routeros.Client) error {
	reply, runErr := conn.Run("/system/resource/print")
	switch {
	case runErr != nil:
		return fmt.Errorf("connection test failed: %w", runErr)
	case len(reply.Re) == 0:
		return fmt.Errorf("connection test failed: no response from MikroTik")
	}

	zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
	return nil
}
// reconnect replaces p.conn with a freshly dialed RouterOS client that has
// passed testMikrotikConnection. Callers in this file invoke it with p.mu held
// for writing.
//
// The old client is closed only after the new one is ready. If dialing or the
// test fails, p.conn is left unchanged (a newly dialed client that failed the
// test is closed), so this function never sets p.conn to nil.
func (p *mikrotikAddressListPlugin) reconnect() error {
	p.log.Info("attempting to reconnect to MikroTik")

	// JoinHostPort brackets IPv6 literal hosts ("[fe80::1]:8728").
	addr := net.JoinHostPort(p.args.Host, strconv.Itoa(p.args.Port))
	conn, err := routeros.Dial(addr, p.args.Username, p.args.Password)
	if err != nil {
		return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // best effort; the test failure is the error reported
		return fmt.Errorf("failed to test reconnection: %w", err)
	}

	if p.conn != nil {
		// The old client is being discarded, usually because it already
		// failed, so its Close error carries no useful information.
		_ = p.conn.Close()
	}
	p.conn = conn

	p.log.Info("successfully reconnected to MikroTik")
	return nil
}
// cacheKey returns the cache map key for cidrAddr in listName, formatted as
// "listName:cidrAddr".
func (p *mikrotikAddressListPlugin) cacheKey(listName, cidrAddr string) string {
	return strings.Join([]string{listName, cidrAddr}, ":")
}
// isInCache reports whether cidrAddr is cached for listName and the entry is
// younger than p.cacheTTL. An expired entry is removed from the cache.
//
// The common path only takes the read lock. Deleting needs the write lock, and
// the entry is re-checked after acquiring it, because another goroutine may
// have refreshed it in the gap; that fresh entry must not be deleted.
func (p *mikrotikAddressListPlugin) isInCache(listName, cidrAddr string) bool {
	key := p.cacheKey(listName, cidrAddr)

	p.cacheMu.RLock()
	addTime, exists := p.cache[key]
	p.cacheMu.RUnlock()
	if !exists {
		return false
	}
	if time.Since(addTime) < p.cacheTTL {
		return true
	}

	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	if current, ok := p.cache[key]; ok {
		if time.Since(current) < p.cacheTTL {
			return true // refreshed concurrently
		}
		delete(p.cache, key)
	}
	return false
}
// addToCache records cidrAddr in listName as present on the router, stamped
// with the current time, and then evicts every expired entry.
//
// Close sets p.cache to nil, and writing to a nil map panics, so a call that
// arrives after Close is ignored.
func (p *mikrotikAddressListPlugin) addToCache(listName, cidrAddr string) {
	key := p.cacheKey(listName, cidrAddr)

	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	if p.cache == nil {
		return
	}
	p.cache[key] = time.Now()
	// cleanupExpiredCache expects p.cacheMu to be held for writing, as here.
	p.cleanupExpiredCache()
}
// cleanupExpiredCache deletes every cache entry whose age, measured against a
// single reference time, has reached p.cacheTTL. The caller must already hold
// p.cacheMu for writing.
func (p *mikrotikAddressListPlugin) cleanupExpiredCache() {
	ref := time.Now()
	for k, stamped := range p.cache {
		if ref.Sub(stamped) < p.cacheTTL {
			continue
		}
		delete(p.cache, k)
	}
}
// getCacheStats returns (total, valid): the number of cache entries and how
// many of them are still younger than p.cacheTTL. Expired entries are counted
// in total but are not removed.
func (p *mikrotikAddressListPlugin) getCacheStats() (int, int) {
	p.cacheMu.RLock()
	defer p.cacheMu.RUnlock()

	var fresh int
	checkedAt := time.Now()
	for _, stamped := range p.cache {
		if checkedAt.Sub(stamped) >= p.cacheTTL {
			continue
		}
		fresh++
	}
	return len(p.cache), fresh
}
// Exec passes the DNS response held by qCtx, if there is one, to
// addToAddressList so that its IPv4 answers are added to the configured
// MikroTik address list. It returns an error when the RouterOS connection is
// nil or when adding the addresses fails.
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
	// NOTE(review): p.conn is read here without p.mu, while reconnect replaces
	// it under p.mu. Taking p.mu here would make every response wait for an
	// in-progress reconnect, so the read stays unsynchronized for now.
	if p.conn == nil {
		p.log.Error("MikroTik connection is nil")
		return fmt.Errorf("mikrotik_addresslist: connection is nil")
	}

	r := qCtx.R()
	if r == nil {
		return nil
	}

	// Indexing Question[0] unconditionally would panic on a message without a
	// question section; the name is only needed for this log line.
	var qname string
	if q := qCtx.Q(); len(q.Question) > 0 {
		qname = q.Question[0].Name
	}
	p.log.Debug("processing DNS response",
		zap.String("qname", qname),
		zap.Int("answer_count", len(r.Answer)))

	if err := p.addToAddressList(r); err != nil {
		p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
		return fmt.Errorf("mikrotik_addresslist: %w", err)
	}
	return nil
}
// Close shuts the plugin down. It waits on p.wg, discards the address cache
// (logging how many entries it held) and closes the RouterOS connection. It
// returns the client's Close error, or nil when there is no connection.
//
// NOTE(review): p.wg is never incremented in the code visible here, so the
// Wait returns immediately and does not wait for in-flight Exec calls.
func (p *mikrotikAddressListPlugin) Close() error {
	p.wg.Wait()

	dropped := func() int {
		p.cacheMu.Lock()
		defer p.cacheMu.Unlock()
		n := len(p.cache)
		p.cache = nil
		return n
	}()
	p.log.Info("plugin closed", zap.Int("cache_cleared", dropped))

	p.mu.Lock()
	defer p.mu.Unlock()
	var closeErr error
	if c := p.conn; c != nil {
		closeErr = c.Close()
	}
	return closeErr
}
// addToAddressList collects the IPv4 addresses from the A records in r and
// adds them to p.args.AddressList4 with p.args.Mask4 via batchAddAddresses.
//
// A records are skipped when no IPv4 address list is configured, and invalid
// A records are logged and skipped. AAAA records and all other record types
// are skipped. After a successful batch, cache statistics are logged.
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
	p.log.Debug("starting to process DNS response",
		zap.String("configured_address_list4", p.args.AddressList4),
		zap.Int("answer_count", len(r.Answer)))

	var addresses []netip.Addr
	for _, answer := range r.Answer {
		aRecord, isA := answer.(*dns.A)
		if !isA {
			if _, isAAAA := answer.(*dns.AAAA); isAAAA {
				// IPv6 is not handled by this plugin.
				p.log.Debug("skipping AAAA record (IPv6 not supported)")
			} else {
				p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", answer)))
			}
			continue
		}

		if p.args.AddressList4 == "" {
			p.log.Debug("skipping A record, no IPv4 address list configured")
			continue
		}

		ip, valid := netip.AddrFromSlice(aRecord.A.To4())
		if !valid {
			// A malformed record is logged but does not abort the response.
			p.log.Error("invalid A record", zap.String("ip", aRecord.A.String()))
			continue
		}

		addresses = append(addresses, ip)
		p.log.Debug("queued A record for processing",
			zap.String("ip", ip.String()),
			zap.String("address_list4", p.args.AddressList4))
	}

	if len(addresses) == 0 {
		p.log.Debug("no IPv4 addresses to process")
		return nil
	}

	p.adjustWorkerPoolSize(len(addresses))
	if err := p.batchAddAddresses(addresses, p.args.AddressList4, p.args.Mask4); err != nil {
		return err
	}

	total, valid := p.getCacheStats()
	p.log.Info("IPv4 addresses processed",
		zap.Int("processed_count", len(addresses)),
		zap.Int("cache_total", total),
		zap.Int("cache_valid", valid))
	return nil
}
// addAddressToMikrotik adds the network containing addr, truncated to mask
// bits (1.2.3.4 with mask 24 becomes "1.2.3.0/24"), to the RouterOS firewall
// address list listName.
//
// Nothing is sent when the network is in the local cache or already exists on
// the router. On success, or when RouterOS replies "already have such entry",
// the network is cached. The add is attempted up to three times, with backoff
// of 100ms and then 200ms between attempts, when reconnecting fails or the
// command fails with a connection error. Any other error is returned at once,
// and an error is also returned when all attempts fail.
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int) error {
	p.log.Debug("addAddressToMikrotik called",
		zap.String("addr", addr.String()),
		zap.String("listName", listName),
		zap.Int("mask", mask))

	// Build the network address in CIDR form. Addr.Prefix clears the host
	// bits, whereas netip.PrefixFrom(addr, mask).Addr() returns addr unchanged.
	// The caller-supplied mask is used for both address families.
	if !addr.IsValid() {
		return fmt.Errorf("invalid address %v", addr)
	}
	prefix, err := addr.Prefix(mask)
	if err != nil {
		return fmt.Errorf("cannot apply mask /%d to %s: %w", mask, addr, err)
	}
	cidrAddr := prefix.String()

	p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))

	// Cheapest check first: the in-memory cache.
	if p.isInCache(listName, cidrAddr) {
		p.log.Debug("address found in cache, skipping", zap.String("cidr", cidrAddr), zap.String("list", listName))
		return nil
	}

	// Not cached: ask the router. A failed lookup is not fatal (the list may
	// not exist yet), so the add is attempted anyway.
	exists, err := p.addressExists(listName, cidrAddr)
	if err != nil {
		p.log.Debug("failed to check if address exists in MikroTik, will try to add anyway", zap.Error(err))
	} else if exists {
		p.log.Debug("address already exists in MikroTik, adding to cache", zap.String("cidr", cidrAddr), zap.String("list", listName))
		p.addToCache(listName, cidrAddr)
		return nil
	}

	// RouterOS API attribute words must start with "=".
	params := []string{
		"=list=" + listName,
		"=address=" + cidrAddr,
	}
	if p.args.Comment != "" {
		params = append(params, "=comment="+p.args.Comment)
	}
	if p.args.TimeoutAddr > 0 {
		params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
	}

	p.log.Info("adding address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName),
		zap.String("comment", p.args.Comment),
		zap.Int("timeout", p.args.TimeoutAddr))
	p.log.Debug("Add to list: ", zap.Strings("params", params))

	cmd := append([]string{"/ip/firewall/address-list/add"}, params...)

	const maxRetries = 3
	backoff := 100 * time.Millisecond
	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if attempt > 1 {
			// Exponential backoff between attempts; none after the last one.
			time.Sleep(backoff)
			backoff *= 2
		}

		dialErr, runErr := p.runAddressListCommand(cmd)
		if dialErr != nil {
			p.log.Error("failed to reconnect", zap.Error(dialErr))
			lastErr = dialErr
			continue
		}

		if runErr == nil {
			p.log.Info("successfully added address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName))
			p.addToCache(listName, cidrAddr)
			return nil
		}

		if strings.Contains(runErr.Error(), "already have such entry") {
			p.log.Debug("Address already exists", zap.String("cidr", cidrAddr))
			p.addToCache(listName, cidrAddr)
			return nil
		}

		if !p.isConnectionError(runErr) {
			p.log.Error("failed to add address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName),
				zap.Error(runErr))
			return fmt.Errorf("failed to add address %s to list %s: %w", cidrAddr, listName, runErr)
		}

		p.log.Warn("connection error, will retry",
			zap.String("cidr", cidrAddr),
			zap.Int("retry", attempt),
			zap.Error(runErr))
		lastErr = runErr
	}

	// Every attempt failed; do not report success or cache the entry.
	return fmt.Errorf("failed to add address %s to list %s after %d attempts: %w",
		cidrAddr, listName, maxRetries, lastErr)
}

// runAddressListCommand runs cmd on the shared RouterOS connection with p.mu
// held, reconnecting first when the connection is missing or marked broken.
// dialErr is non-nil when that reconnect failed, in which case cmd was not
// sent; otherwise runErr is the command's result. A connection error from the
// command marks the connection broken so the next call reconnects.
func (p *mikrotikAddressListPlugin) runAddressListCommand(cmd []string) (dialErr, runErr error) {
	// Holding p.mu for the whole exchange means only one goroutine reconnects
	// at a time. It also serializes commands on the shared client, which is
	// used synchronously (Async is never enabled in this file).
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.conn == nil || !p.isConnected {
		p.log.Debug("connection not available, attempting to reconnect")
		p.isConnected = false
		if err := p.reconnect(); err != nil {
			return err, nil
		}
		p.isConnected = true
	}

	if _, err := p.conn.Run(cmd...); err != nil {
		if p.isConnectionError(err) {
			p.isConnected = false
		}
		return nil, err
	}
	return nil, nil
}
// addressExists reports whether the RouterOS address list listName already
// contains an entry whose address equals address.
//
// The filters are sent as "?" query words. On a print command, "=" words are
// attributes rather than filters and do not restrict which entries are
// returned. Only the .id property is requested, because only the number of
// matches matters. The command runs with p.mu held so that it cannot
// interleave with other commands on the shared connection (Async is never
// enabled in this file). A nil connection is reported as an error instead of
// being dereferenced.
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
	p.log.Debug("checking address existence",
		zap.String("list", listName),
		zap.String("address", address))

	cmd := []string{
		"/ip/firewall/address-list/print",
		"=.proplist=.id",
		"?list=" + listName,
		"?address=" + address,
	}

	exists, err := func() (bool, error) {
		p.mu.Lock()
		defer p.mu.Unlock()
		if p.conn == nil {
			return false, errors.New("connection is nil")
		}
		resp, err := p.conn.Run(cmd...)
		if err != nil {
			return false, err
		}
		// Each matching entry comes back as one !re sentence in resp.Re.
		return len(resp.Re) > 0, nil
	}()
	if err != nil {
		p.log.Error("failed to check address existence",
			zap.String("list", listName),
			zap.String("address", address),
			zap.Error(err))
		return false, err
	}

	p.log.Debug("address existence check",
		zap.String("list", listName),
		zap.String("address", address),
		zap.Bool("exists", exists))
	return exists, nil
}
// batchAddAddresses adds every address in addresses to listName, passing mask
// through to addAddressToMikrotik. The addresses are split into batches of 10;
// each batch runs in its own goroutine and handles its addresses one at a time.
// It returns nil if every address succeeded. Otherwise it returns the first
// error appended by any goroutine, and logs the success and error counts.
func (p *mikrotikAddressListPlugin) batchAddAddresses(addresses []netip.Addr, listName string, mask int) error {
	if len(addresses) == 0 {
		return nil
	}

	const batchSize = 10

	// Take a single snapshot of the pool for this call. adjustWorkerPoolSize
	// can replace p.workerPool while batches are running, and re-reading the
	// field to release a slot could then receive from a newer, empty channel
	// and block forever. The read is done under p.cacheMu, which is never held
	// across RouterOS I/O, so it is synchronized with writers that take the
	// same lock.
	p.cacheMu.RLock()
	pool := p.workerPool
	p.cacheMu.RUnlock()

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards errs and successCount
		errs         []error
		successCount int
	)
	for start := 0; start < len(addresses); start += batchSize {
		end := start + batchSize
		if end > len(addresses) {
			end = len(addresses)
		}

		wg.Add(1)
		go func(batch []netip.Addr) {
			defer wg.Done()

			// Non-blocking acquire: if no slot is free the batch still runs,
			// so the pool does not bound concurrency.
			select {
			case pool <- struct{}{}:
				defer func() { <-pool }()
			default:
				p.log.Debug("worker pool full, processing batch directly")
			}

			for _, addr := range batch {
				err := p.addAddressToMikrotik(addr, listName, mask)
				mu.Lock()
				if err != nil {
					errs = append(errs, err)
				} else {
					successCount++
				}
				mu.Unlock()
			}
		}(addresses[start:end])
	}
	wg.Wait()

	if len(errs) > 0 {
		p.log.Error("batch processing completed with errors",
			zap.Int("success_count", successCount),
			zap.Int("error_count", len(errs)),
			zap.Error(errs[0]))
		return errs[0]
	}

	p.log.Info("batch processing completed successfully",
		zap.Int("success_count", successCount),
		zap.Int("total_count", len(addresses)))
	return nil
}
// adjustWorkerPoolSize chooses a pool capacity from addressCount (<=5: 3,
// <=20: 5, <=50: 10, otherwise 15). It replaces p.workerPool with a new
// channel of that capacity when the current pool is smaller. The pool never
// shrinks.
//
// p.workerPool is read and replaced under p.cacheMu. In this file that lock is
// only held for short in-memory work and never across RouterOS I/O, so taking
// it here does not make DNS responses wait on the router. Readers of
// p.workerPool need to hold p.cacheMu as well.
func (p *mikrotikAddressListPlugin) adjustWorkerPoolSize(addressCount int) {
	var targetSize int
	switch {
	case addressCount <= 5:
		targetSize = 3
	case addressCount <= 20:
		targetSize = 5
	case addressCount <= 50:
		targetSize = 10
	default:
		targetSize = 15
	}

	// Fast path: usually the pool is already large enough.
	p.cacheMu.RLock()
	current := cap(p.workerPool)
	p.cacheMu.RUnlock()
	if current >= targetSize {
		return
	}

	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	// Re-check: another call may have grown the pool in the meantime.
	if cap(p.workerPool) >= targetSize {
		return
	}
	p.log.Debug("adjusting worker pool size",
		zap.Int("old_size", cap(p.workerPool)),
		zap.Int("new_size", targetSize),
		zap.Int("address_count", addressCount))
	p.workerPool = make(chan struct{}, targetSize)
}
// isConnectionError reports whether err means the RouterOS connection itself
// failed, so reconnecting may help, as opposed to the device rejecting a
// command on a working connection.
//
// Device replies (*routeros.DeviceError) are classified first, by sentence
// word: "!fatal" means the router is ending the session, while "!trap" is an
// ordinary command failure. This check comes first because trap messages can
// contain words such as "timeout" (the add command has a timeout argument)
// that the substring fallback below would misread. Typed network and EOF
// errors come next. The substring match is kept as a last resort for errors
// that have no distinguishing type.
func (p *mikrotikAddressListPlugin) isConnectionError(err error) bool {
	if err == nil {
		return false
	}

	var devErr *routeros.DeviceError
	if errors.As(err, &devErr) {
		return devErr.Sentence != nil && devErr.Sentence.Word == "!fatal"
	}

	var netErr net.Error
	if errors.Is(err, io.EOF) ||
		errors.Is(err, io.ErrUnexpectedEOF) ||
		errors.Is(err, net.ErrClosed) ||
		errors.As(err, &netErr) {
		return true
	}

	errStr := err.Error()
	return strings.Contains(errStr, "EOF") ||
		strings.Contains(errStr, "connection") ||
		strings.Contains(errStr, "closed") ||
		strings.Contains(errStr, "timeout")
}