mosdns/pkg/nftset_utils/handler.go
dengxiongjian cd761e8145
Some checks are pending
Test mosdns / build (push) Waiting to run
新增Mikrotik API 插入解析ip
2025-07-31 11:28:55 +08:00

155 lines
3.6 KiB
Go

//go:build linux
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package nftset_utils
import (
"errors"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/nftables"
"go4.org/netipx"
)
// ErrClosed is returned by NftSetHandler.AddElems once Close has been called.
var ErrClosed = errors.New("closed handler")
// NftSetHandler inserts netip.Prefix values into one named nftables set.
// The target table is identified by HandlerOpts.TableFamily and
// HandlerOpts.TableName. If the set carries the 'interval' flag, each prefix
// is inserted as an address range; otherwise only the prefix's address is
// inserted.
type NftSetHandler struct {
	// opts is set at construction and never modified afterwards.
	opts HandlerOpts

	// m serializes all access to the fields below. It is required because
	// the lasting netlink connection is not safe for concurrent use.
	m           sync.Mutex
	closed      bool           // set by Close; AddElems refuses to run afterwards
	lastUpdate  time.Time      // when set was last fetched from the kernel
	set         *nftables.Set  // cached set description (see getSetLocked)
	lastingConn *nftables.Conn // opened lazily by AddElems, released by Close

	// disableSetCache forces every lookup to hit the kernel. Test only.
	disableSetCache bool
}
// HandlerOpts identifies the nftables set that an NftSetHandler writes to.
type HandlerOpts struct {
	TableFamily nftables.TableFamily // address family of the table holding the set
	TableName   string               // name of the table holding the set
	SetName     string               // name of the set inside that table
}
// NewNtSetHandler returns an NftSetHandler targeting the set described by
// opts. No netlink connection is opened here; AddElems opens one on first use.
func NewNtSetHandler(opts HandlerOpts) *NftSetHandler {
	h := new(NftSetHandler)
	h.opts = opts
	return h
}
// getSetLocked returns the description of the target set. A cached copy is
// reused while it is younger than one second; otherwise (or when
// disableSetCache is set) the set is fetched from the kernel again.
// The caller must hold h.m and must already have opened h.lastingConn.
// On lookup failure the cache is left untouched and the error is returned.
func (h *NftSetHandler) getSetLocked() (*nftables.Set, error) {
	const refreshInterval = time.Second

	now := time.Now()
	cacheFresh := h.set != nil && now.Sub(h.lastUpdate) < refreshInterval
	if cacheFresh && !h.disableSetCache {
		return h.set, nil
	}

	// GetSetByName is not safe for concurrent use; h.m serializes callers.
	table := &nftables.Table{Name: h.opts.TableName, Family: h.opts.TableFamily}
	set, err := h.lastingConn.GetSetByName(table, h.opts.SetName)
	if err != nil {
		return nil, err
	}

	h.set, h.lastUpdate = set, now
	return set, nil
}
// AddElems adds the given prefixes to the set in a single netlink batch.
//
// If the set has the 'interval' flag, each prefix is inserted as a range
// starting at its masked address. Otherwise only the prefix's address is
// inserted and the prefix length is ignored.
//
// All prefixes are validated before any netlink I/O, so an invalid prefix
// never causes a connection to be opened or a batch to be sent. Calling
// AddElems with no prefixes is a no-op. ErrClosed is returned after Close.
func (h *NftSetHandler) AddElems(es ...netip.Prefix) error {
	h.m.Lock()
	defer h.m.Unlock()

	if h.closed {
		return ErrClosed
	}

	// Reject bad input before touching netlink.
	for i, e := range es {
		if !e.IsValid() {
			return fmt.Errorf("invalid prefix at index %d", i)
		}
	}
	if len(es) == 0 {
		return nil
	}

	if h.lastingConn == nil {
		c, err := nftables.New(nftables.AsLasting())
		if err != nil {
			return fmt.Errorf("failed to open netlink, %w", err)
		}
		h.lastingConn = c
	}

	set, err := h.getSetLocked()
	if err != nil {
		return fmt.Errorf("failed to get set, %w", err)
	}

	capacity := len(es)
	if set.Interval {
		capacity *= 2 // a start and (usually) an end element per prefix
	}
	elems := make([]nftables.SetElement, 0, capacity)
	for _, e := range es {
		if !set.Interval {
			elems = append(elems, nftables.SetElement{Key: e.Addr().AsSlice()})
			continue
		}
		start := e.Masked().Addr()
		elems = append(elems, nftables.SetElement{Key: start.AsSlice(), IntervalEnd: false})
		// The end element is the address just past the prefix's last address.
		// Next() yields an invalid Addr when the prefix reaches the top of the
		// address space; in that case only the start element is added.
		if end := netipx.PrefixLastIP(e).Next(); end.IsValid() {
			elems = append(elems, nftables.SetElement{Key: end.AsSlice(), IntervalEnd: true})
		}
	}

	if err = h.lastingConn.SetAddElements(set, elems); err != nil {
		return fmt.Errorf("failed to add elements, %w", err)
	}
	if err = h.lastingConn.Flush(); err != nil {
		// The cached set description may be stale (e.g. the set was deleted
		// or redefined); drop it so the next call re-reads it from the kernel.
		h.set = nil
		return fmt.Errorf("failed to flush, %w", err)
	}
	return nil
}
// Close marks the handler as closed and releases its lasting netlink
// connection, if AddElems ever opened one. After Close, AddElems returns
// ErrClosed. Calls after the first do nothing and return nil.
func (h *NftSetHandler) Close() error {
	h.m.Lock()
	defer h.m.Unlock()

	if h.closed {
		return nil
	}
	h.closed = true

	if h.lastingConn == nil {
		return nil
	}
	return h.lastingConn.CloseLasting()
}