mosdns/plugin/executable/forward/forward.go
dengxiongjian cd761e8145
Some checks are pending
Test mosdns / build (push) Waiting to run
新增Mikrotik API 插入解析ip
2025-07-31 11:28:55 +08:00

328 lines
8.3 KiB
Go

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package fastforward
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
"time"
"github.com/IrineSistiana/mosdns/v5/coremain"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
)
// PluginType is the name this plugin and its quick-setup executable are
// registered under in init. Init also uses it (plus "_") as the prefix of
// the plugin's metric names.
const PluginType = "forward"
func init() {
coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
sequence.MustRegExecQuickSetup(PluginType, quickSetup)
}
const (
	// maxConcurrentQueries is the upper bound applied to Args.Concurrent:
	// exchange never starts more than this many parallel upstream queries
	// for a single request.
	maxConcurrentQueries = 3

	// queryTimeout bounds every individual upstream query started by
	// exchange, independently of the caller's context.
	queryTimeout = 5 * time.Second
)
// Args is the YAML configuration of the forward plugin.
type Args struct {
	// Upstreams lists the upstream servers. NewForward rejects an empty list.
	Upstreams []UpstreamConfig `yaml:"upstreams"`
	// Concurrent is how many parallel upstream queries exchange starts per
	// request. Values <= 0 are treated as 1 and values above
	// maxConcurrentQueries are capped to it.
	Concurrent int `yaml:"concurrent"`
	// Global options.
	// NewForward passes each of these to utils.SetDefault* for every
	// UpstreamConfig, so they act as defaults for the same-named per-upstream
	// field (presumably only when that field is left unset — confirm against
	// the utils helpers).
	Socks5       string `yaml:"socks5"`
	SoMark       int    `yaml:"so_mark"`
	BindToDevice string `yaml:"bind_to_device"`
	Bootstrap    string `yaml:"bootstrap"`
	BootstrapVer int    `yaml:"bootstrap_version"`
}
// UpstreamConfig configures a single upstream server.
type UpstreamConfig struct {
	// Tag optionally names this upstream. Tags must be unique within one
	// plugin. Tagged upstreams can be selected by QuickConfigureExec, and only
	// tagged upstreams get metrics registered by RegisterMetricsTo.
	Tag  string `yaml:"tag"`
	Addr string `yaml:"addr"` // Required.
	// DialAddr is forwarded unchanged to upstream.Opt.DialAddr.
	DialAddr string `yaml:"dial_addr"`
	// IdleTimeout is in seconds; NewForward converts it to a time.Duration.
	IdleTimeout int `yaml:"idle_timeout"`
	// Deprecated: This option has no effect.
	// TODO: (v6) Remove this option.
	MaxConns int `yaml:"max_conns"`
	// The options below are forwarded to the same-named upstream.Opt fields;
	// InsecureSkipVerify goes into the upstream's tls.Config.
	EnablePipeline     bool `yaml:"enable_pipeline"`
	EnableHTTP3        bool `yaml:"enable_http3"`
	InsecureSkipVerify bool `yaml:"insecure_skip_verify"`
	// These may be filled in by NewForward from the same-named global
	// options in Args.
	Socks5       string `yaml:"socks5"`
	SoMark       int    `yaml:"so_mark"`
	BindToDevice string `yaml:"bind_to_device"`
	Bootstrap    string `yaml:"bootstrap"`
	BootstrapVer int    `yaml:"bootstrap_version"`
}
// Init is the plugin constructor registered with coremain. It builds a
// Forward from args (which must be *Args) and registers the per-upstream
// metrics under the PluginType+"_" prefix. If metric registration fails, the
// freshly built Forward is closed again before the error is returned.
func Init(bp *coremain.BP, args any) (any, error) {
	cfg := args.(*Args)
	fwd, err := NewForward(cfg, Opts{Logger: bp.L(), MetricsTag: bp.Tag()})
	if err != nil {
		return nil, err
	}

	reg := prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg())
	if regErr := fwd.RegisterMetricsTo(reg); regErr != nil {
		_ = fwd.Close()
		return nil, regErr
	}
	return fwd, nil
}
// Compile-time checks that *Forward satisfies the sequence plugin interfaces.
var (
	_ sequence.Executable            = (*Forward)(nil)
	_ sequence.QuickConfigurableExec = (*Forward)(nil)
)
// Forward is the forward plugin. It sends each query to its upstreams (see
// exchange) and sets the chosen response on the query context.
type Forward struct {
	// args is the configuration NewForward was built from; exchange reads
	// args.Concurrent from it.
	args   *Args
	logger *zap.Logger
	// us holds the upstreams in configuration order.
	us           []*upstreamWrapper
	tag2Upstream map[string]*upstreamWrapper // for fast tag lookup only.
}
// Opts holds non-YAML settings for NewForward.
type Opts struct {
	// Logger is used by Forward and passed to every upstream.
	// NewForward replaces a nil Logger with zap.NewNop().
	Logger *zap.Logger
	// MetricsTag is passed to newWrapper for every upstream; presumably used
	// to label the upstream metrics — confirm in the wrapper code.
	MetricsTag string
}
// NewForward inits a Forward from given args.
// args must contain at least one upstream.
//
// The global options in args (Socks5, SoMark, BindToDevice, Bootstrap,
// BootstrapVer) are applied to every upstream config via utils.SetDefault*
// before the upstream is built. Every upstream needs a non-empty Addr, and
// non-empty tags must be unique.
//
// On any error, every upstream created so far is closed before returning,
// so a failed call never leaks upstream resources.
func NewForward(args *Args, opt Opts) (*Forward, error) {
	if len(args.Upstreams) == 0 {
		return nil, errors.New("no upstream is configured")
	}
	if opt.Logger == nil {
		opt.Logger = zap.NewNop()
	}

	f := &Forward{
		args:         args,
		logger:       opt.Logger,
		tag2Upstream: make(map[string]*upstreamWrapper),
	}

	// applyGlobal fills the per-upstream options from the plugin-level ones.
	applyGlobal := func(c *UpstreamConfig) {
		utils.SetDefaultString(&c.Socks5, args.Socks5)
		utils.SetDefaultUnsignNum(&c.SoMark, args.SoMark)
		utils.SetDefaultString(&c.BindToDevice, args.BindToDevice)
		utils.SetDefaultString(&c.Bootstrap, args.Bootstrap)
		utils.SetDefaultUnsignNum(&c.BootstrapVer, args.BootstrapVer)
	}

	for i, c := range args.Upstreams {
		if len(c.Addr) == 0 {
			// Upstreams built by earlier iterations must be released too.
			_ = f.Close()
			return nil, fmt.Errorf("#%d upstream invalid args, addr is required", i)
		}
		// Reject duplicated tags before building the upstream, so nothing is
		// allocated for a config that is going to be refused anyway.
		if len(c.Tag) > 0 {
			if _, dup := f.tag2Upstream[c.Tag]; dup {
				_ = f.Close()
				return nil, fmt.Errorf("duplicated upstream tag %s", c.Tag)
			}
		}

		// c is a per-iteration copy, so args.Upstreams itself is not modified.
		applyGlobal(&c)

		uw := newWrapper(i, c, opt.MetricsTag)
		uOpt := upstream.Opt{
			DialAddr:       c.DialAddr,
			Socks5:         c.Socks5,
			SoMark:         c.SoMark,
			BindToDevice:   c.BindToDevice,
			IdleTimeout:    time.Duration(c.IdleTimeout) * time.Second,
			EnablePipeline: c.EnablePipeline,
			EnableHTTP3:    c.EnableHTTP3,
			Bootstrap:      c.Bootstrap,
			BootstrapVer:   c.BootstrapVer,
			TLSConfig: &tls.Config{
				InsecureSkipVerify: c.InsecureSkipVerify,
				ClientSessionCache: tls.NewLRUClientSessionCache(4),
			},
			Logger:        opt.Logger,
			EventObserver: uw,
		}

		u, err := upstream.NewUpstream(c.Addr, uOpt)
		if err != nil {
			_ = f.Close()
			return nil, fmt.Errorf("failed to init upstream #%d: %w", i, err)
		}
		uw.u = u
		f.us = append(f.us, uw)
		if len(c.Tag) > 0 {
			f.tag2Upstream[c.Tag] = uw
		}
	}
	return f, nil
}
// RegisterMetricsTo registers the metrics of every tagged upstream to r.
// Upstreams without a tag are skipped. Registration stops at the first
// failure, and that error is returned.
func (f *Forward) RegisterMetricsTo(r prometheus.Registerer) error {
	for idx := range f.us {
		w := f.us[idx]
		if len(w.cfg.Tag) > 0 {
			if err := w.registerMetricsTo(r); err != nil {
				return err
			}
		}
	}
	return nil
}
// Exec forwards the query in qCtx to all configured upstreams and, on
// success, stores the chosen response in qCtx. On failure qCtx is left
// untouched and the exchange error is returned.
func (f *Forward) Exec(ctx context.Context, qCtx *query_context.Context) error {
	resp, err := f.exchange(ctx, qCtx, f.us)
	if err == nil {
		qCtx.SetResponse(resp)
	}
	return err
}
// QuickConfigureExec format: [upstream_tag]...
//
// args is a whitespace-separated list of upstream tags. The returned
// executable forwards queries only to those upstreams. If args contains no
// tag at all (empty or whitespace-only), every upstream of f is used. An
// unknown tag is reported as an error.
func (f *Forward) QuickConfigureExec(args string) (any, error) {
	tags := strings.Fields(args)

	var us []*upstreamWrapper
	if len(tags) == 0 {
		// No tags, use all upstreams. The check is on the split result rather
		// than on len(args): otherwise a whitespace-only string would select no
		// upstream, and exchange would fail every query with
		// "no upstream to exchange".
		us = f.us
	} else { // Pick up upstreams by tags.
		us = make([]*upstreamWrapper, 0, len(tags))
		for _, tag := range tags {
			u := f.tag2Upstream[tag]
			if u == nil {
				return nil, fmt.Errorf("cannot find upstream by tag %s", tag)
			}
			us = append(us, u)
		}
	}

	var execFunc sequence.ExecutableFunc = func(ctx context.Context, qCtx *query_context.Context) error {
		r, err := f.exchange(ctx, qCtx, us)
		if err != nil {
			return err
		}
		qCtx.SetResponse(r)
		return nil
	}
	return execFunc, nil
}
// Close closes every upstream held by f. Errors from individual upstreams
// are deliberately ignored so that each one still gets closed, and Close
// always returns nil.
func (f *Forward) Close() error {
	for idx := range f.us {
		_ = f.us[idx].Close()
	}
	return nil
}
// exchange sends the query in qCtx to upstreams picked at random from us and
// returns the most suitable response.
//
// Behavior:
//   - Parallelism: f.args.Concurrent queries are started in parallel, after
//     clamping the value to [1, maxConcurrentQueries].
//   - Timeouts: each query has its own queryTimeout, which does not depend
//     on ctx.
//   - Response selection: a NOERROR or NXDOMAIN response is returned as soon
//     as it arrives. Any other response (e.g. SERVFAIL, REFUSED) is kept as a
//     fallback and returned only if nothing better arrives. A later transport
//     error therefore cannot discard an answer that was already received.
//   - Errors: if no upstream produced a response, the returned error wraps
//     the last upstream error. If ctx is done first, context.Cause(ctx) is
//     returned.
func (f *Forward) exchange(ctx context.Context, qCtx *query_context.Context, us []*upstreamWrapper) (*dns.Msg, error) {
	if len(us) == 0 {
		return nil, errors.New("no upstream to exchange")
	}

	queryPayload, err := pool.PackBuffer(qCtx.Q())
	if err != nil {
		return nil, err
	}
	defer pool.ReleaseBuf(queryPayload)

	concurrent := f.args.Concurrent
	if concurrent <= 0 {
		concurrent = 1
	}
	if concurrent > maxConcurrentQueries {
		concurrent = maxConcurrentQueries
	}

	type res struct {
		r   *dns.Msg
		err error
	}
	resChan := make(chan res)
	// done is closed when exchange returns. Goroutines whose results are no
	// longer wanted then stop waiting on the unbuffered resChan instead of
	// blocking forever.
	done := make(chan struct{})
	defer close(done)

	for i := 0; i < concurrent; i++ {
		u := randPick(us)
		qc := copyPayload(queryPayload)
		go func(uqid uint32, question dns.Question) {
			defer pool.ReleaseBuf(qc)

			// Give each upstream a fixed timeout to finish the query. The context
			// is detached from ctx, so an in-flight query keeps running (up to
			// queryTimeout) even after exchange has stopped waiting for it.
			upstreamCtx, cancel := context.WithTimeout(context.Background(), queryTimeout)
			defer cancel()

			var r *dns.Msg
			respPayload, err := u.ExchangeContext(upstreamCtx, *qc)
			if err == nil {
				r = new(dns.Msg)
				err = r.Unpack(*respPayload)
				pool.ReleaseBuf(respPayload)
				if err != nil {
					r = nil
					err = fmt.Errorf("failed to unpack upstream response: %w", err)
				}
			}
			if err != nil {
				// Covers both transport errors and malformed responses.
				f.logger.Warn(
					"upstream error",
					zap.Uint32("uqid", uqid),
					zap.String("qname", question.Name),
					zap.Uint16("qclass", question.Qclass),
					zap.Uint16("qtype", question.Qtype),
					zap.String("upstream", u.name()),
					zap.Error(err),
				)
			}

			select {
			case resChan <- res{r: r, err: err}:
			case <-done:
			}
		}(qCtx.Id(), qCtx.QQuestion())
	}

	var (
		fallback *dns.Msg // latest response whose rcode was not good enough
		lastErr  error    // latest upstream/unpack error
	)
	for i := 0; i < concurrent; i++ {
		select {
		case got := <-resChan:
			if got.err != nil {
				lastErr = got.err
				continue
			}
			if got.r.Rcode == dns.RcodeSuccess || got.r.Rcode == dns.RcodeNameError {
				return got.r, nil
			}
			// Keep waiting for a better answer, but remember this one.
			fallback = got.r
		case <-ctx.Done():
			return nil, context.Cause(ctx)
		}
	}

	if fallback != nil {
		return fallback, nil
	}
	if lastErr != nil {
		return nil, fmt.Errorf("all upstream servers failed: %w", lastErr)
	}
	return nil, errors.New("all upstream servers failed")
}
// quickSetup builds a Forward from a quick-setup string. Every
// whitespace-separated field of s becomes one upstream address. Concurrent
// is set to maxConcurrentQueries, so each request may start up to that many
// parallel upstream queries.
func quickSetup(bq sequence.BQ, s string) (any, error) {
	addrs := strings.Fields(s)
	ups := make([]UpstreamConfig, 0, len(addrs))
	for _, addr := range addrs {
		ups = append(ups, UpstreamConfig{Addr: addr})
	}
	cfg := &Args{
		Upstreams:  ups,
		Concurrent: maxConcurrentQueries,
	}
	return NewForward(cfg, Opts{Logger: bq.L()})
}