// Package ratelimit provides a token-bucket request limiter with a global
// budget plus an independent budget per client IP.
package ratelimit

import (
	"net"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)

const (
	// defaultPerIPRequestsPerSec is the per-IP refill rate used when
	// Config.PerIPRequestsPerSec is not set (the value that used to be
	// hard-coded for every client).
	defaultPerIPRequestsPerSec = 10
	// defaultCleanupInterval is used when Config.CleanupInterval is not positive.
	defaultCleanupInterval = 5 * time.Minute
	// bucketIdleTTL is how long a per-IP bucket may go without any Allow call
	// before cleanupExpiredBuckets discards it.
	bucketIdleTTL = 10 * time.Minute
)

// RateLimiter enforces a global token-bucket limit plus a separate token
// bucket per client host.
//
// A RateLimiter built by NewRateLimiter from an enabled Config owns a
// background cleanup goroutine; call Stop to terminate it.
type RateLimiter struct {
	logger          *logrus.Logger
	globalLimit     *TokenBucket            // nil when rate limiting is disabled
	perIPLimiters   map[string]*TokenBucket // keyed by client host; guarded by mu
	perIPBurst      int64                   // capacity of each per-IP bucket
	perIPRate       int64                   // refill rate (tokens/s) of each per-IP bucket
	cleanupInterval time.Duration
	mu              sync.RWMutex
	stopCh          chan struct{}
	stopOnce        sync.Once
}

// TokenBucket is a mutex-guarded token bucket. It holds at most capacity
// tokens, refills at refillRate tokens per second, and each Allow call
// consumes one token.
type TokenBucket struct {
	capacity   int64
	tokens     int64
	refillRate int64     // tokens per second
	lastRefill time.Time // instant up to which refill credit has been accounted
	lastUsed   time.Time // instant of the most recent Allow call; drives idle cleanup
	mu         sync.Mutex
}

// Config configures a RateLimiter.
//
// Unset (non-positive) fields fall back to defaults in NewRateLimiter:
// BurstSize -> 2*RequestsPerSecond, PerIPRequestsPerSec -> 10,
// PerIPBurstSize -> 2*PerIPRequestsPerSec, CleanupInterval -> 5m.
// With Enabled set, RequestsPerSecond should be positive: if it is zero and
// BurstSize is unset the global bucket has capacity 0 and rejects everything.
// NOTE: CleanupInterval is a time.Duration, so encoding/json reads a bare
// number as nanoseconds.
type Config struct {
	Enabled             bool          `json:"enabled"`
	RequestsPerSecond   int           `json:"requestsPerSecond"`
	BurstSize           int           `json:"burstSize"`
	PerIPRequestsPerSec int           `json:"perIPRequestsPerSec"`
	PerIPBurstSize      int           `json:"perIPBurstSize"`
	CleanupInterval     time.Duration `json:"cleanupInterval"`
}

// NewRateLimiter builds a RateLimiter from config.
//
// When config.Enabled is false the returned limiter allows every request and
// starts no goroutine. Otherwise missing values are filled with defaults (see
// Config) and a cleanup goroutine is started; call Stop when done. A nil
// logger is replaced by logrus.StandardLogger().
func NewRateLimiter(config Config, logger *logrus.Logger) *RateLimiter {
	if logger == nil {
		logger = logrus.StandardLogger()
	}
	if !config.Enabled {
		return &RateLimiter{logger: logger}
	}

	if config.BurstSize <= 0 {
		config.BurstSize = config.RequestsPerSecond * 2
	}
	if config.PerIPRequestsPerSec <= 0 {
		config.PerIPRequestsPerSec = defaultPerIPRequestsPerSec
	}
	if config.PerIPBurstSize <= 0 {
		config.PerIPBurstSize = config.PerIPRequestsPerSec * 2
	}
	// time.NewTicker panics on a non-positive interval, so negative values
	// must be replaced as well, not only zero.
	if config.CleanupInterval <= 0 {
		config.CleanupInterval = defaultCleanupInterval
	}

	rl := &RateLimiter{
		logger:          logger,
		globalLimit:     NewTokenBucket(int64(config.BurstSize), int64(config.RequestsPerSecond)),
		perIPLimiters:   make(map[string]*TokenBucket),
		perIPBurst:      int64(config.PerIPBurstSize),
		perIPRate:       int64(config.PerIPRequestsPerSec),
		cleanupInterval: config.CleanupInterval,
		stopCh:          make(chan struct{}),
	}

	go rl.cleanup()
	return rl
}

// Allow reports whether a request from remoteAddr ("host:port" or a bare
// host) may proceed. It always returns true when rate limiting is disabled.
//
// The per-IP bucket is checked before the global one so that requests from
// an over-limit client are rejected without spending global tokens;
// otherwise a single noisy client could exhaust the global budget for
// everyone else.
func (rl *RateLimiter) Allow(remoteAddr string) bool {
	if rl.globalLimit == nil {
		return true // rate limiting disabled
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// No port (or otherwise unparsable): key on the address as given.
		host = remoteAddr
	}

	if !rl.bucketFor(host).Allow() {
		rl.logger.WithFields(logrus.Fields{
			"type": "per_ip",
			"ip":   host,
		}).Debug("Per-IP rate limit exceeded")
		return false
	}

	if !rl.globalLimit.Allow() {
		rl.logger.WithField("type", "global").Debug("Rate limit exceeded")
		return false
	}
	return true
}

// bucketFor returns the per-IP bucket for host, creating it on first use.
// The read lock serves the common hit path; under the write lock the map is
// re-checked because another goroutine may have inserted the bucket meanwhile.
func (rl *RateLimiter) bucketFor(host string) *TokenBucket {
	rl.mu.RLock()
	bucket, ok := rl.perIPLimiters[host]
	rl.mu.RUnlock()
	if ok {
		return bucket
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	if bucket, ok = rl.perIPLimiters[host]; !ok {
		bucket = NewTokenBucket(rl.perIPBurst, rl.perIPRate)
		rl.perIPLimiters[host] = bucket
	}
	return bucket
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once and on a disabled limiter.
func (rl *RateLimiter) Stop() {
	if rl.stopCh == nil {
		return
	}
	rl.stopOnce.Do(func() { close(rl.stopCh) })
}

// GetStats returns a snapshot of limiter state: "enabled" always; when
// enabled, also "global_tokens" (tokens currently in the global bucket) and
// "per_ip_buckets" (number of tracked client hosts).
func (rl *RateLimiter) GetStats() map[string]interface{} {
	if rl.globalLimit == nil {
		return map[string]interface{}{
			"enabled": false,
		}
	}

	rl.mu.RLock()
	perIPCount := len(rl.perIPLimiters)
	rl.mu.RUnlock()

	return map[string]interface{}{
		"enabled":        true,
		"global_tokens":  rl.globalLimit.GetTokens(),
		"per_ip_buckets": perIPCount,
	}
}

// cleanup runs cleanupExpiredBuckets every cleanupInterval until Stop is called.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(rl.cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-rl.stopCh:
			return
		case <-ticker.C:
			rl.cleanupExpiredBuckets()
		}
	}
}

// cleanupExpiredBuckets drops per-IP buckets that have not seen an Allow call
// for longer than bucketIdleTTL.
func (rl *RateLimiter) cleanupExpiredBuckets() {
	now := time.Now()
	removed := 0

	rl.mu.Lock()
	for ip, bucket := range rl.perIPLimiters {
		// lastActivity takes the bucket's own lock: Allow updates the bucket
		// outside rl.mu, so reading its fields directly here would race.
		if now.Sub(bucket.lastActivity()) > bucketIdleTTL {
			delete(rl.perIPLimiters, ip) // deleting during range is allowed in Go
			removed++
		}
	}
	rl.mu.Unlock()

	if removed > 0 {
		rl.logger.WithField("count", removed).Debug("Cleaned up expired rate limit buckets")
	}
}

// NewTokenBucket returns a full bucket holding capacity tokens that refills at
// refillRate tokens per second.
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
	now := time.Now()
	return &TokenBucket{
		capacity:   capacity,
		tokens:     capacity,
		refillRate: refillRate,
		lastRefill: now,
		lastUsed:   now,
	}
}

// Allow refills the bucket and consumes one token if available, reporting
// whether it did. Every call, allowed or not, counts as activity for idle
// cleanup.
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	now := time.Now()
	tb.lastUsed = now
	tb.refillAt(now)

	if tb.tokens > 0 {
		tb.tokens--
		return true
	}
	return false
}

// refill credits tokens earned up to the current time. Caller must hold tb.mu.
func (tb *TokenBucket) refill() {
	tb.refillAt(time.Now())
}

// refillAt credits tokens earned between lastRefill and now. Caller must hold
// tb.mu.
//
// Fractional accrual is preserved: lastRefill advances only by the time the
// credited whole tokens represent, so frequent calls neither lose time nor
// stall refill at sub-second granularity.
func (tb *TokenBucket) refillAt(now time.Time) {
	if tb.tokens >= tb.capacity {
		// A full bucket earns nothing; restart accrual from now so time spent
		// full is not credited once tokens start being consumed.
		tb.tokens = tb.capacity
		tb.lastRefill = now
		return
	}

	elapsed := now.Sub(tb.lastRefill)
	if tb.refillRate <= 0 || elapsed <= 0 {
		return
	}

	earned := elapsed.Seconds() * float64(tb.refillRate)
	if missing := tb.capacity - tb.tokens; earned >= float64(missing) {
		tb.tokens = tb.capacity
		tb.lastRefill = now
		return
	}

	whole := int64(earned) // earned < missing, so this cannot overflow
	if whole == 0 {
		return // less than one token accrued so far; keep accumulating
	}
	tb.tokens += whole
	// Carry the fractional remainder into the next refill.
	tb.lastRefill = tb.lastRefill.Add(time.Duration(float64(whole) / float64(tb.refillRate) * float64(time.Second)))
}

// lastActivity returns the time of the most recent Allow call.
func (tb *TokenBucket) lastActivity() time.Time {
	tb.mu.Lock()
	defer tb.mu.Unlock()
	return tb.lastUsed
}

// GetTokens refills the bucket and returns its current token count without
// consuming a token. It does not count as activity for idle cleanup.
func (tb *TokenBucket) GetTokens() int64 {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	tb.refill()
	return tb.tokens
}