/* * Copyright (C) 2020-2022, IrineSistiana * * This file is part of mosdns. * * mosdns is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * mosdns is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ package server_handler import ( "context" "sync" "time" "github.com/IrineSistiana/mosdns/v5/mlog" "github.com/IrineSistiana/mosdns/v5/pkg/query_context" "github.com/IrineSistiana/mosdns/v5/pkg/server" "github.com/IrineSistiana/mosdns/v5/pkg/utils" "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence" "github.com/miekg/dns" "go.uber.org/zap" ) const ( defaultQueryTimeout = time.Second * 5 ) var ( nopLogger = mlog.Nop() // options that can forward to upstream queryForwardEDNS0Option = map[uint16]struct{}{ dns.EDNS0SUBNET: {}, } // options that useless for downstream respRemoveEDNS0Option = map[uint16]struct{}{ dns.EDNS0PADDING: {}, } ) type EntryHandlerOpts struct { // Logger is used for logging. Default is a noop logger. Logger *zap.Logger // Required. Entry sequence.Executable // QueryTimeout limits the timeout value of each query. // Default is defaultQueryTimeout. QueryTimeout time.Duration } func (opts *EntryHandlerOpts) init() { if opts.Logger == nil { opts.Logger = nopLogger } utils.SetDefaultNum(&opts.QueryTimeout, defaultQueryTimeout) } type EntryHandler struct { opts EntryHandlerOpts } var _ server.Handler = (*EntryHandler)(nil) func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler { opts.init() return &EntryHandler{opts: opts} } // ServeDNS implements server.Handler. // If entry returns an error, a SERVFAIL response will be returned. // If entry returns without a response, a REFUSED response will be returned. func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte { // 记录查询开始时间(用于统计) startTime := time.Now() // basic query check. if q.Response || len(q.Question) != 1 || len(q.Answer)+len(q.Ns) > 0 || len(q.Extra) > 1 { return nil } ddl := time.Now().Add(h.opts.QueryTimeout) ctx, cancel := context.WithDeadline(ctx, ddl) defer cancel() qCtx := query_context.NewContext(q) qCtx.ServerMeta = serverMeta // exec entry err := h.opts.Entry.Exec(ctx, qCtx) var resp *dns.Msg success := true if err != nil { h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(err)) resp = new(dns.Msg) resp.SetReply(q) resp.Rcode = dns.RcodeServerFailure success = false } else { resp = qCtx.R() } if resp == nil { resp = new(dns.Msg) resp.SetReply(q) resp.Rcode = dns.RcodeRefused success = false } // We assume that our server is a forwarder. resp.RecursionAvailable = true // add respOpt back to resp if respOpt := qCtx.RespOpt(); respOpt != nil { resp.Extra = append(resp.Extra, respOpt) } if serverMeta.FromUDP { udpSize := getValidUDPSize(qCtx.ClientOpt()) resp.Truncate(udpSize) } payload, err := packMsgPayload(resp) if err != nil { h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(err)) return nil } // 记录查询统计 responseTime := time.Since(startTime).Milliseconds() cached := checkIfCachedResponse(qCtx) // 检查是否命中缓存 recordQueryStats(success, cached, responseTime) return payload } // opt can be nil. func getValidUDPSize(opt *dns.OPT) int { var s uint16 if opt != nil { s = opt.UDPSize() } if s < dns.MinMsgSize { s = dns.MinMsgSize } return int(s) } func newOpt() *dns.OPT { opt := new(dns.OPT) opt.Hdr.Name = "." opt.Hdr.Rrtype = dns.TypeOPT return opt } // checkIfCachedResponse 检查响应是否来自缓存 // 简化实现:当前版本暂时假设所有查询都未命中缓存 // TODO: 未来可以通过缓存插件在 qCtx 中设置标记来识别缓存命中 func checkIfCachedResponse(qCtx *query_context.Context) bool { // 简化实现,始终返回 false // 缓存统计需要在缓存插件中单独实现 return false } // 全局查询统计变量(内部实现,避免循环依赖) var ( totalQueries int64 successfulQueries int64 failedQueries int64 statsMutex sync.RWMutex ) // recordQueryStats 记录查询统计(直接实现,避免循环依赖) func recordQueryStats(success bool, cached bool, responseTimeMs int64) { statsMutex.Lock() defer statsMutex.Unlock() totalQueries++ if success { successfulQueries++ } // 简化版本:暂时不记录缓存命中、响应时间等 // 这些可以通过 Prometheus metrics 或者单独的统计插件实现 } // GetQueryStats 返回查询统计数据(供 coremain 调用) func GetQueryStats() (total, successful, failed int64) { statsMutex.RLock() defer statsMutex.RUnlock() return totalQueries, successfulQueries, failedQueries }