package proxy import ( "context" "testing" "time" ) func TestNewSOCKS5Proxy(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) if proxy == nil { t.Fatal("NewSOCKS5Proxy returned nil") } if proxy.serverAddr != "127.0.0.1:1080" { t.Errorf("Expected serverAddr '127.0.0.1:1080', got '%s'", proxy.serverAddr) } if proxy.username != "user" { t.Errorf("Expected username 'user', got '%s'", proxy.username) } if proxy.password != "pass" { t.Errorf("Expected password 'pass', got '%s'", proxy.password) } if proxy.timeout != 30*time.Second { t.Errorf("Expected timeout 30s, got %v", proxy.timeout) } // 检查连接池是否正确初始化 if proxy.connPool == nil { t.Error("Connection pool should be initialized") } if proxy.connPool.maxSize != 10 { t.Errorf("Expected connection pool size 10, got %d", proxy.connPool.maxSize) } } func TestParsePort(t *testing.T) { tests := []struct { input string expected int hasError bool }{ {"80", 80, false}, {"443", 443, false}, {"8080", 8080, false}, {"", 80, false}, // 默认端口 {"0", 0, true}, // 无效端口 {"65536", 0, true}, // 端口超出范围 {"abc", 0, true}, // 非数字 {"-1", 0, true}, // 负数 {"65535", 65535, false}, // 最大有效端口 {"1", 1, false}, // 最小有效端口 } for _, test := range tests { result, err := parsePort(test.input) if test.hasError { if err == nil { t.Errorf("Expected error for input '%s', but got none", test.input) } } else { if err != nil { t.Errorf("Unexpected error for input '%s': %v", test.input, err) } if result != test.expected { t.Errorf("For input '%s', expected %d, got %d", test.input, test.expected, result) } } } } func TestCreateHTTPProxy(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) server := proxy.CreateHTTPProxy(8080) if server == nil { t.Fatal("CreateHTTPProxy returned nil") } if server.Addr != ":8080" { t.Errorf("Expected server address ':8080', got '%s'", server.Addr) } // 检查服务器配置 if server.ReadTimeout != 30*time.Second { t.Errorf("Expected ReadTimeout 30s, got %v", server.ReadTimeout) } if server.WriteTimeout != 30*time.Second { t.Errorf("Expected WriteTimeout 30s, got %v", server.WriteTimeout) } if server.IdleTimeout != 120*time.Second { t.Errorf("Expected IdleTimeout 120s, got %v", server.IdleTimeout) } } func TestBuildConnectRequest(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) tests := []struct { name string targetAddr string expectError bool expectedType byte // 地址类型: 1=IPv4, 3=域名, 4=IPv6 }{ { name: "IPv4 address", targetAddr: "192.168.1.1:80", expectError: false, expectedType: 0x01, }, { name: "IPv6 address", targetAddr: "[2001:db8::1]:80", expectError: false, expectedType: 0x04, }, { name: "Domain name", targetAddr: "example.com:80", expectError: false, expectedType: 0x03, }, { name: "Domain with HTTPS port", targetAddr: "example.com:443", expectError: false, expectedType: 0x03, }, { name: "Invalid address format", targetAddr: "invalid", expectError: true, }, { name: "Invalid port", targetAddr: "example.com:invalid", expectError: true, }, { name: "Port out of range", targetAddr: "example.com:70000", expectError: true, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { req, err := proxy.buildConnectRequest(test.targetAddr) if test.expectError { if err == nil { t.Errorf("Expected error for %s, but got none", test.targetAddr) } return } if err != nil { t.Errorf("Unexpected error for %s: %v", test.targetAddr, err) return } if len(req) < 4 { t.Errorf("Request too short: %d bytes", len(req)) return } // 检查SOCKS版本 if req[0] != 0x05 { t.Errorf("Expected SOCKS version 5, got %d", req[0]) } // 检查命令类型 if req[1] != 0x01 { t.Errorf("Expected CONNECT command (1), got %d", req[1]) } // 检查保留字段 if req[2] != 0x00 { t.Errorf("Expected reserved field to be 0, got %d", req[2]) } // 检查地址类型 if req[3] != test.expectedType { t.Errorf("Expected address type %d, got %d", test.expectedType, req[3]) } }) } } func TestGetSOCKS5ErrorMessage(t *testing.T) { tests := []struct { code byte expected string }{ {0x01, "general SOCKS server failure"}, {0x02, "connection not allowed by ruleset"}, {0x03, "network unreachable"}, {0x04, "host unreachable"}, {0x05, "connection refused"}, {0x06, "TTL expired"}, {0x07, "command not supported"}, {0x08, "address type not supported"}, {0xFF, "unknown error"}, } for _, test := range tests { result := getSOCKS5ErrorMessage(test.code) if result != test.expected { t.Errorf("For code %d, expected '%s', got '%s'", test.code, test.expected, result) } } } func TestDialTCPWithContext(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 5*time.Second) // 测试超时上下文 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() // 由于没有真实的SOCKS5服务器,这应该会超时或连接失败 _, err := proxy.DialTCPWithContext(ctx, "example.com:80") if err == nil { t.Error("Expected error when connecting to non-existent SOCKS5 server") } } func TestHTTPProxyHandler(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) handler := &httpProxyHandler{socks5Proxy: proxy} if handler.socks5Proxy != proxy { t.Error("Handler should have reference to SOCKS5 proxy") } } // BenchmarkParsePort 性能测试 func BenchmarkParsePort(b *testing.B) { ports := []string{"80", "443", "8080", "3000", "9999"} for i := 0; i < b.N; i++ { port := ports[i%len(ports)] _, err := parsePort(port) if err != nil { b.Errorf("Unexpected error: %v", err) } } } // TestIPv6AddressHandling 测试IPv6地址处理 func TestIPv6AddressHandling(t *testing.T) { proxy := NewSOCKS5Proxy("127.0.0.1:1080", "user", "pass", 30*time.Second) // 测试完整的IPv6地址 req, err := proxy.buildConnectRequest("[2001:db8:85a3::8a2e:370:7334]:443") if err != nil { t.Fatalf("Failed to build connect request for IPv6: %v", err) } if req[3] != 0x04 { t.Errorf("Expected IPv6 address type (4), got %d", req[3]) } // IPv6地址应该是16字节 + 头部4字节 + 端口2字节 = 22字节 expectedLen := 4 + 16 + 2 // 头部 + IPv6地址 + 端口 if len(req) != expectedLen { t.Errorf("Expected request length %d for IPv6, got %d", expectedLen, len(req)) } }