/* * Copyright (C) 2020-2022, IrineSistiana * * This file is part of mosdns. * * mosdns is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * mosdns is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ package fastforward import ( "context" "crypto/tls" "errors" "fmt" "strings" "time" "github.com/IrineSistiana/mosdns/v5/coremain" "github.com/IrineSistiana/mosdns/v5/pkg/pool" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/pkg/upstream" "github.com/IrineSistiana/mosdns/v5/pkg/utils" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" "github.com/miekg/dns" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" ) const PluginType = "forward" func init() { coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) }) sequence.MustRegExecQuickSetup(PluginType, quickSetup) } const ( maxConcurrentQueries = 3 queryTimeout = time.Second * 5 ) type Args struct { Upstreams []UpstreamConfig `yaml:"upstreams"` Concurrent int `yaml:"concurrent"` // Global options. Socks5 string `yaml:"socks5"` SoMark int `yaml:"so_mark"` BindToDevice string `yaml:"bind_to_device"` Bootstrap string `yaml:"bootstrap"` BootstrapVer int `yaml:"bootstrap_version"` } type UpstreamConfig struct { Tag string `yaml:"tag"` Addr string `yaml:"addr"` // Required. DialAddr string `yaml:"dial_addr"` IdleTimeout int `yaml:"idle_timeout"` // Deprecated: This option has no affect. // TODO: (v6) Remove this option. MaxConns int `yaml:"max_conns"` EnablePipeline bool `yaml:"enable_pipeline"` EnableHTTP3 bool `yaml:"enable_http3"` InsecureSkipVerify bool `yaml:"insecure_skip_verify"` Socks5 string `yaml:"socks5"` SoMark int `yaml:"so_mark"` BindToDevice string `yaml:"bind_to_device"` Bootstrap string `yaml:"bootstrap"` BootstrapVer int `yaml:"bootstrap_version"` } func Init(bp *coremain.BP, args any) (any, error) { f, err := NewForward(args.(*Args), Opts{Logger: bp.L(), MetricsTag: bp.Tag()}) if err != nil { return nil, err } if err := f.RegisterMetricsTo(prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg())); err != nil { _ = f.Close() return nil, err } return f, nil } var _ sequence.Executable = (*Forward)(nil) var _ sequence.QuickConfigurableExec = (*Forward)(nil) type Forward struct { args *Args logger *zap.Logger us []*upstreamWrapper tag2Upstream map[string]*upstreamWrapper // for fast tag lookup only. } type Opts struct { Logger *zap.Logger MetricsTag string } // NewForward inits a Forward from given args. // args must contain at least one upstream. func NewForward(args *Args, opt Opts) (*Forward, error) { if len(args.Upstreams) == 0 { return nil, errors.New("no upstream is configured") } if opt.Logger == nil { opt.Logger = zap.NewNop() } f := &Forward{ args: args, logger: opt.Logger, tag2Upstream: make(map[string]*upstreamWrapper), } applyGlobal := func(c *UpstreamConfig) { utils.SetDefaultString(&c.Socks5, args.Socks5) utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark) utils.SetDefaultString(&c.BindToDevice, args.BindToDevice) utils.SetDefaultString(&c.Bootstrap, args.Bootstrap) utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer) } for i, c := range args.Upstreams { if len(c.Addr) == 0 { return nil, fmt.Errorf("#%d upstream invalid args, addr is required", i) } applyGlobal(&c) uw := newWrapper(i, c, opt.MetricsTag) uOpt := upstream.Opt{ DialAddr: c.DialAddr, Socks5: c.Socks5, SoMark: c.SoMark, BindToDevice: c.BindToDevice, IdleTimeout: time.Duration(c.IdleTimeout) * time.Second, EnablePipeline: c.EnablePipeline, EnableHTTP3: c.EnableHTTP3, Bootstrap: c.Bootstrap, BootstrapVer: c.BootstrapVer, TLSConfig: &tls.Config{ InsecureSkipVerify: c.InsecureSkipVerify, ClientSessionCache: tls.NewLRUClientSessionCache(4), }, Logger: opt.Logger, EventObserver: uw, } u, err := upstream.NewUpstream(c.Addr, uOpt) if err != nil { _ = f.Close() return nil, fmt.Errorf("failed to init upstream #%d: %w", i, err) } uw.u = u f.us = append(f.us, uw) if len(c.Tag) > 0 { if _, dup := f.tag2Upstream[c.Tag]; dup { _ = f.Close() return nil, fmt.Errorf("duplicated upstream tag %s", c.Tag) } f.tag2Upstream[c.Tag] = uw } } return f, nil } func (f *Forward) RegisterMetricsTo(r prometheus.Registerer) error { for _, wu := range f.us { // Only register metrics for upstream that has a tag. if len(wu.cfg.Tag) == 0 { continue } if err := wu.registerMetricsTo(r); err != nil { return err } } return nil } func (f *Forward) Exec(ctx context.Context, qCtx *query_context.Context) (err error) { r, err := f.exchange(ctx, qCtx, f.us) if err != nil { return err } qCtx.SetResponse(r) return nil } // QuickConfigureExec format: [upstream_tag]... func (f *Forward) QuickConfigureExec(args string) (any, error) { var us []*upstreamWrapper if len(args) == 0 { // No args, use all upstreams. us = f.us } else { // Pick up upstreams by tags. for _, tag := range strings.Fields(args) { u := f.tag2Upstream[tag] if u == nil { return nil, fmt.Errorf("cannot find upstream by tag %s", tag) } us = append(us, u) } } var execFunc sequence.ExecutableFunc = func(ctx context.Context, qCtx *query_context.Context) error { r, err := f.exchange(ctx, qCtx, us) if err != nil { return err } qCtx.SetResponse(r) return nil } return execFunc, nil } func (f *Forward) Close() error { for _, u := range f.us { _ = u.Close() } return nil } func (f *Forward) exchange(ctx context.Context, qCtx *query_context.Context, us []*upstreamWrapper) (*dns.Msg, error) { if len(us) == 0 { return nil, errors.New("no upstream to exchange") } queryPayload, err := pool.PackBuffer(qCtx.Q()) if err != nil { return nil, err } defer pool.ReleaseBuf(queryPayload) concurrent := f.args.Concurrent if concurrent <= 0 { concurrent = 1 } if concurrent > maxConcurrentQueries { concurrent = maxConcurrentQueries } type res struct { r *dns.Msg err error } resChan := make(chan res) done := make(chan struct{}) defer close(done) for i := 0; i < concurrent; i++ { u := randPick(us) qc := copyPayload(queryPayload) go func(uqid uint32, question dns.Question) { defer pool.ReleaseBuf(qc) // Give each upstream a fixed timeout to finish the query. upstreamCtx, cancel := context.WithTimeout(context.Background(), queryTimeout) defer cancel() var r *dns.Msg respPayload, err := u.ExchangeContext(upstreamCtx, *qc) if err != nil { f.logger.Warn( "upstream error", zap.Uint32("uqid", uqid), zap.String("qname", question.Name), zap.Uint16("qclass", question.Qclass), zap.Uint16("qtype", question.Qtype), zap.String("upstream", u.name()), zap.Error(err), ) } else { r = new(dns.Msg) err = r.Unpack(*respPayload) pool.ReleaseBuf(respPayload) if err != nil { r = nil } } select { case resChan <- res{r: r, err: err}: case <-done: } }(qCtx.Id(), qCtx.QQuestion()) } for i := 0; i < concurrent; i++ { select { case res := <-resChan: r, err := res.r, res.err if err != nil { continue } // Retry until the last if i < concurrent-1 && r.Rcode != dns.RcodeSuccess && r.Rcode != dns.RcodeNameError { continue } return r, nil case <-ctx.Done(): return nil, context.Cause(ctx) } } return nil, errors.New("all upstream servers failed") } func quickSetup(bq sequence.BQ, s string) (any, error) { args := new(Args) args.Concurrent = maxConcurrentQueries for _, u := range strings.Fields(s) { args.Upstreams = append(args.Upstreams, UpstreamConfig{Addr: u}) } return NewForward(args, Opts{Logger: bq.L()}) }