mosdns/pkg/rate_limiter/rate_limiter.go
dengxiongjian cd761e8145
Some checks are pending
Test mosdns / build (push) Waiting to run
新增Mikrotik API 插入解析ip
2025-07-31 11:28:55 +08:00

154 lines
3.1 KiB
Go

package rate_limiter
import (
"net/netip"
"sync"
"time"
"golang.org/x/time/rate"
)
// Tuning constants for the sharded client table.
const (
// tableShards is the number of shards in Limiter.tables; getTableShardIdx
// reduces its hash modulo this value.
tableShards = 32
// gcInterval is the period NewRateLimiter passes to gcLoop. doGc also uses
// it as the idle age after which a client entry is deleted.
gcInterval = time.Minute
)
// Limiter keeps one rate.Limiter per client address. Entries are spread over
// tableShards shards, each guarded by its own mutex.
type Limiter struct {
// Limit and Burst are read-only.
// They are handed to rate.NewLimiter whenever Allow sees a new address.
Limit rate.Limit
Burst int
// closeOnce makes sure closeNotify is closed at most once by Close.
closeOnce sync.Once
// closeNotify is closed by Close; gcLoop returns once it is closed.
closeNotify chan struct{}
// tables holds the shards; getTableShard picks one from an address.
tables [tableShards]*tableShard
}
// tableShard is one slice of the client table together with its lock.
type tableShard struct {
// m is held around every access to table in this file.
m sync.Mutex
// table maps a client address to its limiter state.
table map[netip.Addr]*limiterEntry
}
// limiterEntry is the per-client state stored in a tableShard.
type limiterEntry struct {
// l is the token bucket for this client.
l *rate.Limiter
// lastSeen is updated by Allow under the shard lock and read by doGc to
// decide whether the entry is idle.
lastSeen time.Time
// NOTE(review): this embedded Once is not used by any code visible in this
// file — confirm whether it is still needed.
sync.Once
}
// NewRateLimiter builds a per-client rate limiter.
// limit and burst should both be greater than zero; see rate.Limiter for
// their exact meaning.
//
// A background goroutine is started that removes idle client entries once per
// gcInterval (1m); call Close to stop it. If refilling a client's tokens
// (burst/limit) takes longer than 1m, the effective average QPS limit may end
// up higher than configured, because the client's state can be deleted and
// re-initialized in between.
func NewRateLimiter(limit rate.Limit, burst int) *Limiter {
	rl := new(Limiter)
	rl.Limit = limit
	rl.Burst = burst
	rl.closeNotify = make(chan struct{})

	// Every shard needs its own map before the first Allow call.
	for idx := 0; idx < len(rl.tables); idx++ {
		rl.tables[idx] = &tableShard{table: make(map[netip.Addr]*limiterEntry)}
	}

	go rl.gcLoop(gcInterval)
	return rl
}
// Allow reports whether one event from unmappedAddr may happen now.
//
// unmappedAddr is used verbatim as the table key. The first call for an
// address creates its limiter from l.Limit and l.Burst; every call refreshes
// the entry's lastSeen so that doGc keeps it alive.
func (l *Limiter) Allow(unmappedAddr netip.Addr) bool {
	ts := time.Now()
	sh := l.getTableShard(unmappedAddr)

	sh.m.Lock()
	entry, found := sh.table[unmappedAddr]
	if !found {
		entry = &limiterEntry{l: rate.NewLimiter(l.Limit, l.Burst)}
		sh.table[unmappedAddr] = entry
	}
	entry.lastSeen = ts
	sh.m.Unlock()

	// The token check runs outside the shard lock.
	return entry.l.AllowN(ts, 1)
}
// Close closes closeNotify, which makes gcLoop return. Only the first call
// closes the channel; later calls are no-ops. The returned error is always nil.
func (l *Limiter) Close() error {
	stop := func() { close(l.closeNotify) }
	l.closeOnce.Do(stop)
	return nil
}
// gcLoop calls doGc on every tick of a gcInterval ticker and returns once
// closeNotify is closed. The ticker is stopped on return.
func (l *Limiter) gcLoop(gcInterval time.Duration) {
	tick := time.NewTicker(gcInterval)
	defer tick.Stop()

	for {
		select {
		case firedAt := <-tick.C:
			l.doGc(firedAt, gcInterval)
		case <-l.closeNotify:
			return
		}
	}
}
// doGc walks every shard and deletes the entries whose lastSeen is more than
// gcInterval before now. Each shard is locked only while it is being scanned.
func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) {
	for i := range l.tables {
		sh := l.tables[i]
		sh.m.Lock()
		for addr, entry := range sh.table {
			if idle := now.Sub(entry.lastSeen); idle <= gcInterval {
				continue // seen recently enough, keep it
			}
			delete(sh.table, addr)
		}
		sh.m.Unlock()
	}
}
// getTableShard returns the shard that stores unmappedAddr, as chosen by
// getTableShardIdx.
func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard {
	idx := getTableShardIdx(unmappedAddr)
	return l.tables[idx]
}
// ForEach calls doFunc for every stored address and its limiter, shard by
// shard. Iteration stops as soon as doFunc returns true, in which case ForEach
// returns true; otherwise it returns false after visiting all entries.
//
// doFunc runs while the current shard's mutex is held, so it must not call
// Allow, Len or ForEach on the same Limiter: those lock the shard again and
// would block.
func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) {
	for i := range l.tables {
		sh := l.tables[i]

		sh.m.Lock()
		stopped := false
		for addr, entry := range sh.table {
			if doFunc(addr, entry.l) {
				stopped = true
				break
			}
		}
		sh.m.Unlock()

		if stopped {
			return true
		}
	}
	return false
}
// Len returns the current number of entries in the Limiter.
// Shards are locked and counted one at a time, so the sum is not a single
// consistent snapshot when other goroutines are adding or removing entries.
func (l *Limiter) Len() int {
	var total int
	for i := 0; i < len(l.tables); i++ {
		sh := l.tables[i]
		sh.m.Lock()
		total += len(sh.table)
		sh.m.Unlock()
	}
	return total
}
// getTableShardIdx maps an address to a shard index in [0, tableShards).
// It XORs all bytes of the address together (the 4-byte form when Is4 reports
// true, the 16-byte form otherwise) and reduces the result modulo tableShards.
func getTableShardIdx(unmappedAddr netip.Addr) int {
	var raw []byte
	if unmappedAddr.Is4() {
		v4 := unmappedAddr.As4()
		raw = v4[:]
	} else {
		v16 := unmappedAddr.As16()
		raw = v16[:]
	}

	var folded byte
	for _, octet := range raw {
		folded ^= octet
	}
	return int(folded % tableShards)
}