413 lines
11 KiB
Go
413 lines
11 KiB
Go
/*
|
||
* Copyright (C) 2020-2022, IrineSistiana
|
||
*
|
||
* This file is part of mosdns.
|
||
*
|
||
* mosdns is free software: you can redistribute it and/or modify
|
||
* it under the terms of the GNU General Public License as published by
|
||
* the Free Software Foundation, either version 3 of the License, or
|
||
* (at your option) any later version.
|
||
*
|
||
* mosdns is distributed in the hope that it will be useful,
|
||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||
* GNU General Public License for more details.
|
||
*
|
||
* You should have received a copy of the GNU General Public License
|
||
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||
*/
|
||
|
||
package mikrotik_addresslist
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/netip"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
|
||
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
||
"github.com/miekg/dns"
|
||
"go.uber.org/zap"
|
||
|
||
routeros "github.com/go-routeros/routeros/v3"
|
||
)
|
||
|
||
type mikrotikAddressListPlugin struct {
|
||
args *Args
|
||
conn *routeros.Client
|
||
log *zap.Logger
|
||
|
||
// 并发控制
|
||
workerPool chan struct{}
|
||
wg sync.WaitGroup
|
||
mu sync.RWMutex // 保护连接的重连操作
|
||
}
|
||
|
||
func newMikrotikAddressListPlugin(args *Args) (*mikrotikAddressListPlugin, error) {
|
||
if args.Mask4 == 0 {
|
||
args.Mask4 = 24
|
||
}
|
||
if args.Mask6 == 0 {
|
||
args.Mask6 = 32
|
||
}
|
||
if args.Port == 0 {
|
||
args.Port = 8728
|
||
}
|
||
if args.Timeout == 0 {
|
||
args.Timeout = 10
|
||
}
|
||
|
||
// 构建连接地址
|
||
addr := fmt.Sprintf("%s:%d", args.Host, args.Port)
|
||
|
||
// 创建 MikroTik 连接
|
||
conn, err := routeros.Dial(addr, args.Username, args.Password)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to MikroTik: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := testMikrotikConnection(conn); err != nil {
|
||
conn.Close()
|
||
return nil, fmt.Errorf("failed to test MikroTik connection: %w", err)
|
||
}
|
||
|
||
// 设置工作池大小(可以根据需要调整)
|
||
workerCount := 10 // 并发工作线程数
|
||
|
||
plugin := &mikrotikAddressListPlugin{
|
||
args: args,
|
||
conn: conn,
|
||
log: zap.L().Named("mikrotik_addresslist"),
|
||
workerPool: make(chan struct{}, workerCount),
|
||
}
|
||
|
||
// 记录连接成功信息
|
||
plugin.log.Info("successfully connected to MikroTik",
|
||
zap.String("host", args.Host),
|
||
zap.Int("port", args.Port),
|
||
zap.String("username", args.Username),
|
||
zap.String("address_list4", args.AddressList4),
|
||
zap.Int("worker_count", workerCount))
|
||
|
||
return plugin, nil
|
||
}
|
||
|
||
// testMikrotikConnection 测试 MikroTik 连接是否正常
|
||
func testMikrotikConnection(conn *routeros.Client) error {
|
||
// 尝试执行一个简单的命令来测试连接
|
||
resp, err := conn.Run("/system/resource/print")
|
||
if err != nil {
|
||
return fmt.Errorf("connection test failed: %w", err)
|
||
}
|
||
|
||
// 检查响应是否有效
|
||
if len(resp.Re) == 0 {
|
||
return fmt.Errorf("connection test failed: no response from MikroTik")
|
||
}
|
||
|
||
// 记录连接测试成功
|
||
zap.L().Named("mikrotik_addresslist").Info("MikroTik connection test successful")
|
||
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) reconnect() error {
|
||
p.log.Info("attempting to reconnect to MikroTik")
|
||
|
||
// 关闭旧连接
|
||
if p.conn != nil {
|
||
p.conn.Close()
|
||
}
|
||
|
||
// 重新建立连接
|
||
var err error
|
||
p.conn, err = routeros.Dial(p.args.Host+":"+strconv.Itoa(p.args.Port), p.args.Username, p.args.Password)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to reconnect to MikroTik: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := testMikrotikConnection(p.conn); err != nil {
|
||
return fmt.Errorf("failed to test reconnection: %w", err)
|
||
}
|
||
|
||
p.log.Info("successfully reconnected to MikroTik")
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) Exec(_ context.Context, qCtx *query_context.Context) error {
|
||
// 检查连接是否正常
|
||
if p.conn == nil {
|
||
p.log.Error("MikroTik connection is nil")
|
||
return fmt.Errorf("mikrotik_addresslist: connection is nil")
|
||
}
|
||
|
||
r := qCtx.R()
|
||
if r != nil {
|
||
p.log.Debug("processing DNS response",
|
||
zap.String("qname", qCtx.Q().Question[0].Name),
|
||
zap.Int("answer_count", len(r.Answer)))
|
||
|
||
if err := p.addToAddressList(r); err != nil {
|
||
p.log.Error("failed to add addresses to MikroTik", zap.Error(err))
|
||
return fmt.Errorf("mikrotik_addresslist: %w", err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) Close() error {
|
||
// 等待所有工作完成
|
||
p.wg.Wait()
|
||
|
||
// 关闭连接
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
if p.conn != nil {
|
||
return p.conn.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addToAddressList(r *dns.Msg) error {
|
||
p.log.Debug("starting to process DNS response",
|
||
zap.String("configured_address_list4", p.args.AddressList4),
|
||
zap.Int("answer_count", len(r.Answer)))
|
||
|
||
// 收集所有需要处理的 IPv4 地址
|
||
var addresses []netip.Addr
|
||
for i := range r.Answer {
|
||
switch rr := r.Answer[i].(type) {
|
||
case *dns.A:
|
||
if len(p.args.AddressList4) == 0 {
|
||
p.log.Debug("skipping A record, no IPv4 address list configured")
|
||
continue
|
||
}
|
||
addr, ok := netip.AddrFromSlice(rr.A.To4())
|
||
if !ok {
|
||
p.log.Error("invalid A record", zap.String("ip", rr.A.String()))
|
||
continue // 跳过无效记录,不中断处理
|
||
}
|
||
addresses = append(addresses, addr)
|
||
p.log.Debug("queued A record for processing",
|
||
zap.String("ip", addr.String()),
|
||
zap.String("address_list4", p.args.AddressList4))
|
||
|
||
case *dns.AAAA:
|
||
// 跳过 IPv6 记录
|
||
p.log.Debug("skipping AAAA record (IPv6 not supported)")
|
||
continue
|
||
default:
|
||
p.log.Debug("skipping non-A record", zap.String("type", fmt.Sprintf("%T", rr)))
|
||
continue
|
||
}
|
||
}
|
||
|
||
if len(addresses) == 0 {
|
||
p.log.Debug("no IPv4 addresses to process")
|
||
return nil
|
||
}
|
||
|
||
// 并发处理所有地址
|
||
var wg sync.WaitGroup
|
||
var mu sync.Mutex
|
||
var errors []error
|
||
addedCount := 0
|
||
|
||
for _, addr := range addresses {
|
||
wg.Add(1)
|
||
go func(addr netip.Addr) {
|
||
defer wg.Done()
|
||
|
||
// 获取工作池槽位
|
||
select {
|
||
case p.workerPool <- struct{}{}:
|
||
defer func() { <-p.workerPool }()
|
||
default:
|
||
// 如果工作池满了,直接处理(避免阻塞)
|
||
p.log.Debug("worker pool full, processing directly")
|
||
}
|
||
|
||
if err := p.addAddressToMikrotik(addr, p.args.AddressList4, p.args.Mask4); err != nil {
|
||
mu.Lock()
|
||
errors = append(errors, err)
|
||
mu.Unlock()
|
||
} else {
|
||
mu.Lock()
|
||
addedCount++
|
||
mu.Unlock()
|
||
}
|
||
}(addr)
|
||
}
|
||
|
||
// 等待所有工作完成
|
||
wg.Wait()
|
||
|
||
// 记录结果
|
||
if addedCount > 0 {
|
||
p.log.Info("concurrently added IPv4 addresses to MikroTik",
|
||
zap.Int("success_count", addedCount),
|
||
zap.Int("total_count", len(addresses)),
|
||
zap.Int("error_count", len(errors)))
|
||
} else {
|
||
p.log.Debug("no IPv4 addresses added to MikroTik")
|
||
}
|
||
|
||
// 如果有错误,返回第一个错误
|
||
if len(errors) > 0 {
|
||
return fmt.Errorf("some addresses failed to add: %v", errors[0])
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addAddressToMikrotik(addr netip.Addr, listName string, mask int) error {
|
||
p.log.Debug("addAddressToMikrotik called",
|
||
zap.String("addr", addr.String()),
|
||
zap.String("listName", listName),
|
||
zap.Int("mask", mask))
|
||
|
||
// 构建 CIDR 格式的地址,将 IP 转换为网段地址
|
||
var cidrAddr string
|
||
if addr.Is4() {
|
||
// 将 IPv4 地址转换为网段地址(主机位清零)
|
||
networkAddr := netip.PrefixFrom(addr, p.args.Mask4).Addr()
|
||
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask4)
|
||
} else {
|
||
// 将 IPv6 地址转换为网段地址(主机位清零)
|
||
networkAddr := netip.PrefixFrom(addr, p.args.Mask6).Addr()
|
||
cidrAddr = networkAddr.String() + "/" + strconv.Itoa(p.args.Mask6)
|
||
}
|
||
|
||
p.log.Debug("checking address", zap.String("cidr", cidrAddr), zap.String("list", listName))
|
||
|
||
// 检查地址是否已存在
|
||
exists, err := p.addressExists(listName, cidrAddr)
|
||
if err != nil {
|
||
// 如果检查失败,可能是地址列表不存在,继续尝试添加
|
||
p.log.Debug("failed to check if address exists, will try to add anyway", zap.Error(err))
|
||
} else if exists {
|
||
// 地址已存在,跳过
|
||
p.log.Debug("address already exists", zap.String("cidr", cidrAddr), zap.String("list", listName))
|
||
return nil
|
||
}
|
||
|
||
// 构造 RouterOS 参数,注意必须以 = 开头!
|
||
params := []string{
|
||
"=list=" + listName,
|
||
"=address=" + cidrAddr,
|
||
}
|
||
|
||
// 添加注释(如果配置了)
|
||
if p.args.Comment != "" {
|
||
params = append(params, "=comment="+p.args.Comment)
|
||
}
|
||
|
||
// 添加超时时间(如果配置了)
|
||
if p.args.TimeoutAddr > 0 {
|
||
params = append(params, "=timeout="+strconv.Itoa(p.args.TimeoutAddr))
|
||
}
|
||
|
||
p.log.Info("adding address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName),
|
||
zap.String("comment", p.args.Comment),
|
||
zap.Int("timeout", p.args.TimeoutAddr))
|
||
|
||
p.log.Debug("Add to list: ", zap.Strings("params", params))
|
||
|
||
// 发送到 RouterOS,带重试机制
|
||
maxRetries := 3
|
||
for i := 0; i < maxRetries; i++ {
|
||
// 使用读锁保护连接访问
|
||
p.mu.RLock()
|
||
conn := p.conn
|
||
p.mu.RUnlock()
|
||
|
||
if conn == nil {
|
||
p.log.Error("connection is nil")
|
||
return fmt.Errorf("connection is nil")
|
||
}
|
||
|
||
args := append([]string{"/ip/firewall/address-list/add"}, params...)
|
||
_, err = conn.Run(args...)
|
||
if err != nil {
|
||
if strings.Contains(err.Error(), "already have such entry") {
|
||
p.log.Debug("Already exists: ", zap.String("cidr", cidrAddr))
|
||
return nil
|
||
}
|
||
|
||
// 如果是连接错误,尝试重新连接
|
||
if strings.Contains(err.Error(), "EOF") || strings.Contains(err.Error(), "connection") {
|
||
p.log.Warn("connection error, attempting to reconnect",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.Int("retry", i+1),
|
||
zap.Error(err))
|
||
|
||
// 使用写锁保护重连操作
|
||
p.mu.Lock()
|
||
if err := p.reconnect(); err != nil {
|
||
p.mu.Unlock()
|
||
p.log.Error("failed to reconnect", zap.Error(err))
|
||
continue
|
||
}
|
||
p.mu.Unlock()
|
||
|
||
// 重试
|
||
continue
|
||
}
|
||
|
||
// 其他错误,记录并返回
|
||
p.log.Error("failed to add address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName),
|
||
zap.Error(err))
|
||
return fmt.Errorf("failed to add address %s to list %s: from RouterOS device: %v", cidrAddr, listName, err)
|
||
}
|
||
|
||
// 成功,跳出重试循环
|
||
break
|
||
}
|
||
|
||
p.log.Info("successfully added address to MikroTik",
|
||
zap.String("cidr", cidrAddr),
|
||
zap.String("list", listName))
|
||
|
||
return nil
|
||
}
|
||
|
||
func (p *mikrotikAddressListPlugin) addressExists(listName, address string) (bool, error) {
|
||
// 查询地址列表中是否已存在该地址
|
||
params := []string{
|
||
"=list=" + listName,
|
||
"=address=" + address,
|
||
}
|
||
|
||
p.log.Debug("checking address existence",
|
||
zap.String("list", listName),
|
||
zap.String("address", address))
|
||
|
||
args := append([]string{"/ip/firewall/address-list/print"}, params...)
|
||
resp, err := p.conn.Run(args...)
|
||
if err != nil {
|
||
p.log.Error("failed to check address existence",
|
||
zap.String("list", listName),
|
||
zap.String("address", address),
|
||
zap.Error(err))
|
||
return false, err
|
||
}
|
||
|
||
// 如果返回结果不为空,说明地址已存在
|
||
exists := len(resp.Re) > 0
|
||
p.log.Debug("address existence check",
|
||
zap.String("list", listName),
|
||
zap.String("address", address),
|
||
zap.Bool("exists", exists))
|
||
|
||
return exists, nil
|
||
}
|