Add concurrency support and reconnect logic in Mikrotik address list plugin; introduce worker pool for processing IP addresses
Some checks failed
Test mosdns / build (push) Has been cancelled

This commit is contained in:
dengxiongjian 2025-07-31 12:47:29 +08:00
parent cd761e8145
commit eb82f1c2f5
2 changed files with 166 additions and 31 deletions

Binary file not shown.

View File

@ -25,6 +25,7 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/miekg/dns"
@ -37,6 +38,11 @@ type mikrotikAddressListPlugin struct {
args *Args
conn *routeros.Client
log *zap.Logger
// 并发控制
workerPool chan struct{}
wg sync.WaitGroup
mu sync.RWMutex // 保护连接的重连操作
}
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
@ -68,10 +74,14 @@ func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error
return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
}
// 设置工作池大小(可以根据需要调整)
workerCount := 10 // 并发工作线程数
plugin := &mikrotikAddressListPlugin{
args: args,
conn: conn,
log: zap.L().Named("mikrotik_addresslist"),
args: args,
conn: conn,
log: zap.L().Named("mikrotik_addresslist"),
workerPool: make(chan struct{}, workerCount),
}
// 记录连接成功信息
@ -79,7 +89,8 @@ func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error
zap.String("host", args.Host),
zap.Int("port", args.Port),
zap.String("username", args.Username),
zap.String("address_list4", args.AddressList4))
zap.String("address_list4", args.AddressList4),
zap.Int("worker_count", workerCount))
return plugin, nil
}
@ -103,6 +114,30 @@ func testMikrotikConnection(conn *routeros.Client) error {
return nil
}
func (p *mikrotikAddressListPlugin) reconnect() error {
p.log.Info("attempting to reconnect to MikroTik")
// 关闭旧连接
if p.conn != nil {
p.conn.Close()
}
// 重新建立连接
var err error
p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password)
if err != nil {
return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
}
// 测试连接
if err := testMikrotikConnection(p.conn); err != nil {
return fmt.Errorf("failed to test reconnection: %w", err)
}
p.log.Info("successfully reconnected to MikroTik")
return nil
}
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
// 检查连接是否正常
if p.conn == nil {
@ -125,6 +160,13 @@ func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.
}
func (p *mikrotikAddressListPlugin) Close() error {
// 等待所有工作完成
p.wg.Wait()
// 关闭连接
p.mu.Lock()
defer p.mu.Unlock()
if p.conn != nil {
return p.conn.Close()
}
@ -132,11 +174,12 @@ func (p *mikrotikAddressListPlugin) Close() error {
}
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
addedCount := 0
p.log.Debug("starting to process DNS response",
zap.String("configured_address_list4", p.args.AddressList4),
zap.Int("answer_count", len(r.Answer)))
// 收集所有需要处理的 IPv4 地址
var addresses []netip.Addr
for i := range r.Answer {
switch rr := r.Answer[i].(type) {
case *dns.A:
@ -147,15 +190,12 @@ func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
addr, ok := netip.AddrFromSlice(rr.A.To4())
if !ok {
p.log.Error("invalid A record", zap.String("ip", rr.A.String()))
return fmt.Errorf("invalid A record with ip: %s", rr.A)
continue // 跳过无效记录,不中断处理
}
p.log.Debug("processing A record",
addresses = append(addresses, addr)
p.log.Debug("queued A record for processing",
zap.String("ip", addr.String()),
zap.String("address_list4", p.args.AddressList4))
if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil {
return err
}
addedCount++
case *dns.AAAA:
// 跳过 IPv6 记录
@ -167,12 +207,61 @@ func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
}
}
if len(addresses) == 0 {
p.log.Debug("no IPv4 addresses to process")
return nil
}
// 并发处理所有地址
var wg sync.WaitGroup
var mu sync.Mutex
var errors []error
addedCount := 0
for _, addr := range addresses {
wg.Add(1)
go func(addr netip.Addr) {
defer wg.Done()
// 获取工作池槽位
select {
case p.workerPool <- struct{}{}:
defer func() { <-p.workerPool }()
default:
// 如果工作池满了,直接处理(避免阻塞)
p.log.Debug("worker pool full, processing directly")
}
if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil {
mu.Lock()
errors = append(errors, err)
mu.Unlock()
} else {
mu.Lock()
addedCount++
mu.Unlock()
}
}(addr)
}
// 等待所有工作完成
wg.Wait()
// 记录结果
if addedCount > 0 {
p.log.Info("added IPv4 addresses to MikroTik", zap.Int("count", addedCount))
p.log.Info("concurrently added IPv4 addresses to MikroTik",
zap.Int("success_count", addedCount),
zap.Int("total_count", len(addresses)),
zap.Int("error_count", len(errors)))
} else {
p.log.Debug("no IPv4 addresses added to MikroTik")
}
// 如果有错误,返回第一个错误
if len(errors) > 0 {
return fmt.Errorf("some addresses failed to add: %v", errors[0])
}
return nil
}
@ -182,12 +271,16 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa
zap.String("listName", listName),
zap.Int("mask", mask))
// 构建 CIDR 格式的地址
// 构建 CIDR 格式的地址,将 IP 转换为网段地址
var cidrAddr string
if addr.Is4() {
cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask4)
// 将 IPv4 地址转换为网段地址(主机位清零)
networkAddr := netip.PrefixFrom(addr, p.args.Mask4).Addr()
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask4)
} else {
cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask6)
// 将 IPv6 地址转换为网段地址(主机位清零)
networkAddr := netip.PrefixFrom(addr, p.args.Mask6).Addr()
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask6)
}
p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))
@ -227,19 +320,57 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa
p.log.Debug("Add to list: ", zap.Strings("params", params))
// 发送到 RouterOS
args := append([]string{"/ip/firewall/address-list/add"}, params...)
_, err = p.conn.Run(args...)
if err != nil {
if strings.Contains(err.Error(), "already have such entry") {
p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr))
return nil
// 发送到 RouterOS带重试机制
maxRetries := 3
for i := 0; i < maxRetries; i++ {
// 使用读锁保护连接访问
p.mu.RLock()
conn := p.conn
p.mu.RUnlock()
if conn == nil {
p.log.Error("connection is nil")
return fmt.Errorf("connection is nil")
}
p.log.Error("failed to add address to MikroTik",
zap.String("cidr", cidrAddr),
zap.String("list", listName),
zap.Error(err))
return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err)
args := append([]string{"/ip/firewall/address-list/add"}, params...)
_, err = conn.Run(args...)
if err != nil {
if strings.Contains(err.Error(), "already have such entry") {
p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr))
return nil
}
// 如果是连接错误,尝试重新连接
if strings.Contains(err.Error(), "EOF") || strings.Contains(err.Error(), "connection") {
p.log.Warn("connection error, attempting to reconnect",
zap.String("cidr", cidrAddr),
zap.Int("retry", i+1),
zap.Error(err))
// 使用写锁保护重连操作
p.mu.Lock()
if err := p.reconnect(); err != nil {
p.mu.Unlock()
p.log.Error("failed to reconnect", zap.Error(err))
continue
}
p.mu.Unlock()
// 重试
continue
}
// 其他错误,记录并返回
p.log.Error("failed to add address to MikroTik",
zap.String("cidr", cidrAddr),
zap.String("list", listName),
zap.Error(err))
return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err)
}
// 成功,跳出重试循环
break
}
p.log.Info("successfully added address to MikroTik",
@ -251,13 +382,17 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
// 查询地址列表中是否已存在该地址
query := fmt.Sprintf("?list=%s&address=%s", listName, address)
params := []string{
"=list=" + listName,
"=address=" + address,
}
p.log.Debug("checking address existence",
zap.String("list", listName),
zap.String("address", address),
zap.String("query", query))
zap.String("address", address))
resp, err := p.conn.Run("/ip/firewall/address-list/print", query)
args := append([]string{"/ip/firewall/address-list/print"}, params...)
resp, err := p.conn.Run(args...)
if err != nil {
p.log.Error("failed to check address existence",
zap.String("list", listName),