parent
77128910b3
commit
52f0fba78d
@ -0,0 +1,289 @@ |
||||
package config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
"github.com/spf13/viper" |
||||
) |
||||
|
||||
// ServerConfig 服务器配置
|
||||
type ServerConfig struct { |
||||
ServiceType string `mapstructure:"serviceType"` |
||||
|
||||
Proxy ProxyConfig `mapstructure:"proxy"` |
||||
Auth AuthConfig `mapstructure:"auth"` |
||||
|
||||
Timeout time.Duration `mapstructure:"timeout"` |
||||
MaxConns int `mapstructure:"maxConns"` |
||||
LogLevel string `mapstructure:"logLevel"` |
||||
|
||||
HealthCheck HealthCheckConfig `mapstructure:"healthCheck"` |
||||
|
||||
OptimizedServer OptimizedServerConfig `mapstructure:"optimizedServer"` |
||||
} |
||||
|
||||
// ProxyConfig 代理配置
|
||||
type ProxyConfig struct { |
||||
Address string `mapstructure:"address"` |
||||
Port int `mapstructure:"port"` |
||||
} |
||||
|
||||
// AuthConfig 认证配置
|
||||
type AuthConfig struct { |
||||
Username string `mapstructure:"username"` |
||||
Password string `mapstructure:"password"` |
||||
Methods []string `mapstructure:"methods"` |
||||
} |
||||
|
||||
// HealthCheckConfig 健康检查配置
|
||||
type HealthCheckConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
Address string `mapstructure:"address"` |
||||
Port int `mapstructure:"port"` |
||||
} |
||||
|
||||
// OptimizedServerConfig 优化服务器配置
|
||||
type OptimizedServerConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
MaxIdleTime time.Duration `mapstructure:"maxIdleTime"` |
||||
BufferSize int `mapstructure:"bufferSize"` |
||||
LogConnections bool `mapstructure:"logConnections"` |
||||
|
||||
DNSCache DNSCacheConfig `mapstructure:"dnsCache"` |
||||
RateLimit RateLimitConfig `mapstructure:"rateLimit"` |
||||
AccessControl AccessControlConfig `mapstructure:"accessControl"` |
||||
Metrics MetricsConfig `mapstructure:"metrics"` |
||||
ConnectionPool ConnectionPoolConfig `mapstructure:"connectionPool"` |
||||
Memory MemoryConfig `mapstructure:"memory"` |
||||
Transparent TransparentConfig `mapstructure:"transparent"` |
||||
} |
||||
|
||||
// DNSCacheConfig DNS缓存配置
|
||||
type DNSCacheConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
MaxSize int `mapstructure:"maxSize"` |
||||
TTL time.Duration `mapstructure:"ttl"` |
||||
} |
||||
|
||||
// RateLimitConfig 速率限制配置
|
||||
type RateLimitConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
RequestsPerSecond int `mapstructure:"requestsPerSecond"` |
||||
BurstSize int `mapstructure:"burstSize"` |
||||
PerIPRequestsPerSec int `mapstructure:"perIPRequestsPerSec"` |
||||
PerIPBurstSize int `mapstructure:"perIPBurstSize"` |
||||
CleanupInterval time.Duration `mapstructure:"cleanupInterval"` |
||||
} |
||||
|
||||
// AccessControlConfig 访问控制配置
|
||||
type AccessControlConfig struct { |
||||
AllowedIPs []string `mapstructure:"allowedIPs"` |
||||
} |
||||
|
||||
// MetricsConfig 指标配置
|
||||
type MetricsConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
Interval time.Duration `mapstructure:"interval"` |
||||
} |
||||
|
||||
// ConnectionPoolConfig 连接池配置
|
||||
type ConnectionPoolConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
MaxSize int `mapstructure:"maxSize"` |
||||
MaxLifetime time.Duration `mapstructure:"maxLifetime"` |
||||
MaxIdle time.Duration `mapstructure:"maxIdle"` |
||||
InitialSize int `mapstructure:"initialSize"` |
||||
} |
||||
|
||||
// MemoryConfig 内存优化配置
|
||||
type MemoryConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
BufferSizes []int `mapstructure:"bufferSizes"` |
||||
MonitorInterval time.Duration `mapstructure:"monitorInterval"` |
||||
EnableAutoGC bool `mapstructure:"enableAutoGC"` |
||||
HeapAllocThresholdMB int64 `mapstructure:"heapAllocThresholdMB"` |
||||
HeapSysThresholdMB int64 `mapstructure:"heapSysThresholdMB"` |
||||
ForceGCThresholdMB int64 `mapstructure:"forceGCThresholdMB"` |
||||
} |
||||
|
||||
// TransparentConfig 透明代理配置
|
||||
type TransparentConfig struct { |
||||
Enabled bool `mapstructure:"enabled"` |
||||
TransparentPort int `mapstructure:"transparentPort"` |
||||
DNSPort int `mapstructure:"dnsPort"` |
||||
BypassIPs []string `mapstructure:"bypassIPs"` |
||||
BypassDomains []string `mapstructure:"bypassDomains"` |
||||
} |
||||
|
||||
// LoadConfig 加载配置文件
|
||||
func LoadConfig(configPath string) (*ServerConfig, error) { |
||||
// 检查配置文件是否存在
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) { |
||||
return nil, fmt.Errorf("config file not found: %s", configPath) |
||||
} |
||||
|
||||
// 初始化viper
|
||||
viper.SetConfigFile(configPath) |
||||
viper.SetConfigType("yaml") |
||||
|
||||
// 设置默认值
|
||||
setDefaults() |
||||
|
||||
// 读取配置文件
|
||||
if err := viper.ReadInConfig(); err != nil { |
||||
return nil, fmt.Errorf("failed to read config file: %w", err) |
||||
} |
||||
|
||||
// 解析配置
|
||||
var config ServerConfig |
||||
if err := viper.Unmarshal(&config); err != nil { |
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err) |
||||
} |
||||
|
||||
// 验证配置
|
||||
if err := validateConfig(&config); err != nil { |
||||
return nil, fmt.Errorf("invalid config: %w", err) |
||||
} |
||||
|
||||
return &config, nil |
||||
} |
||||
|
||||
// setDefaults 设置默认配置值
|
||||
func setDefaults() { |
||||
// 基本配置默认值
|
||||
viper.SetDefault("serviceType", "server") |
||||
viper.SetDefault("proxy.address", "0.0.0.0") |
||||
viper.SetDefault("proxy.port", 1080) |
||||
viper.SetDefault("timeout", "30s") |
||||
viper.SetDefault("maxConns", 5000) |
||||
viper.SetDefault("logLevel", "info") |
||||
|
||||
// 健康检查默认值
|
||||
viper.SetDefault("healthCheck.enabled", true) |
||||
viper.SetDefault("healthCheck.address", "127.0.0.1") |
||||
viper.SetDefault("healthCheck.port", 8090) |
||||
|
||||
// 优化服务器默认值
|
||||
viper.SetDefault("optimizedServer.enabled", true) |
||||
viper.SetDefault("optimizedServer.maxIdleTime", "5m") |
||||
viper.SetDefault("optimizedServer.bufferSize", 65536) |
||||
viper.SetDefault("optimizedServer.logConnections", true) |
||||
|
||||
// DNS缓存默认值
|
||||
viper.SetDefault("optimizedServer.dnsCache.enabled", true) |
||||
viper.SetDefault("optimizedServer.dnsCache.maxSize", 10000) |
||||
viper.SetDefault("optimizedServer.dnsCache.ttl", "10m") |
||||
|
||||
// 速率限制默认值
|
||||
viper.SetDefault("optimizedServer.rateLimit.enabled", true) |
||||
viper.SetDefault("optimizedServer.rateLimit.requestsPerSecond", 100) |
||||
|
||||
// 指标默认值
|
||||
viper.SetDefault("optimizedServer.metrics.enabled", true) |
||||
viper.SetDefault("optimizedServer.metrics.interval", "5m") |
||||
|
||||
// 连接池默认值
|
||||
viper.SetDefault("optimizedServer.connectionPool.enabled", true) |
||||
viper.SetDefault("optimizedServer.connectionPool.maxSize", 1000) |
||||
viper.SetDefault("optimizedServer.connectionPool.maxLifetime", "30m") |
||||
viper.SetDefault("optimizedServer.connectionPool.maxIdle", "5m") |
||||
viper.SetDefault("optimizedServer.connectionPool.initialSize", 100) |
||||
|
||||
// 内存优化默认值
|
||||
viper.SetDefault("optimizedServer.memory.enabled", true) |
||||
viper.SetDefault("optimizedServer.memory.bufferSizes", []int{64, 128, 256, 512, 1024}) |
||||
viper.SetDefault("optimizedServer.memory.monitorInterval", "5m") |
||||
viper.SetDefault("optimizedServer.memory.enableAutoGC", true) |
||||
viper.SetDefault("optimizedServer.memory.heapAllocThresholdMB", 1024) |
||||
viper.SetDefault("optimizedServer.memory.heapSysThresholdMB", 2048) |
||||
viper.SetDefault("optimizedServer.memory.forceGCThresholdMB", 512) |
||||
|
||||
// 透明代理默认值
|
||||
viper.SetDefault("optimizedServer.transparent.enabled", false) |
||||
viper.SetDefault("optimizedServer.transparent.transparentPort", 8080) |
||||
viper.SetDefault("optimizedServer.transparent.dnsPort", 53) |
||||
viper.SetDefault("optimizedServer.transparent.bypassIPs", []string{}) |
||||
viper.SetDefault("optimizedServer.transparent.bypassDomains", []string{}) |
||||
} |
||||
|
||||
// validateConfig 验证配置
|
||||
func validateConfig(config *ServerConfig) error { |
||||
// 验证端口范围
|
||||
if config.Proxy.Port < 1 || config.Proxy.Port > 65535 { |
||||
return fmt.Errorf("invalid proxy port: %d", config.Proxy.Port) |
||||
} |
||||
|
||||
if config.HealthCheck.Enabled { |
||||
if config.HealthCheck.Port < 1 || config.HealthCheck.Port > 65535 { |
||||
return fmt.Errorf("invalid health check port: %d", config.HealthCheck.Port) |
||||
} |
||||
} |
||||
|
||||
// 验证认证配置
|
||||
if config.Auth.Username != "" && config.Auth.Password == "" { |
||||
return fmt.Errorf("password is required when username is set") |
||||
} |
||||
|
||||
// 验证日志级别
|
||||
switch config.LogLevel { |
||||
case "debug", "info", "warn", "error": |
||||
// 有效的日志级别
|
||||
default: |
||||
return fmt.Errorf("invalid log level: %s", config.LogLevel) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// GetLogLevel 获取logrus日志级别
|
||||
func GetLogLevel(level string) logrus.Level { |
||||
switch level { |
||||
case "debug": |
||||
return logrus.DebugLevel |
||||
case "info": |
||||
return logrus.InfoLevel |
||||
case "warn": |
||||
return logrus.WarnLevel |
||||
case "error": |
||||
return logrus.ErrorLevel |
||||
default: |
||||
return logrus.InfoLevel |
||||
} |
||||
} |
||||
|
||||
// ToSOCKS5Config 转换为SOCKS5配置
|
||||
func (c *ServerConfig) ToSOCKS5Config() SOCKS5Config { |
||||
return SOCKS5Config{ |
||||
Auth: SOCKS5AuthConfig{ |
||||
Methods: c.Auth.Methods, |
||||
Username: c.Auth.Username, |
||||
Password: c.Auth.Password, |
||||
}, |
||||
Timeout: c.Timeout, |
||||
Rules: []SOCKS5RuleConfig{}, // 从访问控制配置转换
|
||||
} |
||||
} |
||||
|
||||
// SOCKS5Config SOCKS5特定配置
|
||||
type SOCKS5Config struct { |
||||
Auth SOCKS5AuthConfig `json:"auth"` |
||||
Timeout time.Duration `json:"timeout"` |
||||
Rules []SOCKS5RuleConfig `json:"rules"` |
||||
} |
||||
|
||||
// SOCKS5AuthConfig SOCKS5认证配置
|
||||
type SOCKS5AuthConfig struct { |
||||
Methods []string `json:"methods"` |
||||
Username string `json:"username"` |
||||
Password string `json:"password"` |
||||
} |
||||
|
||||
// SOCKS5RuleConfig SOCKS5规则配置
|
||||
type SOCKS5RuleConfig struct { |
||||
Action string `json:"action"` |
||||
IPs []string `json:"ips"` |
||||
Ports []int `json:"ports"` |
||||
} |
@ -0,0 +1,379 @@ |
||||
package memory |
||||
|
||||
import ( |
||||
"runtime" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
// BufferPool 缓冲区池
|
||||
type BufferPool struct { |
||||
pools map[int]*sync.Pool |
||||
sizes []int |
||||
logger *logrus.Logger |
||||
stats BufferStats |
||||
} |
||||
|
||||
// BufferStats 缓冲区统计
|
||||
type BufferStats struct { |
||||
mu sync.RWMutex |
||||
Gets int64 `json:"gets"` |
||||
Puts int64 `json:"puts"` |
||||
News int64 `json:"news"` |
||||
Reuses int64 `json:"reuses"` |
||||
TotalAlloc int64 `json:"totalAlloc"` |
||||
TotalReused int64 `json:"totalReused"` |
||||
} |
||||
|
||||
// MemoryMonitor 内存监控器
|
||||
type MemoryMonitor struct { |
||||
logger *logrus.Logger |
||||
ticker *time.Ticker |
||||
stopCh chan struct{} |
||||
thresholds Thresholds |
||||
lastStats runtime.MemStats |
||||
callbacks []MemoryCallback |
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
// Thresholds 内存阈值配置
|
||||
type Thresholds struct { |
||||
HeapAllocMB int64 `json:"heapAllocMB"` |
||||
HeapSysMB int64 `json:"heapSysMB"` |
||||
GCPercent int `json:"gcPercent"` |
||||
ForceGCThreshMB int64 `json:"forceGCThreshMB"` |
||||
} |
||||
|
||||
// MemoryCallback 内存回调函数
|
||||
type MemoryCallback func(stats runtime.MemStats) |
||||
|
||||
// Config 内存管理配置
|
||||
type Config struct { |
||||
BufferSizes []int `json:"bufferSizes"` |
||||
MonitorInterval time.Duration `json:"monitorInterval"` |
||||
Thresholds Thresholds `json:"thresholds"` |
||||
EnableAutoGC bool `json:"enableAutoGC"` |
||||
EnableOptimization bool `json:"enableOptimization"` |
||||
} |
||||
|
||||
// Manager 内存管理器
|
||||
type Manager struct { |
||||
bufferPool *BufferPool |
||||
monitor *MemoryMonitor |
||||
config Config |
||||
logger *logrus.Logger |
||||
} |
||||
|
||||
// NewManager 创建内存管理器
|
||||
func NewManager(config Config, logger *logrus.Logger) *Manager { |
||||
// 设置默认值
|
||||
if len(config.BufferSizes) == 0 { |
||||
config.BufferSizes = []int{512, 1024, 2048, 4096, 8192, 16384, 32768, 65536} |
||||
} |
||||
if config.MonitorInterval == 0 { |
||||
config.MonitorInterval = 30 * time.Second |
||||
} |
||||
if config.Thresholds.HeapAllocMB == 0 { |
||||
config.Thresholds.HeapAllocMB = 100 |
||||
} |
||||
if config.Thresholds.HeapSysMB == 0 { |
||||
config.Thresholds.HeapSysMB = 200 |
||||
} |
||||
if config.Thresholds.GCPercent == 0 { |
||||
config.Thresholds.GCPercent = 100 |
||||
} |
||||
if config.Thresholds.ForceGCThreshMB == 0 { |
||||
config.Thresholds.ForceGCThreshMB = 500 |
||||
} |
||||
|
||||
manager := &Manager{ |
||||
config: config, |
||||
logger: logger, |
||||
} |
||||
|
||||
// 创建缓冲区池
|
||||
if config.EnableOptimization { |
||||
manager.bufferPool = NewBufferPool(config.BufferSizes, logger) |
||||
} |
||||
|
||||
// 创建内存监控器
|
||||
manager.monitor = NewMemoryMonitor(config.MonitorInterval, config.Thresholds, logger) |
||||
|
||||
// 启用自动GC
|
||||
if config.EnableAutoGC { |
||||
manager.monitor.AddCallback(manager.autoGCCallback) |
||||
} |
||||
|
||||
return manager |
||||
} |
||||
|
||||
// NewBufferPool 创建缓冲区池
|
||||
func NewBufferPool(sizes []int, logger *logrus.Logger) *BufferPool { |
||||
bp := &BufferPool{ |
||||
pools: make(map[int]*sync.Pool), |
||||
sizes: make([]int, len(sizes)), |
||||
logger: logger, |
||||
} |
||||
|
||||
copy(bp.sizes, sizes) |
||||
|
||||
// 为每个大小创建池
|
||||
for _, size := range sizes { |
||||
sz := size // 捕获循环变量
|
||||
bp.pools[sz] = &sync.Pool{ |
||||
New: func() interface{} { |
||||
bp.stats.incNews() |
||||
bp.stats.addTotalAlloc(int64(sz)) |
||||
return make([]byte, sz) |
||||
}, |
||||
} |
||||
} |
||||
|
||||
return bp |
||||
} |
||||
|
||||
// Get 获取缓冲区
|
||||
func (bp *BufferPool) Get(size int) []byte { |
||||
bp.stats.incGets() |
||||
|
||||
// 找到最合适的大小
|
||||
for _, poolSize := range bp.sizes { |
||||
if size <= poolSize { |
||||
buf := bp.pools[poolSize].Get().([]byte) |
||||
bp.stats.incReuses() |
||||
bp.stats.addTotalReused(int64(poolSize)) |
||||
return buf[:size] |
||||
} |
||||
} |
||||
|
||||
// 没有合适的池,创建新缓冲区
|
||||
bp.stats.incNews() |
||||
bp.stats.addTotalAlloc(int64(size)) |
||||
return make([]byte, size) |
||||
} |
||||
|
||||
// Put 归还缓冲区
|
||||
func (bp *BufferPool) Put(buf []byte) { |
||||
if buf == nil { |
||||
return |
||||
} |
||||
|
||||
bp.stats.incPuts() |
||||
capacity := cap(buf) |
||||
|
||||
// 找到对应的池
|
||||
for _, poolSize := range bp.sizes { |
||||
if capacity == poolSize { |
||||
bp.pools[poolSize].Put(buf[:poolSize]) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// GetStats 获取缓冲区统计
|
||||
func (bp *BufferPool) GetStats() BufferStats { |
||||
bp.stats.mu.RLock() |
||||
defer bp.stats.mu.RUnlock() |
||||
return bp.stats |
||||
} |
||||
|
||||
// NewMemoryMonitor 创建内存监控器
|
||||
func NewMemoryMonitor(interval time.Duration, thresholds Thresholds, logger *logrus.Logger) *MemoryMonitor { |
||||
monitor := &MemoryMonitor{ |
||||
logger: logger, |
||||
ticker: time.NewTicker(interval), |
||||
stopCh: make(chan struct{}), |
||||
thresholds: thresholds, |
||||
callbacks: make([]MemoryCallback, 0), |
||||
} |
||||
|
||||
// 启动监控
|
||||
go monitor.run() |
||||
|
||||
return monitor |
||||
} |
||||
|
||||
// AddCallback 添加内存回调
|
||||
func (mm *MemoryMonitor) AddCallback(callback MemoryCallback) { |
||||
mm.mu.Lock() |
||||
mm.callbacks = append(mm.callbacks, callback) |
||||
mm.mu.Unlock() |
||||
} |
||||
|
||||
// Stop 停止监控
|
||||
func (mm *MemoryMonitor) Stop() { |
||||
close(mm.stopCh) |
||||
mm.ticker.Stop() |
||||
} |
||||
|
||||
// run 运行监控
|
||||
func (mm *MemoryMonitor) run() { |
||||
for { |
||||
select { |
||||
case <-mm.stopCh: |
||||
return |
||||
case <-mm.ticker.C: |
||||
mm.check() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// check 检查内存状态
|
||||
func (mm *MemoryMonitor) check() { |
||||
var stats runtime.MemStats |
||||
runtime.ReadMemStats(&stats) |
||||
|
||||
// 记录内存统计
|
||||
heapAllocMB := stats.HeapAlloc / 1024 / 1024 |
||||
heapSysMB := stats.HeapSys / 1024 / 1024 |
||||
|
||||
mm.logger.WithFields(logrus.Fields{ |
||||
"heap_alloc_mb": heapAllocMB, |
||||
"heap_sys_mb": heapSysMB, |
||||
"gc_num": stats.NumGC, |
||||
"goroutines": runtime.NumGoroutine(), |
||||
}).Debug("Memory stats") |
||||
|
||||
// 检查阈值
|
||||
if int64(heapAllocMB) > mm.thresholds.HeapAllocMB { |
||||
mm.logger.WithField("heap_alloc_mb", heapAllocMB).Warn("Heap allocation threshold exceeded") |
||||
} |
||||
|
||||
if int64(heapSysMB) > mm.thresholds.HeapSysMB { |
||||
mm.logger.WithField("heap_sys_mb", heapSysMB).Warn("Heap system threshold exceeded") |
||||
} |
||||
|
||||
// 强制GC阈值
|
||||
if int64(heapAllocMB) > mm.thresholds.ForceGCThreshMB { |
||||
mm.logger.WithField("heap_alloc_mb", heapAllocMB).Info("Force GC triggered") |
||||
runtime.GC() |
||||
} |
||||
|
||||
// 调用回调函数
|
||||
mm.mu.RLock() |
||||
for _, callback := range mm.callbacks { |
||||
go callback(stats) |
||||
} |
||||
mm.mu.RUnlock() |
||||
|
||||
mm.lastStats = stats |
||||
} |
||||
|
||||
// GetStats 获取最新内存统计
|
||||
func (mm *MemoryMonitor) GetStats() runtime.MemStats { |
||||
return mm.lastStats |
||||
} |
||||
|
||||
// Manager 方法
|
||||
|
||||
// GetBuffer 获取缓冲区
|
||||
func (m *Manager) GetBuffer(size int) []byte { |
||||
if m.bufferPool != nil { |
||||
return m.bufferPool.Get(size) |
||||
} |
||||
return make([]byte, size) |
||||
} |
||||
|
||||
// PutBuffer 归还缓冲区
|
||||
func (m *Manager) PutBuffer(buf []byte) { |
||||
if m.bufferPool != nil { |
||||
m.bufferPool.Put(buf) |
||||
} |
||||
} |
||||
|
||||
// GetBufferStats 获取缓冲区统计
|
||||
func (m *Manager) GetBufferStats() BufferStats { |
||||
if m.bufferPool != nil { |
||||
return m.bufferPool.GetStats() |
||||
} |
||||
return BufferStats{} |
||||
} |
||||
|
||||
// GetMemoryStats 获取内存统计
|
||||
func (m *Manager) GetMemoryStats() runtime.MemStats { |
||||
return m.monitor.GetStats() |
||||
} |
||||
|
||||
// Stop 停止内存管理器
|
||||
func (m *Manager) Stop() { |
||||
if m.monitor != nil { |
||||
m.monitor.Stop() |
||||
} |
||||
} |
||||
|
||||
// ForceGC 强制垃圾回收
|
||||
func (m *Manager) ForceGC() { |
||||
m.logger.Info("Manual GC triggered") |
||||
runtime.GC() |
||||
} |
||||
|
||||
// autoGCCallback 自动GC回调
|
||||
func (m *Manager) autoGCCallback(stats runtime.MemStats) { |
||||
heapAllocMB := stats.HeapAlloc / 1024 / 1024 |
||||
|
||||
// 当堆分配超过阈值时触发GC
|
||||
if int64(heapAllocMB) > m.config.Thresholds.ForceGCThreshMB { |
||||
m.ForceGC() |
||||
} |
||||
} |
||||
|
||||
// GetOverallStats 获取总体统计信息
|
||||
func (m *Manager) GetOverallStats() map[string]interface{} { |
||||
memStats := m.GetMemoryStats() |
||||
bufferStats := m.GetBufferStats() |
||||
|
||||
return map[string]interface{}{ |
||||
"memory": map[string]interface{}{ |
||||
"heap_alloc_mb": memStats.HeapAlloc / 1024 / 1024, |
||||
"heap_sys_mb": memStats.HeapSys / 1024 / 1024, |
||||
"gc_num": memStats.NumGC, |
||||
"goroutines": runtime.NumGoroutine(), |
||||
}, |
||||
"buffers": bufferStats, |
||||
"config": map[string]interface{}{ |
||||
"optimization_enabled": m.config.EnableOptimization, |
||||
"auto_gc_enabled": m.config.EnableAutoGC, |
||||
"buffer_sizes": m.config.BufferSizes, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
// BufferStats 方法
|
||||
|
||||
func (bs *BufferStats) incGets() { |
||||
bs.mu.Lock() |
||||
bs.Gets++ |
||||
bs.mu.Unlock() |
||||
} |
||||
|
||||
func (bs *BufferStats) incPuts() { |
||||
bs.mu.Lock() |
||||
bs.Puts++ |
||||
bs.mu.Unlock() |
||||
} |
||||
|
||||
func (bs *BufferStats) incNews() { |
||||
bs.mu.Lock() |
||||
bs.News++ |
||||
bs.mu.Unlock() |
||||
} |
||||
|
||||
func (bs *BufferStats) incReuses() { |
||||
bs.mu.Lock() |
||||
bs.Reuses++ |
||||
bs.mu.Unlock() |
||||
} |
||||
|
||||
func (bs *BufferStats) addTotalAlloc(size int64) { |
||||
bs.mu.Lock() |
||||
bs.TotalAlloc += size |
||||
bs.mu.Unlock() |
||||
} |
||||
|
||||
func (bs *BufferStats) addTotalReused(size int64) { |
||||
bs.mu.Lock() |
||||
bs.TotalReused += size |
||||
bs.mu.Unlock() |
||||
} |
@ -0,0 +1,413 @@ |
||||
package pool |
||||
|
||||
import ( |
||||
"errors" |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
var ( |
||||
ErrPoolClosed = errors.New("connection pool is closed") |
||||
ErrPoolFull = errors.New("connection pool is full") |
||||
ErrConnExpired = errors.New("connection has expired") |
||||
ErrConnInvalid = errors.New("connection is invalid") |
||||
) |
||||
|
||||
// ConnectionPool 连接池
|
||||
type ConnectionPool struct { |
||||
logger *logrus.Logger |
||||
factory Factory |
||||
pool chan *PooledConnection |
||||
mu sync.RWMutex |
||||
closed bool |
||||
maxSize int |
||||
maxLifetime time.Duration |
||||
maxIdle time.Duration |
||||
|
||||
// 统计信息
|
||||
stats Stats |
||||
} |
||||
|
||||
// PooledConnection 池化连接
|
||||
type PooledConnection struct { |
||||
conn net.Conn |
||||
createdAt time.Time |
||||
lastUsed time.Time |
||||
pool *ConnectionPool |
||||
} |
||||
|
||||
// Factory 连接工厂
|
||||
type Factory interface { |
||||
Create() (net.Conn, error) |
||||
Validate(net.Conn) bool |
||||
Close(net.Conn) error |
||||
} |
||||
|
||||
// Config 连接池配置
|
||||
type Config struct { |
||||
MaxSize int `json:"maxSize"` |
||||
MaxLifetime time.Duration `json:"maxLifetime"` |
||||
MaxIdle time.Duration `json:"maxIdle"` |
||||
InitialSize int `json:"initialSize"` |
||||
} |
||||
|
||||
// Stats 统计信息
|
||||
type Stats struct { |
||||
mu sync.RWMutex |
||||
Created int64 `json:"created"` |
||||
Reused int64 `json:"reused"` |
||||
Closed int64 `json:"closed"` |
||||
Active int64 `json:"active"` |
||||
Idle int64 `json:"idle"` |
||||
Failures int64 `json:"failures"` |
||||
} |
||||
|
||||
// NewConnectionPool 创建新的连接池
|
||||
func NewConnectionPool(config Config, factory Factory, logger *logrus.Logger) (*ConnectionPool, error) { |
||||
if config.MaxSize <= 0 { |
||||
config.MaxSize = 100 |
||||
} |
||||
if config.MaxLifetime == 0 { |
||||
config.MaxLifetime = 30 * time.Minute |
||||
} |
||||
if config.MaxIdle == 0 { |
||||
config.MaxIdle = 5 * time.Minute |
||||
} |
||||
|
||||
pool := &ConnectionPool{ |
||||
logger: logger, |
||||
factory: factory, |
||||
pool: make(chan *PooledConnection, config.MaxSize), |
||||
maxSize: config.MaxSize, |
||||
maxLifetime: config.MaxLifetime, |
||||
maxIdle: config.MaxIdle, |
||||
} |
||||
|
||||
// 预创建连接
|
||||
for i := 0; i < config.InitialSize && i < config.MaxSize; i++ { |
||||
conn, err := pool.factory.Create() |
||||
if err != nil { |
||||
pool.logger.WithError(err).Warn("Failed to create initial connection") |
||||
continue |
||||
} |
||||
|
||||
pooledConn := &PooledConnection{ |
||||
conn: conn, |
||||
createdAt: time.Now(), |
||||
lastUsed: time.Now(), |
||||
pool: pool, |
||||
} |
||||
|
||||
select { |
||||
case pool.pool <- pooledConn: |
||||
pool.stats.incCreated() |
||||
pool.stats.incIdle() |
||||
default: |
||||
conn.Close() |
||||
} |
||||
} |
||||
|
||||
// 启动清理goroutine
|
||||
go pool.cleaner() |
||||
|
||||
return pool, nil |
||||
} |
||||
|
||||
// Get 获取连接
|
||||
func (p *ConnectionPool) Get() (*PooledConnection, error) { |
||||
p.mu.RLock() |
||||
if p.closed { |
||||
p.mu.RUnlock() |
||||
return nil, ErrPoolClosed |
||||
} |
||||
p.mu.RUnlock() |
||||
|
||||
// 尝试从池中获取连接
|
||||
for { |
||||
select { |
||||
case conn := <-p.pool: |
||||
p.stats.decIdle() |
||||
|
||||
// 检查连接是否有效
|
||||
if p.isConnValid(conn) { |
||||
conn.lastUsed = time.Now() |
||||
p.stats.incReused() |
||||
p.stats.incActive() |
||||
return conn, nil |
||||
} |
||||
|
||||
// 连接无效,关闭并继续尝试
|
||||
p.closeConn(conn) |
||||
continue |
||||
|
||||
default: |
||||
// 池中没有可用连接,创建新连接
|
||||
return p.createConnection() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Put 归还连接到池
|
||||
func (p *ConnectionPool) Put(conn *PooledConnection) error { |
||||
if conn == nil { |
||||
return nil |
||||
} |
||||
|
||||
p.mu.RLock() |
||||
if p.closed { |
||||
p.mu.RUnlock() |
||||
p.closeConn(conn) |
||||
return ErrPoolClosed |
||||
} |
||||
p.mu.RUnlock() |
||||
|
||||
p.stats.decActive() |
||||
|
||||
// 检查连接是否有效
|
||||
if !p.isConnValid(conn) { |
||||
p.closeConn(conn) |
||||
return ErrConnInvalid |
||||
} |
||||
|
||||
// 尝试归还到池
|
||||
select { |
||||
case p.pool <- conn: |
||||
p.stats.incIdle() |
||||
return nil |
||||
default: |
||||
// 池已满,关闭连接
|
||||
p.closeConn(conn) |
||||
return ErrPoolFull |
||||
} |
||||
} |
||||
|
||||
// Close 关闭连接池
|
||||
func (p *ConnectionPool) Close() error { |
||||
p.mu.Lock() |
||||
if p.closed { |
||||
p.mu.Unlock() |
||||
return nil |
||||
} |
||||
p.closed = true |
||||
p.mu.Unlock() |
||||
|
||||
// 关闭所有池中的连接
|
||||
close(p.pool) |
||||
for conn := range p.pool { |
||||
p.closeConn(conn) |
||||
} |
||||
|
||||
p.logger.Info("Connection pool closed") |
||||
return nil |
||||
} |
||||
|
||||
// GetStats 获取统计信息
|
||||
func (p *ConnectionPool) GetStats() Stats { |
||||
p.stats.mu.RLock() |
||||
defer p.stats.mu.RUnlock() |
||||
|
||||
stats := p.stats |
||||
stats.Idle = int64(len(p.pool)) |
||||
return stats |
||||
} |
||||
|
||||
// createConnection 创建新连接
|
||||
func (p *ConnectionPool) createConnection() (*PooledConnection, error) { |
||||
conn, err := p.factory.Create() |
||||
if err != nil { |
||||
p.stats.incFailures() |
||||
return nil, err |
||||
} |
||||
|
||||
pooledConn := &PooledConnection{ |
||||
conn: conn, |
||||
createdAt: time.Now(), |
||||
lastUsed: time.Now(), |
||||
pool: p, |
||||
} |
||||
|
||||
p.stats.incCreated() |
||||
p.stats.incActive() |
||||
return pooledConn, nil |
||||
} |
||||
|
||||
// isConnValid 检查连接是否有效
|
||||
func (p *ConnectionPool) isConnValid(conn *PooledConnection) bool { |
||||
now := time.Now() |
||||
|
||||
// 检查连接生命周期
|
||||
if now.Sub(conn.createdAt) > p.maxLifetime { |
||||
return false |
||||
} |
||||
|
||||
// 检查空闲时间
|
||||
if now.Sub(conn.lastUsed) > p.maxIdle { |
||||
return false |
||||
} |
||||
|
||||
// 使用工厂验证连接
|
||||
return p.factory.Validate(conn.conn) |
||||
} |
||||
|
||||
// closeConn 关闭连接
|
||||
func (p *ConnectionPool) closeConn(conn *PooledConnection) { |
||||
if conn != nil && conn.conn != nil { |
||||
p.factory.Close(conn.conn) |
||||
p.stats.incClosed() |
||||
} |
||||
} |
||||
|
||||
// cleaner 清理过期连接
|
||||
func (p *ConnectionPool) cleaner() { |
||||
ticker := time.NewTicker(1 * time.Minute) |
||||
defer ticker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-ticker.C: |
||||
p.cleanExpiredConnections() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// cleanExpiredConnections 清理过期连接
|
||||
func (p *ConnectionPool) cleanExpiredConnections() { |
||||
p.mu.RLock() |
||||
if p.closed { |
||||
p.mu.RUnlock() |
||||
return |
||||
} |
||||
p.mu.RUnlock() |
||||
|
||||
// 检查池中的连接
|
||||
poolSize := len(p.pool) |
||||
cleaned := 0 |
||||
|
||||
for i := 0; i < poolSize; i++ { |
||||
select { |
||||
case conn := <-p.pool: |
||||
if p.isConnValid(conn) { |
||||
// 连接有效,放回池中
|
||||
select { |
||||
case p.pool <- conn: |
||||
default: |
||||
// 池满了,关闭连接
|
||||
p.closeConn(conn) |
||||
cleaned++ |
||||
} |
||||
} else { |
||||
// 连接无效,关闭
|
||||
p.closeConn(conn) |
||||
p.stats.decIdle() |
||||
cleaned++ |
||||
} |
||||
default: |
||||
break |
||||
} |
||||
} |
||||
|
||||
if cleaned > 0 { |
||||
p.logger.WithField("count", cleaned).Debug("Cleaned expired connections") |
||||
} |
||||
} |
||||
|
||||
// PooledConnection 方法
|
||||
|
||||
// Read 读取数据
|
||||
func (pc *PooledConnection) Read(b []byte) (n int, err error) { |
||||
return pc.conn.Read(b) |
||||
} |
||||
|
||||
// Write 写入数据
|
||||
func (pc *PooledConnection) Write(b []byte) (n int, err error) { |
||||
return pc.conn.Write(b) |
||||
} |
||||
|
||||
// Close 关闭连接(归还到池)
|
||||
func (pc *PooledConnection) Close() error { |
||||
return pc.pool.Put(pc) |
||||
} |
||||
|
||||
// ForceClose 强制关闭连接
|
||||
func (pc *PooledConnection) ForceClose() error { |
||||
pc.pool.closeConn(pc) |
||||
return nil |
||||
} |
||||
|
||||
// LocalAddr 获取本地地址
|
||||
func (pc *PooledConnection) LocalAddr() net.Addr { |
||||
return pc.conn.LocalAddr() |
||||
} |
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (pc *PooledConnection) RemoteAddr() net.Addr { |
||||
return pc.conn.RemoteAddr() |
||||
} |
||||
|
||||
// SetDeadline 设置截止时间
|
||||
func (pc *PooledConnection) SetDeadline(t time.Time) error { |
||||
return pc.conn.SetDeadline(t) |
||||
} |
||||
|
||||
// SetReadDeadline 设置读截止时间
|
||||
func (pc *PooledConnection) SetReadDeadline(t time.Time) error { |
||||
return pc.conn.SetReadDeadline(t) |
||||
} |
||||
|
||||
// SetWriteDeadline 设置写截止时间
|
||||
func (pc *PooledConnection) SetWriteDeadline(t time.Time) error { |
||||
return pc.conn.SetWriteDeadline(t) |
||||
} |
||||
|
||||
// Stats 方法
|
||||
|
||||
func (s *Stats) incCreated() { |
||||
s.mu.Lock() |
||||
s.Created++ |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) incReused() { |
||||
s.mu.Lock() |
||||
s.Reused++ |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) incClosed() { |
||||
s.mu.Lock() |
||||
s.Closed++ |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) incActive() { |
||||
s.mu.Lock() |
||||
s.Active++ |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) decActive() { |
||||
s.mu.Lock() |
||||
s.Active-- |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) incIdle() { |
||||
s.mu.Lock() |
||||
s.Idle++ |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) decIdle() { |
||||
s.mu.Lock() |
||||
s.Idle-- |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
func (s *Stats) incFailures() { |
||||
s.mu.Lock() |
||||
s.Failures++ |
||||
s.mu.Unlock() |
||||
} |
@ -0,0 +1,231 @@ |
||||
package ratelimit |
||||
|
||||
import ( |
||||
"net" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
// RateLimiter 速率限制器
|
||||
type RateLimiter struct { |
||||
logger *logrus.Logger |
||||
globalLimit *TokenBucket |
||||
perIPLimiters map[string]*TokenBucket |
||||
cleanupInterval time.Duration |
||||
mu sync.RWMutex |
||||
stopCh chan struct{} |
||||
} |
||||
|
||||
// TokenBucket Token桶算法实现
|
||||
type TokenBucket struct { |
||||
capacity int64 |
||||
tokens int64 |
||||
refillRate int64 // tokens per second
|
||||
lastRefill time.Time |
||||
mu sync.Mutex |
||||
} |
||||
|
||||
// Config 速率限制配置
|
||||
type Config struct { |
||||
Enabled bool `json:"enabled"` |
||||
RequestsPerSecond int `json:"requestsPerSecond"` |
||||
BurstSize int `json:"burstSize"` |
||||
PerIPRequestsPerSec int `json:"perIPRequestsPerSec"` |
||||
PerIPBurstSize int `json:"perIPBurstSize"` |
||||
CleanupInterval time.Duration `json:"cleanupInterval"` |
||||
} |
||||
|
||||
// NewRateLimiter 创建新的速率限制器
|
||||
func NewRateLimiter(config Config, logger *logrus.Logger) *RateLimiter { |
||||
if !config.Enabled { |
||||
return &RateLimiter{ |
||||
logger: logger, |
||||
} |
||||
} |
||||
|
||||
// 默认值
|
||||
if config.BurstSize == 0 { |
||||
config.BurstSize = config.RequestsPerSecond * 2 |
||||
} |
||||
if config.PerIPBurstSize == 0 { |
||||
config.PerIPBurstSize = config.PerIPRequestsPerSec * 2 |
||||
} |
||||
if config.CleanupInterval == 0 { |
||||
config.CleanupInterval = 5 * time.Minute |
||||
} |
||||
|
||||
rl := &RateLimiter{ |
||||
logger: logger, |
||||
globalLimit: NewTokenBucket(int64(config.BurstSize), int64(config.RequestsPerSecond)), |
||||
perIPLimiters: make(map[string]*TokenBucket), |
||||
cleanupInterval: config.CleanupInterval, |
||||
stopCh: make(chan struct{}), |
||||
} |
||||
|
||||
// 启动清理goroutine
|
||||
go rl.cleanup() |
||||
|
||||
return rl |
||||
} |
||||
|
||||
// Allow 检查是否允许请求
|
||||
func (rl *RateLimiter) Allow(remoteAddr string) bool { |
||||
if rl.globalLimit == nil { |
||||
return true // 未启用速率限制
|
||||
} |
||||
|
||||
// 全局速率限制
|
||||
if !rl.globalLimit.Allow() { |
||||
rl.logger.WithField("type", "global").Debug("Rate limit exceeded") |
||||
return false |
||||
} |
||||
|
||||
// 获取客户端IP
|
||||
host, _, err := net.SplitHostPort(remoteAddr) |
||||
if err != nil { |
||||
host = remoteAddr |
||||
} |
||||
|
||||
// 单IP速率限制
|
||||
rl.mu.RLock() |
||||
bucket, exists := rl.perIPLimiters[host] |
||||
rl.mu.RUnlock() |
||||
|
||||
if !exists { |
||||
rl.mu.Lock() |
||||
bucket, exists = rl.perIPLimiters[host] |
||||
if !exists { |
||||
bucket = NewTokenBucket(20, 10) // 默认每IP 10 rps, burst 20
|
||||
rl.perIPLimiters[host] = bucket |
||||
} |
||||
rl.mu.Unlock() |
||||
} |
||||
|
||||
if !bucket.Allow() { |
||||
rl.logger.WithFields(logrus.Fields{ |
||||
"type": "per_ip", |
||||
"ip": host, |
||||
}).Debug("Per-IP rate limit exceeded") |
||||
return false |
||||
} |
||||
|
||||
return true |
||||
} |
||||
|
||||
// Stop 停止速率限制器
|
||||
func (rl *RateLimiter) Stop() { |
||||
if rl.stopCh != nil { |
||||
close(rl.stopCh) |
||||
} |
||||
} |
||||
|
||||
// GetStats 获取统计信息
|
||||
func (rl *RateLimiter) GetStats() map[string]interface{} { |
||||
if rl.globalLimit == nil { |
||||
return map[string]interface{}{ |
||||
"enabled": false, |
||||
} |
||||
} |
||||
|
||||
rl.mu.RLock() |
||||
perIPCount := len(rl.perIPLimiters) |
||||
rl.mu.RUnlock() |
||||
|
||||
return map[string]interface{}{ |
||||
"enabled": true, |
||||
"global_tokens": rl.globalLimit.GetTokens(), |
||||
"per_ip_buckets": perIPCount, |
||||
} |
||||
} |
||||
|
||||
// cleanup 清理过期的IP限制器
|
||||
func (rl *RateLimiter) cleanup() { |
||||
ticker := time.NewTicker(rl.cleanupInterval) |
||||
defer ticker.Stop() |
||||
|
||||
for { |
||||
select { |
||||
case <-rl.stopCh: |
||||
return |
||||
case <-ticker.C: |
||||
rl.cleanupExpiredBuckets() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// cleanupExpiredBuckets 清理过期的IP桶
|
||||
func (rl *RateLimiter) cleanupExpiredBuckets() { |
||||
rl.mu.Lock() |
||||
defer rl.mu.Unlock() |
||||
|
||||
now := time.Now() |
||||
expiredIPs := make([]string, 0) |
||||
|
||||
for ip, bucket := range rl.perIPLimiters { |
||||
// 如果桶超过10分钟没有活动,清理掉
|
||||
if now.Sub(bucket.lastRefill) > 10*time.Minute { |
||||
expiredIPs = append(expiredIPs, ip) |
||||
} |
||||
} |
||||
|
||||
for _, ip := range expiredIPs { |
||||
delete(rl.perIPLimiters, ip) |
||||
} |
||||
|
||||
if len(expiredIPs) > 0 { |
||||
rl.logger.WithField("count", len(expiredIPs)).Debug("Cleaned up expired rate limit buckets") |
||||
} |
||||
} |
||||
|
||||
// NewTokenBucket 创建新的Token桶
|
||||
func NewTokenBucket(capacity, refillRate int64) *TokenBucket { |
||||
return &TokenBucket{ |
||||
capacity: capacity, |
||||
tokens: capacity, |
||||
refillRate: refillRate, |
||||
lastRefill: time.Now(), |
||||
} |
||||
} |
||||
|
||||
// Allow 检查是否允许请求(消耗一个token)
|
||||
func (tb *TokenBucket) Allow() bool { |
||||
tb.mu.Lock() |
||||
defer tb.mu.Unlock() |
||||
|
||||
tb.refill() |
||||
|
||||
if tb.tokens > 0 { |
||||
tb.tokens-- |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
// refill 补充tokens
|
||||
func (tb *TokenBucket) refill() { |
||||
now := time.Now() |
||||
elapsed := now.Sub(tb.lastRefill) |
||||
|
||||
// 计算应该补充的tokens
|
||||
tokensToAdd := int64(elapsed.Seconds()) * tb.refillRate |
||||
|
||||
if tokensToAdd > 0 { |
||||
tb.tokens += tokensToAdd |
||||
if tb.tokens > tb.capacity { |
||||
tb.tokens = tb.capacity |
||||
} |
||||
tb.lastRefill = now |
||||
} |
||||
} |
||||
|
||||
// GetTokens 获取当前token数量
|
||||
func (tb *TokenBucket) GetTokens() int64 { |
||||
tb.mu.Lock() |
||||
defer tb.mu.Unlock() |
||||
|
||||
tb.refill() |
||||
return tb.tokens |
||||
} |
@ -0,0 +1,58 @@ |
||||
package socks5 |
||||
|
||||
import "strings" |
||||
|
||||
// SimpleAuthHandler 简单认证处理器
|
||||
type SimpleAuthHandler struct { |
||||
username string |
||||
password string |
||||
methods []byte |
||||
} |
||||
|
||||
// NewAuthHandler 创建认证处理器
|
||||
func NewAuthHandler(config AuthConfig) AuthHandler { |
||||
handler := &SimpleAuthHandler{ |
||||
username: config.Username, |
||||
password: config.Password, |
||||
} |
||||
|
||||
// 设置支持的认证方法
|
||||
if config.Username != "" && config.Password != "" { |
||||
handler.methods = []byte{AuthPassword} |
||||
} else { |
||||
handler.methods = []byte{AuthNone} |
||||
} |
||||
|
||||
// 如果配置中指定了方法,使用配置的方法
|
||||
if len(config.Methods) > 0 { |
||||
handler.methods = []byte{} |
||||
for _, method := range config.Methods { |
||||
switch strings.ToLower(method) { |
||||
case "none": |
||||
handler.methods = append(handler.methods, AuthNone) |
||||
case "password": |
||||
handler.methods = append(handler.methods, AuthPassword) |
||||
} |
||||
} |
||||
} |
||||
|
||||
return handler |
||||
} |
||||
|
||||
// Authenticate 验证用户名和密码
|
||||
func (h *SimpleAuthHandler) Authenticate(username, password string) bool { |
||||
// 如果支持无认证,直接返回true
|
||||
for _, method := range h.methods { |
||||
if method == AuthNone { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
// 检查用户名密码
|
||||
return h.username == username && h.password == password |
||||
} |
||||
|
||||
// Methods 返回支持的认证方法
|
||||
func (h *SimpleAuthHandler) Methods() []byte { |
||||
return h.methods |
||||
} |
@ -0,0 +1,34 @@ |
||||
package socks5 |
||||
|
||||
import ( |
||||
"net" |
||||
"time" |
||||
) |
||||
|
||||
// DirectDialer 直接连接拨号器
|
||||
type DirectDialer struct { |
||||
timeout time.Duration |
||||
} |
||||
|
||||
// Dial 建立连接
|
||||
func (d *DirectDialer) Dial(network, address string) (net.Conn, error) { |
||||
if d.timeout > 0 { |
||||
return net.DialTimeout(network, address, d.timeout) |
||||
} |
||||
return net.Dial(network, address) |
||||
} |
||||
|
||||
// DirectResolver 直接DNS解析器
|
||||
type DirectResolver struct{} |
||||
|
||||
// Resolve 解析域名
|
||||
func (r *DirectResolver) Resolve(domain string) (net.IP, error) { |
||||
ips, err := net.LookupIP(domain) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(ips) == 0 { |
||||
return nil, net.ErrClosed |
||||
} |
||||
return ips[0], nil |
||||
} |
@ -0,0 +1,114 @@ |
||||
package socks5 |
||||
|
||||
import ( |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
) |
||||
|
||||
// SimpleRule 简单访问规则
|
||||
type SimpleRule struct { |
||||
action string // allow, deny
|
||||
networks []*net.IPNet |
||||
ports []int |
||||
} |
||||
|
||||
// NewRule 创建访问规则
|
||||
func NewRule(config RuleConfig) Rule { |
||||
rule := &SimpleRule{ |
||||
action: strings.ToLower(config.Action), |
||||
ports: config.Ports, |
||||
} |
||||
|
||||
// 解析IP网段
|
||||
for _, ipStr := range config.IPs { |
||||
if !strings.Contains(ipStr, "/") { |
||||
// 单个IP,添加适当的子网掩码
|
||||
if strings.Contains(ipStr, ":") { |
||||
ipStr += "/128" // IPv6
|
||||
} else { |
||||
ipStr += "/32" // IPv4
|
||||
} |
||||
} |
||||
|
||||
_, network, err := net.ParseCIDR(ipStr) |
||||
if err != nil { |
||||
// 如果解析失败,尝试作为单个IP处理
|
||||
ip := net.ParseIP(ipStr) |
||||
if ip != nil { |
||||
if ip.To4() != nil { |
||||
_, network, _ = net.ParseCIDR(ip.String() + "/32") |
||||
} else { |
||||
_, network, _ = net.ParseCIDR(ip.String() + "/128") |
||||
} |
||||
} |
||||
} |
||||
|
||||
if network != nil { |
||||
rule.networks = append(rule.networks, network) |
||||
} |
||||
} |
||||
|
||||
return rule |
||||
} |
||||
|
||||
// Allow 检查是否允许访问
|
||||
func (r *SimpleRule) Allow(addr net.Addr) bool { |
||||
// 解析地址
|
||||
var ip net.IP |
||||
var port int |
||||
|
||||
switch a := addr.(type) { |
||||
case *net.IPAddr: |
||||
ip = a.IP |
||||
case *net.TCPAddr: |
||||
ip = a.IP |
||||
port = a.Port |
||||
case *net.UDPAddr: |
||||
ip = a.IP |
||||
port = a.Port |
||||
default: |
||||
// 尝试从字符串解析
|
||||
host, portStr, err := net.SplitHostPort(addr.String()) |
||||
if err != nil { |
||||
return r.action == "allow" // 默认策略
|
||||
} |
||||
ip = net.ParseIP(host) |
||||
if p, err := strconv.Atoi(portStr); err == nil { |
||||
port = p |
||||
} |
||||
} |
||||
|
||||
if ip == nil { |
||||
return r.action == "allow" // 默认策略
|
||||
} |
||||
|
||||
// 检查IP是否匹配
|
||||
ipMatches := len(r.networks) == 0 // 如果没有指定网段,默认匹配
|
||||
for _, network := range r.networks { |
||||
if network.Contains(ip) { |
||||
ipMatches = true |
||||
break |
||||
} |
||||
} |
||||
|
||||
// 检查端口是否匹配
|
||||
portMatches := len(r.ports) == 0 // 如果没有指定端口,默认匹配
|
||||
for _, p := range r.ports { |
||||
if p == port { |
||||
portMatches = true |
||||
break |
||||
} |
||||
} |
||||
|
||||
// 根据规则类型返回结果
|
||||
matches := ipMatches && portMatches |
||||
switch r.action { |
||||
case "allow": |
||||
return matches |
||||
case "deny": |
||||
return !matches |
||||
default: |
||||
return true // 默认允许
|
||||
} |
||||
} |
@ -0,0 +1,432 @@ |
||||
package socks5 |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
// SOCKS5 协议常量
|
||||
const ( |
||||
// SOCKS版本
|
||||
Version5 = 0x05 |
||||
|
||||
// 认证方法
|
||||
AuthNone = 0x00 |
||||
AuthPassword = 0x02 |
||||
AuthNoSupported = 0xFF |
||||
|
||||
// 命令类型
|
||||
CmdConnect = 0x01 |
||||
CmdBind = 0x02 |
||||
CmdUDP = 0x03 |
||||
|
||||
// 地址类型
|
||||
AddrIPv4 = 0x01 |
||||
AddrDomain = 0x03 |
||||
AddrIPv6 = 0x04 |
||||
|
||||
// 响应状态
|
||||
StatusSuccess = 0x00 |
||||
StatusServerFailure = 0x01 |
||||
StatusConnectionNotAllowed = 0x02 |
||||
StatusNetworkUnreachable = 0x03 |
||||
StatusHostUnreachable = 0x04 |
||||
StatusConnectionRefused = 0x05 |
||||
StatusTTLExpired = 0x06 |
||||
StatusCommandNotSupported = 0x07 |
||||
StatusAddressNotSupported = 0x08 |
||||
) |
||||
|
||||
type Server struct { |
||||
logger *logrus.Logger |
||||
auth AuthHandler |
||||
dialer Dialer |
||||
resolver Resolver |
||||
rules []Rule |
||||
} |
||||
|
||||
type Config struct { |
||||
Auth AuthConfig |
||||
Timeout time.Duration |
||||
Rules []RuleConfig |
||||
} |
||||
|
||||
type AuthConfig struct { |
||||
Methods []string |
||||
Username string |
||||
Password string |
||||
} |
||||
|
||||
type RuleConfig struct { |
||||
Action string // allow, deny
|
||||
IPs []string |
||||
Ports []int |
||||
} |
||||
|
||||
type AuthHandler interface { |
||||
Authenticate(username, password string) bool |
||||
Methods() []byte |
||||
} |
||||
|
||||
type Dialer interface { |
||||
Dial(network, address string) (net.Conn, error) |
||||
} |
||||
|
||||
type Resolver interface { |
||||
Resolve(domain string) (net.IP, error) |
||||
} |
||||
|
||||
type Rule interface { |
||||
Allow(addr net.Addr) bool |
||||
} |
||||
|
||||
// NewServer 创建新的SOCKS5服务器
|
||||
func NewServer(config Config, logger *logrus.Logger) *Server { |
||||
server := &Server{ |
||||
logger: logger, |
||||
dialer: &DirectDialer{timeout: config.Timeout}, |
||||
resolver: &DirectResolver{}, |
||||
} |
||||
|
||||
// 设置认证处理器
|
||||
server.auth = NewAuthHandler(config.Auth) |
||||
|
||||
// 设置访问规则
|
||||
for _, ruleConfig := range config.Rules { |
||||
rule := NewRule(ruleConfig) |
||||
server.rules = append(server.rules, rule) |
||||
} |
||||
|
||||
return server |
||||
} |
||||
|
||||
// HandleConnection 处理SOCKS5连接
|
||||
func (s *Server) HandleConnection(conn net.Conn) error { |
||||
defer conn.Close() |
||||
|
||||
s.logger.WithField("remote_addr", conn.RemoteAddr()).Debug("New SOCKS5 connection") |
||||
|
||||
// 设置连接超时
|
||||
conn.SetDeadline(time.Now().Add(30 * time.Second)) |
||||
|
||||
// 1. 协议版本协商
|
||||
if err := s.handleVersionNegotiation(conn); err != nil { |
||||
s.logger.WithError(err).Error("Version negotiation failed") |
||||
return err |
||||
} |
||||
|
||||
// 2. 认证
|
||||
if err := s.handleAuthentication(conn); err != nil { |
||||
s.logger.WithError(err).Error("Authentication failed") |
||||
return err |
||||
} |
||||
|
||||
// 3. 处理请求
|
||||
return s.handleRequest(conn) |
||||
} |
||||
|
||||
// handleVersionNegotiation 处理版本协商
|
||||
func (s *Server) handleVersionNegotiation(conn net.Conn) error { |
||||
// 读取客户端版本协商请求
|
||||
buf := make([]byte, 2) |
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return fmt.Errorf("failed to read version: %w", err) |
||||
} |
||||
|
||||
version := buf[0] |
||||
nmethods := buf[1] |
||||
|
||||
if version != Version5 { |
||||
return fmt.Errorf("unsupported SOCKS version: %d", version) |
||||
} |
||||
|
||||
// 读取支持的认证方法
|
||||
methods := make([]byte, nmethods) |
||||
if _, err := io.ReadFull(conn, methods); err != nil { |
||||
return fmt.Errorf("failed to read methods: %w", err) |
||||
} |
||||
|
||||
// 选择认证方法
|
||||
authMethods := s.auth.Methods() |
||||
selectedMethod := byte(AuthNoSupported) |
||||
|
||||
for _, method := range methods { |
||||
for _, supported := range authMethods { |
||||
if method == supported { |
||||
selectedMethod = method |
||||
break |
||||
} |
||||
} |
||||
if selectedMethod != AuthNoSupported { |
||||
break |
||||
} |
||||
} |
||||
|
||||
// 发送选择的认证方法
|
||||
response := []byte{Version5, selectedMethod} |
||||
if _, err := conn.Write(response); err != nil { |
||||
return fmt.Errorf("failed to write method selection: %w", err) |
||||
} |
||||
|
||||
if selectedMethod == AuthNoSupported { |
||||
return fmt.Errorf("no supported authentication method") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// handleAuthentication 处理认证
|
||||
func (s *Server) handleAuthentication(conn net.Conn) error { |
||||
// 读取认证请求
|
||||
buf := make([]byte, 2) |
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return fmt.Errorf("failed to read auth version: %w", err) |
||||
} |
||||
|
||||
version := buf[0] |
||||
usernameLen := buf[1] |
||||
|
||||
if version != 0x01 { |
||||
return fmt.Errorf("unsupported auth version: %d", version) |
||||
} |
||||
|
||||
// 读取用户名
|
||||
username := make([]byte, usernameLen) |
||||
if _, err := io.ReadFull(conn, username); err != nil { |
||||
return fmt.Errorf("failed to read username: %w", err) |
||||
} |
||||
|
||||
// 读取密码长度
|
||||
if _, err := io.ReadFull(conn, buf[:1]); err != nil { |
||||
return fmt.Errorf("failed to read password length: %w", err) |
||||
} |
||||
passwordLen := buf[0] |
||||
|
||||
// 读取密码
|
||||
password := make([]byte, passwordLen) |
||||
if _, err := io.ReadFull(conn, password); err != nil { |
||||
return fmt.Errorf("failed to read password: %w", err) |
||||
} |
||||
|
||||
// 验证认证
|
||||
success := s.auth.Authenticate(string(username), string(password)) |
||||
|
||||
// 发送认证结果
|
||||
status := byte(0x01) // 失败
|
||||
if success { |
||||
status = 0x00 // 成功
|
||||
} |
||||
|
||||
response := []byte{0x01, status} |
||||
if _, err := conn.Write(response); err != nil { |
||||
return fmt.Errorf("failed to write auth response: %w", err) |
||||
} |
||||
|
||||
if !success { |
||||
return fmt.Errorf("authentication failed") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// handleRequest 处理SOCKS5请求
|
||||
func (s *Server) handleRequest(conn net.Conn) error { |
||||
// 读取请求头
|
||||
buf := make([]byte, 4) |
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return fmt.Errorf("failed to read request header: %w", err) |
||||
} |
||||
|
||||
version := buf[0] |
||||
cmd := buf[1] |
||||
// rsv := buf[2] // 保留字段
|
||||
addrType := buf[3] |
||||
|
||||
if version != Version5 { |
||||
return fmt.Errorf("invalid SOCKS version: %d", version) |
||||
} |
||||
|
||||
// 读取目标地址
|
||||
addr, err := s.readAddress(conn, addrType) |
||||
if err != nil { |
||||
s.sendResponse(conn, StatusServerFailure, "0.0.0.0", 0) |
||||
return fmt.Errorf("failed to read address: %w", err) |
||||
} |
||||
|
||||
// 检查访问规则
|
||||
if !s.checkRules(addr) { |
||||
s.sendResponse(conn, StatusConnectionNotAllowed, "0.0.0.0", 0) |
||||
return fmt.Errorf("connection not allowed: %s", addr) |
||||
} |
||||
|
||||
// 处理不同的命令
|
||||
switch cmd { |
||||
case CmdConnect: |
||||
return s.handleConnect(conn, addr) |
||||
case CmdBind: |
||||
return s.handleBind(conn, addr) |
||||
case CmdUDP: |
||||
return s.handleUDP(conn, addr) |
||||
default: |
||||
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||
return fmt.Errorf("unsupported command: %d", cmd) |
||||
} |
||||
} |
||||
|
||||
// readAddress 读取地址信息
|
||||
func (s *Server) readAddress(conn net.Conn, addrType byte) (string, error) { |
||||
switch addrType { |
||||
case AddrIPv4: |
||||
buf := make([]byte, 6) // 4字节IP + 2字节端口
|
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return "", err |
||||
} |
||||
ip := net.IP(buf[:4]) |
||||
port := binary.BigEndian.Uint16(buf[4:6]) |
||||
return fmt.Sprintf("%s:%d", ip.String(), port), nil |
||||
|
||||
case AddrDomain: |
||||
buf := make([]byte, 1) |
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return "", err |
||||
} |
||||
domainLen := buf[0] |
||||
|
||||
domain := make([]byte, domainLen+2) // 域名 + 2字节端口
|
||||
if _, err := io.ReadFull(conn, domain); err != nil { |
||||
return "", err |
||||
} |
||||
port := binary.BigEndian.Uint16(domain[domainLen:]) |
||||
return fmt.Sprintf("%s:%d", string(domain[:domainLen]), port), nil |
||||
|
||||
case AddrIPv6: |
||||
buf := make([]byte, 18) // 16字节IP + 2字节端口
|
||||
if _, err := io.ReadFull(conn, buf); err != nil { |
||||
return "", err |
||||
} |
||||
ip := net.IP(buf[:16]) |
||||
port := binary.BigEndian.Uint16(buf[16:18]) |
||||
return fmt.Sprintf("[%s]:%d", ip.String(), port), nil |
||||
|
||||
default: |
||||
return "", fmt.Errorf("unsupported address type: %d", addrType) |
||||
} |
||||
} |
||||
|
||||
// handleConnect 处理CONNECT命令
|
||||
func (s *Server) handleConnect(conn net.Conn, addr string) error { |
||||
s.logger.WithField("target", addr).Debug("Handling CONNECT request") |
||||
|
||||
// 连接到目标服务器
|
||||
target, err := s.dialer.Dial("tcp", addr) |
||||
if err != nil { |
||||
s.logger.WithError(err).WithField("target", addr).Error("Failed to connect to target") |
||||
s.sendResponse(conn, StatusConnectionRefused, "0.0.0.0", 0) |
||||
return err |
||||
} |
||||
defer target.Close() |
||||
|
||||
// 发送成功响应
|
||||
localAddr := target.LocalAddr().(*net.TCPAddr) |
||||
s.sendResponse(conn, StatusSuccess, localAddr.IP.String(), uint16(localAddr.Port)) |
||||
|
||||
// 开始数据转发
|
||||
s.logger.WithField("target", addr).Info("Starting data relay") |
||||
return s.relay(conn, target) |
||||
} |
||||
|
||||
// handleBind 处理BIND命令
|
||||
func (s *Server) handleBind(conn net.Conn, addr string) error { |
||||
// BIND命令实现(简化版本)
|
||||
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||
return fmt.Errorf("BIND command not implemented") |
||||
} |
||||
|
||||
// handleUDP 处理UDP命令
|
||||
func (s *Server) handleUDP(conn net.Conn, addr string) error { |
||||
// UDP关联实现(简化版本)
|
||||
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) |
||||
return fmt.Errorf("UDP command not implemented") |
||||
} |
||||
|
||||
// sendResponse 发送SOCKS5响应
|
||||
func (s *Server) sendResponse(conn net.Conn, status byte, ip string, port uint16) error { |
||||
response := []byte{ |
||||
Version5, |
||||
status, |
||||
0x00, // 保留字段
|
||||
AddrIPv4, |
||||
} |
||||
|
||||
// 添加IP地址
|
||||
ipAddr := net.ParseIP(ip) |
||||
if ipAddr == nil { |
||||
ipAddr = net.ParseIP("0.0.0.0") |
||||
} |
||||
if ipv4 := ipAddr.To4(); ipv4 != nil { |
||||
response = append(response, ipv4...) |
||||
} else { |
||||
response[3] = AddrIPv6 |
||||
response = append(response, ipAddr.To16()...) |
||||
} |
||||
|
||||
// 添加端口
|
||||
portBytes := make([]byte, 2) |
||||
binary.BigEndian.PutUint16(portBytes, port) |
||||
response = append(response, portBytes...) |
||||
|
||||
_, err := conn.Write(response) |
||||
return err |
||||
} |
||||
|
||||
// checkRules 检查访问规则
|
||||
func (s *Server) checkRules(addr string) bool { |
||||
if len(s.rules) == 0 { |
||||
return true // 没有规则时允许所有连接
|
||||
} |
||||
|
||||
host, _, err := net.SplitHostPort(addr) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
|
||||
targetAddr, err := net.ResolveIPAddr("ip", host) |
||||
if err != nil { |
||||
return false |
||||
} |
||||
|
||||
for _, rule := range s.rules { |
||||
if rule.Allow(targetAddr) { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
// relay 数据转发
|
||||
func (s *Server) relay(client, target net.Conn) error { |
||||
// 创建双向数据转发
|
||||
errChan := make(chan error, 2) |
||||
|
||||
// 客户端到目标服务器
|
||||
go func() { |
||||
_, err := io.Copy(target, client) |
||||
errChan <- err |
||||
}() |
||||
|
||||
// 目标服务器到客户端
|
||||
go func() { |
||||
_, err := io.Copy(client, target) |
||||
errChan <- err |
||||
}() |
||||
|
||||
// 等待任一方向的连接关闭
|
||||
err := <-errChan |
||||
return err |
||||
} |
@ -0,0 +1,54 @@ |
||||
package socks5 |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/azoic/wormhole-server/pkg/memory" |
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
func BenchmarkBufferPool(b *testing.B) { |
||||
logger := logrus.New() |
||||
logger.SetLevel(logrus.ErrorLevel) |
||||
|
||||
memConfig := memory.Config{ |
||||
BufferSizes: []int{512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}, |
||||
EnableOptimization: true, |
||||
EnableAutoGC: false, |
||||
MonitorInterval: time.Minute, |
||||
} |
||||
memManager := memory.NewManager(memConfig, logger) |
||||
defer memManager.Stop() |
||||
|
||||
b.ResetTimer() |
||||
b.RunParallel(func(pb *testing.PB) { |
||||
for pb.Next() { |
||||
// 测试不同大小的缓冲区
|
||||
sizes := []int{512, 1024, 2048, 4096, 8192} |
||||
for _, size := range sizes { |
||||
buf := memManager.GetBuffer(size) |
||||
memManager.PutBuffer(buf) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func BenchmarkMemoryManagerStats(b *testing.B) { |
||||
logger := logrus.New() |
||||
logger.SetLevel(logrus.ErrorLevel) |
||||
|
||||
memConfig := memory.Config{ |
||||
BufferSizes: []int{512, 1024, 2048, 4096}, |
||||
EnableOptimization: true, |
||||
EnableAutoGC: false, |
||||
MonitorInterval: time.Minute, |
||||
} |
||||
memManager := memory.NewManager(memConfig, logger) |
||||
defer memManager.Stop() |
||||
|
||||
b.ResetTimer() |
||||
for i := 0; i < b.N; i++ { |
||||
_ = memManager.GetOverallStats() |
||||
} |
||||
} |
@ -0,0 +1,71 @@ |
||||
package socks5 |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
func TestNewServer(t *testing.T) { |
||||
logger := logrus.New() |
||||
logger.SetLevel(logrus.ErrorLevel) // 减少测试输出
|
||||
|
||||
config := Config{ |
||||
Auth: AuthConfig{ |
||||
Username: "test", |
||||
Password: "pass", |
||||
}, |
||||
Timeout: 30 * time.Second, |
||||
} |
||||
|
||||
server := NewServer(config, logger) |
||||
if server == nil { |
||||
t.Fatal("Expected server to be created, got nil") |
||||
} |
||||
|
||||
if server.logger != logger { |
||||
t.Error("Logger not set correctly") |
||||
} |
||||
} |
||||
|
||||
func TestAuthHandler(t *testing.T) { |
||||
config := AuthConfig{ |
||||
Username: "testuser", |
||||
Password: "testpass", |
||||
Methods: []string{"password"}, |
||||
} |
||||
|
||||
handler := NewAuthHandler(config) |
||||
|
||||
// 测试正确的认证
|
||||
if !handler.Authenticate("testuser", "testpass") { |
||||
t.Error("Valid authentication failed") |
||||
} |
||||
|
||||
// 测试错误的认证
|
||||
if handler.Authenticate("wronguser", "wrongpass") { |
||||
t.Error("Invalid authentication succeeded") |
||||
} |
||||
|
||||
// 检查支持的方法
|
||||
methods := handler.Methods() |
||||
if len(methods) != 1 || methods[0] != AuthPassword { |
||||
t.Error("Expected password authentication method") |
||||
} |
||||
} |
||||
|
||||
func TestRule(t *testing.T) { |
||||
config := RuleConfig{ |
||||
Action: "allow", |
||||
IPs: []string{"192.168.1.0/24", "127.0.0.1"}, |
||||
Ports: []int{80, 443}, |
||||
} |
||||
|
||||
rule := NewRule(config) |
||||
if rule == nil { |
||||
t.Fatal("Expected rule to be created, got nil") |
||||
} |
||||
|
||||
// 这里可以添加更多的规则测试
|
||||
} |
@ -1,368 +0,0 @@ |
||||
package system |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"os/exec" |
||||
"runtime" |
||||
"strings" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
type SystemProxy struct { |
||||
logger *logrus.Logger |
||||
proxyAddr string |
||||
backupConfig map[string]string |
||||
} |
||||
|
||||
type Config struct { |
||||
HTTPProxy string |
||||
HTTPSProxy string |
||||
SOCKSProxy string |
||||
NoProxy []string |
||||
} |
||||
|
||||
func NewSystemProxy(proxyAddr string, logger *logrus.Logger) *SystemProxy { |
||||
return &SystemProxy{ |
||||
logger: logger, |
||||
proxyAddr: proxyAddr, |
||||
backupConfig: make(map[string]string), |
||||
} |
||||
} |
||||
|
||||
// SetSystemProxy 设置系统代理
|
||||
func (sp *SystemProxy) SetSystemProxy(config Config) error { |
||||
sp.logger.Info("Setting system proxy configuration") |
||||
|
||||
// 备份当前配置
|
||||
if err := sp.backupCurrentConfig(); err != nil { |
||||
sp.logger.WithError(err).Warn("Failed to backup current proxy config") |
||||
} |
||||
|
||||
switch runtime.GOOS { |
||||
case "darwin": |
||||
return sp.setMacOSProxy(config) |
||||
case "linux": |
||||
return sp.setLinuxProxy(config) |
||||
case "windows": |
||||
return sp.setWindowsProxy(config) |
||||
default: |
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) |
||||
} |
||||
} |
||||
|
||||
// RestoreSystemProxy 恢复系统代理设置
|
||||
func (sp *SystemProxy) RestoreSystemProxy() error { |
||||
sp.logger.Info("Restoring system proxy configuration") |
||||
|
||||
switch runtime.GOOS { |
||||
case "darwin": |
||||
return sp.restoreMacOSProxy() |
||||
case "linux": |
||||
return sp.restoreLinuxProxy() |
||||
case "windows": |
||||
return sp.restoreWindowsProxy() |
||||
default: |
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) |
||||
} |
||||
} |
||||
|
||||
// setMacOSProxy 设置macOS系统代理
|
||||
func (sp *SystemProxy) setMacOSProxy(config Config) error { |
||||
// 获取所有网络服务
|
||||
services, err := sp.getMacOSNetworkServices() |
||||
if err != nil { |
||||
return fmt.Errorf("failed to get network services: %w", err) |
||||
} |
||||
|
||||
for _, service := range services { |
||||
// 设置HTTP代理
|
||||
if config.HTTPProxy != "" { |
||||
parts := strings.Split(config.HTTPProxy, ":") |
||||
if len(parts) == 2 { |
||||
if err := sp.runCommand("networksetup", "-setwebproxy", service, parts[0], parts[1]); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set HTTP proxy") |
||||
} |
||||
if err := sp.runCommand("networksetup", "-setwebproxystate", service, "on"); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable HTTP proxy") |
||||
} |
||||
} |
||||
} |
||||
|
||||
// 设置HTTPS代理
|
||||
if config.HTTPSProxy != "" { |
||||
parts := strings.Split(config.HTTPSProxy, ":") |
||||
if len(parts) == 2 { |
||||
if err := sp.runCommand("networksetup", "-setsecurewebproxy", service, parts[0], parts[1]); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set HTTPS proxy") |
||||
} |
||||
if err := sp.runCommand("networksetup", "-setsecurewebproxystate", service, "on"); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable HTTPS proxy") |
||||
} |
||||
} |
||||
} |
||||
|
||||
// 设置SOCKS代理
|
||||
if config.SOCKSProxy != "" { |
||||
parts := strings.Split(config.SOCKSProxy, ":") |
||||
if len(parts) == 2 { |
||||
if err := sp.runCommand("networksetup", "-setsocksfirewallproxy", service, parts[0], parts[1]); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set SOCKS proxy") |
||||
} |
||||
if err := sp.runCommand("networksetup", "-setsocksfirewallproxystate", service, "on"); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to enable SOCKS proxy") |
||||
} |
||||
} |
||||
} |
||||
|
||||
// 设置代理绕过列表
|
||||
if len(config.NoProxy) > 0 { |
||||
bypassList := strings.Join(config.NoProxy, " ") |
||||
if err := sp.runCommand("networksetup", "-setproxybypassdomains", service, bypassList); err != nil { |
||||
sp.logger.WithError(err).WithField("service", service).Warn("Failed to set proxy bypass list") |
||||
} |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// setLinuxProxy 设置Linux系统代理(通过环境变量)
|
||||
func (sp *SystemProxy) setLinuxProxy(config Config) error { |
||||
envVars := make(map[string]string) |
||||
|
||||
if config.HTTPProxy != "" { |
||||
envVars["http_proxy"] = "http://" + config.HTTPProxy |
||||
envVars["HTTP_PROXY"] = "http://" + config.HTTPProxy |
||||
} |
||||
|
||||
if config.HTTPSProxy != "" { |
||||
envVars["https_proxy"] = "https://" + config.HTTPSProxy |
||||
envVars["HTTPS_PROXY"] = "https://" + config.HTTPSProxy |
||||
} |
||||
|
||||
if len(config.NoProxy) > 0 { |
||||
noProxy := strings.Join(config.NoProxy, ",") |
||||
envVars["no_proxy"] = noProxy |
||||
envVars["NO_PROXY"] = noProxy |
||||
} |
||||
|
||||
// 写入到用户的shell配置文件
|
||||
return sp.writeLinuxProxyConfig(envVars) |
||||
} |
||||
|
||||
// setWindowsProxy 设置Windows系统代理
|
||||
func (sp *SystemProxy) setWindowsProxy(config Config) error { |
||||
// Windows代理设置通过注册表
|
||||
if config.HTTPProxy != "" { |
||||
// 启用代理
|
||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable", "/t", "REG_DWORD", "/d", "1", "/f"); err != nil { |
||||
return fmt.Errorf("failed to enable proxy: %w", err) |
||||
} |
||||
|
||||
// 设置代理服务器
|
||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyServer", "/t", "REG_SZ", "/d", config.HTTPProxy, "/f"); err != nil { |
||||
return fmt.Errorf("failed to set proxy server: %w", err) |
||||
} |
||||
} |
||||
|
||||
// 设置代理绕过列表
|
||||
if len(config.NoProxy) > 0 { |
||||
bypassList := strings.Join(config.NoProxy, ";") |
||||
if err := sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyOverride", "/t", "REG_SZ", "/d", bypassList, "/f"); err != nil { |
||||
return fmt.Errorf("failed to set proxy bypass list: %w", err) |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) getMacOSNetworkServices() ([]string, error) { |
||||
output, err := exec.Command("networksetup", "-listallnetworkservices").Output() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
lines := strings.Split(string(output), "\n") |
||||
var services []string |
||||
|
||||
for _, line := range lines { |
||||
line = strings.TrimSpace(line) |
||||
if line != "" && !strings.HasPrefix(line, "*") && line != "An asterisk (*) denotes that a network service is disabled." { |
||||
services = append(services, line) |
||||
} |
||||
} |
||||
|
||||
return services, nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) backupCurrentConfig() error { |
||||
switch runtime.GOOS { |
||||
case "darwin": |
||||
return sp.backupMacOSConfig() |
||||
case "linux": |
||||
return sp.backupLinuxConfig() |
||||
case "windows": |
||||
return sp.backupWindowsConfig() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) backupMacOSConfig() error { |
||||
services, err := sp.getMacOSNetworkServices() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, service := range services { |
||||
// 备份HTTP代理设置
|
||||
output, _ := exec.Command("networksetup", "-getwebproxy", service).Output() |
||||
sp.backupConfig[fmt.Sprintf("http_%s", service)] = string(output) |
||||
|
||||
// 备份HTTPS代理设置
|
||||
output, _ = exec.Command("networksetup", "-getsecurewebproxy", service).Output() |
||||
sp.backupConfig[fmt.Sprintf("https_%s", service)] = string(output) |
||||
|
||||
// 备份SOCKS代理设置
|
||||
output, _ = exec.Command("networksetup", "-getsocksfirewallproxy", service).Output() |
||||
sp.backupConfig[fmt.Sprintf("socks_%s", service)] = string(output) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) backupLinuxConfig() error { |
||||
// 备份环境变量
|
||||
sp.backupConfig["http_proxy"] = os.Getenv("http_proxy") |
||||
sp.backupConfig["https_proxy"] = os.Getenv("https_proxy") |
||||
sp.backupConfig["no_proxy"] = os.Getenv("no_proxy") |
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) backupWindowsConfig() error { |
||||
// 备份Windows注册表设置
|
||||
output, _ := exec.Command("reg", "query", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable").Output() |
||||
sp.backupConfig["ProxyEnable"] = string(output) |
||||
|
||||
output, _ = exec.Command("reg", "query", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyServer").Output() |
||||
sp.backupConfig["ProxyServer"] = string(output) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) restoreMacOSProxy() error { |
||||
services, err := sp.getMacOSNetworkServices() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, service := range services { |
||||
// 禁用所有代理
|
||||
sp.runCommand("networksetup", "-setwebproxystate", service, "off") |
||||
sp.runCommand("networksetup", "-setsecurewebproxystate", service, "off") |
||||
sp.runCommand("networksetup", "-setsocksfirewallproxystate", service, "off") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) restoreLinuxProxy() error { |
||||
// 清除环境变量(这里简化处理)
|
||||
os.Unsetenv("http_proxy") |
||||
os.Unsetenv("https_proxy") |
||||
os.Unsetenv("no_proxy") |
||||
os.Unsetenv("HTTP_PROXY") |
||||
os.Unsetenv("HTTPS_PROXY") |
||||
os.Unsetenv("NO_PROXY") |
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) restoreWindowsProxy() error { |
||||
// 禁用代理
|
||||
return sp.runCommand("reg", "add", `HKCU\Software\Microsoft\Windows\CurrentVersion\Internet Settings`, "/v", "ProxyEnable", "/t", "REG_DWORD", "/d", "0", "/f") |
||||
} |
||||
|
||||
func (sp *SystemProxy) writeLinuxProxyConfig(envVars map[string]string) error { |
||||
// 写入到 ~/.bashrc 和 ~/.profile
|
||||
homeDir, err := os.UserHomeDir() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
configFiles := []string{ |
||||
homeDir + "/.bashrc", |
||||
homeDir + "/.profile", |
||||
} |
||||
|
||||
proxyLines := []string{"\n# Wormhole SOCKS5 Proxy Configuration"} |
||||
for key, value := range envVars { |
||||
proxyLines = append(proxyLines, fmt.Sprintf("export %s=%s", key, value)) |
||||
} |
||||
proxyLines = append(proxyLines, "# End Wormhole Configuration\n") |
||||
|
||||
for _, configFile := range configFiles { |
||||
file, err := os.OpenFile(configFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) |
||||
if err != nil { |
||||
sp.logger.WithError(err).WithField("file", configFile).Warn("Failed to open config file") |
||||
continue |
||||
} |
||||
|
||||
for _, line := range proxyLines { |
||||
if _, err := file.WriteString(line + "\n"); err != nil { |
||||
sp.logger.WithError(err).WithField("file", configFile).Warn("Failed to write to config file") |
||||
} |
||||
} |
||||
|
||||
file.Close() |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (sp *SystemProxy) runCommand(name string, args ...string) error { |
||||
cmd := exec.Command(name, args...) |
||||
output, err := cmd.CombinedOutput() |
||||
if err != nil { |
||||
sp.logger.WithFields(logrus.Fields{ |
||||
"command": fmt.Sprintf("%s %s", name, strings.Join(args, " ")), |
||||
"output": string(output), |
||||
}).WithError(err).Debug("Command execution failed") |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetCurrentConfig 获取当前系统代理配置
|
||||
func (sp *SystemProxy) GetCurrentConfig() (Config, error) { |
||||
config := Config{ |
||||
NoProxy: []string{}, |
||||
} |
||||
|
||||
switch runtime.GOOS { |
||||
case "linux": |
||||
config.HTTPProxy = os.Getenv("http_proxy") |
||||
config.HTTPSProxy = os.Getenv("https_proxy") |
||||
noProxy := os.Getenv("no_proxy") |
||||
if noProxy != "" { |
||||
config.NoProxy = strings.Split(noProxy, ",") |
||||
} |
||||
} |
||||
|
||||
return config, nil |
||||
} |
||||
|
||||
// IsProxySet 检查是否已设置代理
|
||||
func (sp *SystemProxy) IsProxySet() bool { |
||||
switch runtime.GOOS { |
||||
case "linux": |
||||
return os.Getenv("http_proxy") != "" || os.Getenv("https_proxy") != "" |
||||
case "darwin": |
||||
// 检查macOS代理设置(简化)
|
||||
return false |
||||
case "windows": |
||||
// 检查Windows代理设置(简化)
|
||||
return false |
||||
} |
||||
return false |
||||
} |
Loading…
Reference in new issue