mosdns/pkg/server_handler/entry_handler.go

205 lines
5.2 KiB
Go

/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package server_handler
import (
"context"
"sync"
"time"
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/pkg/server"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
"github.com/miekg/dns"
"go.uber.org/zap"
)
const (
defaultQueryTimeout = time.Second * 5
)
var (
// nopLogger is the logger that EntryHandlerOpts.init installs when no
// Logger was supplied.
nopLogger = mlog.Nop()
// queryForwardEDNS0Option is the set of EDNS0 option codes that can be
// forwarded to the upstream.
// NOTE(review): not referenced in the visible part of this file; confirm
// it is still used elsewhere in the package.
queryForwardEDNS0Option = map[uint16]struct{}{
dns.EDNS0SUBNET: {},
}
// respRemoveEDNS0Option is the set of EDNS0 option codes that are useless
// for the downstream.
// NOTE(review): not referenced in the visible part of this file; confirm
// it is still used elsewhere in the package.
respRemoveEDNS0Option = map[uint16]struct{}{
dns.EDNS0PADDING: {},
}
)
// EntryHandlerOpts holds the options accepted by NewEntryHandler.
type EntryHandlerOpts struct {
// Logger is used for logging. Default is a noop logger.
Logger *zap.Logger
// Entry is executed by EntryHandler.Handle for each query that passes the
// basic query check. Required.
Entry sequence.Executable
// QueryTimeout limits the timeout value of each query.
// Default is defaultQueryTimeout.
QueryTimeout time.Duration
}
// init fills in the defaults of opts: QueryTimeout is handed to
// utils.SetDefaultNum together with defaultQueryTimeout, and a nil Logger is
// replaced by nopLogger.
func (opts *EntryHandlerOpts) init() {
	utils.SetDefaultNum(&opts.QueryTimeout, defaultQueryTimeout)

	if opts.Logger != nil {
		return
	}
	opts.Logger = nopLogger
}
// EntryHandler is a server.Handler that runs the configured entry for each
// query and turns the outcome into a packed response.
type EntryHandler struct {
// opts holds the options passed to NewEntryHandler, with defaults applied.
opts EntryHandlerOpts
}
// Compile-time check that *EntryHandler implements server.Handler.
var _ server.Handler = (*EntryHandler)(nil)
// NewEntryHandler returns an EntryHandler configured with opts.
// opts is taken by value, so the defaults are applied to the handler's own
// copy and the caller's struct is left untouched.
func NewEntryHandler(opts EntryHandlerOpts) *EntryHandler {
	h := &EntryHandler{opts: opts}
	h.opts.init()
	return h
}
// Handle implements server.Handler.
//
// It executes the configured entry for q under a deadline of
// opts.QueryTimeout and returns the response packed by packMsgPayload.
// If entry returns an error, a SERVFAIL response will be returned.
// If entry returns without a response, a REFUSED response will be returned.
// Handle returns nil when q fails the basic query check or when the response
// cannot be packed; neither of those cases is recorded in the query
// statistics.
func (h *EntryHandler) Handle(ctx context.Context, q *dns.Msg, serverMeta server.QueryMeta, packMsgPayload func(m *dns.Msg) (*[]byte, error)) *[]byte {
	// Record the query start time (used for statistics).
	began := time.Now()

	// Basic query check: only a plain query with exactly one question, no
	// answer/authority records and at most one additional record is accepted.
	switch {
	case q.Response,
		len(q.Question) != 1,
		len(q.Answer)+len(q.Ns) > 0,
		len(q.Extra) > 1:
		return nil
	}

	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(h.opts.QueryTimeout))
	defer cancel()

	qCtx := query_context.NewContext(q)
	qCtx.ServerMeta = serverMeta

	// Exec entry. failRcode stays negative as long as the entry produced a
	// usable response; otherwise it carries the rcode of the reply built here.
	var resp *dns.Msg
	failRcode := -1
	if execErr := h.opts.Entry.Exec(ctx, qCtx); execErr != nil {
		h.opts.Logger.Warn("entry err", qCtx.InfoField(), zap.Error(execErr))
		failRcode = dns.RcodeServerFailure
	} else if resp = qCtx.R(); resp == nil {
		failRcode = dns.RcodeRefused
	}
	succeeded := failRcode < 0
	if !succeeded {
		resp = new(dns.Msg)
		resp.SetReply(q)
		resp.Rcode = failRcode
	}

	// We assume that our server is a forwarder.
	resp.RecursionAvailable = true

	// Add respOpt back to resp.
	if o := qCtx.RespOpt(); o != nil {
		resp.Extra = append(resp.Extra, o)
	}

	// UDP responses are truncated to the size derived from the client's OPT.
	if serverMeta.FromUDP {
		resp.Truncate(getValidUDPSize(qCtx.ClientOpt()))
	}

	payload, packErr := packMsgPayload(resp)
	if packErr != nil {
		h.opts.Logger.Error("internal err: failed to pack resp msg", qCtx.InfoField(), zap.Error(packErr))
		return nil
	}

	// Record query statistics.
	elapsedMs := time.Since(began).Milliseconds()
	recordQueryStats(succeeded, checkIfCachedResponse(qCtx), elapsedMs)
	return payload
}
// getValidUDPSize returns the UDP size reported by opt, raised to
// dns.MinMsgSize when it is smaller than that.
// opt can be nil, in which case dns.MinMsgSize is returned.
func getValidUDPSize(opt *dns.OPT) int {
	if opt == nil {
		return dns.MinMsgSize
	}
	if size := opt.UDPSize(); size >= dns.MinMsgSize {
		return int(size)
	}
	return dns.MinMsgSize
}
// newOpt returns a new OPT record whose header has the name "." and the
// type dns.TypeOPT; every other field is left at its zero value.
func newOpt() *dns.OPT {
	return &dns.OPT{
		Hdr: dns.RR_Header{
			Name:   ".",
			Rrtype: dns.TypeOPT,
		},
	}
}
// checkIfCachedResponse reports whether the response came from the cache.
// Simplified implementation: the current version assumes that no query hits
// the cache.
// TODO: in the future the cache plugin could set a flag on qCtx so that cache
// hits can be identified here.
func checkIfCachedResponse(qCtx *query_context.Context) bool {
// Simplified implementation: always returns false.
// Cache statistics need to be implemented separately in the cache plugin.
return false
}
// Global query statistics (kept inside this package to avoid a circular
// dependency). All three counters are guarded by statsMutex.
var (
	totalQueries      int64
	successfulQueries int64
	failedQueries     int64
	statsMutex        sync.RWMutex
)

// recordQueryStats records the outcome of one query (implemented here
// directly to avoid a circular dependency).
//
// Every call increments the total counter and exactly one of the successful
// or failed counters, so total == successful + failed always holds.
// cached and responseTimeMs are accepted but not recorded in this simplified
// version; cache hits and response times can be tracked through Prometheus
// metrics or a dedicated statistics plugin.
func recordQueryStats(success bool, cached bool, responseTimeMs int64) {
	statsMutex.Lock()
	defer statsMutex.Unlock()

	totalQueries++
	if success {
		successfulQueries++
	} else {
		// Previously the failed counter was never incremented, so
		// GetQueryStats always reported zero failures.
		failedQueries++
	}
}

// GetQueryStats returns a snapshot of the query counters taken under the
// read lock (intended to be called by coremain).
func GetQueryStats() (total, successful, failed int64) {
	statsMutex.RLock()
	defer statsMutex.RUnlock()
	return totalQueries, successfulQueries, failedQueries
}