/* * Copyright (C) 2020-2022, IrineSistiana * * This file is part of mosdns. * * mosdns is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * mosdns is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ package transport import ( "context" "encoding/binary" "errors" "fmt" "sync" "sync/atomic" "time" "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" "github.com/IrineSistiana/mosdns/v5/pkg/pool" ) var ( ErrTDCTooManyQueries = errors.New("too many queries") // Connection has too many ongoing queries. ErrTDCClosed = errors.New("dns connection closed") ) var _ DnsConn = (*TraditionalDnsConn)(nil) // TraditionalDnsConn is a low-level connection for traditional dns protocol, where // dns frames transport in a single and simple connection. (e.g. udp, tcp, tls) type TraditionalDnsConn struct { c NetConn isTcp bool idleTimeout time.Duration maxCq int closeOnce sync.Once closeNotify chan struct{} closed atomic.Bool // atomic, for fast check closeErr error // closeErr is ready (not nil) when closeNotify is closed. queueMu sync.RWMutex reservedQuery int nextQid uint16 queue map[uint32]chan *[]byte // uint32 has fast path // waitingResp indicates connection is waiting a reply from the peer. // It can identify c is dead or buggy in some circumstances. e.g. Network is dropped // and the sockets were still open because no fin or rst was received. waitingResp atomic.Bool } type TraditionalDnsConnOpts struct { // Set to true if underlayer connection require a length header. // e.g. TCP and DoT. WithLengthHeader bool // IdleTimeout controls the maximum idle time for each connection. // Default is defaultIdleTimeout. IdleTimeout time.Duration // MaxConcurrentQuery limits the number of maximum concurrent queries // in the connection. Default is defaultTdcMaxConcurrentQuery. MaxConcurrentQuery int } func NewDnsConn(opt TraditionalDnsConnOpts, conn NetConn) *TraditionalDnsConn { dc := &TraditionalDnsConn{ c: conn, isTcp: opt.WithLengthHeader, closeNotify: make(chan struct{}), queue: make(map[uint32]chan *[]byte), } setDefaultGZ(&dc.idleTimeout, opt.IdleTimeout, defaultIdleTimeout) setDefaultGZ(&dc.maxCq, opt.MaxConcurrentQuery, defaultTdcMaxConcurrentQuery) go dc.readLoop() return dc } // exchange sends q out and waits for its reply. func (dc *TraditionalDnsConn) exchange(ctx context.Context, q []byte) (*[]byte, error) { select { case <-dc.closeNotify: return nil, ErrTDCClosed default: } assignedQid, respChan := dc.addQueueC() if respChan == nil { return nil, ErrTDCTooManyQueries } defer dc.deleteQueueC(assignedQid) // Reminder: Set write deadline here is not very useful to avoid dead connections. // Typically, a write operation will time out only if its socket buffer is full. // Ser read deadline is enough. err := dc.writeQuery(q, assignedQid) if err != nil { // Write error usually is fatal. Abort and close this connection. dc.CloseWithErr(fmt.Errorf("write err, %w", err)) return nil, err } // If a query was sent, server should have a reply (even not for this query) in a short time. // This indicates the connection is healthy. Otherwise, this connection might be dead. // The Read deadline will be refreshed in DnsConn.readLoop() after every successful read. // Note: There has a race condition in this SetReadDeadline() call and the one in // readLoop(). It's not a big problem. if dc.waitingResp.CompareAndSwap(false, true) { dc.c.SetReadDeadline(time.Now().Add(waitingReplyTimeout)) } select { case <-ctx.Done(): return nil, context.Cause(ctx) case r := <-respChan: orgId := binary.BigEndian.Uint16(q) binary.BigEndian.PutUint16(*r, orgId) return r, nil case <-dc.closeNotify: return nil, dc.closeErr } } func (dc *TraditionalDnsConn) writeQuery(q []byte, assignedQid uint16) error { var payload *[]byte if dc.isTcp { var err error payload, err = copyMsgWithLenHdr(q) if err != nil { return err } binary.BigEndian.PutUint16((*payload)[2:], assignedQid) } else { payload = copyMsg(q) binary.BigEndian.PutUint16(*payload, assignedQid) } _, err := dc.c.Write(*payload) pool.ReleaseBuf(payload) return err } func (dc *TraditionalDnsConn) readResp() (payload *[]byte, err error) { if dc.isTcp { return dnsutils.ReadRawMsgFromTCP(dc.c) } return readMsgUdp(dc.c) } // readLoop reads DnsConn until there was a read error. func (dc *TraditionalDnsConn) readLoop() { for { dc.c.SetReadDeadline(time.Now().Add(dc.idleTimeout)) r, err := dc.readResp() if err != nil { dc.CloseWithErr(fmt.Errorf("read err, %w", err)) // abort this connection. return } dc.waitingResp.Store(false) rid := binary.BigEndian.Uint16(*r) resChan := dc.getQueueC(rid) if resChan != nil { select { case resChan <- r: // resChan has buffer default: pool.ReleaseBuf(r) } } else { pool.ReleaseBuf(r) } } } func (dc *TraditionalDnsConn) IsClosed() bool { return dc.closed.Load() } func (dc *TraditionalDnsConn) Close() error { dc.CloseWithErr(ErrTDCClosed) return nil } // CloseWithErr closes DnsConn with an error. The error will be sent // to the waiting Exchange calls. // Subsequent calls are noop. // Default err is ErrTDCClosed. func (dc *TraditionalDnsConn) CloseWithErr(err error) { if err == nil { err = ErrTDCClosed } dc.closeOnce.Do(func() { dc.closed.Store(true) dc.closeErr = err close(dc.closeNotify) dc.c.Close() }) } func (dc *TraditionalDnsConn) getQueueC(qid uint16) chan<- *[]byte { dc.queueMu.RLock() defer dc.queueMu.RUnlock() return dc.queue[uint32(qid)] } func (dc *TraditionalDnsConn) queueLen() int { dc.queueMu.RLock() defer dc.queueMu.RUnlock() return len(dc.queue) + dc.reservedQuery } // addQueueC assigns a qid and add it to the queue. // It returns a nil c if queue has too many queries. // Caller must call deleteQueueC to release the qid in queue. func (dc *TraditionalDnsConn) addQueueC() (qid uint16, c chan *[]byte) { c = make(chan *[]byte) dc.queueMu.Lock() for i := 0; i < 100; i++ { qid = dc.nextQid dc.nextQid++ if _, dup := dc.queue[uint32(qid)]; dup { continue } dc.queue[uint32(qid)] = c dc.queueMu.Unlock() return qid, c } dc.queueMu.Unlock() // Too many queries in queue. Can't assign qid. return 0, nil } func (dc *TraditionalDnsConn) deleteQueueC(qid uint16) { dc.queueMu.Lock() delete(dc.queue, uint32(qid)) dc.queueMu.Unlock() } func (dc *TraditionalDnsConn) ReserveNewQuery() (_ ReservedExchanger, closed bool) { if dc.closed.Load() { return nil, true } dc.queueMu.Lock() defer dc.queueMu.Unlock() if len(dc.queue)+dc.reservedQuery >= dc.maxCq { return nil, false } dc.reservedQuery++ return (*tdcOneTimeExchanger)(dc), false } type tdcOneTimeExchanger TraditionalDnsConn var _ ReservedExchanger = (*tdcOneTimeExchanger)(nil) func (ote *tdcOneTimeExchanger) ExchangeReserved(ctx context.Context, q []byte) (resp *[]byte, err error) { defer ote.WithdrawReserved() return (*TraditionalDnsConn)(ote).exchange(ctx, q) } func (ote *tdcOneTimeExchanger) WithdrawReserved() { ote.queueMu.Lock() ote.reservedQuery-- ote.queueMu.Unlock() }