mosdns/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_optimized.go

485 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* MikroTik Address List 插件 - 性能优化版
*
* 主要优化:
* 1. 完全移除验证功能
* 2. 启动时从MikroTik加载现有IP到内存
* 3. 内存中判断IP是否存在避免重复写入
* 4. 支持/24网段掩码
*/
package mikrotik_addresslist
import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
	routeros "github.com/go-routeros/routeros/v3"
	"github.com/miekg/dns"
	"go.uber.org/zap"
)
// optimizedMikrotikAddressListPlugin writes addresses from DNS answers into
// MikroTik firewall address lists and remembers, in memory, which entries it
// has already seen so that repeated answers do not cause repeated API calls.
type optimizedMikrotikAddressListPlugin struct {
	// Configuration and logging.
	args *Args
	log  *zap.Logger

	// RouterOS API session. Close takes mu before closing conn.
	// NOTE(review): isConnected is only ever set to true at construction;
	// confirm whether anything reads it.
	conn        *routeros.Client
	mu          sync.RWMutex
	isConnected bool

	// workerPool is a buffered channel used as a counting semaphore that
	// bounds concurrent add goroutines; wg is waited on by Close.
	workerPool chan struct{}
	wg         sync.WaitGroup

	// ipCache marks entries known to exist on the router:
	// list name -> entry address -> true. Guarded by cacheMu.
	// cacheTTL bounds how long a subnetCache stamp counts as fresh.
	ipCache  map[string]map[string]bool
	cacheMu  sync.RWMutex
	cacheTTL time.Duration

	// subnetCache stamps each prefixed ("addr/len") entry with the time it
	// was cached: list name -> prefix -> time. Guarded by subnetMu.
	subnetCache map[string]map[string]time.Time
	subnetMu    sync.RWMutex
}
func newOptimizedMikrotikAddressListPlugin(args *Args) (*optimizedMikrotikAddressListPlugin, error) {
// 设置默认值
if args.Mask4 == 0 {
args.Mask4 = 24 // 默认使用/24网段掩码
}
if args.Mask6 == 0 {
args.Mask6 = 64 // IPv6使用/64
}
if args.Port == 0 {
args.Port = 9728
}
if args.Timeout == 0 {
args.Timeout = 3 // 优化:减少超时时间
}
if !args.AddAllIPs {
args.AddAllIPs = true
}
// 构建连接地址
addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
// 创建 MikroTik 连接
conn, err := routeros.Dial(addr, args.Username, args.Password)
if err != nil {
return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
}
// 测试连接
if err := testMikrotikConnection(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
}
// 设置工作池大小
workerCount := 20 // 增加工作线程数
// 设置缓存 TTL
cacheTTL := time.Hour * 2 // 2小时缓存
if args.CacheTTL > 0 {
cacheTTL = time.Duration(args.CacheTTL) * time.Second
}
plugin := &optimizedMikrotikAddressListPlugin{
args: args,
conn: conn,
log: zap.L().Named("mikrotik_optimized"),
workerPool: make(chan struct{}, workerCount),
ipCache: make(map[string]map[string]bool),
subnetCache: make(map[string]map[string]time.Time),
cacheTTL: cacheTTL,
isConnected: true,
}
// 🚀 核心优化启动时加载现有IP到内存
if err := plugin.loadExistingIPs(); err != nil {
plugin.log.Warn("failed to load existing IPs, continuing anyway", zap.Error(err))
}
plugin.log.Info("optimized MikroTik plugin initialized",
zap.String("host", args.Host),
zap.Int("port", args.Port),
zap.String("address_list4", args.AddressList4),
zap.String("address_list6", args.AddressList6),
zap.Int("worker_count", workerCount),
zap.Duration("cache_ttl", cacheTTL),
zap.Int("mask4", args.Mask4),
zap.Int("mask6", args.Mask6))
return plugin, nil
}
// loadExistingIPs seeds the in-memory cache from the router. It loads the
// IPv4 list and then the IPv6 list (each only if configured) and logs the
// resulting per-list and total entry counts. A failure to load one list is
// logged and does not stop the other; the function always returns nil.
func (p *optimizedMikrotikAddressListPlugin) loadExistingIPs() error {
	p.log.Info("loading existing IPs from MikroTik...")

	families := [...]struct {
		list   string
		errMsg string
	}{
		{list: p.args.AddressList4, errMsg: "failed to load IPv4 addresses"},
		{list: p.args.AddressList6, errMsg: "failed to load IPv6 addresses"},
	}
	for _, fam := range families {
		if fam.list == "" {
			continue // family not configured
		}
		if err := p.loadAddressListIPs(fam.list); err != nil {
			p.log.Error(fam.errMsg,
				zap.String("list", fam.list),
				zap.Error(err))
		}
	}

	var total int
	p.cacheMu.RLock()
	for name, entries := range p.ipCache {
		n := len(entries)
		total += n
		p.log.Info("loaded address list",
			zap.String("list", name),
			zap.Int("ip_count", n))
	}
	p.cacheMu.RUnlock()

	p.log.Info("finished loading existing IPs", zap.Int("total_ips", total))
	return nil
}
// loadAddressListIPs reads every entry of the RouterOS address list named
// listName and records each entry's "address" field in p.ipCache[listName].
// Addresses that contain "/" are additionally stamped with the current time
// in p.subnetCache[listName]. Existing cache contents are kept; entries are
// only added. An error is returned if the API query fails.
func (p *optimizedMikrotikAddressListPlugin) loadAddressListIPs(listName string) error {
	// The RouterOS API filters print output with query words ("?name=value").
	// The previous "=list=" form is an attribute word, not a query, so it
	// cannot be relied on to restrict the reply to this one list (rows from
	// other lists would then be cached as present here). ".proplist" trims
	// each row to the only field read below.
	// conn is shared with the add workers, so calls are serialized on p.mu.
	p.mu.Lock()
	resp, err := p.conn.Run("/ip/firewall/address-list/print", "?list="+listName, "=.proplist=address")
	p.mu.Unlock()
	if err != nil {
		return fmt.Errorf("failed to query address list %s: %w", listName, err)
	}

	now := time.Now()

	// Lock order cacheMu -> subnetMu, the same order addToMemoryCache uses.
	// subnetMu is taken once for the whole batch instead of once per row.
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()
	p.subnetMu.Lock()
	defer p.subnetMu.Unlock()

	listCache := p.ipCache[listName]
	if listCache == nil {
		listCache = make(map[string]bool, len(resp.Re))
		p.ipCache[listName] = listCache
	}

	for _, re := range resp.Re {
		address, ok := re.Map["address"]
		if !ok {
			continue
		}
		listCache[address] = true

		if !strings.Contains(address, "/") {
			continue
		}
		subnets := p.subnetCache[listName]
		if subnets == nil {
			subnets = make(map[string]time.Time)
			p.subnetCache[listName] = subnets
		}
		subnets[address] = now
	}
	return nil
}
// isIPInMemoryCache reports whether cidrAddr is marked present for listName
// in the in-memory cache. It is a read-locked map lookup; no API call is made.
func (p *optimizedMikrotikAddressListPlugin) isIPInMemoryCache(listName, cidrAddr string) bool {
	p.cacheMu.RLock()
	// Indexing a missing (nil) inner map yields the zero value, false.
	present := p.ipCache[listName][cidrAddr]
	p.cacheMu.RUnlock()
	return present
}
// isSubnetInCache reports whether subnet has a timestamp in the subnet cache
// for listName that is younger than p.cacheTTL. Missing and expired entries
// both report false.
func (p *optimizedMikrotikAddressListPlugin) isSubnetInCache(listName, subnet string) bool {
	p.subnetMu.RLock()
	defer p.subnetMu.RUnlock()

	addedAt, found := p.subnetCache[listName][subnet]
	return found && time.Since(addedAt) < p.cacheTTL
}
// addToMemoryCache marks cidrAddr as present in listName. If cidrAddr
// contains "/", it is also stamped with the current time in the subnet
// cache. subnetMu is acquired while cacheMu is still held (cacheMu -> subnetMu).
func (p *optimizedMikrotikAddressListPlugin) addToMemoryCache(listName, cidrAddr string) {
	p.cacheMu.Lock()
	defer p.cacheMu.Unlock()

	entries := p.ipCache[listName]
	if entries == nil {
		entries = make(map[string]bool)
		p.ipCache[listName] = entries
	}
	entries[cidrAddr] = true

	if !strings.Contains(cidrAddr, "/") {
		return // plain host entry: nothing to stamp
	}

	p.subnetMu.Lock()
	stamps := p.subnetCache[listName]
	if stamps == nil {
		stamps = make(map[string]time.Time)
		p.subnetCache[listName] = stamps
	}
	stamps[cidrAddr] = time.Now()
	p.subnetMu.Unlock()
}
// Exec is the plugin entry point. When the query has a response, it hands a
// deep copy of that response to a background goroutine, which writes the A
// and AAAA addresses to MikroTik. Exec never waits on the RouterOS API and
// always returns nil; failures are only logged.
func (p *optimizedMikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
	if p.conn == nil {
		p.log.Error("MikroTik connection is nil")
		return nil
	}

	r := qCtx.R()
	if r == nil {
		return nil
	}

	var domain string
	if q := qCtx.Q(); len(q.Question) > 0 {
		domain = strings.TrimSuffix(q.Question[0].Name, ".")
	}
	p.log.Debug("processing DNS response",
		zap.String("qname", domain),
		zap.Int("answer_count", len(r.Answer)))

	// The response keeps flowing through the pipeline after Exec returns and
	// may presumably be modified or reused there, so the background
	// goroutine gets its own deep copy instead of sharing r.
	resp := r.Copy()

	// Register the goroutine so Close (which waits on p.wg) does not close
	// the connection while this work is still in flight.
	p.wg.Add(1)
	go func() {
		defer p.wg.Done()
		if err := p.addToAddressList(resp, domain); err != nil {
			p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
		}
	}()
	return nil
}
// addToAddressList collects the A and AAAA records of r, skipping any family
// whose address list is not configured. When args.MaxIPs is positive, each
// family is capped at that many addresses. The addresses are then passed to
// processIPAddresses with that family's list name and mask. It blocks until
// both families have been processed and currently always returns nil.
//
// Callers run this off the DNS path (Exec invokes it from a goroutine), so
// both families are processed inline here. The caller's goroutine then
// covers all of the work, rather than spawning further goroutines that
// nothing tracks or waits for.
func (p *optimizedMikrotikAddressListPlugin) addToAddressList(r *dns.Msg, domain string) error {
	hasList4 := p.args.AddressList4 != ""
	hasList6 := p.args.AddressList6 != ""

	var ipv4Addresses, ipv6Addresses []netip.Addr
	for _, answer := range r.Answer {
		switch rr := answer.(type) {
		case *dns.A:
			if hasList4 {
				if addr, ok := netip.AddrFromSlice(rr.A.To4()); ok {
					ipv4Addresses = append(ipv4Addresses, addr)
				}
			}
		case *dns.AAAA:
			if hasList6 {
				if addr, ok := netip.AddrFromSlice(rr.AAAA.To16()); ok {
					ipv6Addresses = append(ipv6Addresses, addr)
				}
			}
		}
	}

	// Optional per-family cap on how many addresses one response may add.
	if limit := p.args.MaxIPs; limit > 0 {
		if len(ipv4Addresses) > limit {
			ipv4Addresses = ipv4Addresses[:limit]
		}
		if len(ipv6Addresses) > limit {
			ipv6Addresses = ipv6Addresses[:limit]
		}
	}

	// Non-empty slices imply the corresponding list is configured.
	if len(ipv4Addresses) > 0 {
		p.processIPAddresses(ipv4Addresses, p.args.AddressList4, p.args.Mask4, domain, "IPv4")
	}
	if len(ipv6Addresses) > 0 {
		p.processIPAddresses(ipv6Addresses, p.args.AddressList6, p.args.Mask6, domain, "IPv6")
	}
	return nil
}
// 🚀 优化的IP处理逻辑
func (p *optimizedMikrotikAddressListPlugin) processIPAddresses(addresses []netip.Addr, listName string, mask int, domain, ipType string) {
var needToAdd []string
skippedCount := 0
// 🚀 关键优化先在内存中过滤已存在的IP
for _, addr := range addresses {
cidrAddr := p.buildCIDRAddress(addr, mask)
// 纯内存检查,极快速度
if p.isIPInMemoryCache(listName, cidrAddr) {
skippedCount++
p.log.Debug("IP already exists in memory cache, skipping",
zap.String("ip", addr.String()),
zap.String("cidr", cidrAddr),
zap.String("list", listName))
continue
}
// 对于/24网段检查网段是否已存在
if mask == 24 && p.isSubnetInCache(listName, cidrAddr) {
skippedCount++
p.log.Debug("subnet already cached, skipping",
zap.String("cidr", cidrAddr),
zap.String("list", listName))
continue
}
needToAdd = append(needToAdd, cidrAddr)
}
p.log.Info("IP processing summary",
zap.String("type", ipType),
zap.String("domain", domain),
zap.Int("total_ips", len(addresses)),
zap.Int("need_to_add", len(needToAdd)),
zap.Int("skipped_cached", skippedCount))
// 只处理需要添加的IP
if len(needToAdd) > 0 {
p.batchAddOptimized(needToAdd, listName, domain)
}
}
// buildCIDRAddress returns addr truncated to its first mask bits, formatted as
// "network/bits". For example, 192.0.2.77 with mask 24 gives "192.0.2.0/24",
// and 2001:db8::1 with mask 64 gives "2001:db8::/64".
//
// netip.PrefixFrom does not clear host bits, so the previous
// PrefixFrom(addr, mask).Addr() returned addr unchanged ("192.0.2.77/24").
// Two hosts in the same subnet then produced different cache keys and each
// triggered its own add. Addr.Prefix masks the host bits.
//
// If mask is outside [0, addr.BitLen()], the address is emitted as a
// single-host prefix (/32 or /128) rather than a malformed entry.
func (p *optimizedMikrotikAddressListPlugin) buildCIDRAddress(addr netip.Addr, mask int) string {
	prefix, err := addr.Prefix(mask)
	if err != nil {
		prefix = netip.PrefixFrom(addr, addr.BitLen())
	}
	return prefix.Addr().String() + "/" + strconv.Itoa(prefix.Bits())
}
// batchAddOptimized writes each address in addresses to the router list
// listName, tagging the entry with domain (see addSingleAddress).
//
// While a workerPool slot is free, each write runs in its own goroutine.
// When the pool is full, the write runs inline on the calling goroutine,
// which throttles the producer. Successful writes are recorded in the
// in-memory cache. Worker goroutines are registered with p.wg so that Close,
// which waits on p.wg, does not close the connection underneath them.
func (p *optimizedMikrotikAddressListPlugin) batchAddOptimized(addresses []string, listName, domain string) {
	// addAndCache performs one write: failures are logged with failMsg,
	// successes are cached.
	addAndCache := func(addr, failMsg string) {
		if err := p.addSingleAddress(addr, listName, domain); err != nil {
			p.log.Error(failMsg,
				zap.String("cidr", addr),
				zap.String("list", listName),
				zap.Error(err))
			return
		}
		p.addToMemoryCache(listName, addr)
		p.log.Debug("successfully added and cached address",
			zap.String("cidr", addr),
			zap.String("list", listName))
	}

	for _, cidrAddr := range addresses {
		select {
		case p.workerPool <- struct{}{}: // acquired a worker slot
			p.wg.Add(1)
			go func(addr string) {
				defer p.wg.Done()
				defer func() { <-p.workerPool }() // release the slot
				addAndCache(addr, "failed to add address")
			}(cidrAddr)
		default:
			// Pool saturated: do the write on this goroutine.
			addAndCache(cidrAddr, "failed to add address (direct)")
		}
	}
}
// addSingleAddress issues /ip/firewall/address-list/add for cidrAddr in
// listName.
//
// The entry comment is domain, or args.Comment when domain is empty; it is
// omitted when both are empty. A "=timeout=" is added when args.TimeoutAddr
// is positive. If the router reports "already have such entry", the address
// is cached and nil is returned. Any other API error is wrapped and returned.
func (p *optimizedMikrotikAddressListPlugin) addSingleAddress(cidrAddr, listName, domain string) error {
	sentence := []string{
		"/ip/firewall/address-list/add",
		"=list=" + listName,
		"=address=" + cidrAddr,
	}

	comment := domain
	if comment == "" {
		comment = p.args.Comment
	}
	if comment != "" {
		sentence = append(sentence, "=comment="+comment)
	}

	if p.args.TimeoutAddr > 0 {
		sentence = append(sentence, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
	}

	p.log.Debug("adding address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName),
		zap.String("domain", domain))

	// This is called concurrently from batchAddOptimized's workers, all
	// sharing one client. NOTE(review): go-routeros' Client is, to my
	// knowledge, only safe for concurrent Run calls after Async() is
	// enabled; in the default synchronous mode one call can consume another
	// call's reply. Async is not enabled here, so the calls are serialized
	// on p.mu.
	p.mu.Lock()
	_, err := p.conn.Run(sentence...)
	p.mu.Unlock()

	if err != nil {
		if strings.Contains(err.Error(), "already have such entry") {
			// The router already has it: remember that and treat as success.
			p.addToMemoryCache(listName, cidrAddr)
			p.log.Debug("address already exists in MikroTik, updated cache",
				zap.String("cidr", cidrAddr))
			return nil
		}
		return fmt.Errorf("failed to add address %s: %w", cidrAddr, err)
	}
	return nil
}
// getCacheStats returns a snapshot mapping each cached list name to the
// number of entries recorded for it in the in-memory cache.
func (p *optimizedMikrotikAddressListPlugin) getCacheStats() map[string]int {
	p.cacheMu.RLock()
	defer p.cacheMu.RUnlock()

	counts := make(map[string]int, len(p.ipCache))
	for name, entries := range p.ipCache {
		counts[name] = len(entries)
	}
	return counts
}
// Close waits for the goroutines registered in p.wg and logs the final
// per-list cache sizes. It then closes the RouterOS connection, if one is
// set, while holding p.mu. The result is the connection's Close error, or nil.
func (p *optimizedMikrotikAddressListPlugin) Close() error {
	p.wg.Wait()

	finalStats := p.getCacheStats()
	p.log.Info("optimized plugin closing", zap.Any("final_cache_stats", finalStats))

	p.mu.Lock()
	defer p.mu.Unlock()
	if p.conn == nil {
		return nil
	}
	return p.conn.Close()
}