154 lines
3.1 KiB
Go
154 lines
3.1 KiB
Go
package rate_limiter
|
|
|
|
import (
|
|
"net/netip"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// tableShards is the number of independently locked partitions of the
// per-address table; getTableShardIdx maps each address to one of them.
const tableShards = 32

// gcInterval is both the period of the background GC started by
// NewRateLimiter and the idle time after which an address's entry becomes
// eligible for removal.
const gcInterval = time.Minute
|
|
|
|
type Limiter struct {
|
|
// Limit and Burst are read-only.
|
|
Limit rate.Limit
|
|
Burst int
|
|
|
|
closeOnce sync.Once
|
|
closeNotify chan struct{}
|
|
tables [tableShards]*tableShard
|
|
}
|
|
|
|
type tableShard struct {
|
|
m sync.Mutex
|
|
table map[netip.Addr]*limiterEntry
|
|
}
|
|
|
|
type limiterEntry struct {
|
|
l *rate.Limiter
|
|
lastSeen time.Time
|
|
sync.Once
|
|
}
|
|
|
|
// NewRateLimiter creates a new client rate limiter.
|
|
// limit and burst should be greater than zero. See rate.Limiter for more
|
|
// details.
|
|
// Limiter has a internal gc which will run and remove old client entries every 1m.
|
|
// If the token refill time (burst/limit) is greater than 1m,
|
|
// the actual average qps limit may be higher than expected because the client status
|
|
// may be deleted and re-initialized.
|
|
func NewRateLimiter(limit rate.Limit, burst int) *Limiter {
|
|
l := &Limiter{
|
|
Limit: limit,
|
|
Burst: burst,
|
|
closeNotify: make(chan struct{}),
|
|
}
|
|
|
|
for i := range l.tables {
|
|
l.tables[i] = &tableShard{table: make(map[netip.Addr]*limiterEntry)}
|
|
}
|
|
|
|
go l.gcLoop(gcInterval)
|
|
return l
|
|
}
|
|
|
|
// maskedUnmappedP must be a masked prefix and contain a unmapped addr.
|
|
func (l *Limiter) Allow(unmappedAddr netip.Addr) bool {
|
|
now := time.Now()
|
|
shard := l.getTableShard(unmappedAddr)
|
|
shard.m.Lock()
|
|
e, ok := shard.table[unmappedAddr]
|
|
if !ok {
|
|
e = &limiterEntry{
|
|
l: rate.NewLimiter(l.Limit, l.Burst),
|
|
lastSeen: now,
|
|
}
|
|
shard.table[unmappedAddr] = e
|
|
}
|
|
e.lastSeen = now
|
|
shard.m.Unlock()
|
|
clientLimiter := e.l
|
|
return clientLimiter.AllowN(now, 1)
|
|
}
|
|
|
|
func (l *Limiter) Close() error {
|
|
l.closeOnce.Do(func() {
|
|
close(l.closeNotify)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (l *Limiter) gcLoop(gcInterval time.Duration) {
|
|
ticker := time.NewTicker(gcInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-l.closeNotify:
|
|
return
|
|
case now := <-ticker.C:
|
|
l.doGc(now, gcInterval)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) {
|
|
for _, shard := range l.tables {
|
|
shard.m.Lock()
|
|
for a, e := range shard.table {
|
|
if now.Sub(e.lastSeen) > gcInterval {
|
|
delete(shard.table, a)
|
|
}
|
|
}
|
|
shard.m.Unlock()
|
|
}
|
|
}
|
|
|
|
func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard {
|
|
return l.tables[getTableShardIdx(unmappedAddr)]
|
|
}
|
|
|
|
func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) {
|
|
for _, shard := range l.tables {
|
|
shard.m.Lock()
|
|
for a, e := range shard.table {
|
|
doBreak = doFunc(a, e.l)
|
|
if doBreak {
|
|
shard.m.Unlock()
|
|
return
|
|
}
|
|
}
|
|
shard.m.Unlock()
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Len returns current number of entries in the Limiter.
|
|
func (l *Limiter) Len() int {
|
|
n := 0
|
|
for _, shard := range l.tables {
|
|
shard.m.Lock()
|
|
n += len(shard.table)
|
|
shard.m.Unlock()
|
|
}
|
|
return n
|
|
}
|
|
|
|
func getTableShardIdx(unmappedAddr netip.Addr) int {
|
|
var i byte
|
|
if unmappedAddr.Is4() {
|
|
for _, b := range unmappedAddr.As4() {
|
|
i ^= b
|
|
}
|
|
} else {
|
|
for _, b := range unmappedAddr.As16() {
|
|
i ^= b
|
|
}
|
|
}
|
|
return int(i % tableShards)
|
|
}
|