diff --git a/mosdns-linux-amd64 b/mosdns similarity index 75% rename from mosdns-linux-amd64 rename to mosdns index b5c7057..1d19d81 100644 Binary files a/mosdns-linux-amd64 and b/mosdns differ diff --git a/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go index fbcbd1e..78e0a97 100644 --- a/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go +++ b/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go @@ -25,6 +25,7 @@ import ( "net/netip" "strconv" "strings" + "sync" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/miekg/dns" @@ -37,6 +38,11 @@ type mikrotikAddressListPlugin struct { args *Args conn *routeros.Client log *zap.Logger + + // 并发控制 + workerPool chan struct{} + wg sync.WaitGroup + mu sync.RWMutex // 保护连接的重连操作 } func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) { @@ -68,10 +74,14 @@ func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error return nil, fmt.Errorf("failed to test MikroTik connection: %w", err) } + // 设置工作池大小(可以根据需要调整) + workerCount := 10 // 并发工作线程数 + plugin := &mikrotikAddressListPlugin{ - args: args, - conn: conn, - log: zap.L().Named("mikrotik_addresslist"), + args: args, + conn: conn, + log: zap.L().Named("mikrotik_addresslist"), + workerPool: make(chan struct{}, workerCount), } // 记录连接成功信息 @@ -79,7 +89,8 @@ func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error zap.String("host", args.Host), zap.Int("port", args.Port), zap.String("username", args.Username), - zap.String("address_list4", args.AddressList4)) + zap.String("address_list4", args.AddressList4), + zap.Int("worker_count", workerCount)) return plugin, nil } @@ -103,6 +114,30 @@ func testMikrotikConnection(conn *routeros.Client) error { return nil } +func (p *mikrotikAddressListPlugin) reconnect() error { + p.log.Info("attempting to reconnect to MikroTik") + + // 关闭旧连接 + if p.conn != nil { + p.conn.Close() + } + + // 重新建立连接 + var err error + p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password) + if err != nil { + return fmt.Errorf("failed to reconnect to MikroTik: %w", err) + } + + // 测试连接 + if err := testMikrotikConnection(p.conn); err != nil { + return fmt.Errorf("failed to test reconnection: %w", err) + } + + p.log.Info("successfully reconnected to MikroTik") + return nil +} + func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error { // 检查连接是否正常 if p.conn == nil { @@ -125,6 +160,13 @@ func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context. } func (p *mikrotikAddressListPlugin) Close() error { + // 等待所有工作完成 + p.wg.Wait() + + // 关闭连接 + p.mu.Lock() + defer p.mu.Unlock() + if p.conn != nil { return p.conn.Close() } @@ -132,11 +174,12 @@ func (p *mikrotikAddressListPlugin) Close() error { } func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error { - addedCount := 0 p.log.Debug("starting to process DNS response", zap.String("configured_address_list4", p.args.AddressList4), zap.Int("answer_count", len(r.Answer))) + // 收集所有需要处理的 IPv4 地址 + var addresses []netip.Addr for i := range r.Answer { switch rr := r.Answer[i].(type) { case *dns.A: @@ -147,15 +190,12 @@ func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error { addr, ok := netip.AddrFromSlice(rr.A.To4()) if !ok { p.log.Error("invalid A record", zap.String("ip", rr.A.String())) - return fmt.Errorf("invalid A record with ip: %s", rr.A) + continue // 跳过无效记录,不中断处理 } - p.log.Debug("processing A record", + addresses = append(addresses, addr) + p.log.Debug("queued A record for processing", zap.String("ip", addr.String()), zap.String("address_list4", p.args.AddressList4)) - if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil { - return err - } - addedCount++ case *dns.AAAA: // 跳过 IPv6 记录 @@ -167,12 +207,61 @@ func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error { } } + if len(addresses) == 0 { + p.log.Debug("no IPv4 addresses to process") + return nil + } + + // 并发处理所有地址 + var wg sync.WaitGroup + var mu sync.Mutex + var errors []error + addedCount := 0 + + for _, addr := range addresses { + wg.Add(1) + go func(addr netip.Addr) { + defer wg.Done() + + // 获取工作池槽位 + select { + case p.workerPool <- struct{}{}: + defer func() { <-p.workerPool }() + default: + // 如果工作池满了,直接处理(避免阻塞) + p.log.Debug("worker pool full, processing directly") + } + + if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil { + mu.Lock() + errors = append(errors, err) + mu.Unlock() + } else { + mu.Lock() + addedCount++ + mu.Unlock() + } + }(addr) + } + + // 等待所有工作完成 + wg.Wait() + + // 记录结果 if addedCount > 0 { - p.log.Info("added IPv4 addresses to MikroTik", zap.Int("count", addedCount)) + p.log.Info("concurrently added IPv4 addresses to MikroTik", + zap.Int("success_count", addedCount), + zap.Int("total_count", len(addresses)), + zap.Int("error_count", len(errors))) } else { p.log.Debug("no IPv4 addresses added to MikroTik") } + // 如果有错误,返回第一个错误 + if len(errors) > 0 { + return fmt.Errorf("some addresses failed to add: %v", errors[0]) + } + return nil } @@ -182,12 +271,16 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa zap.String("listName", listName), zap.Int("mask", mask)) - // 构建 CIDR 格式的地址 + // 构建 CIDR 格式的地址,将 IP 转换为网段地址 var cidrAddr string if addr.Is4() { - cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask4) + // 将 IPv4 地址转换为网段地址(主机位清零) + networkAddr := netip.PrefixFrom(addr, p.args.Mask4).Addr() + cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask4) } else { - cidrAddr = addr.String() + "/" + strconv.Itoa(p.args.Mask6) + // 将 IPv6 地址转换为网段地址(主机位清零) + networkAddr := netip.PrefixFrom(addr, p.args.Mask6).Addr() + cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask6) } p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName)) @@ -227,19 +320,57 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa p.log.Debug("Add to list: ", zap.Strings("params", params)) - // 发送到 RouterOS - args := append([]string{"/ip/firewall/address-list/add"}, params...) - _, err = p.conn.Run(args...) - if err != nil { - if strings.Contains(err.Error(), "already have such entry") { - p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr)) - return nil + // 发送到 RouterOS,带重试机制 + maxRetries := 3 + for i := 0; i < maxRetries; i++ { + // 使用读锁保护连接访问 + p.mu.RLock() + conn := p.conn + p.mu.RUnlock() + + if conn == nil { + p.log.Error("connection is nil") + return fmt.Errorf("connection is nil") } - p.log.Error("failed to add address to MikroTik", - zap.String("cidr", cidrAddr), - zap.String("list", listName), - zap.Error(err)) - return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err) + + args := append([]string{"/ip/firewall/address-list/add"}, params...) + _, err = conn.Run(args...) + if err != nil { + if strings.Contains(err.Error(), "already have such entry") { + p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr)) + return nil + } + + // 如果是连接错误,尝试重新连接 + if strings.Contains(err.Error(), "EOF") || strings.Contains(err.Error(), "connection") { + p.log.Warn("connection error, attempting to reconnect", + zap.String("cidr", cidrAddr), + zap.Int("retry", i+1), + zap.Error(err)) + + // 使用写锁保护重连操作 + p.mu.Lock() + if err := p.reconnect(); err != nil { + p.mu.Unlock() + p.log.Error("failed to reconnect", zap.Error(err)) + continue + } + p.mu.Unlock() + + // 重试 + continue + } + + // 其他错误,记录并返回 + p.log.Error("failed to add address to MikroTik", + zap.String("cidr", cidrAddr), + zap.String("list", listName), + zap.Error(err)) + return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err) + } + + // 成功,跳出重试循环 + break } p.log.Info("successfully added address to MikroTik", @@ -251,13 +382,17 @@ func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listNa func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) { // 查询地址列表中是否已存在该地址 - query := fmt.Sprintf("?list=%s&address=%s", listName, address) + params := []string{ + "=list=" + listName, + "=address=" + address, + } + p.log.Debug("checking address existence", zap.String("list", listName), - zap.String("address", address), - zap.String("query", query)) + zap.String("address", address)) - resp, err := p.conn.Run("/ip/firewall/address-list/print", query) + args := append([]string{"/ip/firewall/address-list/print"}, params...) + resp, err := p.conn.Run(args...) if err != nil { p.log.Error("failed to check address existence", zap.String("list", listName),