You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
4.8 KiB
231 lines
4.8 KiB
package ratelimit
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// RateLimiter 速率限制器
|
|
type RateLimiter struct {
|
|
logger *logrus.Logger
|
|
globalLimit *TokenBucket
|
|
perIPLimiters map[string]*TokenBucket
|
|
cleanupInterval time.Duration
|
|
mu sync.RWMutex
|
|
stopCh chan struct{}
|
|
}
|
|
|
|
// TokenBucket Token桶算法实现
|
|
type TokenBucket struct {
|
|
capacity int64
|
|
tokens int64
|
|
refillRate int64 // tokens per second
|
|
lastRefill time.Time
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// Config 速率限制配置
|
|
type Config struct {
|
|
Enabled bool `json:"enabled"`
|
|
RequestsPerSecond int `json:"requestsPerSecond"`
|
|
BurstSize int `json:"burstSize"`
|
|
PerIPRequestsPerSec int `json:"perIPRequestsPerSec"`
|
|
PerIPBurstSize int `json:"perIPBurstSize"`
|
|
CleanupInterval time.Duration `json:"cleanupInterval"`
|
|
}
|
|
|
|
// NewRateLimiter 创建新的速率限制器
|
|
func NewRateLimiter(config Config, logger *logrus.Logger) *RateLimiter {
|
|
if !config.Enabled {
|
|
return &RateLimiter{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// 默认值
|
|
if config.BurstSize == 0 {
|
|
config.BurstSize = config.RequestsPerSecond * 2
|
|
}
|
|
if config.PerIPBurstSize == 0 {
|
|
config.PerIPBurstSize = config.PerIPRequestsPerSec * 2
|
|
}
|
|
if config.CleanupInterval == 0 {
|
|
config.CleanupInterval = 5 * time.Minute
|
|
}
|
|
|
|
rl := &RateLimiter{
|
|
logger: logger,
|
|
globalLimit: NewTokenBucket(int64(config.BurstSize), int64(config.RequestsPerSecond)),
|
|
perIPLimiters: make(map[string]*TokenBucket),
|
|
cleanupInterval: config.CleanupInterval,
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
|
|
// 启动清理goroutine
|
|
go rl.cleanup()
|
|
|
|
return rl
|
|
}
|
|
|
|
// Allow 检查是否允许请求
|
|
func (rl *RateLimiter) Allow(remoteAddr string) bool {
|
|
if rl.globalLimit == nil {
|
|
return true // 未启用速率限制
|
|
}
|
|
|
|
// 全局速率限制
|
|
if !rl.globalLimit.Allow() {
|
|
rl.logger.WithField("type", "global").Debug("Rate limit exceeded")
|
|
return false
|
|
}
|
|
|
|
// 获取客户端IP
|
|
host, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
host = remoteAddr
|
|
}
|
|
|
|
// 单IP速率限制
|
|
rl.mu.RLock()
|
|
bucket, exists := rl.perIPLimiters[host]
|
|
rl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
rl.mu.Lock()
|
|
bucket, exists = rl.perIPLimiters[host]
|
|
if !exists {
|
|
bucket = NewTokenBucket(20, 10) // 默认每IP 10 rps, burst 20
|
|
rl.perIPLimiters[host] = bucket
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
|
|
if !bucket.Allow() {
|
|
rl.logger.WithFields(logrus.Fields{
|
|
"type": "per_ip",
|
|
"ip": host,
|
|
}).Debug("Per-IP rate limit exceeded")
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Stop 停止速率限制器
|
|
func (rl *RateLimiter) Stop() {
|
|
if rl.stopCh != nil {
|
|
close(rl.stopCh)
|
|
}
|
|
}
|
|
|
|
// GetStats 获取统计信息
|
|
func (rl *RateLimiter) GetStats() map[string]interface{} {
|
|
if rl.globalLimit == nil {
|
|
return map[string]interface{}{
|
|
"enabled": false,
|
|
}
|
|
}
|
|
|
|
rl.mu.RLock()
|
|
perIPCount := len(rl.perIPLimiters)
|
|
rl.mu.RUnlock()
|
|
|
|
return map[string]interface{}{
|
|
"enabled": true,
|
|
"global_tokens": rl.globalLimit.GetTokens(),
|
|
"per_ip_buckets": perIPCount,
|
|
}
|
|
}
|
|
|
|
// cleanup 清理过期的IP限制器
|
|
func (rl *RateLimiter) cleanup() {
|
|
ticker := time.NewTicker(rl.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-rl.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
rl.cleanupExpiredBuckets()
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanupExpiredBuckets 清理过期的IP桶
|
|
func (rl *RateLimiter) cleanupExpiredBuckets() {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
expiredIPs := make([]string, 0)
|
|
|
|
for ip, bucket := range rl.perIPLimiters {
|
|
// 如果桶超过10分钟没有活动,清理掉
|
|
if now.Sub(bucket.lastRefill) > 10*time.Minute {
|
|
expiredIPs = append(expiredIPs, ip)
|
|
}
|
|
}
|
|
|
|
for _, ip := range expiredIPs {
|
|
delete(rl.perIPLimiters, ip)
|
|
}
|
|
|
|
if len(expiredIPs) > 0 {
|
|
rl.logger.WithField("count", len(expiredIPs)).Debug("Cleaned up expired rate limit buckets")
|
|
}
|
|
}
|
|
|
|
// NewTokenBucket 创建新的Token桶
|
|
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
|
|
return &TokenBucket{
|
|
capacity: capacity,
|
|
tokens: capacity,
|
|
refillRate: refillRate,
|
|
lastRefill: time.Now(),
|
|
}
|
|
}
|
|
|
|
// Allow 检查是否允许请求(消耗一个token)
|
|
func (tb *TokenBucket) Allow() bool {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
|
|
tb.refill()
|
|
|
|
if tb.tokens > 0 {
|
|
tb.tokens--
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// refill 补充tokens
|
|
func (tb *TokenBucket) refill() {
|
|
now := time.Now()
|
|
elapsed := now.Sub(tb.lastRefill)
|
|
|
|
// 计算应该补充的tokens
|
|
tokensToAdd := int64(elapsed.Seconds()) * tb.refillRate
|
|
|
|
if tokensToAdd > 0 {
|
|
tb.tokens += tokensToAdd
|
|
if tb.tokens > tb.capacity {
|
|
tb.tokens = tb.capacity
|
|
}
|
|
tb.lastRefill = now
|
|
}
|
|
}
|
|
|
|
// GetTokens 获取当前token数量
|
|
func (tb *TokenBucket) GetTokens() int64 {
|
|
tb.mu.Lock()
|
|
defer tb.mu.Unlock()
|
|
|
|
tb.refill()
|
|
return tb.tokens
|
|
}
|
|
|