mosdns/pkg/upstream/transport/conn_traditional.go
dengxiongjian cd761e8145
Some checks are pending
Test mosdns / build (push) Waiting to run
新增Mikrotik API 插入解析ip
2025-07-31 11:28:55 +08:00

284 lines
7.6 KiB
Go

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package transport
import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
)
// Sentinel errors reported by TraditionalDnsConn.
var (
// ErrTDCTooManyQueries is returned by exchange when no query id could be
// assigned to a new query.
ErrTDCTooManyQueries = errors.New("too many queries") // Connection has too many ongoing queries.
// ErrTDCClosed is the close error used by Close and the default used by
// CloseWithErr when it is given a nil error. exchange also returns it when
// the connection is already closed on entry.
ErrTDCClosed = errors.New("dns connection closed")
)
// Compile-time assertion that *TraditionalDnsConn implements DnsConn.
var _ DnsConn = (*TraditionalDnsConn)(nil)
// TraditionalDnsConn is a low-level connection for the traditional dns protocols, where
// dns frames are transported over a single and simple connection (e.g. udp, tcp, tls).
//
// Queries sent through it are multiplexed: each one is given a connection-local
// query id, and readLoop routes every reply to the waiting query with that id.
type TraditionalDnsConn struct {
// c is the underlying connection that queries are written to and replies are read from.
c NetConn
// isTcp selects the length-header framing for both writes and reads.
// It is set from TraditionalDnsConnOpts.WithLengthHeader.
isTcp bool
// idleTimeout is the read deadline readLoop applies before every read.
idleTimeout time.Duration
// maxCq is the limit ReserveNewQuery enforces on len(queue)+reservedQuery.
maxCq int
// closeOnce makes CloseWithErr run its close sequence at most once.
closeOnce sync.Once
// closeNotify is closed by CloseWithErr to wake up waiting exchange calls.
closeNotify chan struct{}
closed atomic.Bool // atomic, for fast check
closeErr error // closeErr is ready (not nil) when closeNotify is closed.
// queueMu guards reservedQuery, nextQid and queue.
queueMu sync.RWMutex
// reservedQuery counts reservations handed out by ReserveNewQuery that have
// not been withdrawn yet.
reservedQuery int
// nextQid is the next query id addQueueC will try to assign.
nextQid uint16
// queue maps the id of each in-flight query to the channel its reply is delivered on.
queue map[uint32]chan *[]byte // uint32 has fast path
// waitingResp indicates that the connection is waiting for a reply from the peer.
// It can identify that c is dead or buggy in some circumstances, e.g. the network
// was dropped and the sockets are still open because no fin or rst was received.
waitingResp atomic.Bool
}
// TraditionalDnsConnOpts holds the options accepted by NewDnsConn.
type TraditionalDnsConnOpts struct {
// WithLengthHeader must be set to true if the underlying connection requires
// a length header, e.g. TCP and DoT.
WithLengthHeader bool
// IdleTimeout controls the maximum idle time for each connection.
// Default is defaultIdleTimeout.
IdleTimeout time.Duration
// MaxConcurrentQuery limits the maximum number of concurrent queries
// in the connection. Default is defaultTdcMaxConcurrentQuery.
MaxConcurrentQuery int
}
// NewDnsConn wraps conn in a TraditionalDnsConn configured by opt and starts
// the connection's read loop in a new goroutine before returning.
//
// The idle timeout and the concurrent query limit are filled in through
// setDefaultGZ with defaultIdleTimeout and defaultTdcMaxConcurrentQuery as
// fallbacks (presumably used when the option is not greater than zero;
// confirm against setDefaultGZ).
func NewDnsConn(opt TraditionalDnsConnOpts, conn NetConn) *TraditionalDnsConn {
	dc := new(TraditionalDnsConn)
	dc.c = conn
	dc.isTcp = opt.WithLengthHeader
	dc.closeNotify = make(chan struct{})
	dc.queue = make(map[uint32]chan *[]byte)

	setDefaultGZ(&dc.idleTimeout, opt.IdleTimeout, defaultIdleTimeout)
	setDefaultGZ(&dc.maxCq, opt.MaxConcurrentQuery, defaultTdcMaxConcurrentQuery)

	// The read loop ends on the first read error, which also closes the connection.
	go dc.readLoop()
	return dc
}
// exchange sends q to the peer under a connection-local query id and blocks
// until the matching reply arrives, ctx is done, or the connection is closed.
//
// On success the returned message carries the id found in q again.
// Errors:
//   - ErrTDCClosed if the connection was already closed on entry.
//   - ErrTDCTooManyQueries if no query id could be assigned.
//   - the write error if sending q failed; the connection is closed as well.
//   - context.Cause(ctx) if ctx ended before a reply arrived.
//   - the connection's close error if it was closed while waiting.
func (dc *TraditionalDnsConn) exchange(ctx context.Context, q []byte) (*[]byte, error) {
	// Fail fast: don't register a query on a connection that is already closed.
	select {
	case <-dc.closeNotify:
		return nil, ErrTDCClosed
	default:
	}

	qid, replyC := dc.addQueueC()
	if replyC == nil {
		return nil, ErrTDCTooManyQueries
	}
	defer dc.deleteQueueC(qid)

	// Reminder: Setting a write deadline here is not very useful to avoid dead connections.
	// Typically, a write operation will time out only if its socket buffer is full.
	// Setting the read deadline is enough.
	if err := dc.writeQuery(q, qid); err != nil {
		// Write error usually is fatal. Abort and close this connection.
		dc.CloseWithErr(fmt.Errorf("write err, %w", err))
		return nil, err
	}

	// If a query was sent, the server should send a reply (even if not for this query)
	// in a short time. This indicates the connection is healthy. Otherwise, this
	// connection might be dead.
	// The read deadline will be refreshed in readLoop() after every successful read.
	// Note: There is a race condition between this SetReadDeadline() call and the one
	// in readLoop(). It's not a big problem.
	if dc.waitingResp.CompareAndSwap(false, true) {
		deadline := time.Now().Add(waitingReplyTimeout)
		dc.c.SetReadDeadline(deadline)
	}

	select {
	case <-dc.closeNotify:
		return nil, dc.closeErr
	case reply := <-replyC:
		// Hand the reply back under the id the caller used in q.
		callerID := binary.BigEndian.Uint16(q)
		binary.BigEndian.PutUint16(*reply, callerID)
		return reply, nil
	case <-ctx.Done():
		return nil, context.Cause(ctx)
	}
}
// writeQuery writes a copy of q to the underlying connection with its message
// id replaced by assignedQid. q itself is not modified.
//
// When the connection uses a length header the copy is made by
// copyMsgWithLenHdr and the id is written at offset 2; otherwise copyMsg is
// used and the id is written at offset 0. The copy is released to the buffer
// pool after the write. It returns the copy error or the write error.
func (dc *TraditionalDnsConn) writeQuery(q []byte, assignedQid uint16) error {
	var (
		buf   *[]byte
		idOff int // offset of the message id inside buf
	)
	if dc.isTcp {
		framed, err := copyMsgWithLenHdr(q)
		if err != nil {
			return err
		}
		buf, idOff = framed, 2
	} else {
		buf = copyMsg(q)
	}
	binary.BigEndian.PutUint16((*buf)[idOff:], assignedQid)

	_, writeErr := dc.c.Write(*buf)
	pool.ReleaseBuf(buf)
	return writeErr
}
// readResp reads one message from the underlying connection, using
// dnsutils.ReadRawMsgFromTCP when the connection uses a length header and
// readMsgUdp otherwise.
func (dc *TraditionalDnsConn) readResp() (payload *[]byte, err error) {
	if !dc.isTcp {
		return readMsgUdp(dc.c)
	}
	return dnsutils.ReadRawMsgFromTCP(dc.c)
}
// readLoop reads replies from the connection until a read fails, at which
// point it closes the connection with the read error and returns.
//
// Before every read the read deadline is reset to idleTimeout. Each reply is
// dispatched by its message id to the channel registered in the queue; a
// reply that has no registered waiter, or whose waiter cannot take it
// immediately, is released back to the buffer pool.
func (dc *TraditionalDnsConn) readLoop() {
	for {
		dc.c.SetReadDeadline(time.Now().Add(dc.idleTimeout))
		resp, err := dc.readResp()
		if err != nil {
			dc.CloseWithErr(fmt.Errorf("read err, %w", err)) // abort this connection.
			return
		}
		// Any successful read shows that the peer is still responding.
		dc.waitingResp.Store(false)

		waiter := dc.getQueueC(binary.BigEndian.Uint16(*resp))
		if waiter == nil {
			pool.ReleaseBuf(resp)
			continue
		}
		// Non-blocking send: the read loop must never stall on a single waiter.
		select {
		case waiter <- resp:
		default:
			pool.ReleaseBuf(resp)
		}
	}
}
// IsClosed reports whether the connection has been closed.
// It reads the atomic closed flag that CloseWithErr sets.
func (dc *TraditionalDnsConn) IsClosed() bool {
return dc.closed.Load()
}
// Close closes the connection with ErrTDCClosed as the close error.
// It always returns nil; any error from closing the underlying connection
// is not reported here.
func (dc *TraditionalDnsConn) Close() error {
dc.CloseWithErr(ErrTDCClosed)
return nil
}
// CloseWithErr closes the connection and records err as the reason.
// Exchange calls that are waiting for a reply are woken up and return
// that error. A nil err is replaced by ErrTDCClosed.
// Only the first call has an effect; subsequent calls are noop.
func (dc *TraditionalDnsConn) CloseWithErr(err error) {
	reason := err
	if reason == nil {
		reason = ErrTDCClosed
	}

	shutdown := func() {
		dc.closed.Store(true)
		// closeErr must be set before closeNotify is closed: waiters read it
		// as soon as they observe the closed channel.
		dc.closeErr = reason
		close(dc.closeNotify)
		_ = dc.c.Close()
	}
	dc.closeOnce.Do(shutdown)
}
// getQueueC returns the send side of the reply channel registered for qid,
// or nil if no query with that id is in the queue.
func (dc *TraditionalDnsConn) getQueueC(qid uint16) chan<- *[]byte {
	key := uint32(qid)
	dc.queueMu.RLock()
	replyC := dc.queue[key]
	dc.queueMu.RUnlock()
	return replyC
}
// queueLen returns the number of queries registered in the queue plus the
// number of outstanding reservations.
func (dc *TraditionalDnsConn) queueLen() int {
	dc.queueMu.RLock()
	total := len(dc.queue) + dc.reservedQuery
	dc.queueMu.RUnlock()
	return total
}
// addQueueC assigns a free qid and registers a reply channel for it in the queue.
// It returns a nil c if no free qid was found within a bounded number of
// attempts, which means the queue has too many queries.
// Caller must call deleteQueueC to release the qid in queue.
//
// The returned channel has a buffer of one. readLoop delivers replies with a
// non-blocking send, so with an unbuffered channel a reply that arrives before
// the caller has started receiving would be dropped and the query would wait
// until it times out.
func (dc *TraditionalDnsConn) addQueueC() (qid uint16, c chan *[]byte) {
	// maxAttempts bounds the time spent under the lock looking for a free id.
	const maxAttempts = 100

	// Allocate before taking the lock to keep the critical section short.
	c = make(chan *[]byte, 1)

	dc.queueMu.Lock()
	defer dc.queueMu.Unlock()
	for i := 0; i < maxAttempts; i++ {
		qid = dc.nextQid
		dc.nextQid++ // uint16: wraps around to 0 after 65535.
		if _, dup := dc.queue[uint32(qid)]; dup {
			continue
		}
		dc.queue[uint32(qid)] = c
		return qid, c
	}
	// Too many queries in queue. Can't assign qid.
	return 0, nil
}
// deleteQueueC removes the reply channel registered for qid from the queue,
// making the id available again. It is a noop if qid is not registered.
func (dc *TraditionalDnsConn) deleteQueueC(qid uint16) {
	key := uint32(qid)
	dc.queueMu.Lock()
	defer dc.queueMu.Unlock()
	delete(dc.queue, key)
}
// ReserveNewQuery reserves a slot for one upcoming query on this connection.
//
// It returns (nil, true) if the connection is closed, and (nil, false) if
// the registered queries plus outstanding reservations have already reached
// the concurrent query limit. Otherwise the reservation counter is
// incremented and a one-time exchanger backed by this connection is returned;
// the reservation is released through that exchanger.
func (dc *TraditionalDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) {
	if dc.closed.Load() {
		return nil, true
	}

	dc.queueMu.Lock()
	inUse := len(dc.queue) + dc.reservedQuery
	hasRoom := inUse < dc.maxCq
	if hasRoom {
		dc.reservedQuery++
	}
	dc.queueMu.Unlock()

	if !hasRoom {
		return nil, false
	}
	return (*tdcOneTimeExchanger)(dc), false
}
// tdcOneTimeExchanger is the ReservedExchanger view of a TraditionalDnsConn.
// ReserveNewQuery returns the connection itself converted to this type, so
// the methods below operate directly on the connection's state.
type tdcOneTimeExchanger TraditionalDnsConn
// Compile-time assertion that *tdcOneTimeExchanger implements ReservedExchanger.
var _ ReservedExchanger = (*tdcOneTimeExchanger)(nil)
// ExchangeReserved sends q over the underlying connection using the
// previously made reservation and returns the result of that exchange.
// The reservation is withdrawn when the call returns, whether it succeeded
// or not.
func (ote *tdcOneTimeExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) {
	conn := (*TraditionalDnsConn)(ote)
	defer ote.WithdrawReserved()
	return conn.exchange(ctx, q)
}
// WithdrawReserved gives back a reservation by decrementing the connection's
// reservation counter under the queue lock.
func (ote *tdcOneTimeExchanger) WithdrawReserved() {
	ote.queueMu.Lock()
	defer ote.queueMu.Unlock()
	ote.reservedQuery--
}