// File: mosdns/plugin/executable/smart_fallback/smart_fallback.go
// Last commit: 0413ee5d44 by dengxiongjian, 2025-10-16 21:07:48 +08:00,
// commit message "二次开发" (secondary development of the upstream project).
// CI at that commit: some checks failed; "Test mosdns / build (push)" was cancelled.
// Listing size reported by the viewer: 269 lines, 6.9 KiB, Go.
// NOTE: this header replaces text pasted from a web code viewer (its
// "Raw / Blame / History" links and a warning that the file contains
// ambiguous Unicode characters), which was not valid Go.

package smart_fallback
import (
"bytes"
"context"
"fmt"
"net/netip"
"os"
"time"
"github.com/IrineSistiana/mosdns/v5/coremain"
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
"github.com/miekg/dns"
"go.uber.org/zap"
)
const PluginType = "smart_fallback"
func init() {
coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
}
// Args holds the configuration of a smart_fallback plugin instance.
type Args struct {
	// Primary is the plugin tag of the primary (domestic) upstream.
	Primary string `yaml:"primary"`
	// Secondary is the plugin tag of the secondary (international) upstream.
	Secondary string `yaml:"secondary"`
	// ChinaIP lists files of CN IP ranges used to judge primary answers.
	ChinaIP []string `yaml:"china_ip"`
	// Timeout is the upstream timeout in milliseconds; 0 selects the 2000 ms default.
	Timeout int `yaml:"timeout"`
	// AlwaysStandby queries the secondary upstream in parallel with the primary.
	AlwaysStandby bool `yaml:"always_standby"`
	// Verbose enables detailed logging.
	Verbose bool `yaml:"verbose"`
}
// SmartFallback is the plugin instance built by Init. It answers queries from
// the primary upstream and switches to the secondary one depending on whether
// the primary's answer resolves to addresses in the CN IP list.
type SmartFallback struct {
	// Upstream executors resolved from the configured plugin tags.
	primary   sequence.Executable
	secondary sequence.Executable

	// Matcher holding the CN IP ranges loaded from the china_ip files.
	chinaIPList *netlist.List

	// Settings derived from Args.
	timeout       time.Duration
	alwaysStandby bool
	verbose       bool

	logger *zap.Logger
}
// Init builds a SmartFallback from its configuration arguments.
//
// It resolves the primary and secondary upstream plugins by tag, loads every
// file listed in china_ip into a single CN IP matcher, and falls back to a
// 2000 ms timeout when timeout is zero or negative. Any lookup or load failure
// is returned as an error and no plugin is created.
func Init(bp *coremain.BP, args any) (any, error) {
	cfg := args.(*Args)

	// 1. Resolve the primary (domestic) upstream.
	primary := bp.M().GetPlugin(cfg.Primary)
	if primary == nil {
		return nil, fmt.Errorf("无法加载主上游 %s", cfg.Primary)
	}
	primaryExec := sequence.ToExecutable(primary)
	if primaryExec == nil {
		return nil, fmt.Errorf("主上游 %s 不是可执行插件", cfg.Primary)
	}

	// 2. Resolve the secondary (international) upstream.
	secondary := bp.M().GetPlugin(cfg.Secondary)
	if secondary == nil {
		return nil, fmt.Errorf("无法加载备用上游 %s", cfg.Secondary)
	}
	secondaryExec := sequence.ToExecutable(secondary)
	if secondaryExec == nil {
		return nil, fmt.Errorf("备用上游 %s 不是可执行插件", cfg.Secondary)
	}

	// 3. Load all CN IP list files into one matcher.
	chinaIPList := netlist.NewList()
	for _, file := range cfg.ChinaIP {
		b, err := os.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("无法读取CN IP地址表文件 %s: %w", file, err)
		}
		if err := netlist.LoadFromReader(chinaIPList, bytes.NewReader(b)); err != nil {
			return nil, fmt.Errorf("无法加载CN IP地址表文件 %s: %w", file, err)
		}
	}
	chinaIPList.Sort()
	if len(cfg.ChinaIP) == 0 {
		// With no list, no address can match, so every primary answer that
		// carries A/AAAA records will be rejected in favour of the secondary.
		bp.L().Warn("smart_fallback: no china_ip file configured, primary answers with IP records will never be accepted")
	}

	// 4. Resolve the timeout. A negative duration would make every derived
	// context expire immediately, so only positive values are honoured.
	timeout := time.Duration(cfg.Timeout) * time.Millisecond
	if timeout <= 0 {
		timeout = 2 * time.Second // default: 2000 ms
	}

	return &SmartFallback{
		primary:       primaryExec,
		secondary:     secondaryExec,
		chinaIPList:   chinaIPList,
		timeout:       timeout,
		alwaysStandby: cfg.AlwaysStandby,
		verbose:       cfg.Verbose,
		logger:        bp.L(),
	}, nil
}
// Exec answers the query in qCtx. It uses the primary upstream's answer when
// every A/AAAA address in it is in the CN IP list, and the secondary
// upstream's answer otherwise.
//
// Each upstream call gets its own deadline of s.timeout (see execWithTimeout)
// instead of one deadline for the whole method. With a single shared
// deadline, a primary that timed out left the secondary with an already
// expired context, so the fallback failed in exactly the case it exists for.
// In sequential mode the worst-case latency is therefore about 2*s.timeout.
func (s *SmartFallback) Exec(ctx context.Context, qCtx *query_context.Context) error {
	if s.verbose {
		s.logger.Info("smart_fallback start",
			zap.String("domain", questionName(qCtx.Q())))
	}
	// Pick the query strategy configured by always_standby.
	if s.alwaysStandby {
		return s.execParallel(ctx, qCtx)
	}
	return s.execSequential(ctx, qCtx)
}

// execWithTimeout runs exec on qCtx under a deadline of s.timeout derived from
// ctx, so cancellation of ctx still reaches the upstream.
func (s *SmartFallback) execWithTimeout(ctx context.Context, exec sequence.Executable, qCtx *query_context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, s.timeout)
	defer cancel()
	return exec.Exec(ctx, qCtx)
}

// acceptPrimary reports whether a successful primary response may be used
// as-is. It must be non-nil, carry at least one answer record, and pass
// isResponseFromChina. These are the same checks execSequential applies.
func (s *SmartFallback) acceptPrimary(resp *dns.Msg) bool {
	return resp != nil && len(resp.Answer) > 0 && s.isResponseFromChina(resp)
}

// questionName returns the name of m's first question, or "" when m is nil or
// has no question, so verbose logging cannot panic on a malformed query.
func questionName(m *dns.Msg) string {
	if m == nil || len(m.Question) == 0 {
		return ""
	}
	return m.Question[0].Name
}

// execSequential queries the primary upstream first. It contacts the secondary
// only when the primary fails, returns no answer, or returns an address
// outside the CN IP list. This saves upstream resources.
func (s *SmartFallback) execSequential(ctx context.Context, qCtx *query_context.Context) error {
	// 1. Query the primary (domestic) upstream on a copy, so a rejected answer
	// never reaches qCtx.
	qCtxCopy := qCtx.Copy()
	if err := s.execWithTimeout(ctx, s.primary, qCtxCopy); err != nil {
		if s.verbose {
			s.logger.Warn("primary upstream failed, using secondary",
				zap.Error(err))
		}
		return s.execWithTimeout(ctx, s.secondary, qCtx)
	}

	resp := qCtxCopy.R()
	if resp == nil || len(resp.Answer) == 0 {
		if s.verbose {
			s.logger.Info("primary upstream returned no answer, using secondary")
		}
		return s.execWithTimeout(ctx, s.secondary, qCtx)
	}

	// 2. Accept the primary answer only if all of its IPs are in the CN list.
	if s.isResponseFromChina(resp) {
		qCtx.SetResponse(resp)
		if s.verbose {
			s.logger.Info("response from China, using primary result")
		}
		return nil
	}

	// 3. A non-CN address was found: re-query through the secondary upstream.
	if s.verbose {
		s.logger.Info("response not from China, using secondary upstream")
	}
	return s.execWithTimeout(ctx, s.secondary, qCtx)
}

// execParallel queries both upstreams at once. It uses the primary answer if
// acceptPrimary approves it, and the secondary answer otherwise. This trades
// extra upstream load for lower latency.
func (s *SmartFallback) execParallel(ctx context.Context, qCtx *query_context.Context) error {
	type result struct {
		resp *dns.Msg
		err  error
	}

	// Cancel whichever query is still running once a decision is made.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Take both copies before starting the goroutines. The losing goroutine
	// may still be running after this method returns, and it must not read
	// qCtx while the caller goes on to modify it.
	primaryCtx := qCtx.Copy()
	secondaryCtx := qCtx.Copy()

	// Buffered so the losing goroutine can always deliver its result and exit.
	primaryCh := make(chan result, 1)
	secondaryCh := make(chan result, 1)
	go func() {
		err := s.execWithTimeout(ctx, s.primary, primaryCtx)
		primaryCh <- result{resp: primaryCtx.R(), err: err}
	}()
	go func() {
		err := s.execWithTimeout(ctx, s.secondary, secondaryCtx)
		secondaryCh <- result{resp: secondaryCtx.R(), err: err}
	}()

	// The primary is preferred, so its verdict is required before a secondary
	// answer may be used, even if the secondary finishes first.
	pr := <-primaryCh
	if pr.err == nil && s.acceptPrimary(pr.resp) {
		qCtx.SetResponse(pr.resp)
		if s.verbose {
			s.logger.Info("parallel mode: primary returned China IP, using it")
		}
		return nil
	}

	sr := <-secondaryCh
	if sr.err == nil {
		qCtx.SetResponse(sr.resp)
		if s.verbose {
			s.logger.Info("parallel mode: primary result rejected, using secondary")
		}
		return nil
	}
	if pr.err != nil {
		return fmt.Errorf("所有上游查询失败: primary: %w; secondary: %w", pr.err, sr.err)
	}
	// The primary answered but acceptPrimary rejected it. As in
	// execSequential, do not fall back to it; report the secondary failure.
	return sr.err
}
// isResponseFromChina reports whether every A/AAAA record in resp.Answer
// carries an address contained in the CN IP list.
//
// It returns false for a nil resp and as soon as any A/AAAA address is either
// unparseable or outside the list. Records of other types (CNAME, TXT, ...)
// are skipped, so a response with no A/AAAA records at all yields true.
func (s *SmartFallback) isResponseFromChina(resp *dns.Msg) bool {
	if resp == nil {
		return false
	}
	for _, ans := range resp.Answer {
		var (
			ip netip.Addr
			ok bool
		)
		// AddrFromSlice accepts both the 4-byte and the 16-byte net.IP forms.
		// A fixed [4]byte conversion panics on a short slice and reads a
		// 16-byte IPv4 as 0.0.0.0. Unmap turns ::ffff:a.b.c.d back into a
		// plain IPv4 address for A records.
		switch rr := ans.(type) {
		case *dns.A:
			ip, ok = netip.AddrFromSlice(rr.A)
			ip = ip.Unmap()
		case *dns.AAAA:
			ip, ok = netip.AddrFromSlice(rr.AAAA)
		default:
			continue
		}
		if !ok {
			// A malformed address cannot be shown to be CN; stay conservative.
			return false
		}

		// A single address outside the CN list makes the whole answer foreign.
		if !s.chinaIPList.Match(ip) {
			if s.verbose {
				domain := ""
				if len(resp.Question) > 0 {
					domain = resp.Question[0].Name
				}
				s.logger.Info("detected foreign IP",
					zap.String("ip", ip.String()),
					zap.String("domain", domain))
			}
			return false
		}
	}
	// Every A/AAAA address (if any) is in the CN list.
	if s.verbose {
		s.logger.Info("all IPs are from China")
	}
	return true
}
// Compile-time check that *SmartFallback implements sequence.Executable.
var _ sequence.Executable = (*SmartFallback)(nil)