parent
77128910b3
commit
52f0fba78d
@ -0,0 +1,289 @@ |
|||||||
|
package config |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"os" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
"github.com/spf13/viper" |
||||||
|
) |
||||||
|
|
||||||
|
// ServerConfig 服务器配置
|
||||||
|
type ServerConfig struct { |
||||||
|
ServiceType string `mapstructure:"serviceType"` |
||||||
|
|
||||||
|
Proxy ProxyConfig `mapstructure:"proxy"` |
||||||
|
Auth AuthConfig `mapstructure:"auth"` |
||||||
|
|
||||||
|
Timeout time.Duration `mapstructure:"timeout"` |
||||||
|
MaxConns int `mapstructure:"maxConns"` |
||||||
|
LogLevel string `mapstructure:"logLevel"` |
||||||
|
|
||||||
|
HealthCheck HealthCheckConfig `mapstructure:"healthCheck"` |
||||||
|
|
||||||
|
OptimizedServer OptimizedServerConfig `mapstructure:"optimizedServer"` |
||||||
|
} |
||||||
|
|
||||||
|
// ProxyConfig 代理配置
|
||||||
|
type ProxyConfig struct { |
||||||
|
Address string `mapstructure:"address"` |
||||||
|
Port int `mapstructure:"port"` |
||||||
|
} |
||||||
|
|
||||||
|
// AuthConfig 认证配置
|
||||||
|
type AuthConfig struct { |
||||||
|
Username string `mapstructure:"username"` |
||||||
|
Password string `mapstructure:"password"` |
||||||
|
Methods []string `mapstructure:"methods"` |
||||||
|
} |
||||||
|
|
||||||
|
// HealthCheckConfig 健康检查配置
|
||||||
|
type HealthCheckConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
Address string `mapstructure:"address"` |
||||||
|
Port int `mapstructure:"port"` |
||||||
|
} |
||||||
|
|
||||||
|
// OptimizedServerConfig 优化服务器配置
|
||||||
|
type OptimizedServerConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
MaxIdleTime time.Duration `mapstructure:"maxIdleTime"` |
||||||
|
BufferSize int `mapstructure:"bufferSize"` |
||||||
|
LogConnections bool `mapstructure:"logConnections"` |
||||||
|
|
||||||
|
DNSCache DNSCacheConfig `mapstructure:"dnsCache"` |
||||||
|
RateLimit RateLimitConfig `mapstructure:"rateLimit"` |
||||||
|
AccessControl AccessControlConfig `mapstructure:"accessControl"` |
||||||
|
Metrics MetricsConfig `mapstructure:"metrics"` |
||||||
|
ConnectionPool ConnectionPoolConfig `mapstructure:"connectionPool"` |
||||||
|
Memory MemoryConfig `mapstructure:"memory"` |
||||||
|
Transparent TransparentConfig `mapstructure:"transparent"` |
||||||
|
} |
||||||
|
|
||||||
|
// DNSCacheConfig DNS缓存配置
|
||||||
|
type DNSCacheConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
MaxSize int `mapstructure:"maxSize"` |
||||||
|
TTL time.Duration `mapstructure:"ttl"` |
||||||
|
} |
||||||
|
|
||||||
|
// RateLimitConfig 速率限制配置
|
||||||
|
type RateLimitConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
RequestsPerSecond int `mapstructure:"requestsPerSecond"` |
||||||
|
BurstSize int `mapstructure:"burstSize"` |
||||||
|
PerIPRequestsPerSec int `mapstructure:"perIPRequestsPerSec"` |
||||||
|
PerIPBurstSize int `mapstructure:"perIPBurstSize"` |
||||||
|
CleanupInterval time.Duration `mapstructure:"cleanupInterval"` |
||||||
|
} |
||||||
|
|
||||||
|
// AccessControlConfig 访问控制配置
|
||||||
|
type AccessControlConfig struct { |
||||||
|
AllowedIPs []string `mapstructure:"allowedIPs"` |
||||||
|
} |
||||||
|
|
||||||
|
// MetricsConfig 指标配置
|
||||||
|
type MetricsConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
Interval time.Duration `mapstructure:"interval"` |
||||||
|
} |
||||||
|
|
||||||
|
// ConnectionPoolConfig 连接池配置
|
||||||
|
type ConnectionPoolConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
MaxSize int `mapstructure:"maxSize"` |
||||||
|
MaxLifetime time.Duration `mapstructure:"maxLifetime"` |
||||||
|
MaxIdle time.Duration `mapstructure:"maxIdle"` |
||||||
|
InitialSize int `mapstructure:"initialSize"` |
||||||
|
} |
||||||
|
|
||||||
|
// MemoryConfig 内存优化配置
|
||||||
|
type MemoryConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
BufferSizes []int `mapstructure:"bufferSizes"` |
||||||
|
MonitorInterval time.Duration `mapstructure:"monitorInterval"` |
||||||
|
EnableAutoGC bool `mapstructure:"enableAutoGC"` |
||||||
|
HeapAllocThresholdMB int64 `mapstructure:"heapAllocThresholdMB"` |
||||||
|
HeapSysThresholdMB int64 `mapstructure:"heapSysThresholdMB"` |
||||||
|
ForceGCThresholdMB int64 `mapstructure:"forceGCThresholdMB"` |
||||||
|
} |
||||||
|
|
||||||
|
// TransparentConfig 透明代理配置
|
||||||
|
type TransparentConfig struct { |
||||||
|
Enabled bool `mapstructure:"enabled"` |
||||||
|
TransparentPort int `mapstructure:"transparentPort"` |
||||||
|
DNSPort int `mapstructure:"dnsPort"` |
||||||
|
BypassIPs []string `mapstructure:"bypassIPs"` |
||||||
|
BypassDomains []string `mapstructure:"bypassDomains"` |
||||||
|
} |
||||||
|
|
||||||
|
// LoadConfig 加载配置文件
|
||||||
|
func LoadConfig(configPath string) (*ServerConfig, error) { |
||||||
|
// 检查配置文件是否存在
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) { |
||||||
|
return nil, fmt.Errorf("config file not found: %s", configPath) |
||||||
|
} |
||||||
|
|
||||||
|
// 初始化viper
|
||||||
|
viper.SetConfigFile(configPath) |
||||||
|
viper.SetConfigType("yaml") |
||||||
|
|
||||||
|
// 设置默认值
|
||||||
|
setDefaults() |
||||||
|
|
||||||
|
// 读取配置文件
|
||||||
|
if err := viper.ReadInConfig(); err != nil { |
||||||
|
return nil, fmt.Errorf("failed to read config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 解析配置
|
||||||
|
var config ServerConfig |
||||||
|
if err := viper.Unmarshal(&config); err != nil { |
||||||
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 验证配置
|
||||||
|
if err := validateConfig(&config); err != nil { |
||||||
|
return nil, fmt.Errorf("invalid config: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return &config, nil |
||||||
|
} |
||||||
|
|
||||||
|
// setDefaults 设置默认配置值
|
||||||
|
func setDefaults() { |
||||||
|
// 基本配置默认值
|
||||||
|
viper.SetDefault("serviceType", "server") |
||||||
|
viper.SetDefault("proxy.address", "0.0.0.0") |
||||||
|
viper.SetDefault("proxy.port", 1080) |
||||||
|
viper.SetDefault("timeout", "30s") |
||||||
|
viper.SetDefault("maxConns", 5000) |
||||||
|
viper.SetDefault("logLevel", "info") |
||||||
|
|
||||||
|
// 健康检查默认值
|
||||||
|
viper.SetDefault("healthCheck.enabled", true) |
||||||
|
viper.SetDefault("healthCheck.address", "127.0.0.1") |
||||||
|
viper.SetDefault("healthCheck.port", 8090) |
||||||
|
|
||||||
|
// 优化服务器默认值
|
||||||
|
viper.SetDefault("optimizedServer.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.maxIdleTime", "5m") |
||||||
|
viper.SetDefault("optimizedServer.bufferSize", 65536) |
||||||
|
viper.SetDefault("optimizedServer.logConnections", true) |
||||||
|
|
||||||
|
// DNS缓存默认值
|
||||||
|
viper.SetDefault("optimizedServer.dnsCache.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.dnsCache.maxSize", 10000) |
||||||
|
viper.SetDefault("optimizedServer.dnsCache.ttl", "10m") |
||||||
|
|
||||||
|
// 速率限制默认值
|
||||||
|
viper.SetDefault("optimizedServer.rateLimit.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.rateLimit.requestsPerSecond", 100) |
||||||
|
|
||||||
|
// 指标默认值
|
||||||
|
viper.SetDefault("optimizedServer.metrics.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.metrics.interval", "5m") |
||||||
|
|
||||||
|
// 连接池默认值
|
||||||
|
viper.SetDefault("optimizedServer.connectionPool.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.connectionPool.maxSize", 1000) |
||||||
|
viper.SetDefault("optimizedServer.connectionPool.maxLifetime", "30m") |
||||||
|
viper.SetDefault("optimizedServer.connectionPool.maxIdle", "5m") |
||||||
|
viper.SetDefault("optimizedServer.connectionPool.initialSize", 100) |
||||||
|
|
||||||
|
// 内存优化默认值
|
||||||
|
viper.SetDefault("optimizedServer.memory.enabled", true) |
||||||
|
viper.SetDefault("optimizedServer.memory.bufferSizes", []int{64, 128, 256, 512, 1024}) |
||||||
|
viper.SetDefault("optimizedServer.memory.monitorInterval", "5m") |
||||||
|
viper.SetDefault("optimizedServer.memory.enableAutoGC", true) |
||||||
|
viper.SetDefault("optimizedServer.memory.heapAllocThresholdMB", 1024) |
||||||
|
viper.SetDefault("optimizedServer.memory.heapSysThresholdMB", 2048) |
||||||
|
viper.SetDefault("optimizedServer.memory.forceGCThresholdMB", 512) |
||||||
|
|
||||||
|
// 透明代理默认值
|
||||||
|
viper.SetDefault("optimizedServer.transparent.enabled", false) |
||||||
|
viper.SetDefault("optimizedServer.transparent.transparentPort", 8080) |
||||||
|
viper.SetDefault("optimizedServer.transparent.dnsPort", 53) |
||||||
|
viper.SetDefault("optimizedServer.transparent.bypassIPs", []string{}) |
||||||
|
viper.SetDefault("optimizedServer.transparent.bypassDomains", []string{}) |
||||||
|
} |
||||||
|
|
||||||
|
// validateConfig 验证配置
|
||||||
|
func validateConfig(config *ServerConfig) error { |
||||||
|
// 验证端口范围
|
||||||
|
if config.Proxy.Port < 1 || config.Proxy.Port > 65535 { |
||||||
|
return fmt.Errorf("invalid proxy port: %d", config.Proxy.Port) |
||||||
|
} |
||||||
|
|
||||||
|
if config.HealthCheck.Enabled { |
||||||
|
if config.HealthCheck.Port < 1 || config.HealthCheck.Port > 65535 { |
||||||
|
return fmt.Errorf("invalid health check port: %d", config.HealthCheck.Port) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 验证认证配置
|
||||||
|
if config.Auth.Username != "" && config.Auth.Password == "" { |
||||||
|
return fmt.Errorf("password is required when username is set") |
||||||
|
} |
||||||
|
|
||||||
|
// 验证日志级别
|
||||||
|
switch config.LogLevel { |
||||||
|
case "debug", "info", "warn", "error": |
||||||
|
// 有效的日志级别
|
||||||
|
default: |
||||||
|
return fmt.Errorf("invalid log level: %s", config.LogLevel) |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetLogLevel 获取logrus日志级别
|
||||||
|
func GetLogLevel(level string) logrus.Level { |
||||||
|
switch level { |
||||||
|
case "debug": |
||||||
|
return logrus.DebugLevel |
||||||
|
case "info": |
||||||
|
return logrus.InfoLevel |
||||||
|
case "warn": |
||||||
|
return logrus.WarnLevel |
||||||
|
case "error": |
||||||
|
return logrus.ErrorLevel |
||||||
|
default: |
||||||
|
return logrus.InfoLevel |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ToSOCKS5Config 转换为SOCKS5配置
|
||||||
|
func (c *ServerConfig) ToSOCKS5Config() SOCKS5Config { |
||||||
|
return SOCKS5Config{ |
||||||
|
Auth: SOCKS5AuthConfig{ |
||||||
|
Methods: c.Auth.Methods, |
||||||
|
Username: c.Auth.Username, |
||||||
|
Password: c.Auth.Password, |
||||||
|
}, |
||||||
|
Timeout: c.Timeout, |
||||||
|
Rules: []SOCKS5RuleConfig{}, // 从访问控制配置转换
|
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// SOCKS5Config SOCKS5特定配置
|
||||||
|
type SOCKS5Config struct { |
||||||
|
Auth SOCKS5AuthConfig `json:"auth"` |
||||||
|
Timeout time.Duration `json:"timeout"` |
||||||
|
Rules []SOCKS5RuleConfig `json:"rules"` |
||||||
|
} |
||||||
|
|
||||||
|
// SOCKS5AuthConfig SOCKS5认证配置
|
||||||
|
type SOCKS5AuthConfig struct { |
||||||
|
Methods []string `json:"methods"` |
||||||
|
Username string `json:"username"` |
||||||
|
Password string `json:"password"` |
||||||
|
} |
||||||
|
|
||||||
|
// SOCKS5RuleConfig SOCKS5规则配置
|
||||||
|
type SOCKS5RuleConfig struct { |
||||||
|
Action string `json:"action"` |
||||||
|
IPs []string `json:"ips"` |
||||||
|
Ports []int `json:"ports"` |
||||||
|
} |
@ -0,0 +1,379 @@ |
|||||||
|
package memory |
||||||
|
|
||||||
|
import ( |
||||||
|
"runtime" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
// BufferPool 缓冲区池
|
||||||
|
type BufferPool struct { |
||||||
|
pools map[int]*sync.Pool |
||||||
|
sizes []int |
||||||
|
logger *logrus.Logger |
||||||
|
stats BufferStats |
||||||
|
} |
||||||
|
|
||||||
|
// BufferStats 缓冲区统计
|
||||||
|
type BufferStats struct { |
||||||
|
mu sync.RWMutex |
||||||
|
Gets int64 `json:"gets"` |
||||||
|
Puts int64 `json:"puts"` |
||||||
|
News int64 `json:"news"` |
||||||
|
Reuses int64 `json:"reuses"` |
||||||
|
TotalAlloc int64 `json:"totalAlloc"` |
||||||
|
TotalReused int64 `json:"totalReused"` |
||||||
|
} |
||||||
|
|
||||||
|
// MemoryMonitor 内存监控器
|
||||||
|
type MemoryMonitor struct { |
||||||
|
logger *logrus.Logger |
||||||
|
ticker *time.Ticker |
||||||
|
stopCh chan struct{} |
||||||
|
thresholds Thresholds |
||||||
|
lastStats runtime.MemStats |
||||||
|
callbacks []MemoryCallback |
||||||
|
mu sync.RWMutex |
||||||
|
} |
||||||
|
|
||||||
|
// Thresholds 内存阈值配置
|
||||||
|
type Thresholds struct { |
||||||
|
HeapAllocMB int64 `json:"heapAllocMB"` |
||||||
|
HeapSysMB int64 `json:"heapSysMB"` |
||||||
|
GCPercent int `json:"gcPercent"` |
||||||
|
ForceGCThreshMB int64 `json:"forceGCThreshMB"` |
||||||
|
} |
||||||
|
|
||||||
|
// MemoryCallback 内存回调函数
|
||||||
|
type MemoryCallback func(stats runtime.MemStats) |
||||||
|
|
||||||
|
// Config 内存管理配置
|
||||||
|
type Config struct { |
||||||
|
BufferSizes []int `json:"bufferSizes"` |
||||||
|
MonitorInterval time.Duration `json:"monitorInterval"` |
||||||
|
Thresholds Thresholds `json:"thresholds"` |
||||||
|
EnableAutoGC bool `json:"enableAutoGC"` |
||||||
|
EnableOptimization bool `json:"enableOptimization"` |
||||||
|
} |
||||||
|
|
||||||
|
// Manager 内存管理器
|
||||||
|
type Manager struct { |
||||||
|
bufferPool *BufferPool |
||||||
|
monitor *MemoryMonitor |
||||||
|
config Config |
||||||
|
logger *logrus.Logger |
||||||
|
} |
||||||
|
|
||||||
|
// NewManager 创建内存管理器
|
||||||
|
func NewManager(config Config, logger *logrus.Logger) *Manager { |
||||||
|
// 设置默认值
|
||||||
|
if len(config.BufferSizes) == 0 { |
||||||
|
config.BufferSizes = []int{512, 1024, 2048, 4096, 8192, 16384, 32768, 65536} |
||||||
|
} |
||||||
|
if config.MonitorInterval == 0 { |
||||||
|
config.MonitorInterval = 30 * time.Second |
||||||
|
} |
||||||
|
if config.Thresholds.HeapAllocMB == 0 { |
||||||
|
config.Thresholds.HeapAllocMB = 100 |
||||||
|
} |
||||||
|
if config.Thresholds.HeapSysMB == 0 { |
||||||
|
config.Thresholds.HeapSysMB = 200 |
||||||
|
} |
||||||
|
if config.Thresholds.GCPercent == 0 { |
||||||
|
config.Thresholds.GCPercent = 100 |
||||||
|
} |
||||||
|
if config.Thresholds.ForceGCThreshMB == 0 { |
||||||
|
config.Thresholds.ForceGCThreshMB = 500 |
||||||
|
} |
||||||
|
|
||||||
|
manager := &Manager{ |
||||||
|
config: config, |
||||||
|
logger: logger, |
||||||
|
} |
||||||
|
|
||||||
|
// 创建缓冲区池
|
||||||
|
if config.EnableOptimization { |
||||||
|
manager.bufferPool = NewBufferPool(config.BufferSizes, logger) |
||||||
|
} |
||||||
|
|
||||||
|
// 创建内存监控器
|
||||||
|
manager.monitor = NewMemoryMonitor(config.MonitorInterval, config.Thresholds, logger) |
||||||
|
|
||||||
|
// 启用自动GC
|
||||||
|
if config.EnableAutoGC { |
||||||
|
manager.monitor.AddCallback(manager.autoGCCallback) |
||||||
|
} |
||||||
|
|
||||||
|
return manager |
||||||
|
} |
||||||
|
|
||||||
|
// NewBufferPool 创建缓冲区池
|
||||||
|
func NewBufferPool(sizes []int, logger *logrus.Logger) *BufferPool { |
||||||
|
bp := &BufferPool{ |
||||||
|
pools: make(map[int]*sync.Pool), |
||||||
|
sizes: make([]int, len(sizes)), |
||||||
|
logger: logger, |
||||||
|
} |
||||||
|
|
||||||
|
copy(bp.sizes, sizes) |
||||||
|
|
||||||
|
// 为每个大小创建池
|
||||||
|
for _, size := range sizes { |
||||||
|
sz := size // 捕获循环变量
|
||||||
|
bp.pools[sz] = &sync.Pool{ |
||||||
|
New: func() interface{} { |
||||||
|
bp.stats.incNews() |
||||||
|
bp.stats.addTotalAlloc(int64(sz)) |
||||||
|
return make([]byte, sz) |
||||||
|
}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return bp |
||||||
|
} |
||||||
|
|
||||||
|
// Get 获取缓冲区
|
||||||
|
func (bp *BufferPool) Get(size int) []byte { |
||||||
|
bp.stats.incGets() |
||||||
|
|
||||||
|
// 找到最合适的大小
|
||||||
|
for _, poolSize := range bp.sizes { |
||||||
|
if size <= poolSize { |
||||||
|
buf := bp.pools[poolSize].Get().([]byte) |
||||||
|
bp.stats.incReuses() |
||||||
|
bp.stats.addTotalReused(int64(poolSize)) |
||||||
|
return buf[:size] |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 没有合适的池,创建新缓冲区
|
||||||
|
bp.stats.incNews() |
||||||
|
bp.stats.addTotalAlloc(int64(size)) |
||||||
|
return make([]byte, size) |
||||||
|
} |
||||||
|
|
||||||
|
// Put 归还缓冲区
|
||||||
|
func (bp *BufferPool) Put(buf []byte) { |
||||||
|
if buf == nil { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
bp.stats.incPuts() |
||||||
|
capacity := cap(buf) |
||||||
|
|
||||||
|
// 找到对应的池
|
||||||
|
for _, poolSize := range bp.sizes { |
||||||
|
if capacity == poolSize { |
||||||
|
bp.pools[poolSize].Put(buf[:poolSize]) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// GetStats 获取缓冲区统计
|
||||||
|
func (bp *BufferPool) GetStats() BufferStats { |
||||||
|
bp.stats.mu.RLock() |
||||||
|
defer bp.stats.mu.RUnlock() |
||||||
|
return bp.stats |
||||||
|
} |
||||||
|
|
||||||
|
// NewMemoryMonitor 创建内存监控器
|
||||||
|
func NewMemoryMonitor(interval time.Duration, thresholds Thresholds, logger *logrus.Logger) *MemoryMonitor { |
||||||
|
monitor := &MemoryMonitor{ |
||||||
|
logger: logger, |
||||||
|
ticker: time.NewTicker(interval), |
||||||
|
stopCh: make(chan struct{}), |
||||||
|
thresholds: thresholds, |
||||||
|
callbacks: make([]MemoryCallback, 0), |
||||||
|
} |
||||||
|
|
||||||
|
// 启动监控
|
||||||
|
go monitor.run() |
||||||
|
|
||||||
|
return monitor |
||||||
|
} |
||||||
|
|
||||||
|
// AddCallback 添加内存回调
|
||||||
|
func (mm *MemoryMonitor) AddCallback(callback MemoryCallback) { |
||||||
|
mm.mu.Lock() |
||||||
|
mm.callbacks = append(mm.callbacks, callback) |
||||||
|
mm.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
// Stop 停止监控
|
||||||
|
func (mm *MemoryMonitor) Stop() { |
||||||
|
close(mm.stopCh) |
||||||
|
mm.ticker.Stop() |
||||||
|
} |
||||||
|
|
||||||
|
// run 运行监控
|
||||||
|
func (mm *MemoryMonitor) run() { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-mm.stopCh: |
||||||
|
return |
||||||
|
case <-mm.ticker.C: |
||||||
|
mm.check() |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// check 检查内存状态
|
||||||
|
func (mm *MemoryMonitor) check() { |
||||||
|
var stats runtime.MemStats |
||||||
|
runtime.ReadMemStats(&stats) |
||||||
|
|
||||||
|
// 记录内存统计
|
||||||
|
heapAllocMB := stats.HeapAlloc / 1024 / 1024 |
||||||
|
heapSysMB := stats.HeapSys / 1024 / 1024 |
||||||
|
|
||||||
|
mm.logger.WithFields(logrus.Fields{ |
||||||
|
"heap_alloc_mb": heapAllocMB, |
||||||
|
"heap_sys_mb": heapSysMB, |
||||||
|
"gc_num": stats.NumGC, |
||||||
|
"goroutines": runtime.NumGoroutine(), |
||||||
|
}).Debug("Memory stats") |
||||||
|
|
||||||
|
// 检查阈值
|
||||||
|
if int64(heapAllocMB) > mm.thresholds.HeapAllocMB { |
||||||
|
mm.logger.WithField("heap_alloc_mb", heapAllocMB).Warn("Heap allocation threshold exceeded") |
||||||
|
} |
||||||
|
|
||||||
|
if int64(heapSysMB) > mm.thresholds.HeapSysMB { |
||||||
|
mm.logger.WithField("heap_sys_mb", heapSysMB).Warn("Heap system threshold exceeded") |
||||||
|
} |
||||||
|
|
||||||
|
// 强制GC阈值
|
||||||
|
if int64(heapAllocMB) > mm.thresholds.ForceGCThreshMB { |
||||||
|
mm.logger.WithField("heap_alloc_mb", heapAllocMB).Info("Force GC triggered") |
||||||
|
runtime.GC() |
||||||
|
} |
||||||
|
|
||||||
|
// 调用回调函数
|
||||||
|
mm.mu.RLock() |
||||||
|
for _, callback := range mm.callbacks { |
||||||
|
go callback(stats) |
||||||
|
} |
||||||
|
mm.mu.RUnlock() |
||||||
|
|
||||||
|
mm.lastStats = stats |
||||||
|
} |
||||||
|
|
||||||
|
// GetStats 获取最新内存统计
|
||||||
|
func (mm *MemoryMonitor) GetStats() runtime.MemStats { |
||||||
|
return mm.lastStats |
||||||
|
} |
||||||
|
|
||||||
|
// Manager 方法
|
||||||
|
|
||||||
|
// GetBuffer 获取缓冲区
|
||||||
|
func (m *Manager) GetBuffer(size int) []byte { |
||||||
|
if m.bufferPool != nil { |
||||||
|
return m.bufferPool.Get(size) |
||||||
|
} |
||||||
|
return make([]byte, size) |
||||||
|
} |
||||||
|
|
||||||
|
// PutBuffer 归还缓冲区
|
||||||
|
func (m *Manager) PutBuffer(buf []byte) { |
||||||
|
if m.bufferPool != nil { |
||||||
|
m.bufferPool.Put(buf) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// GetBufferStats 获取缓冲区统计
|
||||||
|
func (m *Manager) GetBufferStats() BufferStats { |
||||||
|
if m.bufferPool != nil { |
||||||
|
return m.bufferPool.GetStats() |
||||||
|
} |
||||||
|
return BufferStats{} |
||||||
|
} |
||||||
|
|
||||||
|
// GetMemoryStats 获取内存统计
|
||||||
|
func (m *Manager) GetMemoryStats() runtime.MemStats { |
||||||
|
return m.monitor.GetStats() |
||||||
|
} |
||||||
|
|
||||||
|
// Stop 停止内存管理器
|
||||||
|
func (m *Manager) Stop() { |
||||||
|
if m.monitor != nil { |
||||||
|
m.monitor.Stop() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ForceGC 强制垃圾回收
|
||||||
|
func (m *Manager) ForceGC() { |
||||||
|
m.logger.Info("Manual GC triggered") |
||||||
|
runtime.GC() |
||||||
|
} |
||||||
|
|
||||||
|
// autoGCCallback 自动GC回调
|
||||||
|
func (m *Manager) autoGCCallback(stats runtime.MemStats) { |
||||||
|
heapAllocMB := stats.HeapAlloc / 1024 / 1024 |
||||||
|
|
||||||
|
// 当堆分配超过阈值时触发GC
|
||||||
|
if int64(heapAllocMB) > m.config.Thresholds.ForceGCThreshMB { |
||||||
|
m.ForceGC() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// GetOverallStats 获取总体统计信息
|
||||||
|
func (m *Manager) GetOverallStats() map[string]interface{} { |
||||||
|
memStats := m.GetMemoryStats() |
||||||
|
bufferStats := m.GetBufferStats() |
||||||
|
|
||||||
|
return map[string]interface{}{ |
||||||
|
"memory": map[string]interface{}{ |
||||||
|
"heap_alloc_mb": memStats.HeapAlloc / 1024 / 1024, |
||||||
|
"heap_sys_mb": memStats.HeapSys / 1024 / 1024, |
||||||
|
"gc_num": memStats.NumGC, |
||||||
|
"goroutines": runtime.NumGoroutine(), |
||||||
|
}, |
||||||
|
"buffers": bufferStats, |
||||||
|
"config": map[string]interface{}{ |
||||||
|
"optimization_enabled": m.config.EnableOptimization, |
||||||
|
"auto_gc_enabled": m.config.EnableAutoGC, |
||||||
|
"buffer_sizes": m.config.BufferSizes, |
||||||
|
}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// BufferStats 方法
|
||||||
|
|
||||||
|
func (bs *BufferStats) incGets() { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.Gets++ |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (bs *BufferStats) incPuts() { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.Puts++ |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (bs *BufferStats) incNews() { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.News++ |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (bs *BufferStats) incReuses() { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.Reuses++ |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (bs *BufferStats) addTotalAlloc(size int64) { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.TotalAlloc += size |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (bs *BufferStats) addTotalReused(size int64) { |
||||||
|
bs.mu.Lock() |
||||||
|
bs.TotalReused += size |
||||||
|
bs.mu.Unlock() |
||||||
|
} |
@ -0,0 +1,413 @@ |
|||||||
|
package pool |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"net" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
ErrPoolClosed = errors.New("connection pool is closed") |
||||||
|
ErrPoolFull = errors.New("connection pool is full") |
||||||
|
ErrConnExpired = errors.New("connection has expired") |
||||||
|
ErrConnInvalid = errors.New("connection is invalid") |
||||||
|
) |
||||||
|
|
||||||
|
// ConnectionPool 连接池
|
||||||
|
type ConnectionPool struct { |
||||||
|
logger *logrus.Logger |
||||||
|
factory Factory |
||||||
|
pool chan *PooledConnection |
||||||
|
mu sync.RWMutex |
||||||
|
closed bool |
||||||
|
maxSize int |
||||||
|
maxLifetime time.Duration |
||||||
|
maxIdle time.Duration |
||||||
|
|
||||||
|
// 统计信息
|
||||||
|
stats Stats |
||||||
|
} |
||||||
|
|
||||||
|
// PooledConnection 池化连接
|
||||||
|
type PooledConnection struct { |
||||||
|
conn net.Conn |
||||||
|
createdAt time.Time |
||||||
|
lastUsed time.Time |
||||||
|
pool *ConnectionPool |
||||||
|
} |
||||||
|
|
||||||
|
// Factory 连接工厂
|
||||||
|
type Factory interface { |
||||||
|
Create() (net.Conn, error) |
||||||
|
Validate(net.Conn) bool |
||||||
|
Close(net.Conn) error |
||||||
|
} |
||||||
|
|
||||||
|
// Config 连接池配置
|
||||||
|
type Config struct { |
||||||
|
MaxSize int `json:"maxSize"` |
||||||
|
MaxLifetime time.Duration `json:"maxLifetime"` |
||||||
|
MaxIdle time.Duration `json:"maxIdle"` |
||||||
|
InitialSize int `json:"initialSize"` |
||||||
|
} |
||||||
|
|
||||||
|
// Stats 统计信息
|
||||||
|
type Stats struct { |
||||||
|
mu sync.RWMutex |
||||||
|
Created int64 `json:"created"` |
||||||
|
Reused int64 `json:"reused"` |
||||||
|
Closed int64 `json:"closed"` |
||||||
|
Active int64 `json:"active"` |
||||||
|
Idle int64 `json:"idle"` |
||||||
|
Failures int64 `json:"failures"` |
||||||
|
} |
||||||
|
|
||||||
|
// NewConnectionPool 创建新的连接池
|
||||||
|
func NewConnectionPool(config Config, factory Factory, logger *logrus.Logger) (*ConnectionPool, error) { |
||||||
|
if config.MaxSize <= 0 { |
||||||
|
config.MaxSize = 100 |
||||||
|
} |
||||||
|
if config.MaxLifetime == 0 { |
||||||
|
config.MaxLifetime = 30 * time.Minute |
||||||
|
} |
||||||
|
if config.MaxIdle == 0 { |
||||||
|
config.MaxIdle = 5 * time.Minute |
||||||
|
} |
||||||
|
|
||||||
|
pool := &ConnectionPool{ |
||||||
|
logger: logger, |
||||||
|
factory: factory, |
||||||
|
pool: make(chan *PooledConnection, config.MaxSize), |
||||||
|
maxSize: config.MaxSize, |
||||||
|
maxLifetime: config.MaxLifetime, |
||||||
|
maxIdle: config.MaxIdle, |
||||||
|
} |
||||||
|
|
||||||
|
// 预创建连接
|
||||||
|
for i := 0; i < config.InitialSize && i < config.MaxSize; i++ { |
||||||
|
conn, err := pool.factory.Create() |
||||||
|
if err != nil { |
||||||
|
pool.logger.WithError(err).Warn("Failed to create initial connection") |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
pooledConn := &PooledConnection{ |
||||||
|
conn: conn, |
||||||
|
createdAt: time.Now(), |
||||||
|
lastUsed: time.Now(), |
||||||
|
pool: pool, |
||||||
|
} |
||||||
|
|
||||||
|
select { |
||||||
|
case pool.pool <- pooledConn: |
||||||
|
pool.stats.incCreated() |
||||||
|
pool.stats.incIdle() |
||||||
|
default: |
||||||
|
conn.Close() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 启动清理goroutine
|
||||||
|
go pool.cleaner() |
||||||
|
|
||||||
|
return pool, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Get 获取连接
|
||||||
|
func (p *ConnectionPool) Get() (*PooledConnection, error) { |
||||||
|
p.mu.RLock() |
||||||
|
if p.closed { |
||||||
|
p.mu.RUnlock() |
||||||
|
return nil, ErrPoolClosed |
||||||
|
} |
||||||
|
p.mu.RUnlock() |
||||||
|
|
||||||
|
// 尝试从池中获取连接
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case conn := <-p.pool: |
||||||
|
p.stats.decIdle() |
||||||
|
|
||||||
|
// 检查连接是否有效
|
||||||
|
if p.isConnValid(conn) { |
||||||
|
conn.lastUsed = time.Now() |
||||||
|
p.stats.incReused() |
||||||
|
p.stats.incActive() |
||||||
|
return conn, nil |
||||||
|
} |
||||||
|
|
||||||
|
// 连接无效,关闭并继续尝试
|
||||||
|
p.closeConn(conn) |
||||||
|
continue |
||||||
|
|
||||||
|
default: |
||||||
|
// 池中没有可用连接,创建新连接
|
||||||
|
return p.createConnection() |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Put 归还连接到池
|
||||||
|
func (p *ConnectionPool) Put(conn *PooledConnection) error { |
||||||
|
if conn == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
p.mu.RLock() |
||||||
|
if p.closed { |
||||||
|
p.mu.RUnlock() |
||||||
|
p.closeConn(conn) |
||||||
|
return ErrPoolClosed |
||||||
|
} |
||||||
|
p.mu.RUnlock() |
||||||
|
|
||||||
|
p.stats.decActive() |
||||||
|
|
||||||
|
// 检查连接是否有效
|
||||||
|
if !p.isConnValid(conn) { |
||||||
|
p.closeConn(conn) |
||||||
|
return ErrConnInvalid |
||||||
|
} |
||||||
|
|
||||||
|
// 尝试归还到池
|
||||||
|
select { |
||||||
|
case p.pool <- conn: |
||||||
|
p.stats.incIdle() |
||||||
|
return nil |
||||||
|
default: |
||||||
|
// 池已满,关闭连接
|
||||||
|
p.closeConn(conn) |
||||||
|
return ErrPoolFull |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Close 关闭连接池
|
||||||
|
func (p *ConnectionPool) Close() error { |
||||||
|
p.mu.Lock() |
||||||
|
if p.closed { |
||||||
|
p.mu.Unlock() |
||||||
|
return nil |
||||||
|
} |
||||||
|
p.closed = true |
||||||
|
p.mu.Unlock() |
||||||
|
|
||||||
|
// 关闭所有池中的连接
|
||||||
|
close(p.pool) |
||||||
|
for conn := range p.pool { |
||||||
|
p.closeConn(conn) |
||||||
|
} |
||||||
|
|
||||||
|
p.logger.Info("Connection pool closed") |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetStats 获取统计信息
|
||||||
|
func (p *ConnectionPool) GetStats() Stats { |
||||||
|
p.stats.mu.RLock() |
||||||
|
defer p.stats.mu.RUnlock() |
||||||
|
|
||||||
|
stats := p.stats |
||||||
|
stats.Idle = int64(len(p.pool)) |
||||||
|
return stats |
||||||
|
} |
||||||
|
|
||||||
|
// createConnection 创建新连接
|
||||||
|
func (p *ConnectionPool) createConnection() (*PooledConnection, error) { |
||||||
|
conn, err := p.factory.Create() |
||||||
|
if err != nil { |
||||||
|
p.stats.incFailures() |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
pooledConn := &PooledConnection{ |
||||||
|
conn: conn, |
||||||
|
createdAt: time.Now(), |
||||||
|
lastUsed: time.Now(), |
||||||
|
pool: p, |
||||||
|
} |
||||||
|
|
||||||
|
p.stats.incCreated() |
||||||
|
p.stats.incActive() |
||||||
|
return pooledConn, nil |
||||||
|
} |
||||||
|
|
||||||
|
// isConnValid 检查连接是否有效
|
||||||
|
func (p *ConnectionPool) isConnValid(conn *PooledConnection) bool { |
||||||
|
now := time.Now() |
||||||
|
|
||||||
|
// 检查连接生命周期
|
||||||
|
if now.Sub(conn.createdAt) > p.maxLifetime { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// 检查空闲时间
|
||||||
|
if now.Sub(conn.lastUsed) > p.maxIdle { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// 使用工厂验证连接
|
||||||
|
return p.factory.Validate(conn.conn) |
||||||
|
} |
||||||
|
|
||||||
|
// closeConn 关闭连接
|
||||||
|
func (p *ConnectionPool) closeConn(conn *PooledConnection) { |
||||||
|
if conn != nil && conn.conn != nil { |
||||||
|
p.factory.Close(conn.conn) |
||||||
|
p.stats.incClosed() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// cleaner 清理过期连接
|
||||||
|
func (p *ConnectionPool) cleaner() { |
||||||
|
ticker := time.NewTicker(1 * time.Minute) |
||||||
|
defer ticker.Stop() |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-ticker.C: |
||||||
|
p.cleanExpiredConnections() |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// cleanExpiredConnections 清理过期连接
|
||||||
|
func (p *ConnectionPool) cleanExpiredConnections() { |
||||||
|
p.mu.RLock() |
||||||
|
if p.closed { |
||||||
|
p.mu.RUnlock() |
||||||
|
return |
||||||
|
} |
||||||
|
p.mu.RUnlock() |
||||||
|
|
||||||
|
// 检查池中的连接
|
||||||
|
poolSize := len(p.pool) |
||||||
|
cleaned := 0 |
||||||
|
|
||||||
|
for i := 0; i < poolSize; i++ { |
||||||
|
select { |
||||||
|
case conn := <-p.pool: |
||||||
|
if p.isConnValid(conn) { |
||||||
|
// 连接有效,放回池中
|
||||||
|
select { |
||||||
|
case p.pool <- conn: |
||||||
|
default: |
||||||
|
// 池满了,关闭连接
|
||||||
|
p.closeConn(conn) |
||||||
|
cleaned++ |
||||||
|
} |
||||||
|
} else { |
||||||
|
// 连接无效,关闭
|
||||||
|
p.closeConn(conn) |
||||||
|
p.stats.decIdle() |
||||||
|
cleaned++ |
||||||
|
} |
||||||
|
default: |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if cleaned > 0 { |
||||||
|
p.logger.WithField("count", cleaned).Debug("Cleaned expired connections") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// PooledConnection 方法
|
||||||
|
|
||||||
|
// Read 读取数据
|
||||||
|
func (pc *PooledConnection) Read(b []byte) (n int, err error) { |
||||||
|
return pc.conn.Read(b) |
||||||
|
} |
||||||
|
|
||||||
|
// Write 写入数据
|
||||||
|
func (pc *PooledConnection) Write(b []byte) (n int, err error) { |
||||||
|
return pc.conn.Write(b) |
||||||
|
} |
||||||
|
|
||||||
|
// Close 关闭连接(归还到池)
|
||||||
|
func (pc *PooledConnection) Close() error { |
||||||
|
return pc.pool.Put(pc) |
||||||
|
} |
||||||
|
|
||||||
|
// ForceClose 强制关闭连接
|
||||||
|
func (pc *PooledConnection) ForceClose() error { |
||||||
|
pc.pool.closeConn(pc) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// LocalAddr 获取本地地址
|
||||||
|
func (pc *PooledConnection) LocalAddr() net.Addr { |
||||||
|
return pc.conn.LocalAddr() |
||||||
|
} |
||||||
|
|
||||||
|
// RemoteAddr 获取远程地址
|
||||||
|
func (pc *PooledConnection) RemoteAddr() net.Addr { |
||||||
|
return pc.conn.RemoteAddr() |
||||||
|
} |
||||||
|
|
||||||
|
// SetDeadline 设置截止时间
|
||||||
|
func (pc *PooledConnection) SetDeadline(t time.Time) error { |
||||||
|
return pc.conn.SetDeadline(t) |
||||||
|
} |
||||||
|
|
||||||
|
// SetReadDeadline 设置读截止时间
|
||||||
|
func (pc *PooledConnection) SetReadDeadline(t time.Time) error { |
||||||
|
return pc.conn.SetReadDeadline(t) |
||||||
|
} |
||||||
|
|
||||||
|
// SetWriteDeadline 设置写截止时间
|
||||||
|
func (pc *PooledConnection) SetWriteDeadline(t time.Time) error { |
||||||
|
return pc.conn.SetWriteDeadline(t) |
||||||
|
} |
||||||
|
|
||||||
|
// Stats 方法
|
||||||
|
|
||||||
|
func (s *Stats) incCreated() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Created++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) incReused() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Reused++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) incClosed() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Closed++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) incActive() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Active++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) decActive() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Active-- |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) incIdle() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Idle++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) decIdle() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Idle-- |
||||||
|
s.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
func (s *Stats) incFailures() { |
||||||
|
s.mu.Lock() |
||||||
|
s.Failures++ |
||||||
|
s.mu.Unlock() |
||||||
|
} |
@ -0,0 +1,231 @@ |
|||||||
|
package ratelimit |
||||||
|
|
||||||
|
import ( |
||||||
|
"net" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
// RateLimiter 速率限制器
|
||||||
|
type RateLimiter struct { |
||||||
|
logger *logrus.Logger |
||||||
|
globalLimit *TokenBucket |
||||||
|
perIPLimiters map[string]*TokenBucket |
||||||
|
cleanupInterval time.Duration |
||||||
|
mu sync.RWMutex |
||||||
|
stopCh chan struct{} |
||||||
|
} |
||||||
|
|
||||||
|
// TokenBucket Token桶算法实现
|
||||||
|
type TokenBucket struct { |
||||||
|
capacity int64 |
||||||
|
tokens int64 |
||||||
|
refillRate int64 // tokens per second
|
||||||
|
lastRefill time.Time |
||||||
|
mu sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
// Config 速率限制配置
|
||||||
|
type Config struct { |
||||||
|
Enabled bool `json:"enabled"` |
||||||
|
RequestsPerSecond int `json:"requestsPerSecond"` |
||||||
|
BurstSize int `json:"burstSize"` |
||||||
|
PerIPRequestsPerSec int `json:"perIPRequestsPerSec"` |
||||||
|
PerIPBurstSize int `json:"perIPBurstSize"` |
||||||
|
CleanupInterval time.Duration `json:"cleanupInterval"` |
||||||
|
} |
||||||
|
|
||||||
|
// NewRateLimiter 创建新的速率限制器
|
||||||
|
func NewRateLimiter(config Config, logger *logrus.Logger) *RateLimiter { |
||||||
|
if !config.Enabled { |
||||||
|
return &RateLimiter{ |
||||||
|
logger: logger, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 默认值
|
||||||
|
if config.BurstSize == 0 { |
||||||
|
config.BurstSize = config.RequestsPerSecond * 2 |
||||||
|
} |
||||||
|
if config.PerIPBurstSize == 0 { |
||||||
|
config.PerIPBurstSize = config.PerIPRequestsPerSec * 2 |
||||||
|
} |
||||||
|
if config.CleanupInterval == 0 { |
||||||
|
config.CleanupInterval = 5 * time.Minute |
||||||
|
} |
||||||
|
|
||||||
|
rl := &RateLimiter{ |
||||||
|
logger: logger, |
||||||
|
globalLimit: NewTokenBucket(int64(config.BurstSize), int64(config.RequestsPerSecond)), |
||||||
|
perIPLimiters: make(map[string]*TokenBucket), |
||||||
|
cleanupInterval: config.CleanupInterval, |
||||||
|
stopCh: make(chan struct{}), |
||||||
|
} |
||||||
|
|
||||||
|
// 启动清理goroutine
|
||||||
|
go rl.cleanup() |
||||||
|
|
||||||
|
return rl |
||||||
|
} |
||||||
|
|
||||||
|
// Allow 检查是否允许请求
|
||||||
|
func (rl *RateLimiter) Allow(remoteAddr string) bool { |
||||||
|
if rl.globalLimit == nil { |
||||||
|
return true // 未启用速率限制
|
||||||
|
} |
||||||
|
|
||||||
|
// 全局速率限制
|
||||||
|
if !rl.globalLimit.Allow() { |
||||||
|
rl.logger.WithField("type", "global").Debug("Rate limit exceeded") |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// 获取客户端IP
|
||||||
|
host, _, err := net.SplitHostPort(remoteAddr) |
||||||
|
if err != nil { |
||||||
|
host = remoteAddr |
||||||
|
} |
||||||
|
|
||||||
|
// 单IP速率限制
|
||||||
|
rl.mu.RLock() |
||||||
|
bucket, exists := rl.perIPLimiters[host] |
||||||
|
rl.mu.RUnlock() |
||||||
|
|
||||||
|
if !exists { |
||||||
|
rl.mu.Lock() |
||||||
|
bucket, exists = rl.perIPLimiters[host] |
||||||
|
if !exists { |
||||||
|
bucket = NewTokenBucket(20, 10) // 默认每IP 10 rps, burst 20
|
||||||
|
rl.perIPLimiters[host] = bucket |
||||||
|
} |
||||||
|
rl.mu.Unlock() |
||||||
|
} |
||||||
|
|
||||||
|
if !bucket.Allow() { |
||||||
|
rl.logger.WithFields(logrus.Fields{ |
||||||
|
"type": "per_ip", |
||||||
|
"ip": host, |
||||||
|
}).Debug("Per-IP rate limit exceeded") |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
// Stop 停止速率限制器
|
||||||
|
func (rl *RateLimiter) Stop() { |
||||||
|
if rl.stopCh != nil { |
||||||
|
close(rl.stopCh) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// GetStats 获取统计信息
|
||||||
|
func (rl *RateLimiter) GetStats() map[string]interface{} { |
||||||
|
if rl.globalLimit == nil { |
||||||
|
return map[string]interface{}{ |
||||||
|
"enabled": false, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
rl.mu.RLock() |
||||||
|
perIPCount := len(rl.perIPLimiters) |
||||||
|
rl.mu.RUnlock() |
||||||
|
|
||||||
|
return map[string]interface{}{ |
||||||
|
"enabled": true, |
||||||
|
"global_tokens": rl.globalLimit.GetTokens(), |
||||||
|
"per_ip_buckets": perIPCount, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// cleanup 清理过期的IP限制器
|
||||||
|
func (rl *RateLimiter) cleanup() { |
||||||
|
ticker := time.NewTicker(rl.cleanupInterval) |
||||||
|
defer ticker.Stop() |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-rl.stopCh: |
||||||
|
return |
||||||
|
case <-ticker.C: |
||||||
|
rl.cleanupExpiredBuckets() |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// cleanupExpiredBuckets 清理过期的IP桶
|
||||||
|
func (rl *RateLimiter) cleanupExpiredBuckets() { |
||||||
|
rl.mu.Lock() |
||||||
|
defer rl.mu.Unlock() |
||||||
|
|
||||||
|
now := time.Now() |
||||||
|
expiredIPs := make([]string, 0) |
||||||
|
|
||||||
|
for ip, bucket := range rl.perIPLimiters { |
||||||
|
// 如果桶超过10分钟没有活动,清理掉
|
||||||
|
if now.Sub(bucket.lastRefill) > 10*time.Minute { |
||||||
|
expiredIPs = append(expiredIPs, ip) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
for _, ip := range expiredIPs { |
||||||
|
delete(rl.perIPLimiters, ip) |
||||||
|
} |
||||||
|
|
||||||
|
if len(expiredIPs) > 0 { |
||||||
|
rl.logger.WithField("count", len(expiredIPs)).Debug("Cleaned up expired rate limit buckets") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// NewTokenBucket 创建新的Token桶
|
||||||
|
func NewTokenBucket(capacity, refillRate int64) *TokenBucket { |
||||||
|
return &TokenBucket{ |
||||||
|
capacity: capacity, |
||||||
|
tokens: capacity, |
||||||
|
refillRate: refillRate, |
||||||
|
lastRefill: time.Now(), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Allow 检查是否允许请求(消耗一个token)
|
||||||
|
func (tb *TokenBucket) Allow() bool { |
||||||
|
tb.mu.Lock() |
||||||
|
defer tb.mu.Unlock() |
||||||
|
|
||||||
|
tb.refill() |
||||||
|
|
||||||
|
if tb.tokens > 0 { |
||||||
|
tb.tokens-- |
||||||
|
return true |
||||||
|
} |
||||||
|
|
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// refill 补充tokens
|
||||||
|
func (tb *TokenBucket) refill() { |
||||||
|
now := time.Now() |
||||||
|
elapsed := now.Sub(tb.lastRefill) |
||||||
|
|
||||||
|
// 计算应该补充的tokens
|
||||||
|
tokensToAdd := int64(elapsed.Seconds()) * tb.refillRate |
||||||
|
|
||||||
|
if tokensToAdd > 0 { |
||||||
|
tb.tokens += tokensToAdd |
||||||
|
if tb.tokens > tb.capacity { |
||||||
|
tb.tokens = tb.capacity |
||||||
|
} |
||||||
|
tb.lastRefill = now |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// GetTokens 获取当前token数量
|
||||||
|
func (tb *TokenBucket) GetTokens() int64 { |
||||||
|
tb.mu.Lock() |
||||||
|
defer tb.mu.Unlock() |
||||||
|
|
||||||
|
tb.refill() |
||||||
|
return tb.tokens |
||||||
|
} |
@ -0,0 +1,58 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import "strings" |
||||||
|
|
||||||
|
// SimpleAuthHandler 简单认证处理器
|
||||||
|
type SimpleAuthHandler struct { |
||||||
|
username string |
||||||
|
password string |
||||||
|
methods []byte |
||||||
|
} |
||||||
|
|
||||||
|
// NewAuthHandler 创建认证处理器
|
||||||
|
func NewAuthHandler(config AuthConfig) AuthHandler { |
||||||
|
handler := &SimpleAuthHandler{ |
||||||
|
username: config.Username, |
||||||
|
password: config.Password, |
||||||
|
} |
||||||
|
|
||||||
|
// 设置支持的认证方法
|
||||||
|
if config.Username != "" && config.Password != "" { |
||||||
|
handler.methods = []byte{AuthPassword} |
||||||
|
} else { |
||||||
|
handler.methods = []byte{AuthNone} |
||||||
|
} |
||||||
|
|
||||||
|
// 如果配置中指定了方法,使用配置的方法
|
||||||
|
if len(config.Methods) > 0 { |
||||||
|
handler.methods = []byte{} |
||||||
|
for _, method := range config.Methods { |
||||||
|
switch strings.ToLower(method) { |
||||||
|
case "none": |
||||||
|
handler.methods = append(handler.methods, AuthNone) |
||||||
|
case "password": |
||||||
|
handler.methods = append(handler.methods, AuthPassword) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return handler |
||||||
|
} |
||||||
|
|
||||||
|
// Authenticate 验证用户名和密码
|
||||||
|
func (h *SimpleAuthHandler) Authenticate(username, password string) bool { |
||||||
|
// 如果支持无认证,直接返回true
|
||||||
|
for _, method := range h.methods { |
||||||
|
if method == AuthNone { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 检查用户名密码
|
||||||
|
return h.username == username && h.password == password |
||||||
|
} |
||||||
|
|
||||||
|
// Methods 返回支持的认证方法
|
||||||
|
func (h *SimpleAuthHandler) Methods() []byte { |
||||||
|
return h.methods |
||||||
|
} |
@ -0,0 +1,34 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import ( |
||||||
|
"net" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
// DirectDialer 直接连接拨号器
|
||||||
|
type DirectDialer struct { |
||||||
|
timeout time.Duration |
||||||
|
} |
||||||
|
|
||||||
|
// Dial 建立连接
|
||||||
|
func (d *DirectDialer) Dial(network, address string) (net.Conn, error) { |
||||||
|
if d.timeout > 0 { |
||||||
|
return net.DialTimeout(network, address, d.timeout) |
||||||
|
} |
||||||
|
return net.Dial(network, address) |
||||||
|
} |
||||||
|
|
||||||
|
// DirectResolver 直接DNS解析器
|
||||||
|
type DirectResolver struct{} |
||||||
|
|
||||||
|
// Resolve 解析域名
|
||||||
|
func (r *DirectResolver) Resolve(domain string) (net.IP, error) { |
||||||
|
ips, err := net.LookupIP(domain) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if len(ips) == 0 { |
||||||
|
return nil, net.ErrClosed |
||||||
|
} |
||||||
|
return ips[0], nil |
||||||
|
} |
@ -0,0 +1,114 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import ( |
||||||
|
"net" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
// SimpleRule 简单访问规则
|
||||||
|
type SimpleRule struct { |
||||||
|
action string // allow, deny
|
||||||
|
networks []*net.IPNet |
||||||
|
ports []int |
||||||
|
} |
||||||
|
|
||||||
|
// NewRule 创建访问规则
|
||||||
|
func NewRule(config RuleConfig) Rule { |
||||||
|
rule := &SimpleRule{ |
||||||
|
action: strings.ToLower(config.Action), |
||||||
|
ports: config.Ports, |
||||||
|
} |
||||||
|
|
||||||
|
// 解析IP网段
|
||||||
|
for _, ipStr := range config.IPs { |
||||||
|
if !strings.Contains(ipStr, "/") { |
||||||
|
// 单个IP,添加适当的子网掩码
|
||||||
|
if strings.Contains(ipStr, ":") { |
||||||
|
ipStr += "/128" // IPv6
|
||||||
|
} else { |
||||||
|
ipStr += "/32" // IPv4
|
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
_, network, err := net.ParseCIDR(ipStr) |
||||||
|
if err != nil { |
||||||
|
// 如果解析失败,尝试作为单个IP处理
|
||||||
|
ip := net.ParseIP(ipStr) |
||||||
|
if ip != nil { |
||||||
|
if ip.To4() != nil { |
||||||
|
_, network, _ = net.ParseCIDR(ip.String() + "/32") |
||||||
|
} else { |
||||||
|
_, network, _ = net.ParseCIDR(ip.String() + "/128") |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if network != nil { |
||||||
|
rule.networks = append(rule.networks, network) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return rule |
||||||
|
} |
||||||
|
|
||||||
|
// Allow 检查是否允许访问
|
||||||
|
func (r *SimpleRule) Allow(addr net.Addr) bool { |
||||||
|
// 解析地址
|
||||||
|
var ip net.IP |
||||||
|
var port int |
||||||
|
|
||||||
|
switch a := addr.(type) { |
||||||
|
case *net.IPAddr: |
||||||
|
ip = a.IP |
||||||
|
case *net.TCPAddr: |
||||||
|
ip = a.IP |
||||||
|
port = a.Port |
||||||
|
case *net.UDPAddr: |
||||||
|
ip = a.IP |
||||||
|
port = a.Port |
||||||
|
default: |
||||||
|
// 尝试从字符串解析
|
||||||
|
host, portStr, err := net.SplitHostPort(addr.String()) |
||||||
|
if err != nil { |
||||||
|
return r.action == "allow" // 默认策略
|
||||||
|
} |
||||||
|
ip = net.ParseIP(host) |
||||||
|
if p, err := strconv.Atoi(portStr); err == nil { |
||||||
|
port = p |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if ip == nil { |
||||||
|
return r.action == "allow" // 默认策略
|
||||||
|
} |
||||||
|
|
||||||
|
// 检查IP是否匹配
|
||||||
|
ipMatches := len(r.networks) == 0 // 如果没有指定网段,默认匹配
|
||||||
|
for _, network := range r.networks { |
||||||
|
if network.Contains(ip) { |
||||||
|
ipMatches = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 检查端口是否匹配
|
||||||
|
portMatches := len(r.ports) == 0 // 如果没有指定端口,默认匹配
|
||||||
|
for _, p := range r.ports { |
||||||
|
if p == port { |
||||||
|
portMatches = true |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 根据规则类型返回结果
|
||||||
|
matches := ipMatches && portMatches |
||||||
|
switch r.action { |
||||||
|
case "allow": |
||||||
|
return matches |
||||||
|
case "deny": |
||||||
|
return !matches |
||||||
|
default: |
||||||
|
return true // 默认允许
|
||||||
|
} |
||||||
|
} |
@ -0,0 +1,432 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/binary" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
// SOCKS5 协议常量
|
||||||
|
const ( |
||||||
|
// SOCKS版本
|
||||||
|
Version5 = 0x05 |
||||||
|
|
||||||
|
// 认证方法
|
||||||
|
AuthNone = 0x00 |
||||||
|
AuthPassword = 0x02 |
||||||
|
AuthNoSupported = 0xFF |
||||||
|
|
||||||
|
// 命令类型
|
||||||
|
CmdConnect = 0x01 |
||||||
|
CmdBind = 0x02 |
||||||
|
CmdUDP = 0x03 |
||||||
|
|
||||||
|
// 地址类型
|
||||||
|
AddrIPv4 = 0x01 |
||||||
|
AddrDomain = 0x03 |
||||||
|
AddrIPv6 = 0x04 |
||||||
|
|
||||||
|
// 响应状态
|
||||||
|
StatusSuccess = 0x00 |
||||||
|
StatusServerFailure = 0x01 |
||||||
|
StatusConnectionNotAllowed = 0x02 |
||||||
|
StatusNetworkUnreachable = 0x03 |
||||||
|
StatusHostUnreachable = 0x04 |
||||||
|
StatusConnectionRefused = 0x05 |
||||||
|
StatusTTLExpired = 0x06 |
||||||
|
StatusCommandNotSupported = 0x07 |
||||||
|
StatusAddressNotSupported = 0x08 |
||||||
|
) |
||||||
|
|
||||||
|
type Server struct { |
||||||
|
logger *logrus.Logger |
||||||
|
auth AuthHandler |
||||||
|
dialer Dialer |
||||||
|
resolver Resolver |
||||||
|
rules []Rule |
||||||
|
} |
||||||
|
|
||||||
|
type Config struct { |
||||||
|
Auth AuthConfig |
||||||
|
Timeout time.Duration |
||||||
|
Rules []RuleConfig |
||||||
|
} |
||||||
|
|
||||||
|
type AuthConfig struct { |
||||||
|
Methods []string |
||||||
|
Username string |
||||||
|
Password string |
||||||
|
} |
||||||
|
|
||||||
|
type RuleConfig struct { |
||||||
|
Action string // allow, deny
|
||||||
|
IPs []string |
||||||
|
Ports []int |
||||||
|
} |
||||||
|
|
||||||
|
type AuthHandler interface { |
||||||
|
Authenticate(username, password string) bool |
||||||
|
Methods() []byte |
||||||
|
} |
||||||
|
|
||||||
|
type Dialer interface { |
||||||
|
Dial(network, address string) (net.Conn, error) |
||||||
|
} |
||||||
|
|
||||||
|
type Resolver interface { |
||||||
|
Resolve(domain string) (net.IP, error) |
||||||
|
} |
||||||
|
|
||||||
|
type Rule interface { |
||||||
|
Allow(addr net.Addr) bool |
||||||
|
} |
||||||
|
|
||||||
|
// NewServer 创建新的SOCKS5服务器
|
||||||
|
func NewServer(config Config, logger *logrus.Logger) *Server { |
||||||
|
server := &Server{ |
||||||
|
logger: logger, |
||||||
|
dialer: &DirectDialer{timeout: config.Timeout}, |
||||||
|
resolver: &DirectResolver{}, |
||||||
|
} |
||||||
|
|
||||||
|
// 设置认证处理器
|
||||||
|
server.auth = NewAuthHandler(config.Auth) |
||||||
|
|
||||||
|
// 设置访问规则
|
||||||
|
for _, ruleConfig := range config.Rules { |
||||||
|
rule := NewRule(ruleConfig) |
||||||
|
server.rules = append(server.rules, rule) |
||||||
|
} |
||||||
|
|
||||||
|
return server |
||||||
|
} |
||||||
|
|
||||||
|
// HandleConnection 处理SOCKS5连接
|
||||||
|
func (s *Server) HandleConnection(conn net.Conn) error { |
||||||
|
defer conn.Close() |
||||||
|
|
||||||
|
s.logger.WithField("remote_addr", conn.RemoteAddr()).Debug("New SOCKS5 connection") |
||||||
|
|
||||||
|
// 设置连接超时
|
||||||
|
conn.SetDeadline(time.Now().Add(30 * time.Second)) |
||||||
|
|
||||||
|
// 1. 协议版本协商
|
||||||
|
if err := s.handleVersionNegotiation(conn); err != nil { |
||||||
|
s.logger.WithError(err).Error("Version negotiation failed") |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// 2. 认证
|
||||||
|
if err := s.handleAuthentication(conn); err != nil { |
||||||
|
s.logger.WithError(err).Error("Authentication failed") |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// 3. 处理请求
|
||||||
|
return s.handleRequest(conn) |
||||||
|
} |
||||||
|
|
||||||
|
// handleVersionNegotiation 处理版本协商
|
||||||
|
func (s *Server) handleVersionNegotiation(conn net.Conn) error { |
||||||
|
// 读取客户端版本协商请求
|
||||||
|
buf := make([]byte, 2) |
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return fmt.Errorf("failed to read version: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
version := buf[0] |
||||||
|
nmethods := buf[1] |
||||||
|
|
||||||
|
if version != Version5 { |
||||||
|
return fmt.Errorf("unsupported SOCKS version: %d", version) |
||||||
|
} |
||||||
|
|
||||||
|
// 读取支持的认证方法
|
||||||
|
methods := make([]byte, nmethods) |
||||||
|
if _, err := io.ReadFull(conn, methods); err != nil { |
||||||
|
return fmt.Errorf("failed to read methods: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 选择认证方法
|
||||||
|
authMethods := s.auth.Methods() |
||||||
|
selectedMethod := byte(AuthNoSupported) |
||||||
|
|
||||||
|
for _, method := range methods { |
||||||
|
for _, supported := range authMethods { |
||||||
|
if method == supported { |
||||||
|
selectedMethod = method |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
if selectedMethod != AuthNoSupported { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// 发送选择的认证方法
|
||||||
|
response := []byte{Version5, selectedMethod} |
||||||
|
if _, err := conn.Write(response); err != nil { |
||||||
|
return fmt.Errorf("failed to write method selection: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
if selectedMethod == AuthNoSupported { |
||||||
|
return fmt.Errorf("no supported authentication method") |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// handleAuthentication 处理认证
|
||||||
|
func (s *Server) handleAuthentication(conn net.Conn) error { |
||||||
|
// 读取认证请求
|
||||||
|
buf := make([]byte, 2) |
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return fmt.Errorf("failed to read auth version: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
version := buf[0] |
||||||
|
usernameLen := buf[1] |
||||||
|
|
||||||
|
if version != 0x01 { |
||||||
|
return fmt.Errorf("unsupported auth version: %d", version) |
||||||
|
} |
||||||
|
|
||||||
|
// 读取用户名
|
||||||
|
username := make([]byte, usernameLen) |
||||||
|
if _, err := io.ReadFull(conn, username); err != nil { |
||||||
|
return fmt.Errorf("failed to read username: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 读取密码长度
|
||||||
|
if _, err := io.ReadFull(conn, buf[:1]); err != nil { |
||||||
|
return fmt.Errorf("failed to read password length: %w", err) |
||||||
|
} |
||||||
|
passwordLen := buf[0] |
||||||
|
|
||||||
|
// 读取密码
|
||||||
|
password := make([]byte, passwordLen) |
||||||
|
if _, err := io.ReadFull(conn, password); err != nil { |
||||||
|
return fmt.Errorf("failed to read password: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 验证认证
|
||||||
|
success := s.auth.Authenticate(string(username), string(password)) |
||||||
|
|
||||||
|
// 发送认证结果
|
||||||
|
status := byte(0x01) // 失败
|
||||||
|
if success { |
||||||
|
status = 0x00 // 成功
|
||||||
|
} |
||||||
|
|
||||||
|
response := []byte{0x01, status} |
||||||
|
if _, err := conn.Write(response); err != nil { |
||||||
|
return fmt.Errorf("failed to write auth response: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
if !success { |
||||||
|
return fmt.Errorf("authentication failed") |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
// handleRequest 处理SOCKS5请求
|
||||||
|
func (s *Server) handleRequest(conn net.Conn) error { |
||||||
|
// 读取请求头
|
||||||
|
buf := make([]byte, 4) |
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return fmt.Errorf("failed to read request header: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
version := buf[0] |
||||||
|
cmd := buf[1] |
||||||
|
// rsv := buf[2] // 保留字段
|
||||||
|
addrType := buf[3] |
||||||
|
|
||||||
|
if version != Version5 { |
||||||
|
return fmt.Errorf("invalid SOCKS version: %d", version) |
||||||
|
} |
||||||
|
|
||||||
|
// 读取目标地址
|
||||||
|
addr, err := s.readAddress(conn, addrType) |
||||||
|
if err != nil { |
||||||
|
s.sendResponse(conn, StatusServerFailure, "0.0.0.0", 0) |
||||||
|
return fmt.Errorf("failed to read address: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
// 检查访问规则
|
||||||
|
if !s.checkRules(addr) { |
||||||
|
s.sendResponse(conn, StatusConnectionNotAllowed, "0.0.0.0", 0) |
||||||
|
return fmt.Errorf("connection not allowed: %s", addr) |
||||||
|
} |
||||||
|
|
||||||
|
// 处理不同的命令
|
||||||
|
switch cmd { |
||||||
|
case CmdConnect: |
||||||
|
return s.handleConnect(conn, addr) |
||||||
|
case CmdBind: |
||||||
|
return s.handleBind(conn, addr) |
||||||
|
case CmdUDP: |
||||||
|
return s.handleUDP(conn, addr) |
||||||
|
default: |
||||||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||||
|
return fmt.Errorf("unsupported command: %d", cmd) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// readAddress 读取地址信息
|
||||||
|
func (s *Server) readAddress(conn net.Conn, addrType byte) (string, error) { |
||||||
|
switch addrType { |
||||||
|
case AddrIPv4: |
||||||
|
buf := make([]byte, 6) // 4字节IP + 2字节端口
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
ip := net.IP(buf[:4]) |
||||||
|
port := binary.BigEndian.Uint16(buf[4:6]) |
||||||
|
return fmt.Sprintf("%s:%d", ip.String(), port), nil |
||||||
|
|
||||||
|
case AddrDomain: |
||||||
|
buf := make([]byte, 1) |
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
domainLen := buf[0] |
||||||
|
|
||||||
|
domain := make([]byte, domainLen+2) // 域名 + 2字节端口
|
||||||
|
if _, err := io.ReadFull(conn, domain); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
port := binary.BigEndian.Uint16(domain[domainLen:]) |
||||||
|
return fmt.Sprintf("%s:%d", string(domain[:domainLen]), port), nil |
||||||
|
|
||||||
|
case AddrIPv6: |
||||||
|
buf := make([]byte, 18) // 16字节IP + 2字节端口
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
ip := net.IP(buf[:16]) |
||||||
|
port := binary.BigEndian.Uint16(buf[16:18]) |
||||||
|
return fmt.Sprintf("[%s]:%d", ip.String(), port), nil |
||||||
|
|
||||||
|
default: |
||||||
|
return "", fmt.Errorf("unsupported address type: %d", addrType) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// handleConnect 处理CONNECT命令
|
||||||
|
func (s *Server) handleConnect(conn net.Conn, addr string) error { |
||||||
|
s.logger.WithField("target", addr).Debug("Handling CONNECT request") |
||||||
|
|
||||||
|
// 连接到目标服务器
|
||||||
|
target, err := s.dialer.Dial("tcp", addr) |
||||||
|
if err != nil { |
||||||
|
s.logger.WithError(err).WithField("target", addr).Error("Failed to connect to target") |
||||||
|
s.sendResponse(conn, StatusConnectionRefused, "0.0.0.0", 0) |
||||||
|
return err |
||||||
|
} |
||||||
|
defer target.Close() |
||||||
|
|
||||||
|
// 发送成功响应
|
||||||
|
localAddr := target.LocalAddr().(*net.TCPAddr) |
||||||
|
s.sendResponse(conn, StatusSuccess, localAddr.IP.String(), uint16(localAddr.Port)) |
||||||
|
|
||||||
|
// 开始数据转发
|
||||||
|
s.logger.WithField("target", addr).Info("Starting data relay") |
||||||
|
return s.relay(conn, target) |
||||||
|
} |
||||||
|
|
||||||
|
// handleBind 处理BIND命令
|
||||||
|
func (s *Server) handleBind(conn net.Conn, addr string) error { |
||||||
|
// BIND命令实现(简化版本)
|
||||||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||||
|
return fmt.Errorf("BIND command not implemented") |
||||||
|
} |
||||||
|
|
||||||
|
// handleUDP 处理UDP命令
|
||||||
|
func (s *Server) handleUDP(conn net.Conn, addr string) error { |
||||||
|
// UDP关联实现(简化版本)
|
||||||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||||
|
return fmt.Errorf("UDP command not implemented") |
||||||
|
} |
||||||
|
|
||||||
|
// sendResponse 发送SOCKS5响应
|
||||||
|
func (s *Server) sendResponse(conn net.Conn, status byte, ip string, port uint16) error { |
||||||
|
response := []byte{ |
||||||
|
Version5, |
||||||
|
status, |
||||||
|
0x00, // 保留字段
|
||||||
|
AddrIPv4, |
||||||
|
} |
||||||
|
|
||||||
|
// 添加IP地址
|
||||||
|
ipAddr := net.ParseIP(ip) |
||||||
|
if ipAddr == nil { |
||||||
|
ipAddr = net.ParseIP("0.0.0.0") |
||||||
|
} |
||||||
|
if ipv4 := ipAddr.To4(); ipv4 != nil { |
||||||
|
response = append(response, ipv4...) |
||||||
|
} else { |
||||||
|
response[3] = AddrIPv6 |
||||||
|
response = append(response, ipAddr.To16()...) |
||||||
|
} |
||||||
|
|
||||||
|
// 添加端口
|
||||||
|
portBytes := make([]byte, 2) |
||||||
|
binary.BigEndian.PutUint16(portBytes, port) |
||||||
|
response = append(response, portBytes...) |
||||||
|
|
||||||
|
_, err := conn.Write(response) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// checkRules 检查访问规则
|
||||||
|
func (s *Server) checkRules(addr string) bool { |
||||||
|
if len(s.rules) == 0 { |
||||||
|
return true // 没有规则时允许所有连接
|
||||||
|
} |
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(addr) |
||||||
|
if err != nil { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
targetAddr, err := net.ResolveIPAddr("ip", host) |
||||||
|
if err != nil { |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
for _, rule := range s.rules { |
||||||
|
if rule.Allow(targetAddr) { |
||||||
|
return true |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
// relay 数据转发
|
||||||
|
func (s *Server) relay(client, target net.Conn) error { |
||||||
|
// 创建双向数据转发
|
||||||
|
errChan := make(chan error, 2) |
||||||
|
|
||||||
|
// 客户端到目标服务器
|
||||||
|
go func() { |
||||||
|
_, err := io.Copy(target, client) |
||||||
|
errChan <- err |
||||||
|
}() |
||||||
|
|
||||||
|
// 目标服务器到客户端
|
||||||
|
go func() { |
||||||
|
_, err := io.Copy(client, target) |
||||||
|
errChan <- err |
||||||
|
}() |
||||||
|
|
||||||
|
// 等待任一方向的连接关闭
|
||||||
|
err := <-errChan |
||||||
|
return err |
||||||
|
} |
@ -0,0 +1,54 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/azoic/wormhole-server/pkg/memory" |
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
func BenchmarkBufferPool(b *testing.B) { |
||||||
|
logger := logrus.New() |
||||||
|
logger.SetLevel(logrus.ErrorLevel) |
||||||
|
|
||||||
|
memConfig := memory.Config{ |
||||||
|
BufferSizes: []int{512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}, |
||||||
|
EnableOptimization: true, |
||||||
|
EnableAutoGC: false, |
||||||
|
MonitorInterval: time.Minute, |
||||||
|
} |
||||||
|
memManager := memory.NewManager(memConfig, logger) |
||||||
|
defer memManager.Stop() |
||||||
|
|
||||||
|
b.ResetTimer() |
||||||
|
b.RunParallel(func(pb *testing.PB) { |
||||||
|
for pb.Next() { |
||||||
|
// 测试不同大小的缓冲区
|
||||||
|
sizes := []int{512, 1024, 2048, 4096, 8192} |
||||||
|
for _, size := range sizes { |
||||||
|
buf := memManager.GetBuffer(size) |
||||||
|
memManager.PutBuffer(buf) |
||||||
|
} |
||||||
|
} |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func BenchmarkMemoryManagerStats(b *testing.B) { |
||||||
|
logger := logrus.New() |
||||||
|
logger.SetLevel(logrus.ErrorLevel) |
||||||
|
|
||||||
|
memConfig := memory.Config{ |
||||||
|
BufferSizes: []int{512, 1024, 2048, 4096}, |
||||||
|
EnableOptimization: true, |
||||||
|
EnableAutoGC: false, |
||||||
|
MonitorInterval: time.Minute, |
||||||
|
} |
||||||
|
memManager := memory.NewManager(memConfig, logger) |
||||||
|
defer memManager.Stop() |
||||||
|
|
||||||
|
b.ResetTimer() |
||||||
|
for i := 0; i < b.N; i++ { |
||||||
|
_ = memManager.GetOverallStats() |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,71 @@ |
|||||||
|
package socks5 |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/sirupsen/logrus" |
||||||
|
) |
||||||
|
|
||||||
|
func TestNewServer(t *testing.T) { |
||||||
|
logger := logrus.New() |
||||||
|
logger.SetLevel(logrus.ErrorLevel) // 减少测试输出
|
||||||
|
|
||||||
|
config := Config{ |
||||||
|
Auth: AuthConfig{ |
||||||
|
Username: "test", |
||||||
|
Password: "pass", |
||||||
|
}, |
||||||
|
Timeout: 30 * time.Second, |
||||||
|
} |
||||||
|
|
||||||
|
server := NewServer(config, logger) |
||||||
|
if server == nil { |
||||||
|
t.Fatal("Expected server to be created, got nil") |
||||||
|
} |
||||||
|
|
||||||
|
if server.logger != logger { |
||||||
|
t.Error("Logger not set correctly") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestAuthHandler(t *testing.T) { |
||||||
|
config := AuthConfig{ |
||||||
|
Username: "testuser", |
||||||
|
Password: "testpass", |
||||||
|
Methods: []string{"password"}, |
||||||
|
} |
||||||
|
|
||||||
|
handler := NewAuthHandler(config) |
||||||
|
|
||||||
|
// 测试正确的认证
|
||||||
|
if !handler.Authenticate("testuser", "testpass") { |
||||||
|
t.Error("Valid authentication failed") |
||||||
|
} |
||||||
|
|
||||||
|
// 测试错误的认证
|
||||||
|
if handler.Authenticate("wronguser", "wrongpass") { |
||||||
|
t.Error("Invalid authentication succeeded") |
||||||
|
} |
||||||
|
|
||||||
|
// 检查支持的方法
|
||||||
|
methods := handler.Methods() |
||||||
|
if len(methods) != 1 || methods[0] != AuthPassword { |
||||||
|
t.Error("Expected password authentication method") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestRule(t *testing.T) { |
||||||
|
config := RuleConfig{ |
||||||
|
Action: "allow", |
||||||
|
IPs: []string{"192.168.1.0/24", "127.0.0.1"}, |
||||||
|
Ports: []int{80, 443}, |
||||||
|
} |
||||||
|
|
||||||
|
rule := NewRule(config) |
||||||
|
if rule == nil { |
||||||
|
t.Fatal("Expected rule to be created, got nil") |
||||||
|
} |
||||||
|
|
||||||
|
// 这里可以添加更多的规则测试
|
||||||
|
} |
@ -1,368 +0,0 @@ |
|||||||
package system |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"os" |
|
||||||
"os/exec" |
|
||||||
"runtime" |
|
||||||
"strings" |
|
||||||
|
|
||||||
"github.com/sirupsen/logrus" |
|
||||||
) |
|
||||||
|
|
||||||
type SystemProxy struct { |
|
||||||
logger *logrus.Logger |
|
||||||
proxyAddr string |
|
||||||
backupConfig map[string]string |
|
||||||
} |
|
||||||
|
|
||||||
type Config struct { |
|
||||||
HTTPProxy string |
|
||||||
HTTPSProxy string |
|
||||||
SOCKSProxy string |
|
||||||
NoProxy []string |
|
||||||
} |
|
||||||
|
|
||||||
func NewSystemProxy(proxyAddr string, logger *logrus.Logger) *SystemProxy { |
|
||||||
return &SystemProxy{ |
|
||||||
logger: logger, |
|
||||||
proxyAddr: proxyAddr, |
|
||||||
backupConfig: make(map[string]string), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// SetSystemProxy 设置系统代理
|
|
||||||
func (sp *SystemProxy) SetSystemProxy(config Config) error { |
|
||||||
sp.logger.Info("Setting system proxy configuration") |
|
||||||
|
|
||||||
// 备份当前配置
|
|
||||||
if err := sp.backupCurrentConfig(); err != nil { |
|
||||||
sp.logger.WithError(err).Warn("Failed to backup current proxy config") |
|
||||||
} |
|
||||||
|
|
||||||
switch runtime.GOOS { |
|
||||||
case "darwin": |
|
||||||
return sp.setMacOSProxy(config) |
|
||||||
case "linux": |
|
||||||
return sp.setLinuxProxy(config) |
|
||||||
case "windows": |
|
||||||
return sp.setWindowsProxy(config) |
|
||||||
default: |
|
||||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// RestoreSystemProxy 恢复系统代理设置
|
|
||||||
func (sp *SystemProxy) RestoreSystemProxy() error { |
|
||||||
sp.logger.Info("Restoring system proxy configuration") |
|
||||||
|
|
||||||
switch runtime.GOOS { |
|
||||||
case "darwin": |
|
||||||
return sp.restoreMacOSProxy() |
|
||||||
case "linux": |
|
||||||
return sp.restoreLinuxProxy() |
|
||||||
case "windows": |
|
||||||
return sp.restoreWindowsProxy() |
|
||||||
default: |
|
||||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// setMacOSProxy 设置macOS系统代理
|
|
||||||
func (sp *SystemProxy) setMacOSProxy(config Config) error { |
|
||||||
// 获取所有网络服务
|
|
||||||
services, err := sp.getMacOSNetworkServices() |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed to get network services: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
for _, service := range services { |
|
||||||
// 设置HTTP代理
|
|
||||||
if config.HTTPProxy != "" { |
|
||||||
parts := strings.Split(config.HTTPProxy, ":") |
|
||||||
if len(parts) == 2 { |
|
||||||
if err := sp.runCommand("networksetup", "-setwebproxy", service, parts[0], parts[1]); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set HTTP proxy") |
|
||||||
} |
|
||||||
if err := sp.runCommand("networksetup", "-setwebproxystate", service, "on"); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable HTTP proxy") |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// 设置HTTPS代理
|
|
||||||
if config.HTTPSProxy != "" { |
|
||||||
parts := strings.Split(config.HTTPSProxy, ":") |
|
||||||
if len(parts) == 2 { |
|
||||||
if err := sp.runCommand("networksetup", "-setsecurewebproxy", service, parts[0], parts[1]); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set HTTPS proxy") |
|
||||||
} |
|
||||||
if err := sp.runCommand("networksetup", "-setsecurewebproxystate", service, "on"); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable HTTPS proxy") |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// 设置SOCKS代理
|
|
||||||
if config.SOCKSProxy != "" { |
|
||||||
parts := strings.Split(config.SOCKSProxy, ":") |
|
||||||
if len(parts) == 2 { |
|
||||||
if err := sp.runCommand("networksetup", "-setsocksfirewallproxy", service, parts[0], parts[1]); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set SOCKS proxy") |
|
||||||
} |
|
||||||
if err := sp.runCommand("networksetup", "-setsocksfirewallproxystate", service, "on"); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable SOCKS proxy") |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// 设置代理绕过列表
|
|
||||||
if len(config.NoProxy) > 0 { |
|
||||||
bypassList := strings.Join(config.NoProxy, " ") |
|
||||||
if err := sp.runCommand("networksetup", "-setproxybypassdomains", service, bypassList); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set proxy bypass list") |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
// setLinuxProxy 设置Linux系统代理(通过环境变量)
|
|
||||||
func (sp *SystemProxy) setLinuxProxy(config Config) error { |
|
||||||
envVars := make(map[string]string) |
|
||||||
|
|
||||||
if config.HTTPProxy != "" { |
|
||||||
envVars["http_proxy"] = "http://" + config.HTTPProxy |
|
||||||
envVars["HTTP_PROXY"] = "http://" + config.HTTPProxy |
|
||||||
} |
|
||||||
|
|
||||||
if config.HTTPSProxy != "" { |
|
||||||
envVars["https_proxy"] = "https://" + config.HTTPSProxy |
|
||||||
envVars["HTTPS_PROXY"] = "https://" + config.HTTPSProxy |
|
||||||
} |
|
||||||
|
|
||||||
if len(config.NoProxy) > 0 { |
|
||||||
noProxy := strings.Join(config.NoProxy, ",") |
|
||||||
envVars["no_proxy"] = noProxy |
|
||||||
envVars["NO_PROXY"] = noProxy |
|
||||||
} |
|
||||||
|
|
||||||
// 写入到用户的shell配置文件
|
|
||||||
return sp.writeLinuxProxyConfig(envVars) |
|
||||||
} |
|
||||||
|
|
||||||
// setWindowsProxy 设置Windows系统代理
|
|
||||||
func (sp *SystemProxy) setWindowsProxy(config Config) error { |
|
||||||
// Windows代理设置通过注册表
|
|
||||||
if config.HTTPProxy != "" { |
|
||||||
// 启用代理
|
|
||||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable", "/t", "REG_DWORD", "/d", "1", "/f"); err != nil { |
|
||||||
return fmt.Errorf("failed to enable proxy: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
// 设置代理服务器
|
|
||||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyServer", "/t", "REG_SZ", "/d", config.HTTPProxy, "/f"); err != nil { |
|
||||||
return fmt.Errorf("failed to set proxy server: %w", err) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// 设置代理绕过列表
|
|
||||||
if len(config.NoProxy) > 0 { |
|
||||||
bypassList := strings.Join(config.NoProxy, ";") |
|
||||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyOverride", "/t", "REG_SZ", "/d", bypassList, "/f"); err != nil { |
|
||||||
return fmt.Errorf("failed to set proxy bypass list: %w", err) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) getMacOSNetworkServices() ([]string, error) { |
|
||||||
output, err := exec.Command("networksetup", "-listallnetworkservices").Output() |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
lines := strings.Split(string(output), "\n") |
|
||||||
var services []string |
|
||||||
|
|
||||||
for _, line := range lines { |
|
||||||
line = strings.TrimSpace(line) |
|
||||||
if line != "" && !strings.HasPrefix(line, "*") && line != "An asterisk (*) denotes that a network service is disabled." { |
|
||||||
services = append(services, line) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return services, nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) backupCurrentConfig() error { |
|
||||||
switch runtime.GOOS { |
|
||||||
case "darwin": |
|
||||||
return sp.backupMacOSConfig() |
|
||||||
case "linux": |
|
||||||
return sp.backupLinuxConfig() |
|
||||||
case "windows": |
|
||||||
return sp.backupWindowsConfig() |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) backupMacOSConfig() error { |
|
||||||
services, err := sp.getMacOSNetworkServices() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
for _, service := range services { |
|
||||||
// 备份HTTP代理设置
|
|
||||||
output, _ := exec.Command("networksetup", "-getwebproxy", service).Output() |
|
||||||
sp.backupConfig[fmt.Sprintf("http_%s", service)] = string(output) |
|
||||||
|
|
||||||
// 备份HTTPS代理设置
|
|
||||||
output, _ = exec.Command("networksetup", "-getsecurewebproxy", service).Output() |
|
||||||
sp.backupConfig[fmt.Sprintf("https_%s", service)] = string(output) |
|
||||||
|
|
||||||
// 备份SOCKS代理设置
|
|
||||||
output, _ = exec.Command("networksetup", "-getsocksfirewallproxy", service).Output() |
|
||||||
sp.backupConfig[fmt.Sprintf("socks_%s", service)] = string(output) |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) backupLinuxConfig() error { |
|
||||||
// 备份环境变量
|
|
||||||
sp.backupConfig["http_proxy"] = os.Getenv("http_proxy") |
|
||||||
sp.backupConfig["https_proxy"] = os.Getenv("https_proxy") |
|
||||||
sp.backupConfig["no_proxy"] = os.Getenv("no_proxy") |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) backupWindowsConfig() error { |
|
||||||
// 备份Windows注册表设置
|
|
||||||
output, _ := exec.Command("reg", "query", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable").Output() |
|
||||||
sp.backupConfig["ProxyEnable"] = string(output) |
|
||||||
|
|
||||||
output, _ = exec.Command("reg", "query", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyServer").Output() |
|
||||||
sp.backupConfig["ProxyServer"] = string(output) |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) restoreMacOSProxy() error { |
|
||||||
services, err := sp.getMacOSNetworkServices() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
for _, service := range services { |
|
||||||
// 禁用所有代理
|
|
||||||
sp.runCommand("networksetup", "-setwebproxystate", service, "off") |
|
||||||
sp.runCommand("networksetup", "-setsecurewebproxystate", service, "off") |
|
||||||
sp.runCommand("networksetup", "-setsocksfirewallproxystate", service, "off") |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) restoreLinuxProxy() error { |
|
||||||
// 清除环境变量(这里简化处理)
|
|
||||||
os.Unsetenv("http_proxy") |
|
||||||
os.Unsetenv("https_proxy") |
|
||||||
os.Unsetenv("no_proxy") |
|
||||||
os.Unsetenv("HTTP_PROXY") |
|
||||||
os.Unsetenv("HTTPS_PROXY") |
|
||||||
os.Unsetenv("NO_PROXY") |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) restoreWindowsProxy() error { |
|
||||||
// 禁用代理
|
|
||||||
return sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable", "/t", "REG_DWORD", "/d", "0", "/f") |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) writeLinuxProxyConfig(envVars map[string]string) error { |
|
||||||
// 写入到 ~/.bashrc 和 ~/.profile
|
|
||||||
homeDir, err := os.UserHomeDir() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
configFiles := []string{ |
|
||||||
homeDir + "/.bashrc", |
|
||||||
homeDir + "/.profile", |
|
||||||
} |
|
||||||
|
|
||||||
proxyLines := []string{"\n# Wormhole SOCKS5 Proxy Configuration"} |
|
||||||
for key, value := range envVars { |
|
||||||
proxyLines = append(proxyLines, fmt.Sprintf("export %s=%s", key, value)) |
|
||||||
} |
|
||||||
proxyLines = append(proxyLines, "# End Wormhole Configuration\n") |
|
||||||
|
|
||||||
for _, configFile := range configFiles { |
|
||||||
file, err := os.OpenFile(configFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) |
|
||||||
if err != nil { |
|
||||||
sp.logger.WithError(err).WithField("file", configFile).Warn("Failed to open config file") |
|
||||||
continue |
|
||||||
} |
|
||||||
|
|
||||||
for _, line := range proxyLines { |
|
||||||
if _, err := file.WriteString(line + "\n"); err != nil { |
|
||||||
sp.logger.WithError(err).WithField("file", configFile).Warn("Failed to write to config file") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
file.Close() |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (sp *SystemProxy) runCommand(name string, args ...string) error { |
|
||||||
cmd := exec.Command(name, args...) |
|
||||||
output, err := cmd.CombinedOutput() |
|
||||||
if err != nil { |
|
||||||
sp.logger.WithFields(logrus.Fields{ |
|
||||||
"command": fmt.Sprintf("%s %s", name, strings.Join(args, " ")), |
|
||||||
"output": string(output), |
|
||||||
}).WithError(err).Debug("Command execution failed") |
|
||||||
return err |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetCurrentConfig 获取当前系统代理配置
|
|
||||||
func (sp *SystemProxy) GetCurrentConfig() (Config, error) { |
|
||||||
config := Config{ |
|
||||||
NoProxy: []string{}, |
|
||||||
} |
|
||||||
|
|
||||||
switch runtime.GOOS { |
|
||||||
case "linux": |
|
||||||
config.HTTPProxy = os.Getenv("http_proxy") |
|
||||||
config.HTTPSProxy = os.Getenv("https_proxy") |
|
||||||
noProxy := os.Getenv("no_proxy") |
|
||||||
if noProxy != "" { |
|
||||||
config.NoProxy = strings.Split(noProxy, ",") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return config, nil |
|
||||||
} |
|
||||||
|
|
||||||
// IsProxySet 检查是否已设置代理
|
|
||||||
func (sp *SystemProxy) IsProxySet() bool { |
|
||||||
switch runtime.GOOS { |
|
||||||
case "linux": |
|
||||||
return os.Getenv("http_proxy") != "" || os.Getenv("https_proxy") != "" |
|
||||||
case "darwin": |
|
||||||
// 检查macOS代理设置(简化)
|
|
||||||
return false |
|
||||||
case "windows": |
|
||||||
// 检查Windows代理设置(简化)
|
|
||||||
return false |
|
||||||
} |
|
||||||
return false |
|
||||||
} |
|
Loading…
Reference in new issue