/*
 * MikroTik Address List plugin - performance-optimized version
 *
 * Main optimizations:
 * 1. Verification has been removed entirely.
 * 2. Existing IPs are loaded from MikroTik into memory at startup.
 * 3. IP existence is checked in memory to avoid duplicate writes.
 * 4. /24 subnet masks are supported.
 */
|
||
|
||
package mikrotik_addresslist
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/netip"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
||
"github.com/miekg/dns"
|
||
"go.uber.org/zap"
|
||
|
||
routeros "github.com/go-routeros/routeros/v3"
|
||
)
|
||
|
||
type optimizedMikrotikAddressListPlugin struct {
|
||
args *Args
|
||
conn *routeros.Client
|
||
log *zap.Logger
|
||
|
||
// 并发控制
|
||
workerPool chan struct{}
|
||
wg sync.WaitGroup
|
||
mu sync.RWMutex
|
||
isConnected bool
|
||
|
||
// 内存IP缓存 - 核心优化
|
||
ipCache map[string]map[string]bool // map[listName]map[cidrAddr]exists
|
||
cacheMu sync.RWMutex // 保护IP缓存访问
|
||
cacheTTL time.Duration
|
||
|
||
// 网段缓存,用于/24掩码优化
|
||
subnetCache map[string]map[string]time.Time // map[listName]map[subnet]addTime
|
||
subnetMu sync.RWMutex
|
||
}
|
||
|
||
func newOptimizedMikrotikAddressListPlugin(args *Args) (*optimizedMikrotikAddressListPlugin, error) {
|
||
// 设置默认值
|
||
if args.Mask4 == 0 {
|
||
args.Mask4 = 24 // 默认使用/24网段掩码
|
||
}
|
||
if args.Mask6 == 0 {
|
||
args.Mask6 = 64 // IPv6使用/64
|
||
}
|
||
if args.Port == 0 {
|
||
args.Port = 9728
|
||
}
|
||
if args.Timeout == 0 {
|
||
args.Timeout = 3 // 优化:减少超时时间
|
||
}
|
||
if !args.AddAllIPs {
|
||
args.AddAllIPs = true
|
||
}
|
||
|
||
// 构建连接地址
|
||
addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
|
||
|
||
// 创建 MikroTik 连接
|
||
conn, err := routeros.Dial(addr, args.Username, args.Password)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := testMikrotikConnection(conn); err != nil {
|
||
conn.Close()
|
||
return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
|
||
}
|
||
|
||
// 设置工作池大小
|
||
workerCount := 20 // 增加工作线程数
|
||
|
||
// 设置缓存 TTL
|
||
cacheTTL := time.Hour * 2 // 2小时缓存
|
||
if args.CacheTTL > 0 {
|
||
cacheTTL = time.Duration(args.CacheTTL) * time.Second
|
||
}
|
||
|
||
plugin := &optimizedMikrotikAddressListPlugin{
|
||
args: args,
|
||
conn: conn,
|
||
log: zap.L().Named("mikrotik_optimized"),
|
||
workerPool: make(chan struct{}, workerCount),
|
||
ipCache: make(map[string]map[string]bool),
|
||
subnetCache: make(map[string]map[string]time.Time),
|
||
cacheTTL: cacheTTL,
|
||
isConnected: true,
|
||
}
|
||
|
||
// 🚀 核心优化:启动时加载现有IP到内存
|
||
if err := plugin.loadExistingIPs(); err != nil {
|
||
plugin.log.Warn("failed to load existing IPs, continuing anyway", zap.Error(err))
|
||
}
|
||
|
||
plugin.log.Info("optimized MikroTik plugin initialized",
|
||
zap.String("host", args.Host),
|
||
zap.Int("port", args.Port),
|
||
zap.String("address_list4", args.AddressList4),
|
||
zap.String("address_list6", args.AddressList6),
|
||
zap.Int("worker_count", workerCount),
|
||
zap.Duration("cache_ttl", cacheTTL),
|
||
zap.Int("mask4", args.Mask4),
|
||
zap.Int("mask6", args.Mask6))
|
||
|
||
return plugin, nil
|
||
}
|
||
|
||
// 🚀 核心功能:启动时从MikroTik加载现有IP
|
||
func (p *optimizedMikrotikAddressListPlugin) loadExistingIPs() error {
|
||
p.log.Info("loading existing IPs from MikroTik...")
|
||
|
||
// 加载IPv4地址列表
|
||
if p.args.AddressList4 != "" {
|
||
if err := p.loadAddressListIPs(p.args.AddressList4); err != nil {
|
||
p.log.Error("failed to load IPv4 addresses",
|
||
zap.String("list", p.args.AddressList4),
|
||
zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 加载IPv6地址列表
|
||
if p.args.AddressList6 != "" {
|
||
if err := p.loadAddressListIPs(p.args.AddressList6); err != nil {
|
||
p.log.Error("failed to load IPv6 addresses",
|
||
zap.String("list", p.args.AddressList6),
|
||
zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// 打印加载统计
|
||
p.cacheMu.RLock()
|
||
totalIPs := 0
|
||
for listName, ips := range p.ipCache {
|
||
count := len(ips)
|
||
totalIPs += count
|
||
p.log.Info("loaded address list",
|
||
zap.String("list", listName),
|
||
zap.Int("ip_count", count))
|
||
}
|
||
p.cacheMu.RUnlock()
|
||
|
||
p.log.Info("finished loading existing IPs", zap.Int("total_ips", totalIPs))
|
||
return nil
|
||
}
|
||
|
||
// 从指定的address list加载所有IP
|
||
func (p *optimizedMikrotikAddressListPlugin) loadAddressListIPs(listName string) error {
|
||
// 查询指定列表的所有地址
|
||
resp, err := p.conn.Run("/ip/firewall/address-list/print", "=list="+listName)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to query address list %s: %w", listName, err)
|
||
}
|
||
|
||
p.cacheMu.Lock()
|
||
defer p.cacheMu.Unlock()
|
||
|
||
// 初始化缓存
|
||
if p.ipCache[listName] == nil {
|
||
p.ipCache[listName] = make(map[string]bool)
|
||
}
|
||
|
||
// 解析响应并添加到缓存
|
||
for _, re := range resp.Re {
|
||
if address, ok := re.Map["address"]; ok {
|
||
p.ipCache[listName][address] = true
|
||
|
||
// 如果是网段地址,也缓存网段信息
|
||
if strings.Contains(address, "/") {
|
||
p.subnetMu.Lock()
|
||
if p.subnetCache[listName] == nil {
|
||
p.subnetCache[listName] = make(map[string]time.Time)
|
||
}
|
||
p.subnetCache[listName][address] = time.Now()
|
||
p.subnetMu.Unlock()
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 🚀 优化的IP存在性检查(纯内存操作)
|
||
func (p *optimizedMikrotikAddressListPlugin) isIPInMemoryCache(listName, cidrAddr string) bool {
|
||
p.cacheMu.RLock()
|
||
defer p.cacheMu.RUnlock()
|
||
|
||
if listCache, exists := p.ipCache[listName]; exists {
|
||
return listCache[cidrAddr]
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 🚀 优化的网段存在性检查
|
||
func (p *optimizedMikrotikAddressListPlugin) isSubnetInCache(listName, subnet string) bool {
|
||
p.subnetMu.RLock()
|
||
defer p.subnetMu.RUnlock()
|
||
|
||
if subnetMap, exists := p.subnetCache[listName]; exists {
|
||
if addTime, exists := subnetMap[subnet]; exists {
|
||
// 检查是否过期
|
||
return time.Since(addTime) < p.cacheTTL
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 添加IP到内存缓存
|
||
func (p *optimizedMikrotikAddressListPlugin) addToMemoryCache(listName, cidrAddr string) {
|
||
p.cacheMu.Lock()
|
||
defer p.cacheMu.Unlock()
|
||
|
||
if p.ipCache[listName] == nil {
|
||
p.ipCache[listName] = make(map[string]bool)
|
||
}
|
||
p.ipCache[listName][cidrAddr] = true
|
||
|
||
// 如果是网段,也更新网段缓存
|
||
if strings.Contains(cidrAddr, "/") {
|
||
p.subnetMu.Lock()
|
||
if p.subnetCache[listName] == nil {
|
||
p.subnetCache[listName] = make(map[string]time.Time)
|
||
}
|
||
p.subnetCache[listName][cidrAddr] = time.Now()
|
||
p.subnetMu.Unlock()
|
||
}
|
||
}
|
||
|
||
// 主执行函数
|
||
func (p *optimizedMikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
|
||
if p.conn == nil {
|
||
p.log.Error("MikroTik connection is nil")
|
||
return nil
|
||
}
|
||
|
||
r := qCtx.R()
|
||
if r != nil {
|
||
var domain string
|
||
if len(qCtx.Q().Question) > 0 {
|
||
domain = strings.TrimSuffix(qCtx.Q().Question[0].Name, ".")
|
||
}
|
||
|
||
p.log.Debug("processing DNS response",
|
||
zap.String("qname", domain),
|
||
zap.Int("answer_count", len(r.Answer)))
|
||
|
||
// 异步处理,不阻塞DNS响应
|
||
go func(response *dns.Msg, domainName string) {
|
||
if err := p.addToAddressList(response, domainName); err != nil {
|
||
p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
|
||
}
|
||
}(r, domain)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 🚀 优化的地址添加逻辑
|
||
func (p *optimizedMikrotikAddressListPlugin) addToAddressList(r *dns.Msg, domain string) error {
|
||
var ipv4Addresses []netip.Addr
|
||
var ipv6Addresses []netip.Addr
|
||
|
||
// 收集所有IP地址
|
||
for i := range r.Answer {
|
||
switch rr := r.Answer[i].(type) {
|
||
case *dns.A:
|
||
if len(p.args.AddressList4) == 0 {
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.A.To4())
|
||
if !ok {
|
||
continue
|
||
}
|
||
ipv4Addresses = append(ipv4Addresses, addr)
|
||
|
||
case *dns.AAAA:
|
||
if len(p.args.AddressList6) == 0 {
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.AAAA.To16())
|
||
if !ok {
|
||
continue
|
||
}
|
||
ipv6Addresses = append(ipv6Addresses, addr)
|
||
}
|
||
}
|
||
|
||
// 应用IP数量限制
|
||
if p.args.MaxIPs > 0 {
|
||
if len(ipv4Addresses) > p.args.MaxIPs {
|
||
ipv4Addresses = ipv4Addresses[:p.args.MaxIPs]
|
||
}
|
||
if len(ipv6Addresses) > p.args.MaxIPs {
|
||
ipv6Addresses = ipv6Addresses[:p.args.MaxIPs]
|
||
}
|
||
}
|
||
|
||
// 异步处理IPv4地址
|
||
if len(ipv4Addresses) > 0 && len(p.args.AddressList4) > 0 {
|
||
go p.processIPAddresses(ipv4Addresses, p.args.AddressList4, p.args.Mask4, domain, "IPv4")
|
||
}
|
||
|
||
// 异步处理IPv6地址
|
||
if len(ipv6Addresses) > 0 && len(p.args.AddressList6) > 0 {
|
||
go p.processIPAddresses(ipv6Addresses, p.args.AddressList6, p.args.Mask6, domain, "IPv6")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 🚀 优化的IP处理逻辑
|
||
func (p *optimizedMikrotikAddressListPlugin) processIPAddresses(addresses []netip.Addr, listName string, mask int, domain, ipType string) {
|
||
var needToAdd []string
|
||
skippedCount := 0
|
||
|
||
// 🚀 关键优化:先在内存中过滤已存在的IP
|
||
for _, addr := range addresses {
|
||
cidrAddr := p.buildCIDRAddress(addr, mask)
|
||
|
||
// 纯内存检查,极快速度
|
||
if p.isIPInMemoryCache(listName, cidrAddr) {
|
||
skippedCount++
|
||
p.log.Debug("IP already exists in memory cache, skipping",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName))
|
||
continue
|
||
}
|
||
|
||
// 对于/24网段,检查网段是否已存在
|
||
if mask == 24 && p.isSubnetInCache(listName, cidrAddr) {
|
||
skippedCount++
|
||
p.log.Debug("subnet already cached, skipping",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName))
|
||
continue
|
||
}
|
||
|
||
needToAdd = append(needToAdd, cidrAddr)
|
||
}
|
||
|
||
p.log.Info("IP processing summary",
|
||
zap.String("type", ipType),
|
||
zap.String("domain", domain),
|
||
zap.Int("total_ips", len(addresses)),
|
||
zap.Int("need_to_add", len(needToAdd)),
|
||
zap.Int("skipped_cached", skippedCount))
|
||
|
||
// 只处理需要添加的IP
|
||
if len(needToAdd) > 0 {
|
||
p.batchAddOptimized(needToAdd, listName, domain)
|
||
}
|
||
}
|
||
|
||
// 构建CIDR地址
|
||
func (p *optimizedMikrotikAddressListPlugin) buildCIDRAddress(addr netip.Addr, mask int) string {
|
||
if addr.Is4() {
|
||
networkAddr := netip.PrefixFrom(addr, mask).Addr()
|
||
return networkAddr.String() + "/" + strconv.Itoa(mask)
|
||
} else {
|
||
networkAddr := netip.PrefixFrom(addr, mask).Addr()
|
||
return networkAddr.String() + "/" + strconv.Itoa(mask)
|
||
}
|
||
}
|
||
|
||
// 🚀 优化的批量添加
|
||
func (p *optimizedMikrotikAddressListPlugin) batchAddOptimized(addresses []string, listName, domain string) {
|
||
for _, cidrAddr := range addresses {
|
||
// 获取工作池槽位
|
||
select {
|
||
case p.workerPool <- struct{}{}:
|
||
go func(addr string) {
|
||
defer func() { <-p.workerPool }()
|
||
|
||
if err := p.addSingleAddress(addr, listName, domain); err != nil {
|
||
p.log.Error("failed to add address",
|
||
zap.String("cidr", addr),
|
||
zap.String("list", listName),
|
||
zap.Error(err))
|
||
} else {
|
||
// 🚀 成功后立即更新内存缓存
|
||
p.addToMemoryCache(listName, addr)
|
||
p.log.Debug("successfully added and cached address",
|
||
zap.String("cidr", addr),
|
||
zap.String("list", listName))
|
||
}
|
||
}(cidrAddr)
|
||
default:
|
||
// 工作池满,直接执行
|
||
if err := p.addSingleAddress(cidrAddr, listName, domain); err != nil {
|
||
p.log.Error("failed to add address (direct)",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.Error(err))
|
||
} else {
|
||
p.addToMemoryCache(listName, cidrAddr)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 添加单个地址到MikroTik
|
||
func (p *optimizedMikrotikAddressListPlugin) addSingleAddress(cidrAddr, listName, domain string) error {
|
||
// 构造参数
|
||
params := []string{
|
||
"=list=" + listName,
|
||
"=address=" + cidrAddr,
|
||
}
|
||
|
||
// 添加注释
|
||
comment := domain
|
||
if comment == "" && p.args.Comment != "" {
|
||
comment = p.args.Comment
|
||
}
|
||
if comment != "" {
|
||
params = append(params, "=comment="+comment)
|
||
}
|
||
|
||
// 添加超时时间
|
||
if p.args.TimeoutAddr > 0 {
|
||
params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
|
||
}
|
||
|
||
p.log.Debug("adding address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName),
|
||
zap.String("domain", domain))
|
||
|
||
// 发送到MikroTik
|
||
args := append([]string{"/ip/firewall/address-list/add"}, params...)
|
||
_, err := p.conn.Run(args...)
|
||
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "already have such entry") {
|
||
// 如果MikroTik说已存在,更新内存缓存
|
||
p.addToMemoryCache(listName, cidrAddr)
|
||
p.log.Debug("address already exists in MikroTik, updated cache",
|
||
zap.String("cidr", cidrAddr))
|
||
return nil
|
||
}
|
||
return fmt.Errorf("failed to add address %s: %v", cidrAddr, err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 获取缓存统计
|
||
func (p *optimizedMikrotikAddressListPlugin) getCacheStats() map[string]int {
|
||
p.cacheMu.RLock()
|
||
defer p.cacheMu.RUnlock()
|
||
|
||
stats := make(map[string]int)
|
||
for listName, ips := range p.ipCache {
|
||
stats[listName] = len(ips)
|
||
}
|
||
return stats
|
||
}
|
||
|
||
// 关闭插件
|
||
func (p *optimizedMikrotikAddressListPlugin) Close() error {
|
||
p.wg.Wait()
|
||
|
||
// 打印最终统计
|
||
stats := p.getCacheStats()
|
||
p.log.Info("optimized plugin closing", zap.Any("final_cache_stats", stats))
|
||
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
if p.conn != nil {
|
||
return p.conn.Close()
|
||
}
|
||
return nil
|
||
}
|