parent
6bad37115b
commit
4c0aeeb96d
@ -0,0 +1,55 @@ |
||||
run: |
||||
timeout: 5m |
||||
tests: true |
||||
|
||||
linters: |
||||
enable: |
||||
- errcheck # 检查未处理的错误 |
||||
- gosimple # 建议代码简化 |
||||
- govet # Go静态分析 |
||||
- ineffassign # 检查无效赋值 |
||||
- staticcheck # 静态分析 |
||||
- typecheck # 类型检查 |
||||
- unused # 检查未使用的代码 |
||||
- gofmt # 检查代码格式 |
||||
- goimports # 检查导入顺序 |
||||
- misspell # 检查拼写错误 |
||||
- unconvert # 检查不必要的类型转换 |
||||
- unparam # 检查未使用的参数 |
||||
- gosec # 安全检查 |
||||
- gocritic # Go代码审查 |
||||
|
||||
linters-settings: |
||||
gosec: |
||||
excludes: |
||||
- G204 # 子进程审计(我们需要执行系统命令) |
||||
|
||||
gocritic: |
||||
enabled-tags: |
||||
- diagnostic |
||||
- style |
||||
- performance |
||||
disabled-checks: |
||||
- paramTypeCombine |
||||
- whyNoLint |
||||
|
||||
issues: |
||||
exclude-rules: |
||||
# 忽略测试文件中的错误检查 |
||||
- path: _test\.go |
||||
linters: |
||||
- errcheck |
||||
- gosec |
||||
|
||||
# 忽略生成的文件 |
||||
- path: ".*\\.pb\\.go" |
||||
linters: |
||||
- all |
||||
|
||||
max-issues-per-linter: 0 |
||||
max-same-issues: 0 |
||||
|
||||
output: |
||||
format: colored-line-number |
||||
print-issued-lines: true |
||||
print-linter-name: true |
@ -0,0 +1,21 @@ |
||||
MIT License |
||||
|
||||
Copyright (c) 2024 Wormhole SOCKS5 Client |
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy |
||||
of this software and associated documentation files (the "Software"), to deal |
||||
in the Software without restriction, including without limitation the rights |
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
||||
copies of the Software, and to permit persons to whom the Software is |
||||
furnished to do so, subject to the following conditions: |
||||
|
||||
The above copyright notice and this permission notice shall be included in all |
||||
copies or substantial portions of the Software. |
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
||||
SOFTWARE. |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,231 @@ |
||||
# Wormhole SOCKS5 Client API |
||||
|
||||
## 概述 |
||||
|
||||
Wormhole SOCKS5 Client 在HTTP代理模式下提供了内置的API端点,用于监控和管理代理服务。 |
||||
|
||||
## 端点 |
||||
|
||||
### 代理功能 |
||||
|
||||
所有非API请求都会通过SOCKS5代理转发: |
||||
|
||||
- **HTTP 代理**: 处理普通HTTP请求 |
||||
- **HTTPS 代理**: 处理CONNECT方法的HTTPS隧道 |
||||
|
||||
### 监控API |
||||
|
||||
#### GET /stats |
||||
|
||||
获取详细的代理统计信息。 |
||||
|
||||
**响应示例:** |
||||
```json |
||||
{ |
||||
"start_time": "2024-01-01T12:00:00Z", |
||||
"uptime": "2h30m15s", |
||||
"total_connections": 150, |
||||
"active_connections": 5, |
||||
"successful_requests": 145, |
||||
"failed_requests": 5, |
||||
"bytes_sent": 1024000, |
||||
"bytes_received": 2048000, |
||||
"socks5_errors": { |
||||
"connection_failed": 3, |
||||
"auth_failed": 1, |
||||
"timeout": 1 |
||||
} |
||||
} |
||||
``` |
||||
|
||||
**字段说明:** |
||||
- `start_time`: 代理启动时间 |
||||
- `uptime`: 运行时间 |
||||
- `total_connections`: 总连接数 |
||||
- `active_connections`: 当前活跃连接数 |
||||
- `successful_requests`: 成功请求数 |
||||
- `failed_requests`: 失败请求数 |
||||
- `bytes_sent`: 发送字节数 |
||||
- `bytes_received`: 接收字节数 |
||||
- `socks5_errors`: SOCKS5错误分类统计 |
||||
|
||||
#### GET /health |
||||
|
||||
获取代理健康状态信息。 |
||||
|
||||
**响应示例:** |
||||
```json |
||||
{ |
||||
"status": "healthy", |
||||
"uptime": "2h30m15s", |
||||
"active_connections": 5, |
||||
"success_rate": 96.67 |
||||
} |
||||
``` |
||||
|
||||
**字段说明:** |
||||
- `status`: 健康状态 ("healthy") |
||||
- `uptime`: 运行时间 |
||||
- `active_connections`: 当前活跃连接数 |
||||
- `success_rate`: 成功率百分比 |
||||
|
||||
## 使用示例 |
||||
|
||||
### 基本用法 |
||||
|
||||
1. **启动HTTP代理模式:** |
||||
```bash |
||||
./bin/wormhole-client -mode http -config configs/client.yaml |
||||
``` |
||||
|
||||
2. **配置浏览器代理:** |
||||
- HTTP代理: `127.0.0.1:8080` |
||||
- HTTPS代理: `127.0.0.1:8080` |
||||
|
||||
3. **访问统计信息:** |
||||
```bash |
||||
curl http://127.0.0.1:8080/stats |
||||
``` |
||||
|
||||
4. **检查健康状态:** |
||||
```bash |
||||
curl http://127.0.0.1:8080/health |
||||
``` |
||||
|
||||
### 监控脚本示例 |
||||
|
||||
#### Bash监控脚本 |
||||
|
||||
```bash |
||||
#!/bin/bash |
||||
|
||||
PROXY_URL="http://127.0.0.1:8080" |
||||
|
||||
echo "=== Wormhole SOCKS5 Client Status ===" |
||||
|
||||
# 健康检查 |
||||
health=$(curl -s "$PROXY_URL/health") |
||||
echo "Health: $health" |
||||
|
||||
# 统计信息 |
||||
stats=$(curl -s "$PROXY_URL/stats") |
||||
echo "Stats: $stats" |
||||
|
||||
# 提取关键指标 |
||||
success_rate=$(echo "$health" | jq -r '.success_rate') |
||||
active_conn=$(echo "$health" | jq -r '.active_connections') |
||||
|
||||
echo "Success Rate: ${success_rate}%" |
||||
echo "Active Connections: $active_conn" |
||||
``` |
||||
|
||||
#### Python监控脚本 |
||||
|
||||
```python |
||||
import requests |
||||
import json |
||||
import time |
||||
|
||||
def get_proxy_stats(): |
||||
try: |
||||
response = requests.get('http://127.0.0.1:8080/stats') |
||||
return response.json() |
||||
except Exception as e: |
||||
print(f"Error getting stats: {e}") |
||||
return None |
||||
|
||||
def get_proxy_health(): |
||||
try: |
||||
response = requests.get('http://127.0.0.1:8080/health') |
||||
return response.json() |
||||
except Exception as e: |
||||
print(f"Error getting health: {e}") |
||||
return None |
||||
|
||||
def monitor_proxy(): |
||||
while True: |
||||
health = get_proxy_health() |
||||
if health: |
||||
print(f"Status: {health['status']}") |
||||
print(f"Success Rate: {health['success_rate']:.2f}%") |
||||
print(f"Active Connections: {health['active_connections']}") |
||||
print(f"Uptime: {health['uptime']}") |
||||
|
||||
print("-" * 40) |
||||
time.sleep(30) # 每30秒检查一次 |
||||
|
||||
if __name__ == "__main__": |
||||
monitor_proxy() |
||||
``` |
||||
|
||||
## SOCKS5协议支持 |
||||
|
||||
### 地址类型支持 |
||||
|
||||
- ✅ **IPv4地址** (0x01): `192.168.1.1:80` |
||||
- ✅ **域名** (0x03): `example.com:80` |
||||
- ✅ **IPv6地址** (0x04): `[2001:db8::1]:80` |
||||
|
||||
### 认证方法支持 |
||||
|
||||
- ✅ **无认证** (0x00) |
||||
- ✅ **用户名密码认证** (0x02) |
||||
|
||||
### 错误处理 |
||||
|
||||
代理能够处理并报告以下SOCKS5错误: |
||||
|
||||
| 错误码 | 含义 | 统计字段 | |
||||
|--------|------|----------| |
||||
| 0x01 | 一般SOCKS服务器故障 | `general_failure` | |
||||
| 0x02 | 连接不被规则集允许 | `not_allowed` | |
||||
| 0x03 | 网络不可达 | `network_unreachable` | |
||||
| 0x04 | 主机不可达 | `host_unreachable` | |
||||
| 0x05 | 连接被拒绝 | `connection_refused` | |
||||
| 0x06 | TTL过期 | `ttl_expired` | |
||||
| 0x07 | 不支持的命令 | `command_not_supported` | |
||||
| 0x08 | 不支持的地址类型 | `address_not_supported` | |
||||
|
||||
## 性能特性 |
||||
|
||||
### 连接管理 |
||||
- 异步处理多个并发连接 |
||||
- 自动资源清理 |
||||
- 超时管理 |
||||
|
||||
### 统计精度 |
||||
- 原子操作保证并发安全 |
||||
- 实时字节传输统计 |
||||
- 详细的错误分类 |
||||
|
||||
### HTTP服务器配置 |
||||
- 读取超时: 30秒 |
||||
- 写入超时: 30秒 |
||||
- 空闲超时: 120秒 |
||||
- 最大头部大小: 1MB |
||||
|
||||
## 故障排除 |
||||
|
||||
### 常见问题 |
||||
|
||||
1. **无法访问统计API** |
||||
- 确认代理运行在HTTP模式 |
||||
- 检查端口是否正确 |
||||
- 确认防火墙设置 |
||||
|
||||
2. **统计数据不准确** |
||||
- 重启代理服务 |
||||
- 检查并发连接情况 |
||||
|
||||
3. **高错误率** |
||||
- 检查SOCKS5服务器状态 |
||||
- 验证认证信息 |
||||
- 查看网络连接质量 |
||||
|
||||
### 日志级别 |
||||
|
||||
设置 `logLevel: debug` 可以获得详细的调试信息: |
||||
|
||||
```yaml |
||||
logLevel: debug # debug, info, warn, error |
||||
``` |
@ -0,0 +1,205 @@ |
||||
package config |
||||
|
||||
import ( |
||||
"os" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestLoadConfig(t *testing.T) { |
||||
// 创建临时配置文件
|
||||
tmpFile, err := os.CreateTemp("", "test_config_*.yaml") |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
defer os.Remove(tmpFile.Name()) |
||||
|
||||
configContent := ` |
||||
serviceType: client |
||||
server: |
||||
address: 127.0.0.1 |
||||
port: 1080 |
||||
username: testuser |
||||
password: testpass |
||||
proxy: |
||||
mode: http |
||||
localPort: 8080 |
||||
globalProxy: |
||||
enabled: true |
||||
dnsProxy: true |
||||
dnsPort: 5353 |
||||
logLevel: debug |
||||
timeout: 60s |
||||
` |
||||
|
||||
if _, err := tmpFile.WriteString(configContent); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
tmpFile.Close() |
||||
|
||||
// 测试加载配置
|
||||
config, err := LoadConfig(tmpFile.Name()) |
||||
if err != nil { |
||||
t.Fatalf("LoadConfig failed: %v", err) |
||||
} |
||||
|
||||
// 验证配置值
|
||||
if config.ServiceType != "client" { |
||||
t.Errorf("Expected ServiceType 'client', got '%s'", config.ServiceType) |
||||
} |
||||
|
||||
if config.Server.Address != "127.0.0.1" { |
||||
t.Errorf("Expected Server.Address '127.0.0.1', got '%s'", config.Server.Address) |
||||
} |
||||
|
||||
if config.Server.Port != 1080 { |
||||
t.Errorf("Expected Server.Port 1080, got %d", config.Server.Port) |
||||
} |
||||
|
||||
if config.Server.Username != "testuser" { |
||||
t.Errorf("Expected Server.Username 'testuser', got '%s'", config.Server.Username) |
||||
} |
||||
|
||||
if config.Server.Password != "testpass" { |
||||
t.Errorf("Expected Server.Password 'testpass', got '%s'", config.Server.Password) |
||||
} |
||||
|
||||
if config.Proxy.LocalPort != 8080 { |
||||
t.Errorf("Expected Proxy.LocalPort 8080, got %d", config.Proxy.LocalPort) |
||||
} |
||||
|
||||
if config.LogLevel != "debug" { |
||||
t.Errorf("Expected LogLevel 'debug', got '%s'", config.LogLevel) |
||||
} |
||||
|
||||
if config.Timeout != 60*time.Second { |
||||
t.Errorf("Expected Timeout 60s, got %v", config.Timeout) |
||||
} |
||||
} |
||||
|
||||
func TestLoadConfigDefaults(t *testing.T) { |
||||
// 创建最小配置
|
||||
tmpFile, err := os.CreateTemp("", "test_config_min_*.yaml") |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
defer os.Remove(tmpFile.Name()) |
||||
|
||||
minimalConfig := ` |
||||
serviceType: client |
||||
server: |
||||
address: 127.0.0.1 |
||||
port: 1080 |
||||
username: user |
||||
password: pass |
||||
` |
||||
|
||||
if _, err := tmpFile.WriteString(minimalConfig); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
tmpFile.Close() |
||||
|
||||
config, err := LoadConfig(tmpFile.Name()) |
||||
if err != nil { |
||||
t.Fatalf("LoadConfig failed: %v", err) |
||||
} |
||||
|
||||
// 验证默认值
|
||||
if config.LogLevel != "info" { |
||||
t.Errorf("Expected default LogLevel 'info', got '%s'", config.LogLevel) |
||||
} |
||||
|
||||
if config.Timeout != 30*time.Second { |
||||
t.Errorf("Expected default Timeout 30s, got %v", config.Timeout) |
||||
} |
||||
|
||||
if config.Proxy.LocalPort != 8080 { |
||||
t.Errorf("Expected default Proxy.LocalPort 8080, got %d", config.Proxy.LocalPort) |
||||
} |
||||
} |
||||
|
||||
func TestGetServerAddr(t *testing.T) { |
||||
config := &Config{ |
||||
Server: Server{ |
||||
Address: "example.com", |
||||
Port: 1080, |
||||
}, |
||||
} |
||||
|
||||
expected := "example.com:1080" |
||||
result := config.GetServerAddr() |
||||
|
||||
if result != expected { |
||||
t.Errorf("Expected '%s', got '%s'", expected, result) |
||||
} |
||||
} |
||||
|
||||
func TestValidate(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
config *Config |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
name: "valid config", |
||||
config: &Config{ |
||||
Server: Server{ |
||||
Address: "127.0.0.1", |
||||
Port: 1080, |
||||
}, |
||||
Proxy: Proxy{ |
||||
Mode: "http", |
||||
}, |
||||
}, |
||||
wantErr: false, |
||||
}, |
||||
{ |
||||
name: "empty address", |
||||
config: &Config{ |
||||
Server: Server{ |
||||
Address: "", |
||||
Port: 1080, |
||||
}, |
||||
Proxy: Proxy{ |
||||
Mode: "http", |
||||
}, |
||||
}, |
||||
wantErr: true, |
||||
}, |
||||
{ |
||||
name: "invalid port", |
||||
config: &Config{ |
||||
Server: Server{ |
||||
Address: "127.0.0.1", |
||||
Port: -1, |
||||
}, |
||||
Proxy: Proxy{ |
||||
Mode: "http", |
||||
}, |
||||
}, |
||||
wantErr: true, |
||||
}, |
||||
{ |
||||
name: "invalid mode", |
||||
config: &Config{ |
||||
Server: Server{ |
||||
Address: "127.0.0.1", |
||||
Port: 1080, |
||||
}, |
||||
Proxy: Proxy{ |
||||
Mode: "invalid", |
||||
}, |
||||
}, |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
err := tt.config.Validate() |
||||
if (err != nil) != tt.wantErr { |
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) |
||||
} |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,277 @@ |
||||
package proxy |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestNewSOCKS5Proxy(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) |
||||
|
||||
if proxy == nil { |
||||
t.Fatal("NewSOCKS5Proxy returned nil") |
||||
} |
||||
|
||||
if proxy.serverAddr != "127.0.0.1:1080" { |
||||
t.Errorf("Expected serverAddr '127.0.0.1:1080', got '%s'", proxy.serverAddr) |
||||
} |
||||
|
||||
if proxy.username != "user" { |
||||
t.Errorf("Expected username 'user', got '%s'", proxy.username) |
||||
} |
||||
|
||||
if proxy.password != "pass" { |
||||
t.Errorf("Expected password 'pass', got '%s'", proxy.password) |
||||
} |
||||
|
||||
if proxy.timeout != 30*time.Second { |
||||
t.Errorf("Expected timeout 30s, got %v", proxy.timeout) |
||||
} |
||||
|
||||
// 检查连接池是否正确初始化
|
||||
if proxy.connPool == nil { |
||||
t.Error("Connection pool should be initialized") |
||||
} |
||||
|
||||
if proxy.connPool.maxSize != 10 { |
||||
t.Errorf("Expected connection pool size 10, got %d", proxy.connPool.maxSize) |
||||
} |
||||
} |
||||
|
||||
func TestParsePort(t *testing.T) { |
||||
tests := []struct { |
||||
input string |
||||
expected int |
||||
hasError bool |
||||
}{ |
||||
{"80", 80, false}, |
||||
{"443", 443, false}, |
||||
{"8080", 8080, false}, |
||||
{"", 80, false}, // 默认端口
|
||||
{"0", 0, true}, // 无效端口
|
||||
{"65536", 0, true}, // 端口超出范围
|
||||
{"abc", 0, true}, // 非数字
|
||||
{"-1", 0, true}, // 负数
|
||||
{"65535", 65535, false}, // 最大有效端口
|
||||
{"1", 1, false}, // 最小有效端口
|
||||
} |
||||
|
||||
for _, test := range tests { |
||||
result, err := parsePort(test.input) |
||||
|
||||
if test.hasError { |
||||
if err == nil { |
||||
t.Errorf("Expected error for input '%s', but got none", test.input) |
||||
} |
||||
} else { |
||||
if err != nil { |
||||
t.Errorf("Unexpected error for input '%s': %v", test.input, err) |
||||
} |
||||
if result != test.expected { |
||||
t.Errorf("For input '%s', expected %d, got %d", test.input, test.expected, result) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestCreateHTTPProxy(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) |
||||
server := proxy.CreateHTTPProxy(8080) |
||||
|
||||
if server == nil { |
||||
t.Fatal("CreateHTTPProxy returned nil") |
||||
} |
||||
|
||||
if server.Addr != ":8080" { |
||||
t.Errorf("Expected server address ':8080', got '%s'", server.Addr) |
||||
} |
||||
|
||||
// 检查服务器配置
|
||||
if server.ReadTimeout != 30*time.Second { |
||||
t.Errorf("Expected ReadTimeout 30s, got %v", server.ReadTimeout) |
||||
} |
||||
|
||||
if server.WriteTimeout != 30*time.Second { |
||||
t.Errorf("Expected WriteTimeout 30s, got %v", server.WriteTimeout) |
||||
} |
||||
|
||||
if server.IdleTimeout != 120*time.Second { |
||||
t.Errorf("Expected IdleTimeout 120s, got %v", server.IdleTimeout) |
||||
} |
||||
} |
||||
|
||||
func TestBuildConnectRequest(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) |
||||
|
||||
tests := []struct { |
||||
name string |
||||
targetAddr string |
||||
expectError bool |
||||
expectedType byte // 地址类型: 1=IPv4, 3=域名, 4=IPv6
|
||||
}{ |
||||
{ |
||||
name: "IPv4 address", |
||||
targetAddr: "192.168.1.1:80", |
||||
expectError: false, |
||||
expectedType: 0x01, |
||||
}, |
||||
{ |
||||
name: "IPv6 address", |
||||
targetAddr: "[2001:db8::1]:80", |
||||
expectError: false, |
||||
expectedType: 0x04, |
||||
}, |
||||
{ |
||||
name: "Domain name", |
||||
targetAddr: "example.com:80", |
||||
expectError: false, |
||||
expectedType: 0x03, |
||||
}, |
||||
{ |
||||
name: "Domain with HTTPS port", |
||||
targetAddr: "example.com:443", |
||||
expectError: false, |
||||
expectedType: 0x03, |
||||
}, |
||||
{ |
||||
name: "Invalid address format", |
||||
targetAddr: "invalid", |
||||
expectError: true, |
||||
}, |
||||
{ |
||||
name: "Invalid port", |
||||
targetAddr: "example.com:invalid", |
||||
expectError: true, |
||||
}, |
||||
{ |
||||
name: "Port out of range", |
||||
targetAddr: "example.com:70000", |
||||
expectError: true, |
||||
}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
req, err := proxy.buildConnectRequest(test.targetAddr) |
||||
|
||||
if test.expectError { |
||||
if err == nil { |
||||
t.Errorf("Expected error for %s, but got none", test.targetAddr) |
||||
} |
||||
return |
||||
} |
||||
|
||||
if err != nil { |
||||
t.Errorf("Unexpected error for %s: %v", test.targetAddr, err) |
||||
return |
||||
} |
||||
|
||||
if len(req) < 4 { |
||||
t.Errorf("Request too short: %d bytes", len(req)) |
||||
return |
||||
} |
||||
|
||||
// 检查SOCKS版本
|
||||
if req[0] != 0x05 { |
||||
t.Errorf("Expected SOCKS version 5, got %d", req[0]) |
||||
} |
||||
|
||||
// 检查命令类型
|
||||
if req[1] != 0x01 { |
||||
t.Errorf("Expected CONNECT command (1), got %d", req[1]) |
||||
} |
||||
|
||||
// 检查保留字段
|
||||
if req[2] != 0x00 { |
||||
t.Errorf("Expected reserved field to be 0, got %d", req[2]) |
||||
} |
||||
|
||||
// 检查地址类型
|
||||
if req[3] != test.expectedType { |
||||
t.Errorf("Expected address type %d, got %d", test.expectedType, req[3]) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestGetSOCKS5ErrorMessage(t *testing.T) { |
||||
tests := []struct { |
||||
code byte |
||||
expected string |
||||
}{ |
||||
{0x01, "general SOCKS server failure"}, |
||||
{0x02, "connection not allowed by ruleset"}, |
||||
{0x03, "network unreachable"}, |
||||
{0x04, "host unreachable"}, |
||||
{0x05, "connection refused"}, |
||||
{0x06, "TTL expired"}, |
||||
{0x07, "command not supported"}, |
||||
{0x08, "address type not supported"}, |
||||
{0xFF, "unknown error"}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
result := getSOCKS5ErrorMessage(test.code) |
||||
if result != test.expected { |
||||
t.Errorf("For code %d, expected '%s', got '%s'", test.code, test.expected, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestDialTCPWithContext(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 5*time.Second) |
||||
|
||||
// 测试超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) |
||||
defer cancel() |
||||
|
||||
// 由于没有真实的SOCKS5服务器,这应该会超时或连接失败
|
||||
_, err := proxy.DialTCPWithContext(ctx, "example.com:80") |
||||
if err == nil { |
||||
t.Error("Expected error when connecting to non-existent SOCKS5 server") |
||||
} |
||||
} |
||||
|
||||
func TestHTTPProxyHandler(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) |
||||
handler := &httpProxyHandler{socks5Proxy: proxy} |
||||
|
||||
if handler.socks5Proxy != proxy { |
||||
t.Error("Handler should have reference to SOCKS5 proxy") |
||||
} |
||||
} |
||||
|
||||
// BenchmarkParsePort 性能测试
|
||||
func BenchmarkParsePort(b *testing.B) { |
||||
ports := []string{"80", "443", "8080", "3000", "9999"} |
||||
|
||||
for i := 0; i < b.N; i++ { |
||||
port := ports[i%len(ports)] |
||||
_, err := parsePort(port) |
||||
if err != nil { |
||||
b.Errorf("Unexpected error: %v", err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// TestIPv6AddressHandling 测试IPv6地址处理
|
||||
func TestIPv6AddressHandling(t *testing.T) { |
||||
proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) |
||||
|
||||
// 测试完整的IPv6地址
|
||||
req, err := proxy.buildConnectRequest("[2001:db8:85a3::8a2e:370:7334]:443") |
||||
if err != nil { |
||||
t.Fatalf("Failed to build connect request for IPv6: %v", err) |
||||
} |
||||
|
||||
if req[3] != 0x04 { |
||||
t.Errorf("Expected IPv6 address type (4), got %d", req[3]) |
||||
} |
||||
|
||||
// IPv6地址应该是16字节 + 头部4字节 + 端口2字节 = 22字节
|
||||
expectedLen := 4 + 16 + 2 // 头部 + IPv6地址 + 端口
|
||||
if len(req) != expectedLen { |
||||
t.Errorf("Expected request length %d for IPv6, got %d", expectedLen, len(req)) |
||||
} |
||||
} |
@ -0,0 +1,121 @@ |
||||
package proxy |
||||
|
||||
import ( |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
) |
||||
|
||||
// ProxyStats 代理统计信息
|
||||
type ProxyStats struct { |
||||
StartTime time.Time |
||||
TotalConnections int64 |
||||
ActiveConnections int64 |
||||
SuccessfulRequests int64 |
||||
FailedRequests int64 |
||||
BytesSent int64 |
||||
BytesReceived int64 |
||||
SOCKS5Errors map[string]int64 |
||||
mutex sync.RWMutex |
||||
} |
||||
|
||||
// NewProxyStats 创建新的统计实例
|
||||
func NewProxyStats() *ProxyStats { |
||||
return &ProxyStats{ |
||||
StartTime: time.Now(), |
||||
SOCKS5Errors: make(map[string]int64), |
||||
} |
||||
} |
||||
|
||||
// IncrementConnections 增加连接计数
|
||||
func (s *ProxyStats) IncrementConnections() { |
||||
atomic.AddInt64(&s.TotalConnections, 1) |
||||
atomic.AddInt64(&s.ActiveConnections, 1) |
||||
} |
||||
|
||||
// DecrementActiveConnections 减少活跃连接计数
|
||||
func (s *ProxyStats) DecrementActiveConnections() { |
||||
atomic.AddInt64(&s.ActiveConnections, -1) |
||||
} |
||||
|
||||
// IncrementSuccessfulRequests 增加成功请求计数
|
||||
func (s *ProxyStats) IncrementSuccessfulRequests() { |
||||
atomic.AddInt64(&s.SuccessfulRequests, 1) |
||||
} |
||||
|
||||
// IncrementFailedRequests 增加失败请求计数
|
||||
func (s *ProxyStats) IncrementFailedRequests() { |
||||
atomic.AddInt64(&s.FailedRequests, 1) |
||||
} |
||||
|
||||
// AddBytesTransferred 添加传输字节数
|
||||
func (s *ProxyStats) AddBytesTransferred(sent, received int64) { |
||||
atomic.AddInt64(&s.BytesSent, sent) |
||||
atomic.AddInt64(&s.BytesReceived, received) |
||||
} |
||||
|
||||
// IncrementSOCKS5Error 增加SOCKS5错误计数
|
||||
func (s *ProxyStats) IncrementSOCKS5Error(errorType string) { |
||||
s.mutex.Lock() |
||||
defer s.mutex.Unlock() |
||||
s.SOCKS5Errors[errorType]++ |
||||
} |
||||
|
||||
// GetStats 获取统计快照
|
||||
func (s *ProxyStats) GetStats() ProxyStatsSnapshot { |
||||
s.mutex.RLock() |
||||
defer s.mutex.RUnlock() |
||||
|
||||
errors := make(map[string]int64) |
||||
for k, v := range s.SOCKS5Errors { |
||||
errors[k] = v |
||||
} |
||||
|
||||
return ProxyStatsSnapshot{ |
||||
StartTime: s.StartTime, |
||||
Uptime: time.Since(s.StartTime), |
||||
TotalConnections: atomic.LoadInt64(&s.TotalConnections), |
||||
ActiveConnections: atomic.LoadInt64(&s.ActiveConnections), |
||||
SuccessfulRequests: atomic.LoadInt64(&s.SuccessfulRequests), |
||||
FailedRequests: atomic.LoadInt64(&s.FailedRequests), |
||||
BytesSent: atomic.LoadInt64(&s.BytesSent), |
||||
BytesReceived: atomic.LoadInt64(&s.BytesReceived), |
||||
SOCKS5Errors: errors, |
||||
} |
||||
} |
||||
|
||||
// ProxyStatsSnapshot 统计快照
|
||||
type ProxyStatsSnapshot struct { |
||||
StartTime time.Time `json:"start_time"` |
||||
Uptime time.Duration `json:"uptime"` |
||||
TotalConnections int64 `json:"total_connections"` |
||||
ActiveConnections int64 `json:"active_connections"` |
||||
SuccessfulRequests int64 `json:"successful_requests"` |
||||
FailedRequests int64 `json:"failed_requests"` |
||||
BytesSent int64 `json:"bytes_sent"` |
||||
BytesReceived int64 `json:"bytes_received"` |
||||
SOCKS5Errors map[string]int64 `json:"socks5_errors"` |
||||
} |
||||
|
||||
// GetSuccessRate 获取成功率
|
||||
func (s *ProxyStatsSnapshot) GetSuccessRate() float64 { |
||||
total := s.SuccessfulRequests + s.FailedRequests |
||||
if total == 0 { |
||||
return 0 |
||||
} |
||||
return float64(s.SuccessfulRequests) / float64(total) * 100 |
||||
} |
||||
|
||||
// GetTotalBytes 获取总传输字节数
|
||||
func (s *ProxyStatsSnapshot) GetTotalBytes() int64 { |
||||
return s.BytesSent + s.BytesReceived |
||||
} |
||||
|
||||
// GetAverageConnectionsPerHour 获取每小时平均连接数
|
||||
func (s *ProxyStatsSnapshot) GetAverageConnectionsPerHour() float64 { |
||||
hours := s.Uptime.Hours() |
||||
if hours == 0 { |
||||
return 0 |
||||
} |
||||
return float64(s.TotalConnections) / hours |
||||
} |
@ -0,0 +1,224 @@ |
||||
package proxy |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestNewProxyStats(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
if stats == nil { |
||||
t.Fatal("NewProxyStats returned nil") |
||||
} |
||||
|
||||
if stats.SOCKS5Errors == nil { |
||||
t.Error("SOCKS5Errors map should be initialized") |
||||
} |
||||
|
||||
if time.Since(stats.StartTime) > time.Second { |
||||
t.Error("StartTime should be recent") |
||||
} |
||||
} |
||||
|
||||
func TestProxyStatsCounters(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
// 测试连接计数
|
||||
stats.IncrementConnections() |
||||
stats.IncrementConnections() |
||||
|
||||
snapshot := stats.GetStats() |
||||
if snapshot.TotalConnections != 2 { |
||||
t.Errorf("Expected 2 total connections, got %d", snapshot.TotalConnections) |
||||
} |
||||
|
||||
if snapshot.ActiveConnections != 2 { |
||||
t.Errorf("Expected 2 active connections, got %d", snapshot.ActiveConnections) |
||||
} |
||||
|
||||
// 测试减少活跃连接
|
||||
stats.DecrementActiveConnections() |
||||
snapshot = stats.GetStats() |
||||
if snapshot.ActiveConnections != 1 { |
||||
t.Errorf("Expected 1 active connection, got %d", snapshot.ActiveConnections) |
||||
} |
||||
|
||||
// 测试请求计数
|
||||
stats.IncrementSuccessfulRequests() |
||||
stats.IncrementSuccessfulRequests() |
||||
stats.IncrementFailedRequests() |
||||
|
||||
snapshot = stats.GetStats() |
||||
if snapshot.SuccessfulRequests != 2 { |
||||
t.Errorf("Expected 2 successful requests, got %d", snapshot.SuccessfulRequests) |
||||
} |
||||
|
||||
if snapshot.FailedRequests != 1 { |
||||
t.Errorf("Expected 1 failed request, got %d", snapshot.FailedRequests) |
||||
} |
||||
} |
||||
|
||||
func TestProxyStatsBytesTransferred(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
stats.AddBytesTransferred(100, 200) |
||||
stats.AddBytesTransferred(50, 75) |
||||
|
||||
snapshot := stats.GetStats() |
||||
if snapshot.BytesSent != 150 { |
||||
t.Errorf("Expected 150 bytes sent, got %d", snapshot.BytesSent) |
||||
} |
||||
|
||||
if snapshot.BytesReceived != 275 { |
||||
t.Errorf("Expected 275 bytes received, got %d", snapshot.BytesReceived) |
||||
} |
||||
|
||||
totalBytes := snapshot.GetTotalBytes() |
||||
if totalBytes != 425 { |
||||
t.Errorf("Expected 425 total bytes, got %d", totalBytes) |
||||
} |
||||
} |
||||
|
||||
func TestProxyStatsSOCKS5Errors(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
stats.IncrementSOCKS5Error("connection_failed") |
||||
stats.IncrementSOCKS5Error("auth_failed") |
||||
stats.IncrementSOCKS5Error("connection_failed") |
||||
|
||||
snapshot := stats.GetStats() |
||||
|
||||
if snapshot.SOCKS5Errors["connection_failed"] != 2 { |
||||
t.Errorf("Expected 2 connection_failed errors, got %d", |
||||
snapshot.SOCKS5Errors["connection_failed"]) |
||||
} |
||||
|
||||
if snapshot.SOCKS5Errors["auth_failed"] != 1 { |
||||
t.Errorf("Expected 1 auth_failed error, got %d", |
||||
snapshot.SOCKS5Errors["auth_failed"]) |
||||
} |
||||
} |
||||
|
||||
func TestProxyStatsSnapshot(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
// 添加一些测试数据
|
||||
stats.IncrementSuccessfulRequests() |
||||
stats.IncrementSuccessfulRequests() |
||||
stats.IncrementSuccessfulRequests() |
||||
stats.IncrementFailedRequests() |
||||
|
||||
snapshot := stats.GetStats() |
||||
|
||||
// 测试成功率计算
|
||||
successRate := snapshot.GetSuccessRate() |
||||
expected := 75.0 // 3 successful out of 4 total = 75%
|
||||
if successRate != expected { |
||||
t.Errorf("Expected success rate %.1f%%, got %.1f%%", expected, successRate) |
||||
} |
||||
|
||||
// 测试零请求时的成功率
|
||||
emptyStats := NewProxyStats() |
||||
emptySnapshot := emptyStats.GetStats() |
||||
if emptySnapshot.GetSuccessRate() != 0 { |
||||
t.Errorf("Expected 0%% success rate for empty stats, got %.1f%%", |
||||
emptySnapshot.GetSuccessRate()) |
||||
} |
||||
} |
||||
|
||||
func TestProxyStatsAverageConnections(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
// 由于uptime很短,我们模拟一些连接
|
||||
stats.IncrementConnections() |
||||
stats.IncrementConnections() |
||||
|
||||
snapshot := stats.GetStats() |
||||
avg := snapshot.GetAverageConnectionsPerHour() |
||||
|
||||
// 应该大于0,具体值取决于测试运行的时间
|
||||
if avg <= 0 { |
||||
t.Error("Average connections per hour should be greater than 0") |
||||
} |
||||
} |
||||
|
||||
func TestConcurrentStatsAccess(t *testing.T) { |
||||
stats := NewProxyStats() |
||||
|
||||
// 并发测试
|
||||
done := make(chan bool) |
||||
|
||||
// 启动多个goroutine来并发更新统计
|
||||
for i := 0; i < 10; i++ { |
||||
go func() { |
||||
for j := 0; j < 100; j++ { |
||||
stats.IncrementConnections() |
||||
stats.IncrementSuccessfulRequests() |
||||
stats.AddBytesTransferred(10, 20) |
||||
stats.IncrementSOCKS5Error("test_error") |
||||
} |
||||
done <- true |
||||
}() |
||||
} |
||||
|
||||
// 等待所有goroutine完成
|
||||
for i := 0; i < 10; i++ { |
||||
<-done |
||||
} |
||||
|
||||
snapshot := stats.GetStats() |
||||
|
||||
// 验证最终计数
|
||||
expectedConnections := int64(1000) // 10 goroutines * 100 increments
|
||||
if snapshot.TotalConnections != expectedConnections { |
||||
t.Errorf("Expected %d total connections, got %d", |
||||
expectedConnections, snapshot.TotalConnections) |
||||
} |
||||
|
||||
if snapshot.SuccessfulRequests != expectedConnections { |
||||
t.Errorf("Expected %d successful requests, got %d", |
||||
expectedConnections, snapshot.SuccessfulRequests) |
||||
} |
||||
|
||||
expectedBytesSent := int64(10000) // 10 * 100 * 10
|
||||
if snapshot.BytesSent != expectedBytesSent { |
||||
t.Errorf("Expected %d bytes sent, got %d", |
||||
expectedBytesSent, snapshot.BytesSent) |
||||
} |
||||
|
||||
if snapshot.SOCKS5Errors["test_error"] != expectedConnections { |
||||
t.Errorf("Expected %d test_error occurrences, got %d", |
||||
expectedConnections, snapshot.SOCKS5Errors["test_error"]) |
||||
} |
||||
} |
||||
|
||||
// BenchmarkStatsOperations 性能测试
|
||||
func BenchmarkStatsOperations(b *testing.B) { |
||||
stats := NewProxyStats() |
||||
|
||||
b.ResetTimer() |
||||
|
||||
for i := 0; i < b.N; i++ { |
||||
stats.IncrementConnections() |
||||
stats.AddBytesTransferred(100, 200) |
||||
stats.IncrementSuccessfulRequests() |
||||
} |
||||
} |
||||
|
||||
func BenchmarkStatsSnapshot(b *testing.B) { |
||||
stats := NewProxyStats() |
||||
|
||||
// 添加一些数据
|
||||
for i := 0; i < 100; i++ { |
||||
stats.IncrementConnections() |
||||
stats.AddBytesTransferred(100, 200) |
||||
stats.IncrementSOCKS5Error("test_error") |
||||
} |
||||
|
||||
b.ResetTimer() |
||||
|
||||
for i := 0; i < b.N; i++ { |
||||
_ = stats.GetStats() |
||||
} |
||||
} |
@ -0,0 +1,263 @@ |
||||
package routing |
||||
|
||||
import ( |
||||
"net" |
||||
"regexp" |
||||
"strings" |
||||
|
||||
"github.com/azoic/wormhole-client/internal/config" |
||||
"github.com/azoic/wormhole-client/pkg/logger" |
||||
) |
||||
|
||||
// RouteMatcher 路由匹配器
|
||||
type RouteMatcher struct { |
||||
config *config.Routing |
||||
bypassDomains []*regexp.Regexp |
||||
forceDomains []*regexp.Regexp |
||||
privateNetworks []*net.IPNet |
||||
} |
||||
|
||||
// MatchResult 匹配结果
|
||||
type MatchResult int |
||||
|
||||
const ( |
||||
// MatchBypass 直连(绕过代理)
|
||||
MatchBypass MatchResult = iota |
||||
// MatchProxy 代理
|
||||
MatchProxy |
||||
// MatchAuto 自动决定
|
||||
MatchAuto |
||||
) |
||||
|
||||
// NewRouteMatcher 创建路由匹配器
|
||||
func NewRouteMatcher(config *config.Routing) (*RouteMatcher, error) { |
||||
matcher := &RouteMatcher{ |
||||
config: config, |
||||
} |
||||
|
||||
// 编译域名规则
|
||||
if err := matcher.compilePatterns(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// 初始化私有网络列表
|
||||
matcher.initPrivateNetworks() |
||||
|
||||
logger.Debug("Route matcher initialized with %d bypass domains, %d force domains", |
||||
len(matcher.bypassDomains), len(matcher.forceDomains)) |
||||
|
||||
return matcher, nil |
||||
} |
||||
|
||||
// Match 匹配主机地址,返回路由决策
|
||||
func (rm *RouteMatcher) Match(host string) MatchResult { |
||||
// 去除端口
|
||||
if hostOnly, _, err := net.SplitHostPort(host); err == nil { |
||||
host = hostOnly |
||||
} |
||||
|
||||
logger.Debug("Matching route for host: %s", host) |
||||
|
||||
// 1. 检查强制代理域名
|
||||
if rm.matchesForceDomains(host) { |
||||
logger.Debug("Host %s matches force domains - using proxy", host) |
||||
return MatchProxy |
||||
} |
||||
|
||||
// 2. 检查绕过域名
|
||||
if rm.matchesBypassDomains(host) { |
||||
logger.Debug("Host %s matches bypass domains - using direct", host) |
||||
return MatchBypass |
||||
} |
||||
|
||||
// 3. 检查是否为IP地址
|
||||
if ip := net.ParseIP(host); ip != nil { |
||||
return rm.matchIP(ip) |
||||
} |
||||
|
||||
// 4. 检查本地域名
|
||||
if rm.config.BypassLocal && rm.isLocalDomain(host) { |
||||
logger.Debug("Host %s is local domain - using direct", host) |
||||
return MatchBypass |
||||
} |
||||
|
||||
// 5. 默认策略:自动决定或代理
|
||||
logger.Debug("Host %s no specific rule - using auto", host) |
||||
return MatchAuto |
||||
} |
||||
|
||||
// matchesForceDomains 检查是否匹配强制代理域名
|
||||
func (rm *RouteMatcher) matchesForceDomains(host string) bool { |
||||
for _, pattern := range rm.forceDomains { |
||||
if pattern.MatchString(host) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// matchesBypassDomains 检查是否匹配绕过域名
|
||||
func (rm *RouteMatcher) matchesBypassDomains(host string) bool { |
||||
for _, pattern := range rm.bypassDomains { |
||||
if pattern.MatchString(host) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// matchIP 匹配IP地址
|
||||
func (rm *RouteMatcher) matchIP(ip net.IP) MatchResult { |
||||
// 检查本地IP
|
||||
if rm.config.BypassLocal && rm.isLocalIP(ip) { |
||||
logger.Debug("IP %s is local - using direct", ip.String()) |
||||
return MatchBypass |
||||
} |
||||
|
||||
// 检查私有网络
|
||||
if rm.config.BypassPrivate && rm.isPrivateIP(ip) { |
||||
logger.Debug("IP %s is private - using direct", ip.String()) |
||||
return MatchBypass |
||||
} |
||||
|
||||
return MatchAuto |
||||
} |
||||
|
||||
// isLocalIP 检查是否为本地IP
|
||||
func (rm *RouteMatcher) isLocalIP(ip net.IP) bool { |
||||
// 环回地址
|
||||
if ip.IsLoopback() { |
||||
return true |
||||
} |
||||
|
||||
// 链路本地地址
|
||||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
// isPrivateIP 检查是否为私有IP
|
||||
func (rm *RouteMatcher) isPrivateIP(ip net.IP) bool { |
||||
for _, network := range rm.privateNetworks { |
||||
if network.Contains(ip) { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// isLocalDomain 检查是否为本地域名
|
||||
func (rm *RouteMatcher) isLocalDomain(host string) bool { |
||||
host = strings.ToLower(host) |
||||
|
||||
// 常见本地域名
|
||||
localSuffixes := []string{ |
||||
".local", |
||||
".localhost", |
||||
".lan", |
||||
".internal", |
||||
".intranet", |
||||
".home", |
||||
".corp", |
||||
} |
||||
|
||||
for _, suffix := range localSuffixes { |
||||
if strings.HasSuffix(host, suffix) { |
||||
return true |
||||
} |
||||
} |
||||
|
||||
// 单词域名(无点)
|
||||
if !strings.Contains(host, ".") { |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
|
||||
// compilePatterns 编译域名匹配模式
|
||||
func (rm *RouteMatcher) compilePatterns() error { |
||||
// 编译绕过域名模式
|
||||
for _, domain := range rm.config.BypassDomains { |
||||
pattern, err := rm.domainToRegexp(domain) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
rm.bypassDomains = append(rm.bypassDomains, pattern) |
||||
} |
||||
|
||||
// 编译强制代理域名模式
|
||||
for _, domain := range rm.config.ForceDomains { |
||||
pattern, err := rm.domainToRegexp(domain) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
rm.forceDomains = append(rm.forceDomains, pattern) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// domainToRegexp 将域名模式转换为正则表达式
|
||||
func (rm *RouteMatcher) domainToRegexp(domain string) (*regexp.Regexp, error) { |
||||
// 转义特殊字符
|
||||
pattern := regexp.QuoteMeta(domain) |
||||
|
||||
// 替换通配符
|
||||
pattern = strings.ReplaceAll(pattern, "\\*", ".*") |
||||
|
||||
// 添加行开始和结束标记
|
||||
pattern = "^" + pattern + "$" |
||||
|
||||
// 编译正则表达式(不区分大小写)
|
||||
return regexp.Compile("(?i)" + pattern) |
||||
} |
||||
|
||||
// initPrivateNetworks 初始化私有网络列表
|
||||
func (rm *RouteMatcher) initPrivateNetworks() { |
||||
privateNetworks := []string{ |
||||
"10.0.0.0/8", // Class A private
|
||||
"172.16.0.0/12", // Class B private
|
||||
"192.168.0.0/16", // Class C private
|
||||
"169.254.0.0/16", // Link-local
|
||||
"127.0.0.0/8", // Loopback
|
||||
"224.0.0.0/4", // Multicast
|
||||
"240.0.0.0/4", // Reserved
|
||||
"::1/128", // IPv6 loopback
|
||||
"fe80::/10", // IPv6 link-local
|
||||
"fc00::/7", // IPv6 unique local
|
||||
} |
||||
|
||||
for _, network := range privateNetworks { |
||||
if _, ipNet, err := net.ParseCIDR(network); err == nil { |
||||
rm.privateNetworks = append(rm.privateNetworks, ipNet) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// GetStats 获取路由统计信息
|
||||
func (rm *RouteMatcher) GetStats() map[string]interface{} { |
||||
return map[string]interface{}{ |
||||
"bypass_domains_count": len(rm.bypassDomains), |
||||
"force_domains_count": len(rm.forceDomains), |
||||
"private_networks_count": len(rm.privateNetworks), |
||||
"bypass_local": rm.config.BypassLocal, |
||||
"bypass_private": rm.config.BypassPrivate, |
||||
} |
||||
} |
||||
|
||||
// ReloadConfig 重新加载配置
|
||||
func (rm *RouteMatcher) ReloadConfig(config *config.Routing) error { |
||||
rm.config = config |
||||
rm.bypassDomains = nil |
||||
rm.forceDomains = nil |
||||
|
||||
if err := rm.compilePatterns(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
logger.Info("Route matcher configuration reloaded") |
||||
return nil |
||||
} |
@ -0,0 +1,329 @@ |
||||
package routing |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/azoic/wormhole-client/internal/config" |
||||
) |
||||
|
||||
func TestNewRouteMatcher(t *testing.T) { |
||||
cfg := &config.Routing{ |
||||
BypassLocal: true, |
||||
BypassPrivate: true, |
||||
BypassDomains: []string{"*.local", "*.lan"}, |
||||
ForceDomains: []string{"*.google.com", "*.github.com"}, |
||||
} |
||||
|
||||
matcher, err := NewRouteMatcher(cfg) |
||||
if err != nil { |
||||
t.Fatalf("Failed to create route matcher: %v", err) |
||||
} |
||||
|
||||
if matcher == nil { |
||||
t.Fatal("Route matcher should not be nil") |
||||
} |
||||
|
||||
if len(matcher.bypassDomains) != 2 { |
||||
t.Errorf("Expected 2 bypass domains, got %d", len(matcher.bypassDomains)) |
||||
} |
||||
|
||||
if len(matcher.forceDomains) != 2 { |
||||
t.Errorf("Expected 2 force domains, got %d", len(matcher.forceDomains)) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_Match(t *testing.T) { |
||||
cfg := &config.Routing{ |
||||
BypassLocal: true, |
||||
BypassPrivate: true, |
||||
BypassDomains: []string{"*.local", "*.baidu.com"}, |
||||
ForceDomains: []string{"*.google.com", "*.github.com"}, |
||||
} |
||||
|
||||
matcher, err := NewRouteMatcher(cfg) |
||||
if err != nil { |
||||
t.Fatalf("Failed to create route matcher: %v", err) |
||||
} |
||||
|
||||
tests := []struct { |
||||
host string |
||||
expected MatchResult |
||||
}{ |
||||
// 强制代理域名
|
||||
{"www.google.com", MatchProxy}, |
||||
{"api.github.com", MatchProxy}, |
||||
{"www.google.com:443", MatchProxy}, |
||||
|
||||
// 绕过域名
|
||||
{"test.local", MatchBypass}, |
||||
{"www.baidu.com", MatchBypass}, |
||||
{"search.baidu.com:80", MatchBypass}, |
||||
|
||||
// 本地域名
|
||||
{"localhost", MatchBypass}, |
||||
{"router.lan", MatchBypass}, |
||||
{"printer.internal", MatchBypass}, |
||||
|
||||
// IP地址 - 本地
|
||||
{"127.0.0.1", MatchBypass}, |
||||
{"::1", MatchBypass}, |
||||
|
||||
// IP地址 - 私有网络
|
||||
{"192.168.1.1", MatchBypass}, |
||||
{"10.0.0.1", MatchBypass}, |
||||
{"172.16.1.1", MatchBypass}, |
||||
|
||||
// 其他域名 - 自动决定
|
||||
{"example.com", MatchAuto}, |
||||
{"stackoverflow.com", MatchAuto}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.host, func(t *testing.T) { |
||||
result := matcher.Match(test.host) |
||||
if result != test.expected { |
||||
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_MatchesForceDomains(t *testing.T) { |
||||
cfg := &config.Routing{ |
||||
ForceDomains: []string{"*.google.com", "github.com", "*.example.*"}, |
||||
} |
||||
|
||||
matcher, err := NewRouteMatcher(cfg) |
||||
if err != nil { |
||||
t.Fatalf("Failed to create route matcher: %v", err) |
||||
} |
||||
|
||||
tests := []struct { |
||||
host string |
||||
expected bool |
||||
}{ |
||||
{"www.google.com", true}, |
||||
{"api.google.com", true}, |
||||
{"google.com", false}, // 不匹配 *.google.com
|
||||
{"github.com", true}, |
||||
{"api.github.com", false}, // 不匹配 github.com
|
||||
{"test.example.org", true}, |
||||
{"sub.example.net", true}, |
||||
{"example.com", false}, // 不匹配 *.example.*
|
||||
{"other.com", false}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.host, func(t *testing.T) { |
||||
result := matcher.matchesForceDomains(test.host) |
||||
if result != test.expected { |
||||
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_MatchesBypassDomains(t *testing.T) { |
||||
cfg := &config.Routing{ |
||||
BypassDomains: []string{"*.local", "localhost", "*.cn"}, |
||||
} |
||||
|
||||
matcher, err := NewRouteMatcher(cfg) |
||||
if err != nil { |
||||
t.Fatalf("Failed to create route matcher: %v", err) |
||||
} |
||||
|
||||
tests := []struct { |
||||
host string |
||||
expected bool |
||||
}{ |
||||
{"test.local", true}, |
||||
{"printer.local", true}, |
||||
{"local", false}, // 不匹配 *.local
|
||||
{"localhost", true}, |
||||
{"www.baidu.cn", true}, |
||||
{"qq.cn", true}, |
||||
{"china.com", false}, // 不匹配 *.cn
|
||||
{"example.com", false}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.host, func(t *testing.T) { |
||||
result := matcher.matchesBypassDomains(test.host) |
||||
if result != test.expected { |
||||
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_IsLocalDomain(t *testing.T) { |
||||
cfg := &config.Routing{} |
||||
matcher, _ := NewRouteMatcher(cfg) |
||||
|
||||
tests := []struct { |
||||
host string |
||||
expected bool |
||||
}{ |
||||
{"localhost", true}, |
||||
{"test.local", true}, |
||||
{"printer.lan", true}, |
||||
{"server.internal", true}, |
||||
{"router.home", true}, |
||||
{"pc.corp", true}, |
||||
{"singleword", true}, // 单词域名
|
||||
{"example.com", false}, |
||||
{"www.google.com", false}, |
||||
{"192.168.1.1", false}, // IP地址不是域名
|
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.host, func(t *testing.T) { |
||||
result := matcher.isLocalDomain(test.host) |
||||
if result != test.expected { |
||||
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_IsPrivateIP(t *testing.T) { |
||||
cfg := &config.Routing{} |
||||
matcher, _ := NewRouteMatcher(cfg) |
||||
|
||||
tests := []struct { |
||||
ip string |
||||
expected bool |
||||
}{ |
||||
// 私有IPv4地址
|
||||
{"192.168.1.1", true}, |
||||
{"10.0.0.1", true}, |
||||
{"172.16.1.1", true}, |
||||
{"127.0.0.1", true}, // 环回地址
|
||||
|
||||
// 公网IPv4地址
|
||||
{"8.8.8.8", false}, |
||||
{"1.1.1.1", false}, |
||||
{"114.114.114.114", false}, |
||||
|
||||
// IPv6地址
|
||||
{"::1", true}, // 环回
|
||||
{"fe80::1", true}, // 链路本地
|
||||
{"fc00::1", true}, // 唯一本地
|
||||
{"2001:db8::1", false}, // 公网(测试用)
|
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.ip, func(t *testing.T) { |
||||
// 解析IP
|
||||
ip := parseIPForTest(test.ip) |
||||
if ip == nil { |
||||
t.Fatalf("Failed to parse IP: %s", test.ip) |
||||
} |
||||
|
||||
result := matcher.isPrivateIP(ip) |
||||
if result != test.expected { |
||||
t.Errorf("For IP %s, expected %v, got %v", test.ip, test.expected, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRouteMatcher_GetStats(t *testing.T) { |
||||
cfg := &config.Routing{ |
||||
BypassLocal: true, |
||||
BypassPrivate: false, |
||||
BypassDomains: []string{"*.local", "*.lan", "*.cn"}, |
||||
ForceDomains: []string{"*.google.com", "*.github.com"}, |
||||
} |
||||
|
||||
matcher, err := NewRouteMatcher(cfg) |
||||
if err != nil { |
||||
t.Fatalf("Failed to create route matcher: %v", err) |
||||
} |
||||
|
||||
stats := matcher.GetStats() |
||||
|
||||
if stats["bypass_domains_count"] != 3 { |
||||
t.Errorf("Expected 3 bypass domains, got %v", stats["bypass_domains_count"]) |
||||
} |
||||
|
||||
if stats["force_domains_count"] != 2 { |
||||
t.Errorf("Expected 2 force domains, got %v", stats["force_domains_count"]) |
||||
} |
||||
|
||||
if stats["bypass_local"] != true { |
||||
t.Errorf("Expected bypass_local to be true, got %v", stats["bypass_local"]) |
||||
} |
||||
|
||||
if stats["bypass_private"] != false { |
||||
t.Errorf("Expected bypass_private to be false, got %v", stats["bypass_private"]) |
||||
} |
||||
} |
||||
|
||||
func TestDomainToRegexp(t *testing.T) { |
||||
cfg := &config.Routing{} |
||||
matcher, _ := NewRouteMatcher(cfg) |
||||
|
||||
tests := []struct { |
||||
domain string |
||||
host string |
||||
matches bool |
||||
}{ |
||||
{"*.google.com", "www.google.com", true}, |
||||
{"*.google.com", "api.google.com", true}, |
||||
{"*.google.com", "google.com", false}, |
||||
{"github.com", "github.com", true}, |
||||
{"github.com", "api.github.com", false}, |
||||
{"*.example.*", "test.example.org", true}, |
||||
{"*.example.*", "sub.example.net", true}, |
||||
{"*.example.*", "example.com", false}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.domain+"->"+test.host, func(t *testing.T) { |
||||
pattern, err := matcher.domainToRegexp(test.domain) |
||||
if err != nil { |
||||
t.Fatalf("Failed to compile pattern %s: %v", test.domain, err) |
||||
} |
||||
|
||||
matches := pattern.MatchString(test.host) |
||||
if matches != test.matches { |
||||
t.Errorf("Pattern %s against host %s: expected %v, got %v", |
||||
test.domain, test.host, test.matches, matches) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// 辅助函数
|
||||
func parseIPForTest(s string) []byte { |
||||
// 简单的IP解析用于测试
|
||||
switch s { |
||||
case "192.168.1.1": |
||||
return []byte{192, 168, 1, 1} |
||||
case "10.0.0.1": |
||||
return []byte{10, 0, 0, 1} |
||||
case "172.16.1.1": |
||||
return []byte{172, 16, 1, 1} |
||||
case "127.0.0.1": |
||||
return []byte{127, 0, 0, 1} |
||||
case "8.8.8.8": |
||||
return []byte{8, 8, 8, 8} |
||||
case "1.1.1.1": |
||||
return []byte{1, 1, 1, 1} |
||||
case "114.114.114.114": |
||||
return []byte{114, 114, 114, 114} |
||||
// IPv6地址处理会更复杂,这里简化
|
||||
case "::1": |
||||
return []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} |
||||
case "fe80::1": |
||||
return []byte{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} |
||||
case "fc00::1": |
||||
return []byte{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} |
||||
case "2001:db8::1": |
||||
return []byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} |
||||
default: |
||||
return nil |
||||
} |
||||
} |
Loading…
Reference in new issue