mosdns/plugin/executable/mikrotik_addresslist/mikrotik_addresslist_impl.go

413 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package mikrotik_addresslist
import (
"context"
"fmt"
"net/netip"
"strconv"
"strings"
"sync"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/miekg/dns"
"go.uber.org/zap"
routeros "github.com/go-routeros/routeros/v3"
)
// mikrotikAddressListPlugin adds IPv4 addresses taken from DNS responses to a
// MikroTik RouterOS firewall address list via the RouterOS API.
type mikrotikAddressListPlugin struct {
	// args is the plugin configuration; newMikrotikAddressListPlugin fills in
	// defaults for some unset fields.
	args *Args
	// conn is the RouterOS API client. reconnect replaces it; its caller in
	// this file holds mu for writing while doing so.
	conn *routeros.Client
	// log is the logger named "mikrotik_addresslist".
	log *zap.Logger
	// Concurrency control.
	// workerPool is a buffered channel whose capacity is the worker count set
	// in newMikrotikAddressListPlugin; addToAddressList takes a slot (send)
	// and releases it (receive) around each address it processes.
	workerPool chan struct{}
	// wg is waited on by Close. NOTE(review): nothing in this file calls
	// wg.Add, so that wait currently returns immediately.
	wg sync.WaitGroup
	mu sync.RWMutex // protects conn, in particular its replacement during reconnect
}
// newMikrotikAddressListPlugin builds the plugin from args: it fills in
// defaults for unset options, validates the prefix lengths, dials the
// RouterOS API at Host:Port and checks the session with
// testMikrotikConnection before returning.
//
// Defaults applied to args in place: Mask4=24, Mask6=32, Port=8728,
// Timeout=10. NOTE(review): Timeout is defaulted here but is not passed to
// routeros.Dial, which is called without a timeout — confirm where (or
// whether) it is used.
//
// An error is returned for an out-of-range Mask4/Mask6, when dialing fails,
// or when the connection test fails (the connection is closed first).
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
	if args.Mask4 == 0 {
		args.Mask4 = 24
	}
	if args.Mask6 == 0 {
		args.Mask6 = 32
	}
	if args.Port == 0 {
		args.Port = 8728
	}
	if args.Timeout == 0 {
		args.Timeout = 10
	}

	// An out-of-range prefix length can never form a valid network prefix;
	// reject it at startup instead of letting every query fail later.
	if args.Mask4 < 0 || args.Mask4 > 32 {
		return nil, fmt.Errorf("invalid mask4 %d: must be between 1 and 32", args.Mask4)
	}
	if args.Mask6 < 0 || args.Mask6 > 128 {
		return nil, fmt.Errorf("invalid mask6 %d: must be between 1 and 128", args.Mask6)
	}

	addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
	conn, err := routeros.Dial(addr, args.Username, args.Password)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // best effort; the test failure is the error reported
		return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
	}

	// Capacity of the slot channel addToAddressList uses to limit how many
	// addresses are processed at once.
	const workerCount = 10

	plugin := &mikrotikAddressListPlugin{
		args:       args,
		conn:       conn,
		log:        zap.L().Named("mikrotik_addresslist"),
		workerPool: make(chan struct{}, workerCount),
	}

	plugin.log.Info("successfully connected to MikroTik",
		zap.String("host", args.Host),
		zap.Int("port", args.Port),
		zap.String("username", args.Username),
		zap.String("address_list4", args.AddressList4),
		zap.Int("worker_count", workerCount))
	return plugin, nil
}
// testMikrotikConnection verifies that conn is usable by running
// /system/resource/print on it. The connection is considered healthy only
// when the command succeeds and yields at least one !re reply sentence.
func testMikrotikConnection(conn *routeros.Client) error {
	reply, runErr := conn.Run("/system/resource/print")
	switch {
	case runErr != nil:
		return fmt.Errorf("connection test failed: %w", runErr)
	case len(reply.Re) == 0:
		// The command "succeeded" but returned nothing to inspect.
		return fmt.Errorf("connection test failed: no response from MikroTik")
	}

	zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
	return nil
}
// reconnect replaces p.conn with a freshly dialed and tested connection.
//
// The new connection is dialed and verified before the current one is
// touched; only on success is the old connection closed and swapped out.
// On failure p.conn is left exactly as it was (it is never set to nil here),
// so a later command failure can trigger another reconnect, and a
// connection that fails the test is closed instead of being kept.
// Callers are expected to hold p.mu for writing.
func (p *mikrotikAddressListPlugin) reconnect() error {
	p.log.Info("attempting to reconnect to MikroTik")

	addr := p.args.Host + ":" + strconv.Itoa(p.args.Port)
	conn, err := routeros.Dial(addr, p.args.Username, p.args.Password)
	if err != nil {
		return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
	}
	if err := testMikrotikConnection(conn); err != nil {
		_ = conn.Close() // discard the unusable connection; report the test error
		return fmt.Errorf("failed to test reconnection: %w", err)
	}

	if p.conn != nil {
		_ = p.conn.Close() // the old connection is being replaced; best effort
	}
	p.conn = conn

	p.log.Info("successfully reconnected to MikroTik")
	return nil
}
// Exec is the per-query entry point. When the query context carries a
// response, the IPv4 addresses from its A records are added to the
// configured MikroTik address list; queries without a response are left
// untouched. A failure to add addresses is logged and returned wrapped with
// the plugin name.
//
// The connection is deliberately not inspected here: reading p.conn without
// holding p.mu would race with reconnect, and the per-address path
// (addAddressToMikrotik) handles a missing connection itself, returning an
// error when it cannot proceed.
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
	r := qCtx.R()
	if r == nil {
		return nil
	}

	// Guard the question lookup so a message without a question section
	// cannot panic just to fill a debug log field.
	qname := ""
	if q := qCtx.Q(); q != nil && len(q.Question) > 0 {
		qname = q.Question[0].Name
	}
	p.log.Debug("processing DNS response",
		zap.String("qname", qname),
		zap.Int("answer_count", len(r.Answer)))

	if err := p.addToAddressList(r); err != nil {
		p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
		return fmt.Errorf("mikrotik_addresslist: %w", err)
	}
	return nil
}
// Close shuts the plugin down. It first waits on p.wg, then, holding p.mu
// for writing, closes the RouterOS connection if there is one. It returns
// the error from closing the connection, or nil when there was none.
// NOTE(review): nothing in this file calls p.wg.Add, so the wait currently
// returns immediately.
func (p *mikrotikAddressListPlugin) Close() error {
	p.wg.Wait() // drain in-flight work before tearing the connection down

	p.mu.Lock()
	defer p.mu.Unlock()

	var closeErr error
	if c := p.conn; c != nil {
		closeErr = c.Close()
	}
	return closeErr
}
// addToAddressList collects the IPv4 addresses from the A records in r and
// adds each one to the address list p.args.AddressList4, using prefix length
// p.args.Mask4, via addAddressToMikrotik.
//
// Nothing is done when no IPv4 address list is configured. AAAA and all
// other record types are skipped. Each address is handled in its own
// goroutine, and p.workerPool (shared by all callers) bounds how many run at
// once. The call returns once every address has been handled. If any
// addition failed, the first failure is returned wrapped.
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
	p.log.Debug("starting to process DNS response",
		zap.String("configured_address_list4", p.args.AddressList4),
		zap.Int("answer_count", len(r.Answer)))

	// Without a target list no A record can be used; decide once, not per record.
	if p.args.AddressList4 == "" {
		p.log.Debug("skipping A record, no IPv4 address list configured")
		return nil
	}

	addresses := make([]netip.Addr, 0, len(r.Answer))
	for _, answer := range r.Answer {
		switch rr := answer.(type) {
		case *dns.A:
			addr, ok := netip.AddrFromSlice(rr.A.To4())
			if !ok {
				// Skip malformed records without aborting the rest.
				p.log.Error("invalid A record", zap.String("ip", rr.A.String()))
				continue
			}
			addresses = append(addresses, addr)
			p.log.Debug("queued A record for processing",
				zap.String("ip", addr.String()),
				zap.String("address_list4", p.args.AddressList4))
		case *dns.AAAA:
			p.log.Debug("skipping AAAA record (IPv6 not supported)")
		default:
			p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", rr)))
		}
	}

	if len(addresses) == 0 {
		p.log.Debug("no IPv4 addresses to process")
		return nil
	}

	var (
		wg         sync.WaitGroup
		mu         sync.Mutex // guards failures and addedCount
		failures   []error
		addedCount int
	)
	for _, addr := range addresses {
		wg.Add(1)
		go func(addr netip.Addr) {
			defer wg.Done()

			// Blocking acquire: a non-blocking send with a default branch
			// would let every goroutine proceed once the pool is full, so
			// the pool would bound nothing. The nil check keeps a zero-value
			// plugin from blocking forever on a nil channel.
			if p.workerPool != nil {
				p.workerPool <- struct{}{}
				defer func() { <-p.workerPool }()
			}

			err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				failures = append(failures, err)
				return
			}
			addedCount++
		}(addr)
	}
	wg.Wait()

	if addedCount > 0 {
		p.log.Info("concurrently added IPv4 addresses to MikroTik",
			zap.Int("success_count", addedCount),
			zap.Int("total_count", len(addresses)),
			zap.Int("error_count", len(failures)))
	} else {
		p.log.Debug("no IPv4 addresses added to MikroTik")
	}

	if len(failures) > 0 {
		return fmt.Errorf("some addresses failed to add: %w", failures[0])
	}
	return nil
}
// addAddressToMikrotik adds the network containing addr, truncated to a
// prefix of mask bits, to the RouterOS firewall address list listName.
//
// The entry carries p.args.Comment and p.args.TimeoutAddr when those are
// set. An entry that already exists counts as success, whether addressExists
// finds it or RouterOS reports "already have such entry". Transport-level
// failures are retried up to three times, with a reconnect between
// attempts. If every attempt fails, an error wrapping the last failure is
// returned. Any other RouterOS error is returned immediately.
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int) error {
	p.log.Debug("addAddressToMikrotik called",
		zap.String("addr", addr.String()),
		zap.String("listName", listName),
		zap.Int("mask", mask))

	// Addr.Prefix clears the host bits and rejects a mask outside
	// [0, addr.BitLen()]; netip.PrefixFrom(addr, mask).Addr() would return
	// addr unchanged and yield e.g. "1.2.3.4/24" instead of "1.2.3.0/24".
	prefix, err := addr.Prefix(mask)
	if err == nil && !prefix.IsValid() {
		err = fmt.Errorf("address is not valid")
	}
	if err != nil {
		return fmt.Errorf("invalid mask %d for address %s: %w", mask, addr, err)
	}
	cidrAddr := prefix.String()

	p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))
	exists, err := p.addressExists(listName, cidrAddr)
	if err != nil {
		// The check is only a shortcut (the list may not exist yet, for
		// example); the add below reports duplicates on its own.
		p.log.Debug("failed to check if address exists, will try to add anyway", zap.Error(err))
	} else if exists {
		p.log.Debug("address already exists", zap.String("cidr", cidrAddr), zap.String("list", listName))
		return nil
	}

	// RouterOS API attribute words must start with "=".
	sentence := []string{
		"/ip/firewall/address-list/add",
		"=list=" + listName,
		"=address=" + cidrAddr,
	}
	if p.args.Comment != "" {
		sentence = append(sentence, "=comment="+p.args.Comment)
	}
	if p.args.TimeoutAddr > 0 {
		sentence = append(sentence, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
	}

	p.log.Info("adding address to MikroTik",
		zap.String("cidr", cidrAddr),
		zap.String("list", listName),
		zap.String("comment", p.args.Comment),
		zap.Int("timeout", p.args.TimeoutAddr))
	p.log.Debug("Add to list: ", zap.Strings("params", sentence[1:]))

	const maxRetries = 3
	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		retry, err := p.runSerialized(sentence)
		switch {
		case err == nil:
			p.log.Info("successfully added address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName))
			return nil
		case strings.Contains(err.Error(), "already have such entry"):
			p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr))
			return nil
		case !retry:
			p.log.Error("failed to add address to MikroTik",
				zap.String("cidr", cidrAddr),
				zap.String("list", listName),
				zap.Error(err))
			return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %w", cidrAddr, listName, err)
		}
		lastErr = err
		p.log.Warn("connection error while adding address, retrying",
			zap.String("cidr", cidrAddr),
			zap.Int("retry", attempt),
			zap.Error(err))
	}

	// Every attempt hit a connection error: report it rather than success.
	return fmt.Errorf("failed to add address %s to list %s after %d attempts: %w", cidrAddr, listName, maxRetries, lastErr)
}

// runSerialized sends one RouterOS command while holding p.mu for writing,
// so at most one command is in flight on the shared connection.
// NOTE(review): this assumes the go-routeros client stays in its default
// synchronous mode; there, replies are read without tags, so concurrent Run
// calls could consume each other's replies.
//
// A missing connection is (re)established first. When the command fails
// with what looks like a transport error (the message mentions "EOF" or
// "connection"), the connection is re-dialed before the lock is released.
// Goroutines queued behind this one then use the fresh connection instead
// of each re-dialing in turn. retry reports whether the failure was
// transport-level and therefore worth retrying.
func (p *mikrotikAddressListPlugin) runSerialized(sentence []string) (retry bool, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if p.conn == nil {
		if rerr := p.reconnect(); rerr != nil {
			return true, rerr
		}
	}

	if _, err = p.conn.Run(sentence...); err == nil {
		return false, nil
	}

	msg := err.Error()
	if !strings.Contains(msg, "EOF") && !strings.Contains(msg, "connection") {
		return false, err
	}
	if rerr := p.reconnect(); rerr != nil {
		p.log.Error("failed to reconnect", zap.Error(rerr))
	}
	return true, err
}
// addressExists reports whether the RouterOS address list listName already
// holds an entry whose address equals address.
//
// In the RouterOS API, print is filtered with query words ("?name=value").
// "=name=value" words are attributes, not filters, so they cannot select
// the matching entries. Only the .id property is requested because the
// reply is merely counted. The command runs while holding p.mu for writing,
// so it does not interleave with other commands on the shared connection
// and its read of p.conn does not race with reconnect.
// NOTE(review): this is an exact match on the stored address string;
// RouterOS may store host entries without a "/32" suffix — confirm whether
// such entries must be detected here.
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
	sentence := []string{
		"/ip/firewall/address-list/print",
		"=.proplist=.id",
		"?list=" + listName,
		"?address=" + address,
	}

	p.log.Debug("checking address existence",
		zap.String("list", listName),
		zap.String("address", address))

	p.mu.Lock()
	defer p.mu.Unlock()

	if p.conn == nil {
		return false, fmt.Errorf("connection is nil")
	}
	resp, err := p.conn.Run(sentence...)
	if err != nil {
		p.log.Error("failed to check address existence",
			zap.String("list", listName),
			zap.String("address", address),
			zap.Error(err))
		return false, err
	}

	// Any matching !re sentence means the entry is already present.
	exists := len(resp.Re) > 0
	p.log.Debug("address existence check",
		zap.String("list", listName),
		zap.String("address", address),
		zap.Bool("exists", exists))
	return exists, nil
}