/* * Copyright (C) 2020-2022, IrineSistiana * * This file is part of mosdns. * * mosdns is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * mosdns is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ package mikrotik_addresslist import ( "context" "fmt" "net/netip" "strconv" "strings" "sync" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/miekg/dns" "go.uber.org/zap" routeros "github.com/go-routeros/routeros/v3" ) type mikrotikAddressListPlugin struct { args *Args conn *routeros.Client log *zap.Logger // 并发控制 workerPool chan struct{} wg sync.WaitGroup mu sync.RWMutex // 保护连接的重连操作 } func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) { if args.Mask4 == 0 { args.Mask4 = 24 } if args.Mask6 == 0 { args.Mask6 = 32 } if args.Port == 0 { args.Port = 8728 } if args.Timeout == 0 { args.Timeout = 10 } // 构建连接地址 addr := fmt.Sprintf("%s:%d", args.Host, args.Port) // 创建 MikroTik 连接 conn, err := routeros.Dial(addr, args.Username, args.Password) if err != nil { return nil, fmt.Errorf("failed to connect to MikroTik: %w", err) } // 测试连接 if err := testMikrotikConnection(conn); err != nil { conn.Close() return nil, fmt.Errorf("failed to test MikroTik connection: %w", err) } // 设置工作池大小(可以根据需要调整) workerCount := 10 // 并发工作线程数 plugin := &mikrotikAddressListPlugin{ args: args, conn: conn, log: zap.L().Named("mikrotik_addresslist"), workerPool: make(chan struct{}, workerCount), } // 记录连接成功信息 plugin.log.Info("successfully connected to MikroTik", zap.String("host", args.Host), zap.Int("port", args.Port), zap.String("username", args.Username), zap.String("address_list4", args.AddressList4), zap.Int("worker_count", workerCount)) return plugin, nil } // testMikrotikConnection 测试 MikroTik 连接是否正常 func testMikrotikConnection(conn *routeros.Client) error { // 尝试执行一个简单的命令来测试连接 resp, err := conn.Run("/system/resource/print") if err != nil { return fmt.Errorf("connection test failed: %w", err) } // 检查响应是否有效 if len(resp.Re) == 0 { return fmt.Errorf("connection test failed: no response from MikroTik") } // 记录连接测试成功 zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful") return nil } func (p *mikrotikAddressListPlugin) reconnect() error { p.log.Info("attempting to reconnect to MikroTik") // 关闭旧连接 if p.conn != nil { p.conn.Close() } // 重新建立连接 var err error p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password) if err != nil { return fmt.Errorf("failed to reconnect to MikroTik: %w", err) } // 测试连接 if err := testMikrotikConnection(p.conn); err != nil { return fmt.Errorf("failed to test reconnection: %w", err) } p.log.Info("successfully reconnected to MikroTik") return nil } func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { // 检查连接是否正常 if p.conn == nil { p.log.Error("MikroTik connection is nil") return fmt.Errorf("mikrotik_addresslist: connection is nil") } r := qCtx.R() if r != nil { p.log.Debug("processing DNS response", zap.String("qname", qCtx.Q().Question[0].Name), zap.Int("answer_count", len(r.Answer))) if err := p.addToAddressList(r); err != nil { p.log.Error("failed to add addresses to MikroTik", zap.Error(err)) return fmt.Errorf("mikrotik_addresslist: %w", err) } } return nil } func (p *mikrotikAddressListPlugin) Close() error { // 等待所有工作完成 p.wg.Wait() // 关闭连接 p.mu.Lock() defer p.mu.Unlock() if p.conn != nil { return p.conn.Close() } return nil } func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error { p.log.Debug("starting to process DNS response", zap.String("configured_address_list4", p.args.AddressList4), zap.Int("answer_count", len(r.Answer))) // 收集所有需要处理的 IPv4 地址 var addresses []netip.Addr for i := range r.Answer { switch rr := r.Answer[i].(type) { case *dns.A: if len(p.args.AddressList4) == 0 { p.log.Debug("skipping A record, no IPv4 address list configured") continue } addr, ok := netip.AddrFromSlice(rr.A.To4()) if !ok { p.log.Error("invalid A record", zap.String("ip", rr.A.String())) continue // 跳过无效记录,不中断处理 } addresses = append(addresses, addr) p.log.Debug("queued A record for processing", zap.String("ip", addr.String()), zap.String("address_list4", p.args.AddressList4)) case *dns.AAAA: // 跳过 IPv6 记录 p.log.Debug("skipping AAAA record (IPv6 not supported)") continue default: p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", rr))) continue } } if len(addresses) == 0 { p.log.Debug("no IPv4 addresses to process") return nil } // 并发处理所有地址 var wg sync.WaitGroup var mu sync.Mutex var errors []error addedCount := 0 for _, addr := range addresses { wg.Add(1) go func(addr netip.Addr) { defer wg.Done() // 获取工作池槽位 select { case p.workerPool <- struct{}{}: defer func() { <-p.workerPool }() default: // 如果工作池满了,直接处理(避免阻塞) p.log.Debug("worker pool full, processing directly") } if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil { mu.Lock() errors = append(errors, err) mu.Unlock() } else { mu.Lock() addedCount++ mu.Unlock() } }(addr) } // 等待所有工作完成 wg.Wait() // 记录结果 if addedCount > 0 { p.log.Info("concurrently added IPv4 addresses to MikroTik", zap.Int("success_count", addedCount), zap.Int("total_count", len(addresses)), zap.Int("error_count", len(errors))) } else { p.log.Debug("no IPv4 addresses added to MikroTik") } // 如果有错误,返回第一个错误 if len(errors) > 0 { return fmt.Errorf("some addresses failed to add: %v", errors[0]) } return nil } func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int) error { p.log.Debug("addAddressToMikrotik called", zap.String("addr", addr.String()), zap.String("listName", listName), zap.Int("mask", mask)) // 构建 CIDR 格式的地址,将 IP 转换为网段地址 var cidrAddr string if addr.Is4() { // 将 IPv4 地址转换为网段地址(主机位清零) networkAddr := netip.PrefixFrom(addr, p.args.Mask4).Addr() cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask4) } else { // 将 IPv6 地址转换为网段地址(主机位清零) networkAddr := netip.PrefixFrom(addr, p.args.Mask6).Addr() cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask6) } p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName)) // 检查地址是否已存在 exists, err := p.addressExists(listName, cidrAddr) if err != nil { // 如果检查失败,可能是地址列表不存在,继续尝试添加 p.log.Debug("failed to check if address exists, will try to add anyway", zap.Error(err)) } else if exists { // 地址已存在,跳过 p.log.Debug("address already exists", zap.String("cidr", cidrAddr), zap.String("list", listName)) return nil } // 构造 RouterOS 参数,注意必须以 = 开头! params := []string{ "=list=" + listName, "=address=" + cidrAddr, } // 添加注释(如果配置了) if p.args.Comment != "" { params = append(params, "=comment="+p.args.Comment) } // 添加超时时间(如果配置了) if p.args.TimeoutAddr > 0 { params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr)) } p.log.Info("adding address to MikroTik", zap.String("cidr", cidrAddr), zap.String("list", listName), zap.String("comment", p.args.Comment), zap.Int("timeout", p.args.TimeoutAddr)) p.log.Debug("Add to list: ", zap.Strings("params", params)) // 发送到 RouterOS,带重试机制 maxRetries := 3 for i := 0; i < maxRetries; i++ { // 使用读锁保护连接访问 p.mu.RLock() conn := p.conn p.mu.RUnlock() if conn == nil { p.log.Error("connection is nil") return fmt.Errorf("connection is nil") } args := append([]string{"/ip/firewall/address-list/add"}, params...) _, err = conn.Run(args...) if err != nil { if strings.Contains(err.Error(), "already have such entry") { p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr)) return nil } // 如果是连接错误,尝试重新连接 if strings.Contains(err.Error(), "EOF") || strings.Contains(err.Error(), "connection") { p.log.Warn("connection error, attempting to reconnect", zap.String("cidr", cidrAddr), zap.Int("retry", i+1), zap.Error(err)) // 使用写锁保护重连操作 p.mu.Lock() if err := p.reconnect(); err != nil { p.mu.Unlock() p.log.Error("failed to reconnect", zap.Error(err)) continue } p.mu.Unlock() // 重试 continue } // 其他错误,记录并返回 p.log.Error("failed to add address to MikroTik", zap.String("cidr", cidrAddr), zap.String("list", listName), zap.Error(err)) return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err) } // 成功,跳出重试循环 break } p.log.Info("successfully added address to MikroTik", zap.String("cidr", cidrAddr), zap.String("list", listName)) return nil } func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) { // 查询地址列表中是否已存在该地址 params := []string{ "=list=" + listName, "=address=" + address, } p.log.Debug("checking address existence", zap.String("list", listName), zap.String("address", address)) args := append([]string{"/ip/firewall/address-list/print"}, params...) resp, err := p.conn.Run(args...) if err != nil { p.log.Error("failed to check address existence", zap.String("list", listName), zap.String("address", address), zap.Error(err)) return false, err } // 如果返回结果不为空,说明地址已存在 exists := len(resp.Re) > 0 p.log.Debug("address existence check", zap.String("list", listName), zap.String("address", address), zap.Bool("exists", exists)) return exists, nil }