package proxy import ( "context" "encoding/json" "fmt" "io" "net" "net/http" "strconv" "sync" "time" "github.com/azoic/wormhole-client/pkg/logger" ) // SOCKS5Proxy SOCKS5代理客户端 type SOCKS5Proxy struct { serverAddr string username string password string timeout time.Duration connPool *connectionPool stats *ProxyStats } // connectionPool 连接池 type connectionPool struct { connections chan net.Conn maxSize int mutex sync.Mutex } // NewSOCKS5Proxy 创建SOCKS5代理客户端 func NewSOCKS5Proxy(serverAddr, username, password string, timeout time.Duration) *SOCKS5Proxy { return &SOCKS5Proxy{ serverAddr: serverAddr, username: username, password: password, timeout: timeout, connPool: &connectionPool{ connections: make(chan net.Conn, 10), maxSize: 10, }, stats: NewProxyStats(), } } // CreateHTTPProxy 创建HTTP代理服务器 func (p *SOCKS5Proxy) CreateHTTPProxy(localPort int) *http.Server { proxyHandler := &httpProxyHandler{ socks5Proxy: p, } // 创建ServeMux来处理不同的路径 mux := http.NewServeMux() mux.Handle("/", proxyHandler) mux.HandleFunc("/stats", p.handleStats) mux.HandleFunc("/health", p.handleHealth) server := &http.Server{ Addr: fmt.Sprintf(":%d", localPort), Handler: mux, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, MaxHeaderBytes: 1 << 20, // 1MB } return server } // httpProxyHandler HTTP代理处理器 type httpProxyHandler struct { socks5Proxy *SOCKS5Proxy } func (h *httpProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 统计连接 h.socks5Proxy.stats.IncrementConnections() defer h.socks5Proxy.stats.DecrementActiveConnections() logger.Debug("Processing request: %s %s from %s", r.Method, r.URL.String(), r.RemoteAddr) if r.Method == http.MethodConnect { h.handleHTTPSProxy(w, r) } else { h.handleHTTPProxy(w, r) } } // handleHTTPSProxy 处理HTTPS代理请求 (CONNECT方法) func (h *httpProxyHandler) handleHTTPSProxy(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout) defer cancel() destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, r.Host) if err != nil { h.socks5Proxy.stats.IncrementFailedRequests() h.socks5Proxy.stats.IncrementSOCKS5Error("connection_failed") logger.Error("Failed to connect via SOCKS5 to %s: %v", r.Host, err) http.Error(w, "Bad Gateway", http.StatusBadGateway) return } defer destConn.Close() // 发送200 Connection established响应 w.WriteHeader(http.StatusOK) hijacker, ok := w.(http.Hijacker) if !ok { h.socks5Proxy.stats.IncrementFailedRequests() logger.Error("Hijacking not supported") http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } clientConn, _, err := hijacker.Hijack() if err != nil { h.socks5Proxy.stats.IncrementFailedRequests() logger.Error("Failed to hijack connection: %v", err) return } defer clientConn.Close() h.socks5Proxy.stats.IncrementSuccessfulRequests() logger.Debug("Established HTTPS tunnel to %s", r.Host) // 双向数据转发 var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() written := h.copyData(clientConn, destConn, "client->server") h.socks5Proxy.stats.AddBytesTransferred(written, 0) }() go func() { defer wg.Done() written := h.copyData(destConn, clientConn, "server->client") h.socks5Proxy.stats.AddBytesTransferred(0, written) }() wg.Wait() logger.Debug("HTTPS tunnel to %s closed", r.Host) } // handleHTTPProxy 处理HTTP代理请求 func (h *httpProxyHandler) handleHTTPProxy(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout) defer cancel() // 确保URL包含Host if r.URL.Host == "" { r.URL.Host = r.Host } if r.URL.Scheme == "" { r.URL.Scheme = "http" } // 通过SOCKS5连接到目标服务器 destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, r.Host) if err != nil { h.socks5Proxy.stats.IncrementFailedRequests() h.socks5Proxy.stats.IncrementSOCKS5Error("connection_failed") logger.Error("Failed to connect via SOCKS5 to %s: %v", r.Host, err) http.Error(w, "Bad Gateway", http.StatusBadGateway) return } defer destConn.Close() // 发送HTTP请求 if err := r.Write(destConn); err != nil { h.socks5Proxy.stats.IncrementFailedRequests() logger.Error("Failed to write request to %s: %v", r.Host, err) http.Error(w, "Bad Gateway", http.StatusBadGateway) return } // 设置响应头 w.Header().Set("Via", "1.1 wormhole-proxy") // 使用自定义ResponseWriter来统计字节数 statsWriter := &statsResponseWriter{ ResponseWriter: w, stats: h.socks5Proxy.stats, } // 读取响应并返回给客户端 written, err := io.Copy(statsWriter, destConn) if err != nil { h.socks5Proxy.stats.IncrementFailedRequests() logger.Error("Failed to copy response from %s: %v", r.Host, err) return } h.socks5Proxy.stats.IncrementSuccessfulRequests() h.socks5Proxy.stats.AddBytesTransferred(0, written) logger.Debug("HTTP request to %s completed, %d bytes", r.Host, written) } // statsResponseWriter 带统计功能的ResponseWriter type statsResponseWriter struct { http.ResponseWriter stats *ProxyStats } func (w *statsResponseWriter) Write(data []byte) (int, error) { n, err := w.ResponseWriter.Write(data) if n > 0 { w.stats.AddBytesTransferred(int64(n), 0) } return n, err } // copyData 数据复制,带方向标识和字节统计 func (h *httpProxyHandler) copyData(dst, src net.Conn, direction string) int64 { defer dst.Close() defer src.Close() written, err := io.Copy(dst, src) if err != nil { logger.Debug("Copy %s finished with error: %v, bytes: %d", direction, err, written) } else { logger.Debug("Copy %s finished successfully, bytes: %d", direction, written) } return written } // handleStats 处理统计信息请求 func (p *SOCKS5Proxy) handleStats(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } stats := p.stats.GetStats() w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") if err := json.NewEncoder(w).Encode(stats); err != nil { logger.Error("Failed to encode stats: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } } // handleHealth 处理健康检查请求 func (p *SOCKS5Proxy) handleHealth(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } stats := p.stats.GetStats() health := map[string]interface{}{ "status": "healthy", "uptime": stats.Uptime.String(), "active_connections": stats.ActiveConnections, "success_rate": stats.GetSuccessRate(), } w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(health); err != nil { logger.Error("Failed to encode health: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } } // GetStats 获取代理统计信息 func (p *SOCKS5Proxy) GetStats() ProxyStatsSnapshot { return p.stats.GetStats() } // IncrementConnections 增加连接计数 func (p *SOCKS5Proxy) IncrementConnections() { p.stats.IncrementConnections() } // DecrementActiveConnections 减少活跃连接计数 func (p *SOCKS5Proxy) DecrementActiveConnections() { p.stats.DecrementActiveConnections() } // IncrementSuccessfulRequests 增加成功请求计数 func (p *SOCKS5Proxy) IncrementSuccessfulRequests() { p.stats.IncrementSuccessfulRequests() } // IncrementFailedRequests 增加失败请求计数 func (p *SOCKS5Proxy) IncrementFailedRequests() { p.stats.IncrementFailedRequests() } // AddBytesTransferred 添加传输字节数 func (p *SOCKS5Proxy) AddBytesTransferred(sent, received int64) { p.stats.AddBytesTransferred(sent, received) } // DialTCP 通过SOCKS5连接到目标地址 func (p *SOCKS5Proxy) DialTCP(address string) (net.Conn, error) { return p.DialTCPWithContext(context.Background(), address) } // DialTCPWithContext 通过SOCKS5连接到目标地址(带上下文) func (p *SOCKS5Proxy) DialTCPWithContext(ctx context.Context, address string) (net.Conn, error) { // 连接到SOCKS5代理服务器 dialer := &net.Dialer{ Timeout: p.timeout, } conn, err := dialer.DialContext(ctx, "tcp", p.serverAddr) if err != nil { return nil, fmt.Errorf("failed to connect to SOCKS5 server: %v", err) } // 设置连接超时 deadline, ok := ctx.Deadline() if ok { if err := conn.SetDeadline(deadline); err != nil { conn.Close() return nil, fmt.Errorf("failed to set deadline: %v", err) } } // 执行SOCKS5握手 if err := p.performSOCKS5Handshake(conn, address); err != nil { conn.Close() return nil, fmt.Errorf("SOCKS5 handshake failed: %v", err) } // 清除deadline,让连接正常使用 if err := conn.SetDeadline(time.Time{}); err != nil { logger.Debug("Failed to clear deadline: %v", err) } logger.Debug("Successfully connected to %s via SOCKS5 proxy", address) return conn, nil } // performSOCKS5Handshake 执行SOCKS5握手协议 func (p *SOCKS5Proxy) performSOCKS5Handshake(conn net.Conn, targetAddr string) error { // 设置握手超时 deadline := time.Now().Add(p.timeout) if err := conn.SetDeadline(deadline); err != nil { return fmt.Errorf("failed to set handshake deadline: %v", err) } // 第一步:发送认证方法选择 authMethods := []byte{0x05, 0x02, 0x00, 0x02} // 版本5,2个方法,无认证+用户名密码认证 if _, err := conn.Write(authMethods); err != nil { return fmt.Errorf("failed to send auth methods: %v", err) } // 读取服务器响应 response := make([]byte, 2) if _, err := io.ReadFull(conn, response); err != nil { return fmt.Errorf("failed to read auth response: %v", err) } if response[0] != 0x05 { return fmt.Errorf("invalid SOCKS version: %d", response[0]) } // 第二步:处理认证 switch response[1] { case 0x00: // 无认证 logger.Debug("SOCKS5 server requires no authentication") case 0x02: // 用户名密码认证 if err := p.performUserPassAuth(conn); err != nil { return fmt.Errorf("user/pass authentication failed: %v", err) } case 0xFF: // 无可接受的认证方法 return fmt.Errorf("no acceptable authentication methods") default: return fmt.Errorf("unsupported authentication method: %d", response[1]) } // 第三步:发送连接请求 connectReq, err := p.buildConnectRequest(targetAddr) if err != nil { return fmt.Errorf("failed to build connect request: %v", err) } if _, err := conn.Write(connectReq); err != nil { return fmt.Errorf("failed to send connect request: %v", err) } // 第四步:读取连接响应 return p.readConnectResponse(conn) } // buildConnectRequest 构建连接请求 func (p *SOCKS5Proxy) buildConnectRequest(targetAddr string) ([]byte, error) { host, portStr, err := net.SplitHostPort(targetAddr) if err != nil { return nil, fmt.Errorf("invalid target address: %v", err) } // 解析端口号 portNum, err := parsePort(portStr) if err != nil { return nil, fmt.Errorf("invalid port: %v", err) } var connectReq []byte // 检测地址类型并构建请求 if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { // IPv4地址 connectReq = []byte{0x05, 0x01, 0x00, 0x01} connectReq = append(connectReq, ip4...) } else if ip6 := ip.To16(); ip6 != nil { // IPv6地址 connectReq = []byte{0x05, 0x01, 0x00, 0x04} connectReq = append(connectReq, ip6...) } } else { // 域名 if len(host) > 255 { return nil, fmt.Errorf("domain name too long: %d", len(host)) } connectReq = []byte{0x05, 0x01, 0x00, 0x03} connectReq = append(connectReq, byte(len(host))) connectReq = append(connectReq, []byte(host)...) } // 添加端口 connectReq = append(connectReq, byte(portNum>>8), byte(portNum&0xFF)) return connectReq, nil } // readConnectResponse 读取连接响应 func (p *SOCKS5Proxy) readConnectResponse(conn net.Conn) error { // 读取响应头部 header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return fmt.Errorf("failed to read connect response header: %v", err) } if header[0] != 0x05 { return fmt.Errorf("invalid SOCKS version in response: %d", header[0]) } if header[1] != 0x00 { return fmt.Errorf("connection failed, status: %d (%s)", header[1], getSOCKS5ErrorMessage(header[1])) } // 读取绑定地址和端口 addrType := header[3] switch addrType { case 0x01: // IPv4 skipBytes := make([]byte, 6) // 4字节IP + 2字节端口 _, err := io.ReadFull(conn, skipBytes) return err case 0x03: // 域名 lenByte := make([]byte, 1) if _, err := io.ReadFull(conn, lenByte); err != nil { return err } skipBytes := make([]byte, int(lenByte[0])+2) // 域名长度 + 2字节端口 _, err := io.ReadFull(conn, skipBytes) return err case 0x04: // IPv6 skipBytes := make([]byte, 18) // 16字节IP + 2字节端口 _, err := io.ReadFull(conn, skipBytes) return err default: return fmt.Errorf("unsupported address type: %d", addrType) } } // performUserPassAuth 执行用户名密码认证 func (p *SOCKS5Proxy) performUserPassAuth(conn net.Conn) error { // 发送用户名密码 authData := []byte{0x01} // 子协议版本 authData = append(authData, byte(len(p.username))) authData = append(authData, []byte(p.username)...) authData = append(authData, byte(len(p.password))) authData = append(authData, []byte(p.password)...) if _, err := conn.Write(authData); err != nil { return fmt.Errorf("failed to send credentials: %v", err) } // 读取认证结果 authResult := make([]byte, 2) if _, err := io.ReadFull(conn, authResult); err != nil { return fmt.Errorf("failed to read auth result: %v", err) } if authResult[0] != 0x01 { return fmt.Errorf("invalid auth response version: %d", authResult[0]) } if authResult[1] != 0x00 { return fmt.Errorf("authentication failed") } logger.Debug("SOCKS5 authentication successful") return nil } // getSOCKS5ErrorMessage 获取SOCKS5错误消息 func getSOCKS5ErrorMessage(code byte) string { switch code { case 0x01: return "general SOCKS server failure" case 0x02: return "connection not allowed by ruleset" case 0x03: return "network unreachable" case 0x04: return "host unreachable" case 0x05: return "connection refused" case 0x06: return "TTL expired" case 0x07: return "command not supported" case 0x08: return "address type not supported" default: return "unknown error" } } // parsePort 解析端口号 func parsePort(portStr string) (int, error) { if portStr == "" { return 80, nil // 默认HTTP端口 } port, err := strconv.Atoi(portStr) if err != nil { return 0, fmt.Errorf("invalid port format: %s", portStr) } if port < 1 || port > 65535 { return 0, fmt.Errorf("port out of range: %d", port) } return port, nil }