package smart_fallback import ( "bytes" "context" "fmt" "net/netip" "os" "time" "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/matcher/netlist" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" "github.com/miekg/dns" "go.uber.org/zap" ) const PluginType = "smart_fallback" func init() { coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) } // Args 配置参数 type Args struct { Primary string `yaml:"primary"` // 主上游(国内DNS) Secondary string `yaml:"secondary"` // 备用上游(国际DNS) ChinaIP []string `yaml:"china_ip"` // CN IP地址表文件路径 Timeout int `yaml:"timeout"` // 超时时间(毫秒) AlwaysStandby bool `yaml:"always_standby"` // 是否总是同时查询备用 Verbose bool `yaml:"verbose"` // 是否启用详细日志 } type SmartFallback struct { primary sequence.Executable // 主上游执行器 secondary sequence.Executable // 备用上游执行器 chinaIPList *netlist.List // CN IP地址匹配器 timeout time.Duration alwaysStandby bool verbose bool logger *zap.Logger } // Init 初始化插件 func Init(bp *coremain.BP, args any) (any, error) { cfg := args.(*Args) // 1. 加载主上游 primary := bp.M().GetPlugin(cfg.Primary) if primary == nil { return nil, fmt.Errorf("无法加载主上游 %s", cfg.Primary) } primaryExec := sequence.ToExecutable(primary) if primaryExec == nil { return nil, fmt.Errorf("主上游 %s 不是可执行插件", cfg.Primary) } // 2. 加载备用上游 secondary := bp.M().GetPlugin(cfg.Secondary) if secondary == nil { return nil, fmt.Errorf("无法加载备用上游 %s", cfg.Secondary) } secondaryExec := sequence.ToExecutable(secondary) if secondaryExec == nil { return nil, fmt.Errorf("备用上游 %s 不是可执行插件", cfg.Secondary) } // 3. 加载CN IP地址表 chinaIPList := netlist.NewList() for _, file := range cfg.ChinaIP { b, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("无法读取CN IP地址表文件 %s: %w", file, err) } if err := netlist.LoadFromReader(chinaIPList, bytes.NewReader(b)); err != nil { return nil, fmt.Errorf("无法加载CN IP地址表文件 %s: %w", file, err) } } chinaIPList.Sort() // 4. 设置超时 timeout := time.Duration(cfg.Timeout) * time.Millisecond if timeout == 0 { timeout = 2000 * time.Millisecond // 默认2秒 } return &SmartFallback{ primary: primaryExec, secondary: secondaryExec, chinaIPList: chinaIPList, timeout: timeout, alwaysStandby: cfg.AlwaysStandby, verbose: cfg.Verbose, logger: bp.L(), }, nil } // Exec 执行查询逻辑 func (s *SmartFallback) Exec(ctx context.Context, qCtx *query_context.Context) error { // 设置超时 ctx, cancel := context.WithTimeout(ctx, s.timeout) defer cancel() if s.verbose { s.logger.Info("smart_fallback start", zap.String("domain", qCtx.Q().Question[0].Name)) } // 根据配置选择查询策略 if s.alwaysStandby { return s.execParallel(ctx, qCtx) } return s.execSequential(ctx, qCtx) } // execSequential 顺序查询(推荐:节省资源) func (s *SmartFallback) execSequential(ctx context.Context, qCtx *query_context.Context) error { // 1. 先查询主上游(国内DNS) qCtxCopy := qCtx.Copy() err := s.primary.Exec(ctx, qCtxCopy) if err != nil { // 主上游失败,直接用备用上游 if s.verbose { s.logger.Warn("primary upstream failed, using secondary", zap.Error(err)) } return s.secondary.Exec(ctx, qCtx) } resp := qCtxCopy.R() if resp == nil || len(resp.Answer) == 0 { // 无结果,用备用上游 if s.verbose { s.logger.Info("primary upstream returned no answer, using secondary") } return s.secondary.Exec(ctx, qCtx) } // 2. 检查返回的IP是否在CN地址表 if s.isResponseFromChina(resp) { // ✅ 是CN IP,直接返回 qCtx.SetResponse(resp) if s.verbose { s.logger.Info("response from China, using primary result") } return nil } // 3. 非CN IP,使用备用上游重新查询 if s.verbose { s.logger.Info("response not from China, using secondary upstream") } return s.secondary.Exec(ctx, qCtx) } // execParallel 并行查询(可选:更快但消耗资源) func (s *SmartFallback) execParallel(ctx context.Context, qCtx *query_context.Context) error { type result struct { resp *dns.Msg err error from string } resChan := make(chan result, 2) // 同时查询主备上游 go func() { qCtxCopy := qCtx.Copy() err := s.primary.Exec(ctx, qCtxCopy) resChan <- result{resp: qCtxCopy.R(), err: err, from: "primary"} }() go func() { qCtxCopy := qCtx.Copy() err := s.secondary.Exec(ctx, qCtxCopy) resChan <- result{resp: qCtxCopy.R(), err: err, from: "secondary"} }() // 优先采用主上游的CN结果 var primaryRes, secondaryRes *result for i := 0; i < 2; i++ { res := <-resChan if res.from == "primary" { primaryRes = &res } else { secondaryRes = &res } // 如果主上游返回CN IP,立即采用 if primaryRes != nil && primaryRes.err == nil && s.isResponseFromChina(primaryRes.resp) { qCtx.SetResponse(primaryRes.resp) if s.verbose { s.logger.Info("parallel mode: primary returned China IP, using it") } return nil } // 如果备用上游先返回且主上游失败,采用备用 if secondaryRes != nil && secondaryRes.err == nil && (primaryRes == nil || primaryRes.err != nil) { qCtx.SetResponse(secondaryRes.resp) if s.verbose { s.logger.Info("parallel mode: secondary returned result first, using it") } return nil } } // 优先返回备用上游结果 if secondaryRes != nil && secondaryRes.err == nil { qCtx.SetResponse(secondaryRes.resp) return nil } if primaryRes != nil { return primaryRes.err } return fmt.Errorf("所有上游查询失败") } // isResponseFromChina 检查响应IP是否来自中国 func (s *SmartFallback) isResponseFromChina(resp *dns.Msg) bool { if resp == nil { return false } // 遍历所有应答记录 for _, ans := range resp.Answer { var ip netip.Addr switch rr := ans.(type) { case *dns.A: // IPv4 地址 ip = netip.AddrFrom4([4]byte(rr.A)) case *dns.AAAA: // IPv6 地址 ip = netip.AddrFrom16([16]byte(rr.AAAA)) default: continue } // 检查是否在CN地址表 matched := s.chinaIPList.Match(ip) if !matched { // 只要有一个IP不在CN表,就认为是国外IP if s.verbose { s.logger.Info("detected foreign IP", zap.String("ip", ip.String()), zap.String("domain", resp.Question[0].Name)) } return false } } // 所有IP都在CN表中 if s.verbose { s.logger.Info("all IPs are from China") } return true } var _ sequence.Executable = (*SmartFallback)(nil)