482 lines
12 KiB
Go
482 lines
12 KiB
Go
/*
|
|
* Copyright (C) 2020-2022, IrineSistiana
|
|
*
|
|
* This file is part of mosdns.
|
|
*
|
|
* mosdns is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* mosdns is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package cache
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/IrineSistiana/mosdns/v5/coremain"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/cache"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/klauspost/compress/gzip"
|
|
"github.com/miekg/dns"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/sync/singleflight"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
const (
|
|
PluginType = "cache"
|
|
)
|
|
|
|
func init() {
|
|
coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
|
|
sequence.MustRegExecQuickSetup(PluginType, quickSetupCache)
|
|
}
|
|
|
|
const (
|
|
defaultLazyUpdateTimeout = time.Second * 5
|
|
expiredMsgTtl = 5
|
|
|
|
minimumChangesToDump = 1024
|
|
dumpHeader = "mosdns_cache_v2"
|
|
dumpBlockSize = 128
|
|
dumpMaximumBlockLength = 1 << 20 // 1M block. 8kb pre entry. Should be enough.
|
|
)
|
|
|
|
var _ sequence.RecursiveExecutable = (*Cache)(nil)
|
|
|
|
type Args struct {
|
|
Size int `yaml:"size"`
|
|
LazyCacheTTL int `yaml:"lazy_cache_ttl"`
|
|
DumpFile string `yaml:"dump_file"`
|
|
DumpInterval int `yaml:"dump_interval"`
|
|
}
|
|
|
|
func (a *Args) init() {
|
|
utils.SetDefaultUnsignNum(&a.Size, 1024)
|
|
utils.SetDefaultUnsignNum(&a.DumpInterval, 600)
|
|
}
|
|
|
|
type Cache struct {
|
|
args *Args
|
|
|
|
logger *zap.Logger
|
|
backend *cache.Cache[key, *item]
|
|
lazyUpdateSF singleflight.Group
|
|
closeOnce sync.Once
|
|
closeNotify chan struct{}
|
|
updatedKey atomic.Uint64
|
|
|
|
queryTotal prometheus.Counter
|
|
hitTotal prometheus.Counter
|
|
lazyHitTotal prometheus.Counter
|
|
size prometheus.GaugeFunc
|
|
}
|
|
|
|
func Init(bp *coremain.BP, args any) (any, error) {
|
|
c := NewCache(args.(*Args), Opts{
|
|
Logger: bp.L(),
|
|
MetricsTag: bp.Tag(),
|
|
})
|
|
|
|
if err := c.RegMetricsTo(prometheus.WrapRegistererWithPrefix(PluginType+"_", bp.M().GetMetricsReg())); err != nil {
|
|
return nil, fmt.Errorf("failed to register metrics, %w", err)
|
|
}
|
|
bp.RegAPI(c.Api())
|
|
return c, nil
|
|
}
|
|
|
|
// QuickSetup format: [size]
|
|
// default is 1024. If size is < 1024, 1024 will be used.
|
|
func quickSetupCache(bq sequence.BQ, s string) (any, error) {
|
|
size := 0
|
|
if len(s) > 0 {
|
|
i, err := strconv.Atoi(s)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid size, %w", err)
|
|
}
|
|
size = i
|
|
}
|
|
// Don't register metrics in quick setup.
|
|
return NewCache(&Args{Size: size}, Opts{Logger: bq.L()}), nil
|
|
}
|
|
|
|
type Opts struct {
|
|
Logger *zap.Logger
|
|
MetricsTag string
|
|
}
|
|
|
|
func NewCache(args *Args, opts Opts) *Cache {
|
|
args.init()
|
|
|
|
logger := opts.Logger
|
|
if logger == nil {
|
|
logger = zap.NewNop()
|
|
}
|
|
|
|
backend := cache.New[key, *item](cache.Opts{Size: args.Size})
|
|
lb := map[string]string{"tag": opts.MetricsTag}
|
|
p := &Cache{
|
|
args: args,
|
|
logger: logger,
|
|
backend: backend,
|
|
closeNotify: make(chan struct{}),
|
|
|
|
queryTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "query_total",
|
|
Help: "The total number of processed queries",
|
|
ConstLabels: lb,
|
|
}),
|
|
hitTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "hit_total",
|
|
Help: "The total number of queries that hit the cache",
|
|
ConstLabels: lb,
|
|
}),
|
|
lazyHitTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
|
Name: "lazy_hit_total",
|
|
Help: "The total number of queries that hit the expired cache",
|
|
ConstLabels: lb,
|
|
}),
|
|
size: prometheus.NewGaugeFunc(prometheus.GaugeOpts{
|
|
Name: "size_current",
|
|
Help: "Current cache size in records",
|
|
ConstLabels: lb,
|
|
}, func() float64 {
|
|
return float64(backend.Len())
|
|
}),
|
|
}
|
|
|
|
if err := p.loadDump(); err != nil {
|
|
p.logger.Error("failed to load cache dump", zap.Error(err))
|
|
}
|
|
p.startDumpLoop()
|
|
|
|
return p
|
|
}
|
|
|
|
func (c *Cache) RegMetricsTo(r prometheus.Registerer) error {
|
|
for _, collector := range [...]prometheus.Collector{c.queryTotal, c.hitTotal, c.lazyHitTotal, c.size} {
|
|
if err := r.Register(collector); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Cache) Exec(ctx context.Context, qCtx *query_context.Context, next sequence.ChainWalker) error {
|
|
c.queryTotal.Inc()
|
|
q := qCtx.Q()
|
|
|
|
msgKey := getMsgKey(q)
|
|
if len(msgKey) == 0 { // skip cache
|
|
return next.ExecNext(ctx, qCtx)
|
|
}
|
|
|
|
cachedResp, lazyHit := getRespFromCache(msgKey, c.backend, c.args.LazyCacheTTL > 0, expiredMsgTtl)
|
|
if lazyHit {
|
|
c.lazyHitTotal.Inc()
|
|
c.doLazyUpdate(msgKey, qCtx, next)
|
|
}
|
|
if cachedResp != nil { // cache hit
|
|
c.hitTotal.Inc()
|
|
cachedResp.Id = q.Id // change msg id
|
|
qCtx.SetResponse(cachedResp)
|
|
}
|
|
|
|
err := next.ExecNext(ctx, qCtx)
|
|
|
|
if r := qCtx.R(); r != nil && cachedResp != r { // pointer compare. r is not cachedResp
|
|
saveRespToCache(msgKey, r, c.backend, c.args.LazyCacheTTL)
|
|
c.updatedKey.Add(1)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// doLazyUpdate starts a new goroutine to execute next node and update the cache in the background.
|
|
// It has an inner singleflight.Group to de-duplicate same msgKey.
|
|
func (c *Cache) doLazyUpdate(msgKey string, qCtx *query_context.Context, next sequence.ChainWalker) {
|
|
qCtxCopy := qCtx.Copy()
|
|
lazyUpdateFunc := func() (any, error) {
|
|
defer c.lazyUpdateSF.Forget(msgKey)
|
|
qCtx := qCtxCopy
|
|
|
|
c.logger.Debug("start lazy cache update", qCtx.InfoField())
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultLazyUpdateTimeout)
|
|
defer cancel()
|
|
|
|
err := next.ExecNext(ctx, qCtx)
|
|
if err != nil {
|
|
c.logger.Warn("failed to update lazy cache", qCtx.InfoField(), zap.Error(err))
|
|
}
|
|
|
|
r := qCtx.R()
|
|
if r != nil {
|
|
saveRespToCache(msgKey, r, c.backend, c.args.LazyCacheTTL)
|
|
c.updatedKey.Add(1)
|
|
}
|
|
c.logger.Debug("lazy cache updated", qCtx.InfoField())
|
|
return nil, nil
|
|
}
|
|
c.lazyUpdateSF.DoChan(msgKey, lazyUpdateFunc) // DoChan won't block this goroutine
|
|
}
|
|
|
|
func (c *Cache) Close() error {
|
|
if err := c.dumpCache(); err != nil {
|
|
c.logger.Error("failed to dump cache", zap.Error(err))
|
|
}
|
|
c.closeOnce.Do(func() {
|
|
close(c.closeNotify)
|
|
})
|
|
return c.backend.Close()
|
|
}
|
|
|
|
func (c *Cache) loadDump() error {
|
|
if len(c.args.DumpFile) == 0 {
|
|
return nil
|
|
}
|
|
f, err := os.Open(c.args.DumpFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
en, err := c.readDump(f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.logger.Info("cache dump loaded", zap.Int("entries", en))
|
|
return nil
|
|
}
|
|
|
|
// startDumpLoop starts a dump loop in another goroutine. It does not block.
|
|
func (c *Cache) startDumpLoop() {
|
|
if len(c.args.DumpFile) == 0 {
|
|
return
|
|
}
|
|
go func() {
|
|
ticker := time.NewTicker(time.Duration(c.args.DumpInterval) * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
// Check if we have enough changes to dump.
|
|
keyUpdated := c.updatedKey.Swap(0)
|
|
if keyUpdated < minimumChangesToDump { // Nop.
|
|
c.updatedKey.Add(keyUpdated)
|
|
continue
|
|
}
|
|
|
|
if err := c.dumpCache(); err != nil {
|
|
c.logger.Error("dump cache", zap.Error(err))
|
|
}
|
|
case <-c.closeNotify:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (c *Cache) dumpCache() error {
|
|
if len(c.args.DumpFile) == 0 {
|
|
return nil
|
|
}
|
|
|
|
f, err := os.Create(c.args.DumpFile)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
en, err := c.writeDump(f)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write dump, %w", err)
|
|
}
|
|
c.logger.Info("cache dumped", zap.Int("entries", en))
|
|
return nil
|
|
}
|
|
|
|
func (c *Cache) Api() *chi.Mux {
|
|
r := chi.NewRouter()
|
|
r.Get("/flush", func(w http.ResponseWriter, req *http.Request) {
|
|
c.backend.Flush()
|
|
})
|
|
r.Get("/dump", func(w http.ResponseWriter, req *http.Request) {
|
|
w.Header().Set("content-type", "application/octet-stream")
|
|
_, err := c.writeDump(w)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
})
|
|
r.Post("/load_dump", func(w http.ResponseWriter, req *http.Request) {
|
|
if _, err := c.readDump(req.Body); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
return r
|
|
}
|
|
|
|
func (c *Cache) writeDump(w io.Writer) (int, error) {
|
|
en := 0
|
|
|
|
gw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed)
|
|
gw.Name = dumpHeader
|
|
|
|
block := new(CacheDumpBlock)
|
|
writeBlock := func() error {
|
|
b, err := proto.Marshal(block)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal protobuf, %w", err)
|
|
}
|
|
|
|
l := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(l, uint64(len(b)))
|
|
_, err = gw.Write(l)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write header, %w", err)
|
|
}
|
|
_, err = gw.Write(b)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write data, %w", err)
|
|
}
|
|
|
|
en += len(block.GetEntries())
|
|
block.Reset()
|
|
return nil
|
|
}
|
|
|
|
now := time.Now()
|
|
rangeFunc := func(k key, v *item, cacheExpirationTime time.Time) error {
|
|
if cacheExpirationTime.Before(now) {
|
|
return nil
|
|
}
|
|
msg, err := v.resp.Pack()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to pack msg, %w", err)
|
|
}
|
|
e := &CachedEntry{
|
|
Key: []byte(k),
|
|
CacheExpirationTime: cacheExpirationTime.Unix(),
|
|
MsgExpirationTime: v.expirationTime.Unix(),
|
|
Msg: msg,
|
|
}
|
|
block.Entries = append(block.Entries, e)
|
|
|
|
// Block is big enough for a write operation.
|
|
if len(block.Entries) >= dumpBlockSize {
|
|
return writeBlock()
|
|
}
|
|
return nil
|
|
}
|
|
if err := c.backend.Range(rangeFunc); err != nil {
|
|
return en, err
|
|
}
|
|
|
|
if len(block.GetEntries()) > 0 {
|
|
if err := writeBlock(); err != nil {
|
|
return en, err
|
|
}
|
|
}
|
|
return en, gw.Close()
|
|
}
|
|
|
|
// readDump reads dumped data from r. It returns the number of bytes read,
|
|
// number of entries read and any error encountered.
|
|
func (c *Cache) readDump(r io.Reader) (int, error) {
|
|
en := 0
|
|
gr, err := gzip.NewReader(r)
|
|
if err != nil {
|
|
return en, fmt.Errorf("failed to read gzip header, %w", err)
|
|
}
|
|
if gr.Name != dumpHeader {
|
|
return en, fmt.Errorf("invalid or old cache dump, header is %s, want %s", gr.Name, dumpHeader)
|
|
}
|
|
|
|
var errReadHeaderEOF = errors.New("")
|
|
readBlock := func() error {
|
|
h := pool.GetBuf(8)
|
|
defer pool.ReleaseBuf(h)
|
|
_, err := io.ReadFull(gr, *h)
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return errReadHeaderEOF
|
|
}
|
|
return fmt.Errorf("failed to read block header, %w", err)
|
|
}
|
|
u := binary.BigEndian.Uint64(*h)
|
|
if u > dumpMaximumBlockLength {
|
|
return fmt.Errorf("invalid header, block length is big, %d", u)
|
|
}
|
|
|
|
b := pool.GetBuf(int(u))
|
|
defer pool.ReleaseBuf(b)
|
|
_, err = io.ReadFull(gr, *b)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read block data, %w", err)
|
|
}
|
|
|
|
block := new(CacheDumpBlock)
|
|
if err := proto.Unmarshal(*b, block); err != nil {
|
|
return fmt.Errorf("failed to decode block data, %w", err)
|
|
}
|
|
|
|
en += len(block.GetEntries())
|
|
for _, entry := range block.GetEntries() {
|
|
cacheExpTime := time.Unix(entry.GetCacheExpirationTime(), 0)
|
|
msgExpTime := time.Unix(entry.GetMsgExpirationTime(), 0)
|
|
storedTime := time.Unix(entry.GetMsgStoredTime(), 0)
|
|
resp := new(dns.Msg)
|
|
if err := resp.Unpack(entry.GetMsg()); err != nil {
|
|
return fmt.Errorf("failed to decode dns msg, %w", err)
|
|
}
|
|
|
|
i := &item{
|
|
resp: resp,
|
|
storedTime: storedTime,
|
|
expirationTime: msgExpTime,
|
|
}
|
|
c.backend.Store(key(entry.GetKey()), i, cacheExpTime)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
for {
|
|
err = readBlock()
|
|
if err != nil {
|
|
if err == errReadHeaderEOF {
|
|
err = nil // This is expected if there is no block to read.
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return en, err
|
|
}
|
|
return en, gr.Close()
|
|
}
|