You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

231 lines
4.8 KiB

package ratelimit
import (
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// RateLimiter combines a global token bucket with one token bucket per
// client host. A request is allowed only if both its host's bucket and the
// global bucket have a token available.
//
// A RateLimiter built from a disabled Config has a nil globalLimit and
// allows every request.
type RateLimiter struct {
	logger          *logrus.Logger
	globalLimit     *TokenBucket
	perIPLimiters   map[string]*TokenBucket // keyed by client host; guarded by mu
	perIPCapacity   int64                   // capacity (burst size) of each per-IP bucket
	perIPRate       int64                   // refill rate of each per-IP bucket, tokens/second
	cleanupInterval time.Duration
	mu              sync.RWMutex
	stopCh          chan struct{}
}

// TokenBucket is a token-bucket limiter. It holds at most capacity tokens
// and regains refillRate tokens per second, and each allowed request
// consumes one token. The mutable state (tokens, lastRefill) must only be
// accessed with mu held.
type TokenBucket struct {
	capacity   int64
	tokens     int64
	refillRate int64 // tokens per second
	lastRefill time.Time
	mu         sync.Mutex
}

// Config configures a RateLimiter.
type Config struct {
	// Enabled turns rate limiting on; when false every request is allowed.
	Enabled bool `json:"enabled"`
	// RequestsPerSecond is the refill rate of the global bucket.
	RequestsPerSecond int `json:"requestsPerSecond"`
	// BurstSize is the global bucket's capacity; 0 means 2*RequestsPerSecond.
	BurstSize int `json:"burstSize"`
	// PerIPRequestsPerSec is the refill rate of each per-IP bucket;
	// 0 or negative means 10.
	PerIPRequestsPerSec int `json:"perIPRequestsPerSec"`
	// PerIPBurstSize is each per-IP bucket's capacity; 0 or negative means
	// 2*PerIPRequestsPerSec.
	PerIPBurstSize int `json:"perIPBurstSize"`
	// CleanupInterval is how often idle per-IP buckets are pruned;
	// 0 or negative means 5 minutes.
	CleanupInterval time.Duration `json:"cleanupInterval"`
}

// NewRateLimiter builds a RateLimiter from config, filling in defaults for
// unset values, and starts a background goroutine that prunes idle per-IP
// buckets until Stop is called. When config.Enabled is false the returned
// limiter allows every request and no goroutine is started.
func NewRateLimiter(config Config, logger *logrus.Logger) *RateLimiter {
	if !config.Enabled {
		return &RateLimiter{
			logger: logger,
		}
	}

	// Defaults.
	if config.BurstSize == 0 {
		config.BurstSize = config.RequestsPerSecond * 2
	}
	// Default per-IP limit: 10 requests/second with a burst of 20, so configs
	// that omit the per-IP settings get that limit.
	if config.PerIPRequestsPerSec <= 0 {
		config.PerIPRequestsPerSec = 10
	}
	if config.PerIPBurstSize <= 0 {
		config.PerIPBurstSize = config.PerIPRequestsPerSec * 2
	}
	// time.NewTicker (used by cleanup) panics on a non-positive interval.
	if config.CleanupInterval <= 0 {
		config.CleanupInterval = 5 * time.Minute
	}

	rl := &RateLimiter{
		logger:          logger,
		globalLimit:     NewTokenBucket(int64(config.BurstSize), int64(config.RequestsPerSecond)),
		perIPLimiters:   make(map[string]*TokenBucket),
		perIPCapacity:   int64(config.PerIPBurstSize),
		perIPRate:       int64(config.PerIPRequestsPerSec),
		cleanupInterval: config.CleanupInterval,
		stopCh:          make(chan struct{}),
	}

	// Prune idle per-IP buckets in the background until Stop is called.
	go rl.cleanup()
	return rl
}

// Allow reports whether a request from remoteAddr may proceed.
//
// remoteAddr is split with net.SplitHostPort. Its host part keys the
// per-IP bucket; if it cannot be split, the whole string is used. The
// per-IP bucket is checked first, and if it rejects, the global bucket is
// left untouched. If the per-IP bucket allows but the global bucket
// rejects, the per-IP token has still been consumed.
func (rl *RateLimiter) Allow(remoteAddr string) bool {
	if rl.globalLimit == nil {
		return true // rate limiting is disabled
	}

	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		host = remoteAddr
	}

	// Check the per-IP limit before the global one. Otherwise requests from
	// a flooding client that are about to be rejected would still drain the
	// shared global bucket and starve every other client.
	if !rl.bucketFor(host).Allow() {
		rl.logger.WithFields(logrus.Fields{
			"type": "per_ip",
			"ip":   host,
		}).Debug("Per-IP rate limit exceeded")
		return false
	}

	if !rl.globalLimit.Allow() {
		rl.logger.WithField("type", "global").Debug("Rate limit exceeded")
		return false
	}
	return true
}

// bucketFor returns the per-IP bucket for host, creating it on first use
// with the limiter's per-IP capacity and refill rate.
func (rl *RateLimiter) bucketFor(host string) *TokenBucket {
	rl.mu.RLock()
	bucket, ok := rl.perIPLimiters[host]
	rl.mu.RUnlock()
	if ok {
		return bucket
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Re-check under the write lock: another goroutine may have created the
	// bucket between RUnlock and Lock.
	if bucket, ok = rl.perIPLimiters[host]; ok {
		return bucket
	}

	capacity, rate := rl.perIPCapacity, rl.perIPRate
	if capacity <= 0 || rate <= 0 {
		// Per-IP fields left at zero (limiter not built by NewRateLimiter):
		// use 10 requests/second with a burst of 20.
		capacity, rate = 20, 10
	}
	bucket = NewTokenBucket(capacity, rate)
	rl.perIPLimiters[host] = bucket
	return bucket
}
// Stop terminates the background cleanup goroutine started by
// NewRateLimiter. It may be called more than once and from multiple
// goroutines. Every call after the first, and any call on a limiter built
// from a disabled Config, does nothing.
func (rl *RateLimiter) Stop() {
	if rl.stopCh == nil {
		return // disabled limiter: no cleanup goroutine was started
	}
	// Serialize Stop calls so the check-then-close below cannot race into a
	// double close, which would panic.
	rl.mu.Lock()
	defer rl.mu.Unlock()
	select {
	case <-rl.stopCh:
		// Already closed by an earlier Stop.
	default:
		close(rl.stopCh)
	}
}
// GetStats returns a snapshot of the limiter for monitoring.
//
// A disabled limiter reports only "enabled": false. An enabled one also
// reports the tokens currently left in the global bucket ("global_tokens")
// and how many per-IP buckets are being tracked ("per_ip_buckets").
func (rl *RateLimiter) GetStats() map[string]interface{} {
	enabled := rl.globalLimit != nil
	stats := map[string]interface{}{"enabled": enabled}
	if !enabled {
		return stats
	}

	rl.mu.RLock()
	tracked := len(rl.perIPLimiters)
	rl.mu.RUnlock()

	stats["global_tokens"] = rl.globalLimit.GetTokens()
	stats["per_ip_buckets"] = tracked
	return stats
}
// cleanup prunes idle per-IP buckets once every cleanupInterval. It returns
// when stopCh is closed.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(rl.cleanupInterval)
	defer ticker.Stop()

	done := rl.stopCh
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
		}
		rl.cleanupExpiredBuckets()
	}
}
// cleanupExpiredBuckets removes every per-IP bucket whose lastRefill is
// more than 10 minutes old, and logs how many were removed.
func (rl *RateLimiter) cleanupExpiredBuckets() {
	const idleTTL = 10 * time.Minute

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	removed := 0
	for ip, bucket := range rl.perIPLimiters {
		// lastRefill is written by TokenBucket.Allow under bucket.mu, and
		// callers of Allow do not hold rl.mu. Read it under bucket.mu too,
		// otherwise this is a data race.
		bucket.mu.Lock()
		last := bucket.lastRefill
		bucket.mu.Unlock()

		if now.Sub(last) > idleTTL {
			// Deleting the current key during a range over the map is
			// permitted by the Go spec.
			delete(rl.perIPLimiters, ip)
			removed++
		}
	}

	if removed > 0 {
		rl.logger.WithField("count", removed).Debug("Cleaned up expired rate limit buckets")
	}
}
// NewTokenBucket returns a bucket that starts full. It holds at most
// capacity tokens and regains refillRate tokens per second.
func NewTokenBucket(capacity, refillRate int64) *TokenBucket {
	tb := new(TokenBucket)
	tb.capacity = capacity
	tb.tokens = capacity // start with a full burst available
	tb.refillRate = refillRate
	tb.lastRefill = time.Now()
	return tb
}
// Allow first credits any tokens earned since the last refill, then tries
// to take one token. It reports whether a token was available.
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	tb.refill()
	if tb.tokens <= 0 {
		return false
	}
	tb.tokens--
	return true
}
// refill credits the tokens earned since lastRefill at refillRate tokens
// per second, never exceeding capacity. The caller must hold tb.mu.
//
// Only whole tokens are credited. lastRefill is advanced by exactly the
// time those tokens account for, so a partial token's worth of elapsed time
// carries over to the next call instead of being dropped. A full bucket, or
// one with no positive refill rate, simply restarts its clock, so time
// spent full is not banked as extra tokens.
func (tb *TokenBucket) refill() {
	now := time.Now()
	if tb.tokens >= tb.capacity || tb.refillRate <= 0 {
		tb.lastRefill = now
		return
	}

	elapsed := now.Sub(tb.lastRefill)
	if elapsed <= 0 {
		return
	}

	earned := elapsed.Seconds() * float64(tb.refillRate)
	if earned >= float64(tb.capacity-tb.tokens) {
		// Enough time has passed to top the bucket up completely.
		tb.tokens = tb.capacity
		tb.lastRefill = now
		return
	}

	whole := int64(earned)
	if whole == 0 {
		return // less than one token earned so far; keep accumulating
	}
	tb.tokens += whole
	// Advance the clock only by the time that paid for the credited tokens.
	consumed := time.Duration(float64(whole) / float64(tb.refillRate) * float64(time.Second))
	tb.lastRefill = tb.lastRefill.Add(consumed)
}
// GetTokens returns the number of tokens currently available. It first
// credits any tokens earned since the last refill, so calling it updates
// the bucket's state.
func (tb *TokenBucket) GetTokens() int64 {
	tb.mu.Lock()
	tb.refill()
	available := tb.tokens
	tb.mu.Unlock()
	return available
}