mosdns/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go
dengxiongjian 444c01d207
Some checks failed
Test mosdns / build (push) Has been cancelled
主要修改:
1. 在Exec方法中获取域名:从DNS查询中提取域名并去除末尾的点
  2. 传递域名参数:将域名参数传递给所有相关的方法
  3. 动态设置注释:优先使用域名作为注释,如果域名为空则使用配置文件中的comment
  4. 更新日志:添加域名信息到日志中便于调试
  5.添加了二次延迟,会对添加到Mikrotik中的IP进行二次验证,确定是否添加成功
2025-08-05 22:05:58 +08:00

783 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package mikrotik_addresslist
import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
	routeros "github.com/go-routeros/routeros/v3"
	"github.com/miekg/dns"
	"go.uber.org/zap"
)
// verifyTask is one pending post-add check: confirm that cidrAddr is present
// in the RouterOS address list named listName. retries counts how many times
// the check has already been re-queued; processVerificationTask also uses it
// to lengthen the delay before checking.
type verifyTask struct {
	listName, cidrAddr string // target address list and the CIDR that was added
	retries            int    // number of re-queues so far
}
// mikrotikAddressListPlugin pushes addresses taken from DNS responses into a
// RouterOS firewall address list through the RouterOS API.
type mikrotikAddressListPlugin struct {
	args *Args
	log  *zap.Logger

	// RouterOS API client. mu is taken around reads of conn/isConnected and
	// held for writing while the connection is re-established.
	mu          sync.RWMutex
	conn        *routeros.Client
	isConnected bool

	// Semaphore-style channels: workerPool is used by add batches,
	// verifyPool by verification goroutines. wg is waited on in Close.
	workerPool chan struct{}
	verifyPool chan struct{}
	wg         sync.WaitGroup

	// In-memory record of recently added entries.
	// Key: "listName:cidrAddr", value: time the entry was recorded.
	cacheMu  sync.RWMutex
	cache    map[string]time.Time
	cacheTTL time.Duration // entry lifetime; 1 hour unless Args.CacheTTL is set

	// Asynchronous post-add verification.
	verifyQueue chan verifyTask
	stopVerify  chan struct{}
}
// newMikrotikAddressListPlugin builds the plugin. It first writes defaults
// into unset Args fields (Mask4=24, Mask6=32, Port=8728, Timeout=10), so args
// is modified in place. It then dials the RouterOS API, checks the session
// with testMikrotikConnection and, when args.VerifyAdd is set, starts the
// verification workers.
//
// It returns an error if the dial or the connection test fails. When the test
// fails, the client is closed before returning.
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
	if args.Mask4 == 0 {
		args.Mask4 = 24
	}
	if args.Mask6 == 0 {
		args.Mask6 = 32
	}
	if args.Port == 0 {
		args.Port = 8728
	}
	if args.Timeout == 0 {
		// NOTE(review): args.Timeout is not read anywhere in this file;
		// presumably it was meant as a dial timeout in seconds — confirm.
		args.Timeout = 10
	}

	// net.JoinHostPort brackets IPv6 literals ("[::1]:8728"). Plain
	// "%s:%d" formatting would produce the unparsable "::1:8728".
	addr := net.JoinHostPort(args.Host, strconv.Itoa(args.Port))

	conn, err := routeros.Dial(addr, args.Username, args.Password)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
	}

	if err := testMikrotikConnection(conn); err != nil {
		conn.Close()
		return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
	}

	// Pool sizes; adjust as needed.
	workerCount := 10 // concurrent add workers
	verifyCount := 5  // verification workers

	cacheTTL := time.Hour // default 1 hour
	if args.CacheTTL > 0 {
		cacheTTL = time.Duration(args.CacheTTL) * time.Second
	}

	plugin := &mikrotikAddressListPlugin{
		args:        args,
		conn:        conn,
		log:         zap.L().Named("mikrotik_addresslist"),
		workerPool:  make(chan struct{}, workerCount),
		verifyPool:  make(chan struct{}, verifyCount),
		cache:       make(map[string]time.Time),
		cacheTTL:    cacheTTL,
		isConnected: true,
		verifyQueue: make(chan verifyTask, 100), // verification task queue
		stopVerify:  make(chan struct{}),
	}

	if args.VerifyAdd {
		for i := 0; i < verifyCount; i++ {
			go plugin.verifyWorker()
		}
	}

	plugin.log.Info("successfully connected to MikroTik",
		zap.String("host", args.Host),
		zap.Int("port", args.Port),
		zap.String("username", args.Username),
		zap.String("address_list4", args.AddressList4),
		zap.Int("worker_count", workerCount),
		zap.Int("verify_count", verifyCount),
		zap.Bool("verify_add", args.VerifyAdd),
		zap.Duration("cache_ttl", cacheTTL))
	return plugin, nil
}
// testMikrotikConnection checks that conn can execute a command. It runs
// "/system/resource/print" and fails if the call errors or the reply has no
// !re sentence. On success it logs through the global zap logger.
func testMikrotikConnection(conn *routeros.Client) error {
	reply, runErr := conn.Run("/system/resource/print")
	switch {
	case runErr != nil:
		return fmt.Errorf("connection test failed: %w", runErr)
	case len(reply.Re) == 0:
		return fmt.Errorf("connection test failed: no response from MikroTik")
	}

	zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
	return nil
}
func (p *mikrotikAddressListPlugin) reconnect() error {
p.log.Info("attempting to reconnect to MikroTik")
// 关闭旧连接
if p.conn != nil {
p.conn.Close()
}
// 重新建立连接
var err error
p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password)
if err != nil {
return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
}
// 测试连接
if err := testMikrotikConnection(p.conn); err != nil {
return fmt.Errorf("failed to test reconnection: %w", err)
}
p.log.Info("successfully reconnected to MikroTik")
return nil
}
// cacheKey returns the cache map key for an entry: listName and cidrAddr
// joined with ":".
func (p *mikrotikAddressListPlugin) cacheKey(listName, cidrAddr string) string {
	return strings.Join([]string{listName, cidrAddr}, ":")
}
// isInCache reports whether (listName, cidrAddr) was recorded less than
// p.cacheTTL ago. An expired entry is removed and false is returned.
//
// The expired-entry path re-checks the entry after taking the write lock.
// Between releasing the read lock and acquiring the write lock, another
// goroutine may have refreshed the entry (addToCache). The old code deleted
// that fresh entry unconditionally.
func (p *mikrotikAddressListPlugin) isInCache(listName, cidrAddr string) bool {
	key := p.cacheKey(listName, cidrAddr)

	p.cacheMu.RLock()
	addTime, exists := p.cache[key]
	p.cacheMu.RUnlock()

	if !exists {
		return false
	}
	if time.Since(addTime) < p.cacheTTL {
		return true
	}

	// Expired: delete it, unless it was refreshed meanwhile.
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	addTime, exists = p.cache[key]
	if !exists {
		return false
	}
	if time.Since(addTime) < p.cacheTTL {
		return true
	}
	delete(p.cache, key)
	return false
}
// addToCache records (listName, cidrAddr) with the current time. While it
// holds the write lock, it also sweeps every expired entry (cleanupExpiredCache).
//
// A nil map is skipped instead of written to. Assigning into a nil map
// panics, and the map can be nil after shutdown has released it while
// in-flight queries are still finishing.
func (p *mikrotikAddressListPlugin) addToCache(listName, cidrAddr string) {
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()

	if p.cache == nil {
		return
	}

	p.cache[p.cacheKey(listName, cidrAddr)] = time.Now()

	// Sweep expired entries. This is O(len(cache)) per insert.
	p.cleanupExpiredCache()
}
// cleanupExpiredCache removes every cache entry whose age is at least
// p.cacheTTL. It does no locking itself; callers must hold p.cacheMu for
// writing (addToCache does).
func (p *mikrotikAddressListPlugin) cleanupExpiredCache() {
	ttl, now := p.cacheTTL, time.Now()
	for k, added := range p.cache {
		if age := now.Sub(added); age < ttl {
			continue
		}
		delete(p.cache, k) // deleting while ranging over a map is allowed in Go
	}
}
// getCacheStats returns two counts: all cache entries, and entries younger
// than p.cacheTTL. Expired entries are counted but not removed.
func (p *mikrotikAddressListPlugin) getCacheStats() (int, int) {
	p.cacheMu.RLock()
	defer p.cacheMu.RUnlock()

	now := time.Now()
	fresh := 0
	for _, added := range p.cache {
		if now.Sub(added) < p.cacheTTL {
			fresh++
		}
	}
	return len(p.cache), fresh
}
// Exec is the plugin entry point for one DNS query. When the query context
// holds a response, the response's A records are pushed to MikroTik. The
// first question's name, without its trailing dot, is passed along as the
// domain; it becomes the entry comment.
//
// It returns an error when the RouterOS client is nil or the add fails. A
// missing response is not an error.
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
	// p.conn is replaced under p.mu's write lock during reconnects, so it is
	// read under the read lock here to avoid a data race.
	p.mu.RLock()
	conn := p.conn
	p.mu.RUnlock()
	if conn == nil {
		p.log.Error("MikroTik connection is nil")
		return fmt.Errorf("mikrotik_addresslist: connection is nil")
	}

	r := qCtx.R()
	if r == nil {
		return nil
	}

	var domain string
	if q := qCtx.Q(); q != nil && len(q.Question) > 0 {
		domain = strings.TrimSuffix(q.Question[0].Name, ".")
	}

	p.log.Debug("processing DNS response",
		zap.String("qname", domain),
		zap.Int("answer_count", len(r.Answer)))

	if err := p.addToAddressList(r, domain); err != nil {
		p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
		return fmt.Errorf("mikrotik_addresslist: %w", err)
	}
	return nil
}
// Close shuts the plugin down in this order:
//  1. Signal the verification workers to stop (when VerifyAdd is enabled).
//  2. Wait on p.wg.
//  3. Empty the cache.
//  4. Close the RouterOS client under p.mu.
//
// It returns the client's Close error, if any.
//
// NOTE(review): a second call panics on close(p.stopVerify); Close is
// assumed to be called once.
func (p *mikrotikAddressListPlugin) Close() error {
	if p.args.VerifyAdd {
		close(p.stopVerify)
	}

	p.wg.Wait()

	p.cacheMu.Lock()
	cacheSize := len(p.cache)
	// Reset to an empty map rather than nil. A query still in flight may call
	// addToCache, and writing into a nil map panics.
	p.cache = make(map[string]time.Time)
	p.cacheMu.Unlock()
	p.log.Info("plugin closed", zap.Int("cache_cleared", cacheSize))

	p.mu.Lock()
	defer p.mu.Unlock()
	if p.conn != nil {
		return p.conn.Close()
	}
	return nil
}
// addToAddressList collects the IPv4 addresses of r's A records and hands
// them to batchAddAddresses for p.args.AddressList4, using p.args.Mask4.
// AAAA and every other record type are skipped. A records are also skipped
// when no IPv4 list is configured. Invalid A records are logged and ignored.
// It returns the first error reported by batchAddAddresses.
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg, domain string) error {
	p.log.Debug("starting to process DNS response",
		zap.String("configured_address_list4", p.args.AddressList4),
		zap.Int("answer_count", len(r.Answer)))

	var addresses []netip.Addr
	for _, answer := range r.Answer {
		a, isA := answer.(*dns.A)
		if !isA {
			if _, isAAAA := answer.(*dns.AAAA); isAAAA {
				p.log.Debug("skipping AAAA record (IPv6 not supported)")
			} else {
				p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", answer)))
			}
			continue
		}

		if p.args.AddressList4 == "" {
			p.log.Debug("skipping A record, no IPv4 address list configured")
			continue
		}

		ip, valid := netip.AddrFromSlice(a.A.To4())
		if !valid {
			// Skip malformed records without aborting the rest of the answer.
			p.log.Error("invalid A record", zap.String("ip", a.A.String()))
			continue
		}

		addresses = append(addresses, ip)
		p.log.Debug("queued A record for processing",
			zap.String("ip", ip.String()),
			zap.String("address_list4", p.args.AddressList4))
	}

	if len(addresses) == 0 {
		p.log.Debug("no IPv4 addresses to process")
		return nil
	}

	p.adjustWorkerPoolSize(len(addresses))

	if err := p.batchAddAddresses(addresses, p.args.AddressList4, p.args.Mask4, domain); err != nil {
		return err
	}

	total, valid := p.getCacheStats()
	p.log.Info("IPv4 addresses processed",
		zap.Int("processed_count", len(addresses)),
		zap.Int("cache_total", total),
		zap.Int("cache_valid", valid))
	return nil
}
// addAddressToMikrotik adds addr, widened to its network prefix, to the
// RouterOS address list listName.
//
// Prefix length: p.args.Mask4 for IPv4 addresses, p.args.Mask6 otherwise.
// The mask parameter is only logged.
//
// The API call is skipped when the entry is in the local cache or
// addressExists reports it is already on the device. The entry comment is
// domain, falling back to p.args.Comment when domain is empty.
//
// Retries and errors:
//   - Connection-level errors (isConnectionError) are retried up to three
//     times with exponential backoff, reconnecting when needed.
//   - If every attempt fails, an error is returned.
//   - Any other RouterOS error is returned immediately.
//   - "already have such entry" counts as success.
//
// On success the entry is cached and, when p.args.VerifyAdd is set, queued
// for verification (dropped with a warning if the queue is full).
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int, domain string) error {
	p.log.Debug("addAddressToMikrotik called",
		zap.String("addr", addr.String()),
		zap.String("listName", listName),
		zap.Int("mask", mask))

	// Build the network CIDR with host bits cleared. netip.PrefixFrom does NOT
	// mask host bits (it would yield "1.2.3.4/24"); Addr.Prefix does
	// ("1.2.3.0/24"). Addr.Prefix also rejects an out-of-range length.
	bits := p.args.Mask6
	if addr.Is4() {
		bits = p.args.Mask4
	}
	prefix, err := addr.Prefix(bits)
	if err != nil {
		return fmt.Errorf("invalid prefix length %d for address %s: %w", bits, addr, err)
	}
	cidrAddr := prefix.String()

	p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))

	if p.isInCache(listName, cidrAddr) {
		p.log.Debug("address found in cache, skipping", zap.String("cidr", cidrAddr), zap.String("list", listName))
		return nil
	}

	exists, err := p.addressExists(listName, cidrAddr)
	if err != nil {
		// The lookup failing (e.g. the list does not exist yet) is not fatal;
		// attempt the add anyway.
		p.log.Debug("failed to check if address exists in MikroTik, will try to add anyway", zap.Error(err))
	} else if exists {
		p.log.Debug("address already exists in MikroTik, adding to cache", zap.String("cidr", cidrAddr), zap.String("list", listName))
		p.addToCache(listName, cidrAddr)
		return nil
	}

	// RouterOS API attribute words must start with "=".
	params := []string{
		"=list=" + listName,
		"=address=" + cidrAddr,
	}
	// The queried domain takes precedence over the configured comment.
	comment := domain
	if comment == "" && p.args.Comment != "" {
		comment = p.args.Comment
	}
	if comment != "" {
		params = append(params, "=comment="+comment)
	}
	if p.args.TimeoutAddr > 0 {
		params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
	}

	p.log.Info("adding address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName),
		zap.String("domain", domain),
		zap.String("comment", comment),
		zap.Int("timeout", p.args.TimeoutAddr))
	p.log.Debug("Add to list: ", zap.Strings("params", params))

	const maxRetries = 3
	backoff := 100 * time.Millisecond
	cmd := append([]string{"/ip/firewall/address-list/add"}, params...)

	var lastErr error
	added := false
	for attempt := 0; attempt < maxRetries; attempt++ {
		// Back off before each retry (not after the final failure).
		if attempt > 0 {
			time.Sleep(backoff)
			backoff *= 2
		}

		p.mu.RLock()
		conn, connected := p.conn, p.isConnected
		p.mu.RUnlock()

		if conn == nil || !connected {
			p.log.Debug("connection not available, attempting to reconnect")
			p.mu.Lock()
			// Re-check under the write lock. Another goroutine may already
			// have reconnected, and reconnecting again would close its live
			// connection.
			if p.conn == nil || !p.isConnected {
				if rerr := p.reconnect(); rerr != nil {
					p.isConnected = false
					p.mu.Unlock()
					p.log.Error("failed to reconnect", zap.Error(rerr))
					lastErr = rerr
					continue
				}
				p.isConnected = true
			}
			conn = p.conn
			p.mu.Unlock()
		}

		_, runErr := conn.Run(cmd...)
		if runErr == nil {
			added = true
			break
		}

		if strings.Contains(runErr.Error(), "already have such entry") {
			p.log.Debug("Address already exists", zap.String("cidr", cidrAddr))
			p.addToCache(listName, cidrAddr)
			return nil
		}

		if !p.isConnectionError(runErr) {
			p.log.Error("failed to add address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName),
				zap.Error(runErr))
			return fmt.Errorf("failed to add address %s to list %s: %w", cidrAddr, listName, runErr)
		}

		p.log.Warn("connection error, will retry",
			zap.String("cidr", cidrAddr),
			zap.Int("retry", attempt+1),
			zap.Error(runErr))
		p.mu.Lock()
		// Only mark the connection down if it is still the one that failed.
		// A newer connection from another goroutine's reconnect stays up.
		if p.conn == conn {
			p.isConnected = false
		}
		p.mu.Unlock()
		lastErr = runErr
	}

	if !added {
		// Previously control fell through to the success path here, which
		// cached and "verified" an address that was never added.
		p.log.Error("failed to add address to MikroTik after retries",
			zap.String("cidr", cidrAddr),
			zap.String("list", listName),
			zap.Int("attempts", maxRetries),
			zap.Error(lastErr))
		return fmt.Errorf("failed to add address %s to list %s after %d attempts: %w", cidrAddr, listName, maxRetries, lastErr)
	}

	p.log.Info("successfully added address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName))

	p.addToCache(listName, cidrAddr)

	if p.args.VerifyAdd {
		select {
		case p.verifyQueue <- verifyTask{
			listName: listName,
			cidrAddr: cidrAddr,
			retries:  0,
		}:
			p.log.Debug("verification task queued", zap.String("cidr", cidrAddr))
		default:
			p.log.Warn("verification queue full, skipping verification", zap.String("cidr", cidrAddr))
		}
	}

	return nil
}
// addressExists reports whether RouterOS has an address-list entry whose
// list equals listName and whose address equals address. It returns an
// error when there is no client or the print command fails.
//
// On a print command, filters must be sent as API query words
// ("?name=value"). Attribute words ("=name=value") are not filters, so the
// previous form did not restrict the result to the requested entry. Only
// ".id" is requested because just the match count is used.
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
	// p.conn may be replaced by a reconnect, so read it under p.mu.
	p.mu.RLock()
	conn := p.conn
	p.mu.RUnlock()
	if conn == nil {
		return false, fmt.Errorf("mikrotik connection is nil")
	}

	p.log.Debug("checking address existence",
		zap.String("list", listName),
		zap.String("address", address))

	resp, err := conn.Run(
		"/ip/firewall/address-list/print",
		"=.proplist=.id",
		"?list="+listName,
		"?address="+address,
	)
	if err != nil {
		p.log.Error("failed to check address existence",
			zap.String("list", listName),
			zap.String("address", address),
			zap.Error(err))
		return false, err
	}

	// Any returned !re sentence is a matching entry.
	exists := len(resp.Re) > 0
	p.log.Debug("address existence check",
		zap.String("list", listName),
		zap.String("address", address),
		zap.Bool("exists", exists))
	return exists, nil
}
// batchAddAddresses splits addresses into batches of 10 and runs each batch
// in its own goroutine. Each goroutine calls addAddressToMikrotik for every
// address in its batch. It waits for all batches and returns the first error
// collected, or nil.
//
// Each goroutine tries to take a workerPool slot without blocking. When the
// pool is full the batch runs anyway, so the pool does not cap concurrency.
func (p *mikrotikAddressListPlugin) batchAddAddresses(addresses []netip.Addr, listName string, mask int, domain string) error {
	if len(addresses) == 0 {
		return nil
	}

	const batchSize = 10

	// Snapshot the pool once. adjustWorkerPoolSize can swap p.workerPool for a
	// new channel, and a slot must be released on the channel it was taken
	// from. Re-reading the field on release could block forever on a
	// different, empty channel.
	pool := p.workerPool

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex // guards errs and successCount
		errs         []error
		successCount int
	)

	for start := 0; start < len(addresses); start += batchSize {
		end := start + batchSize
		if end > len(addresses) {
			end = len(addresses)
		}
		batch := addresses[start:end]

		wg.Add(1)
		go func(batch []netip.Addr) {
			defer wg.Done()

			select {
			case pool <- struct{}{}:
				defer func() { <-pool }()
			default:
				p.log.Debug("worker pool full, processing batch directly")
			}

			for _, addr := range batch {
				err := p.addAddressToMikrotik(addr, listName, mask, domain)
				mu.Lock()
				if err != nil {
					errs = append(errs, err)
				} else {
					successCount++
				}
				mu.Unlock()
			}
		}(batch)
	}
	wg.Wait()

	if len(errs) > 0 {
		p.log.Error("batch processing completed with errors",
			zap.Int("success_count", successCount),
			zap.Int("error_count", len(errs)),
			zap.Error(errs[0]))
		return errs[0]
	}

	p.log.Info("batch processing completed successfully",
		zap.Int("success_count", successCount),
		zap.Int("total_count", len(addresses)))
	return nil
}
// adjustWorkerPoolSize picks a target pool capacity from addressCount:
//   - ≤5 → 3
//   - ≤20 → 5
//   - ≤50 → 10
//   - otherwise → 15
//
// It replaces p.workerPool with a fresh channel only when the current
// capacity is smaller than the target; the pool never shrinks.
//
// NOTE(review): the field is reassigned without any lock, while
// batchAddAddresses goroutines may be using the old channel concurrently.
// This is a data race; consider a fixed-size pool or synchronizing access.
func (p *mikrotikAddressListPlugin) adjustWorkerPoolSize(addressCount int) {
	targetSize := 15
	for _, tier := range []struct{ limit, size int }{{5, 3}, {20, 5}, {50, 10}} {
		if addressCount <= tier.limit {
			targetSize = tier.size
			break
		}
	}

	if cap(p.workerPool) >= targetSize {
		return
	}

	p.log.Debug("adjusting worker pool size",
		zap.Int("old_size", cap(p.workerPool)),
		zap.Int("new_size", targetSize),
		zap.Int("address_count", addressCount))
	p.workerPool = make(chan struct{}, targetSize)
}
// isConnectionError classifies err as connection-level when its message
// contains "EOF", "connection", "closed" or "timeout". A nil error is never
// a connection error.
//
// NOTE(review): this is a substring heuristic. A device-side error whose
// text merely mentions one of these words would also match.
func (p *mikrotikAddressListPlugin) isConnectionError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, marker := range [...]string{"EOF", "connection", "closed", "timeout"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// verifyWorker 验证工作器,独立处理验证任务,避免阻塞写入操作
func (p *mikrotikAddressListPlugin) verifyWorker() {
p.wg.Add(1)
defer p.wg.Done()
ticker := time.NewTicker(2 * time.Second) // 每2秒处理一批验证任务
defer ticker.Stop()
for {
select {
case <-p.stopVerify:
p.log.Info("verification worker stopping")
return
case <-ticker.C:
// 批量处理验证任务
p.processBatchVerification()
case task := <-p.verifyQueue:
// 获取验证工作池槽位
select {
case p.verifyPool <- struct{}{}:
go func(task verifyTask) {
defer func() { <-p.verifyPool }()
p.processVerificationTask(task)
}(task)
default:
// 验证池满,延迟处理
go func() {
time.Sleep(time.Second)
select {
case p.verifyQueue <- task:
default:
p.log.Warn("failed to requeue verification task",
zap.String("cidr", task.cidrAddr))
}
}()
}
}
}
}
// processBatchVerification takes up to 10 queued verification tasks without
// blocking and starts one goroutine per task while verifyPool has free
// slots. Tasks that find the pool full are put back on the queue, or dropped
// with a warning when the queue is also full.
func (p *mikrotikAddressListPlugin) processBatchVerification() {
	const maxBatch = 10
	tasks := make([]verifyTask, 0, maxBatch)

	// A bare `break` inside `select` only leaves the select. The old loop
	// therefore kept polling after the queue was empty. The label exits the
	// loop as soon as nothing is immediately available.
drain:
	for len(tasks) < maxBatch {
		select {
		case task := <-p.verifyQueue:
			tasks = append(tasks, task)
		default:
			break drain
		}
	}

	if len(tasks) == 0 {
		return
	}

	p.log.Debug("processing batch verification", zap.Int("task_count", len(tasks)))

	for _, task := range tasks {
		select {
		case p.verifyPool <- struct{}{}:
			go func(task verifyTask) {
				defer func() { <-p.verifyPool }()
				p.processVerificationTask(task)
			}(task)
		default:
			// Pool is full: requeue for a later round.
			select {
			case p.verifyQueue <- task:
			default:
				p.log.Warn("verification queue full, dropping task",
					zap.String("cidr", task.cidrAddr))
			}
		}
	}
}
// processVerificationTask waits (500 ms + 200 ms × task.retries) and then
// asks RouterOS whether task.cidrAddr is in task.listName. The wait gives
// the device time to apply the add.
//
//   - Lookup error: the task is requeued with retries+1, up to 3 retries,
//     after which the failure is logged.
//   - Entry missing: it is removed from the local cache, so the next
//     matching DNS response adds it again. Nothing is re-added here, because
//     verifyTask does not carry the domain/comment needed for the add.
//   - Entry present: success is logged at debug level.
func (p *mikrotikAddressListPlugin) processVerificationTask(task verifyTask) {
	time.Sleep(time.Duration(500+task.retries*200) * time.Millisecond)

	exists, err := p.addressExists(task.listName, task.cidrAddr)
	if err != nil {
		if task.retries >= 3 {
			p.log.Error("verification failed after max retries",
				zap.String("cidr", task.cidrAddr),
				zap.String("list", task.listName),
				zap.Error(err))
			return
		}

		task.retries++
		p.log.Debug("verification failed, retrying",
			zap.String("cidr", task.cidrAddr),
			zap.Int("retries", task.retries),
			zap.Error(err))
		select {
		case p.verifyQueue <- task:
		default:
			p.log.Warn("failed to requeue verification task for retry",
				zap.String("cidr", task.cidrAddr))
		}
		return
	}

	if exists {
		p.log.Debug("address verification successful",
			zap.String("cidr", task.cidrAddr),
			zap.String("list", task.listName))
		return
	}

	p.log.Warn("address not found in MikroTik after add operation",
		zap.String("cidr", task.cidrAddr),
		zap.String("list", task.listName))

	// Evict from the cache so the next DNS response for this address is not
	// short-circuited by isInCache.
	p.cacheMu.Lock()
	delete(p.cache, p.cacheKey(task.listName, task.cidrAddr))
	p.cacheMu.Unlock()

	// The old log line claimed a re-add was being attempted, but none was.
	// Report what actually happens.
	p.log.Info("address evicted from cache; it will be re-added on the next matching DNS response",
		zap.String("cidr", task.cidrAddr),
		zap.String("list", task.listName))
}