269 lines
6.9 KiB
Go
269 lines
6.9 KiB
Go
package smart_fallback
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"net/netip"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/IrineSistiana/mosdns/v5/coremain"
|
||
"github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist"
|
||
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
||
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
|
||
"github.com/miekg/dns"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// PluginType is the plugin type name passed to coremain.RegNewPluginFunc.
const (
	PluginType = "smart_fallback"
)
|
||
|
||
func init() {
|
||
coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
|
||
}
|
||
|
||
// Args is the YAML configuration of a smart_fallback plugin instance.
type Args struct {
	// Primary is the tag of the upstream plugin queried first (domestic DNS).
	Primary string `yaml:"primary"`
	// Secondary is the tag of the fallback upstream plugin (international DNS).
	Secondary string `yaml:"secondary"`
	// ChinaIP lists paths of files holding the China IP ranges.
	ChinaIP []string `yaml:"china_ip"`
	// Timeout is the query timeout in milliseconds; Init substitutes a
	// default when it is left unset.
	Timeout int `yaml:"timeout"`
	// AlwaysStandby queries the secondary upstream in parallel with the primary.
	AlwaysStandby bool `yaml:"always_standby"`
	// Verbose enables detailed per-query logging.
	Verbose bool `yaml:"verbose"`
}
|
||
|
||
type SmartFallback struct {
|
||
primary sequence.Executable // 主上游执行器
|
||
secondary sequence.Executable // 备用上游执行器
|
||
chinaIPList *netlist.List // CN IP地址匹配器
|
||
timeout time.Duration
|
||
alwaysStandby bool
|
||
verbose bool
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// Init 初始化插件
|
||
func Init(bp *coremain.BP, args any) (any, error) {
|
||
cfg := args.(*Args)
|
||
|
||
// 1. 加载主上游
|
||
primary := bp.M().GetPlugin(cfg.Primary)
|
||
if primary == nil {
|
||
return nil, fmt.Errorf("无法加载主上游 %s", cfg.Primary)
|
||
}
|
||
primaryExec := sequence.ToExecutable(primary)
|
||
if primaryExec == nil {
|
||
return nil, fmt.Errorf("主上游 %s 不是可执行插件", cfg.Primary)
|
||
}
|
||
|
||
// 2. 加载备用上游
|
||
secondary := bp.M().GetPlugin(cfg.Secondary)
|
||
if secondary == nil {
|
||
return nil, fmt.Errorf("无法加载备用上游 %s", cfg.Secondary)
|
||
}
|
||
secondaryExec := sequence.ToExecutable(secondary)
|
||
if secondaryExec == nil {
|
||
return nil, fmt.Errorf("备用上游 %s 不是可执行插件", cfg.Secondary)
|
||
}
|
||
|
||
// 3. 加载CN IP地址表
|
||
chinaIPList := netlist.NewList()
|
||
for _, file := range cfg.ChinaIP {
|
||
b, err := os.ReadFile(file)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("无法读取CN IP地址表文件 %s: %w", file, err)
|
||
}
|
||
if err := netlist.LoadFromReader(chinaIPList, bytes.NewReader(b)); err != nil {
|
||
return nil, fmt.Errorf("无法加载CN IP地址表文件 %s: %w", file, err)
|
||
}
|
||
}
|
||
chinaIPList.Sort()
|
||
|
||
// 4. 设置超时
|
||
timeout := time.Duration(cfg.Timeout) * time.Millisecond
|
||
if timeout == 0 {
|
||
timeout = 2000 * time.Millisecond // 默认2秒
|
||
}
|
||
|
||
return &SmartFallback{
|
||
primary: primaryExec,
|
||
secondary: secondaryExec,
|
||
chinaIPList: chinaIPList,
|
||
timeout: timeout,
|
||
alwaysStandby: cfg.AlwaysStandby,
|
||
verbose: cfg.Verbose,
|
||
logger: bp.L(),
|
||
}, nil
|
||
}
|
||
|
||
// Exec 执行查询逻辑
|
||
func (s *SmartFallback) Exec(ctx context.Context, qCtx *query_context.Context) error {
|
||
// 设置超时
|
||
ctx, cancel := context.WithTimeout(ctx, s.timeout)
|
||
defer cancel()
|
||
|
||
if s.verbose {
|
||
s.logger.Info("smart_fallback start",
|
||
zap.String("domain", qCtx.Q().Question[0].Name))
|
||
}
|
||
|
||
// 根据配置选择查询策略
|
||
if s.alwaysStandby {
|
||
return s.execParallel(ctx, qCtx)
|
||
}
|
||
|
||
return s.execSequential(ctx, qCtx)
|
||
}
|
||
|
||
// execSequential 顺序查询(推荐:节省资源)
|
||
func (s *SmartFallback) execSequential(ctx context.Context, qCtx *query_context.Context) error {
|
||
// 1. 先查询主上游(国内DNS)
|
||
qCtxCopy := qCtx.Copy()
|
||
err := s.primary.Exec(ctx, qCtxCopy)
|
||
if err != nil {
|
||
// 主上游失败,直接用备用上游
|
||
if s.verbose {
|
||
s.logger.Warn("primary upstream failed, using secondary",
|
||
zap.Error(err))
|
||
}
|
||
return s.secondary.Exec(ctx, qCtx)
|
||
}
|
||
|
||
resp := qCtxCopy.R()
|
||
if resp == nil || len(resp.Answer) == 0 {
|
||
// 无结果,用备用上游
|
||
if s.verbose {
|
||
s.logger.Info("primary upstream returned no answer, using secondary")
|
||
}
|
||
return s.secondary.Exec(ctx, qCtx)
|
||
}
|
||
|
||
// 2. 检查返回的IP是否在CN地址表
|
||
if s.isResponseFromChina(resp) {
|
||
// ✅ 是CN IP,直接返回
|
||
qCtx.SetResponse(resp)
|
||
if s.verbose {
|
||
s.logger.Info("response from China, using primary result")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 3. 非CN IP,使用备用上游重新查询
|
||
if s.verbose {
|
||
s.logger.Info("response not from China, using secondary upstream")
|
||
}
|
||
return s.secondary.Exec(ctx, qCtx)
|
||
}
|
||
|
||
// execParallel 并行查询(可选:更快但消耗资源)
|
||
func (s *SmartFallback) execParallel(ctx context.Context, qCtx *query_context.Context) error {
|
||
type result struct {
|
||
resp *dns.Msg
|
||
err error
|
||
from string
|
||
}
|
||
|
||
resChan := make(chan result, 2)
|
||
|
||
// 同时查询主备上游
|
||
go func() {
|
||
qCtxCopy := qCtx.Copy()
|
||
err := s.primary.Exec(ctx, qCtxCopy)
|
||
resChan <- result{resp: qCtxCopy.R(), err: err, from: "primary"}
|
||
}()
|
||
|
||
go func() {
|
||
qCtxCopy := qCtx.Copy()
|
||
err := s.secondary.Exec(ctx, qCtxCopy)
|
||
resChan <- result{resp: qCtxCopy.R(), err: err, from: "secondary"}
|
||
}()
|
||
|
||
// 优先采用主上游的CN结果
|
||
var primaryRes, secondaryRes *result
|
||
|
||
for i := 0; i < 2; i++ {
|
||
res := <-resChan
|
||
if res.from == "primary" {
|
||
primaryRes = &res
|
||
} else {
|
||
secondaryRes = &res
|
||
}
|
||
|
||
// 如果主上游返回CN IP,立即采用
|
||
if primaryRes != nil && primaryRes.err == nil &&
|
||
s.isResponseFromChina(primaryRes.resp) {
|
||
qCtx.SetResponse(primaryRes.resp)
|
||
if s.verbose {
|
||
s.logger.Info("parallel mode: primary returned China IP, using it")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// 如果备用上游先返回且主上游失败,采用备用
|
||
if secondaryRes != nil && secondaryRes.err == nil &&
|
||
(primaryRes == nil || primaryRes.err != nil) {
|
||
qCtx.SetResponse(secondaryRes.resp)
|
||
if s.verbose {
|
||
s.logger.Info("parallel mode: secondary returned result first, using it")
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 优先返回备用上游结果
|
||
if secondaryRes != nil && secondaryRes.err == nil {
|
||
qCtx.SetResponse(secondaryRes.resp)
|
||
return nil
|
||
}
|
||
|
||
if primaryRes != nil {
|
||
return primaryRes.err
|
||
}
|
||
|
||
return fmt.Errorf("所有上游查询失败")
|
||
}
|
||
|
||
// isResponseFromChina 检查响应IP是否来自中国
|
||
func (s *SmartFallback) isResponseFromChina(resp *dns.Msg) bool {
|
||
if resp == nil {
|
||
return false
|
||
}
|
||
|
||
// 遍历所有应答记录
|
||
for _, ans := range resp.Answer {
|
||
var ip netip.Addr
|
||
|
||
switch rr := ans.(type) {
|
||
case *dns.A:
|
||
// IPv4 地址
|
||
ip = netip.AddrFrom4([4]byte(rr.A))
|
||
case *dns.AAAA:
|
||
// IPv6 地址
|
||
ip = netip.AddrFrom16([16]byte(rr.AAAA))
|
||
default:
|
||
continue
|
||
}
|
||
|
||
// 检查是否在CN地址表
|
||
matched := s.chinaIPList.Match(ip)
|
||
|
||
if !matched {
|
||
// 只要有一个IP不在CN表,就认为是国外IP
|
||
if s.verbose {
|
||
s.logger.Info("detected foreign IP",
|
||
zap.String("ip", ip.String()),
|
||
zap.String("domain", resp.Question[0].Name))
|
||
}
|
||
return false
|
||
}
|
||
}
|
||
|
||
// 所有IP都在CN表中
|
||
if s.verbose {
|
||
s.logger.Info("all IPs are from China")
|
||
}
|
||
return true
|
||
}
|
||
|
||
// Compile-time check that *SmartFallback satisfies sequence.Executable.
var (
	_ sequence.Executable = (*SmartFallback)(nil)
)
|