You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

528 lines
14 KiB

package proxy
import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/azoic/wormhole-client/pkg/logger"
)
// SOCKS5Proxy is a SOCKS5 client. Besides dialing targets through the
// configured SOCKS5 server, it can front a local HTTP/HTTPS proxy server
// (see CreateHTTPProxy) and keeps request statistics.
type SOCKS5Proxy struct {
	// serverAddr is the SOCKS5 server address handed to net.Dialer.
	// username and password are sent when the server selects
	// username/password authentication.
	serverAddr, username, password string

	// timeout bounds dialing the SOCKS5 server and the handshake, and is
	// also used as the per-request context timeout by the HTTP handlers.
	timeout time.Duration

	// connPool is allocated by NewSOCKS5Proxy.
	// NOTE(review): not read by any code visible in this file.
	connPool *connectionPool

	// stats collects connection/request/byte counters.
	stats *ProxyStats
}
// connectionPool groups a buffered channel of net.Conn values with a size
// limit and a mutex, presumably to pool connections to the SOCKS5 server.
// NOTE(review): no code visible in this file reads from or writes to the
// pool (NewSOCKS5Proxy only allocates it); confirm it is used elsewhere in
// the package before relying on it.
type connectionPool struct {
// connections: buffered channel; its capacity is chosen by the constructor.
connections chan net.Conn
// maxSize: intended upper bound on pooled connections (mirrors the channel capacity in NewSOCKS5Proxy).
maxSize int
mutex sync.Mutex
}
// NewSOCKS5Proxy returns a SOCKS5 client that connects to serverAddr,
// authenticates with username/password when the server asks for it, and
// uses timeout for dialing and handshaking. It also allocates a 10-slot
// connection pool and a fresh statistics collector.
func NewSOCKS5Proxy(serverAddr, username, password string, timeout time.Duration) *SOCKS5Proxy {
	const poolCapacity = 10

	pool := &connectionPool{
		connections: make(chan net.Conn, poolCapacity),
		maxSize:     poolCapacity,
	}

	return &SOCKS5Proxy{
		serverAddr: serverAddr,
		username:   username,
		password:   password,
		timeout:    timeout,
		connPool:   pool,
		stats:      NewProxyStats(),
	}
}
// CreateHTTPProxy builds (but does not start) an HTTP server on localPort
// that forwards proxy traffic through the SOCKS5 server.
//
// Routing:
//   - CONNECT requests and absolute-form requests ("GET http://host/path")
//     always go to the proxy handler.
//   - Origin-form requests addressed to this server itself are routed by
//     path: /stats and /health serve monitoring endpoints, anything else
//     falls through to the proxy handler.
//
// Proxy traffic is dispatched before http.ServeMux on purpose. ServeMux
// routes on r.URL.Path, so:
//   - an authority-form CONNECT (empty path) is not guaranteed to reach "/";
//   - an absolute URL whose path happened to be /stats or /health would be
//     answered locally instead of proxied;
//   - paths ServeMux considers unclean (e.g. "/a//b") would be redirected
//     instead of forwarded.
//
// NOTE(review): WriteTimeout also caps how long a plain-HTTP response can
// stream back to the client; confirm 30s is intended for large downloads.
// CONNECT tunnels are not affected because the handler clears the deadlines
// on the hijacked connection.
func (p *SOCKS5Proxy) CreateHTTPProxy(localPort int) *http.Server {
	proxyHandler := &httpProxyHandler{socks5Proxy: p}

	// Local endpoints, only consulted for origin-form requests.
	mux := http.NewServeMux()
	mux.Handle("/", proxyHandler)
	mux.HandleFunc("/stats", p.handleStats)
	mux.HandleFunc("/health", p.handleHealth)

	root := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A non-empty URL.Host means authority-form (CONNECT) or
		// absolute-form: the request is meant for another server.
		if r.Method == http.MethodConnect || r.URL.Host != "" {
			proxyHandler.ServeHTTP(w, r)
			return
		}
		mux.ServeHTTP(w, r)
	})

	return &http.Server{
		Addr:           fmt.Sprintf(":%d", localPort),
		Handler:        root,
		ReadTimeout:    30 * time.Second,
		WriteTimeout:   30 * time.Second,
		IdleTimeout:    120 * time.Second,
		MaxHeaderBytes: 1 << 20, // 1 MiB
	}
}
// httpProxyHandler is the http.Handler for proxy traffic: it tunnels CONNECT
// requests and forwards all other requests, opening every upstream
// connection through socks5Proxy.
type httpProxyHandler struct {
// socks5Proxy provides dialing (DialTCPWithContext), the timeout, and the stats collector.
socks5Proxy *SOCKS5Proxy
}
// ServeHTTP calls IncrementConnections on entry and DecrementActiveConnections
// on return, then dispatches the request: CONNECT becomes a raw TCP tunnel,
// every other method is forwarded as a plain HTTP request.
func (h *httpProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	stats := h.socks5Proxy.stats
	stats.IncrementConnections()
	defer stats.DecrementActiveConnections()

	logger.Debug("Processing request: %s %s from %s", r.Method, r.URL.String(), r.RemoteAddr)

	switch r.Method {
	case http.MethodConnect:
		h.handleHTTPSProxy(w, r)
	default:
		h.handleHTTPProxy(w, r)
	}
}
// handleHTTPSProxy serves a CONNECT request. It opens a TCP connection to
// r.Host through the SOCKS5 server, then relays raw bytes in both directions
// until either side closes.
//
// Steps, in order:
//  1. Dial upstream first, so a failure can still be reported as a 502.
//  2. Hijack before writing anything. Writing a status through the
//     ResponseWriter first would make any later error reply invalid.
//  3. Clear the read/write deadlines net/http may leave on a hijacked conn
//     (from Server.ReadTimeout/WriteTimeout). Otherwise every tunnel would be
//     cut off after those timeouts, regardless of activity.
//  4. Send "200 Connection established" directly on the raw connection.
//  5. Forward any client bytes the server had already buffered (e.g. a TLS
//     ClientHello sent right behind the CONNECT line), then start relaying.
//
// Byte accounting: client->server bytes go in the first argument of
// AddBytesTransferred, server->client bytes in the second.
func (h *httpProxyHandler) handleHTTPSProxy(w http.ResponseWriter, r *http.Request) {
	stats := h.socks5Proxy.stats

	ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout)
	defer cancel()

	destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, r.Host)
	if err != nil {
		stats.IncrementFailedRequests()
		stats.IncrementSOCKS5Error("connection_failed")
		logger.Error("Failed to connect via SOCKS5 to %s: %v", r.Host, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer destConn.Close()

	hijacker, ok := w.(http.Hijacker)
	if !ok {
		stats.IncrementFailedRequests()
		logger.Error("Hijacking not supported")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		stats.IncrementFailedRequests()
		logger.Error("Failed to hijack connection: %v", err)
		return
	}
	defer clientConn.Close()

	// The server's Read/WriteTimeout deadlines would otherwise kill the tunnel.
	if err := clientConn.SetDeadline(time.Time{}); err != nil {
		logger.Debug("Failed to clear client connection deadline: %v", err)
	}

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
		stats.IncrementFailedRequests()
		logger.Error("Failed to send CONNECT response to client: %v", err)
		return
	}

	// Bytes the server read past the request line/headers belong to the tunnel.
	if pendingLen := clientBuf.Reader.Buffered(); pendingLen > 0 {
		// Peek of at most Buffered() bytes is served from the buffer and cannot fail.
		pending, _ := clientBuf.Reader.Peek(pendingLen)
		if _, err := destConn.Write(pending); err != nil {
			stats.IncrementFailedRequests()
			logger.Error("Failed to forward buffered client data to %s: %v", r.Host, err)
			return
		}
		stats.AddBytesTransferred(int64(pendingLen), 0)
	}

	stats.IncrementSuccessfulRequests()
	logger.Debug("Established HTTPS tunnel to %s", r.Host)

	// Relay both directions. copyData(dst, src) closes both conns when it
	// finishes, which also unblocks the copy running the other way.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		written := h.copyData(destConn, clientConn, "client->server")
		stats.AddBytesTransferred(written, 0)
	}()
	go func() {
		defer wg.Done()
		written := h.copyData(clientConn, destConn, "server->client")
		stats.AddBytesTransferred(0, written)
	}()
	wg.Wait()

	logger.Debug("HTTPS tunnel to %s closed", r.Host)
}
// handleHTTPProxy forwards a plain (non-CONNECT) HTTP request to its origin
// server over a fresh connection opened through the SOCKS5 server, and relays
// the origin's response back to the client.
//
// The upstream response is parsed with http.ReadResponse rather than copied
// byte-for-byte into w, for two reasons:
//   - The ResponseWriter emits its own status line and headers, so a raw copy
//     would embed the origin's status line and headers inside the body.
//   - A raw copy would block until the origin closed a keep-alive
//     connection. Parsing yields the real status, headers and a properly
//     delimited body.
//
// Only response body bytes are counted, via AddBytesTransferred(0, n),
// i.e. the server->client slot.
func (h *httpProxyHandler) handleHTTPProxy(w http.ResponseWriter, r *http.Request) {
	stats := h.socks5Proxy.stats

	ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout)
	defer cancel()

	// Build the upstream request from a copy; handlers should not mutate r.
	outReq := r.Clone(r.Context())
	if outReq.URL.Host == "" {
		outReq.URL.Host = r.Host
	}
	if outReq.URL.Scheme == "" {
		outReq.URL.Scheme = "http"
	}
	if outReq.Header == nil {
		outReq.Header = make(http.Header)
	}
	removeProxyHopHeaders(outReq.Header)
	// One request per upstream connection; Request.Write sends "Connection: close".
	outReq.Close = true
	if _, ok := outReq.Header["User-Agent"]; !ok {
		// An explicit empty value stops Request.Write from injecting Go's default User-Agent.
		outReq.Header.Set("User-Agent", "")
	}

	// Absolute-form requests ("GET http://example.com/") usually carry no
	// port, but the SOCKS5 CONNECT request needs host:port.
	port := outReq.URL.Port()
	if port == "" {
		port = "80"
		if outReq.URL.Scheme == "https" {
			port = "443"
		}
	}
	targetAddr := net.JoinHostPort(outReq.URL.Hostname(), port)

	destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, targetAddr)
	if err != nil {
		stats.IncrementFailedRequests()
		stats.IncrementSOCKS5Error("connection_failed")
		logger.Error("Failed to connect via SOCKS5 to %s: %v", targetAddr, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer destConn.Close()

	// Unblock upstream I/O if the client goes away. r.Context() is also
	// cancelled when this handler returns, so the goroutine always exits.
	go func() {
		<-r.Context().Done()
		destConn.Close()
	}()

	if err := outReq.Write(destConn); err != nil {
		stats.IncrementFailedRequests()
		logger.Error("Failed to write request to %s: %v", targetAddr, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}

	resp, err := readUpstreamResponse(bufio.NewReader(destConn), outReq)
	if err != nil {
		stats.IncrementFailedRequests()
		logger.Error("Failed to read response from %s: %v", targetAddr, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	removeProxyHopHeaders(resp.Header)
	header := w.Header()
	for name, values := range resp.Header {
		for _, value := range values {
			header.Add(name, value)
		}
	}
	header.Set("Via", "1.1 wormhole-proxy")
	w.WriteHeader(resp.StatusCode)

	written, err := io.Copy(w, resp.Body)
	if err != nil {
		stats.IncrementFailedRequests()
		logger.Error("Failed to copy response from %s: %v", targetAddr, err)
		return
	}

	stats.IncrementSuccessfulRequests()
	stats.AddBytesTransferred(0, written)
	logger.Debug("HTTP request to %s completed, %d bytes", targetAddr, written)
}

// proxyHopHeaders lists hop-by-hop headers (RFC 7230 section 6.1, plus the
// non-standard Proxy-Connection) that a proxy must not forward.
var proxyHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeProxyHopHeaders deletes hop-by-hop headers from hdr. That includes
// any extra header names listed in its Connection header.
func removeProxyHopHeaders(hdr http.Header) {
	for _, value := range hdr["Connection"] {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				hdr.Del(name)
			}
		}
	}
	for _, name := range proxyHopHeaders {
		hdr.Del(name)
	}
}

// readUpstreamResponse reads the origin's response to req from br. It skips
// interim 1xx responses (e.g. "100 Continue") that may precede the final one.
// A 101 is returned as-is; it should not occur because Upgrade is stripped
// from forwarded requests.
func readUpstreamResponse(br *bufio.Reader, req *http.Request) (*http.Response, error) {
	for {
		resp, err := http.ReadResponse(br, req)
		if err != nil {
			return nil, err
		}
		if resp.StatusCode/100 != 1 || resp.StatusCode == http.StatusSwitchingProtocols {
			return resp, nil
		}
		resp.Body.Close()
	}
}
// statsResponseWriter wraps an http.ResponseWriter. It reports every byte
// the wrapped writer accepts to stats via AddBytesTransferred(n, 0).
// NOTE(review): the plain-HTTP handler records response (server->client)
// bytes in the second argument of AddBytesTransferred; confirm which slot
// this wrapper is meant to use before reusing it.
type statsResponseWriter struct {
	http.ResponseWriter
	stats *ProxyStats
}

// Write forwards data to the wrapped writer. It records however many bytes
// were accepted, even if the write also returned an error.
func (sw *statsResponseWriter) Write(data []byte) (int, error) {
	n, err := sw.ResponseWriter.Write(data)
	if n <= 0 {
		return n, err
	}
	sw.stats.AddBytesTransferred(int64(n), 0)
	return n, err
}
// copyData streams src into dst until src reports EOF or either side fails.
// It then closes both connections, which also makes a concurrent copy in the
// opposite direction on the same pair return. direction only labels the
// debug log line. The return value is the number of bytes copied, including
// bytes copied before an error.
func (h *httpProxyHandler) copyData(dst, src net.Conn, direction string) int64 {
	defer dst.Close()
	defer src.Close()

	n, copyErr := io.Copy(dst, src)
	if copyErr == nil {
		logger.Debug("Copy %s finished successfully, bytes: %d", direction, n)
		return n
	}
	logger.Debug("Copy %s finished with error: %v, bytes: %d", direction, copyErr, n)
	return n
}
// handleStats serves GET /stats. It returns the current statistics snapshot
// as JSON, readable cross-origin (Access-Control-Allow-Origin: *). Non-GET
// requests get 405.
//
// The snapshot is marshaled before anything is written, so an encoding
// failure can still be reported as a clean 500. Encoding straight into the
// ResponseWriter could commit a 200 status and a partial body first, and
// http.Error would then be appended to that body.
func (p *SOCKS5Proxy) handleStats(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	body, err := json.Marshal(p.stats.GetStats())
	if err != nil {
		logger.Error("Failed to encode stats: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	body = append(body, '\n') // same trailing newline json.Encoder.Encode emits

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	if _, err := w.Write(body); err != nil {
		// Headers are already sent; nothing more can be reported to the client.
		logger.Debug("Failed to write stats response: %v", err)
	}
}
// handleHealth serves GET /health with a small JSON document. It contains a
// constant "status": "healthy" plus uptime, active connection count and
// success rate, all taken from one stats snapshot. Non-GET requests get 405.
//
// The document is marshaled before anything is written, so an encoding
// failure still produces a clean 500 instead of a committed 200 with an
// error message appended to a partial body.
func (p *SOCKS5Proxy) handleHealth(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	snapshot := p.stats.GetStats()
	health := map[string]interface{}{
		"status":             "healthy",
		"uptime":             snapshot.Uptime.String(),
		"active_connections": snapshot.ActiveConnections,
		"success_rate":       snapshot.GetSuccessRate(),
	}

	body, err := json.Marshal(health)
	if err != nil {
		logger.Error("Failed to encode health: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	body = append(body, '\n') // same trailing newline json.Encoder.Encode emits

	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(body); err != nil {
		// Headers are already sent; nothing more can be reported to the client.
		logger.Debug("Failed to write health response: %v", err)
	}
}
// GetStats returns the current statistics snapshot. It is the programmatic
// counterpart of the /stats endpoint.
func (p *SOCKS5Proxy) GetStats() ProxyStatsSnapshot {
	snapshot := p.stats.GetStats()
	return snapshot
}
// DialTCP connects to address ("host:port") through the SOCKS5 server. It is
// DialTCPWithContext with a background context, so only the proxy's
// configured timeout bounds the dial and handshake.
func (p *SOCKS5Proxy) DialTCP(address string) (net.Conn, error) {
	ctx := context.Background()
	return p.DialTCPWithContext(ctx, address)
}
// DialTCPWithContext connects to address ("host:port") through the SOCKS5
// server. It returns a connection with no deadline set, ready to carry
// application data for that target.
//
// How ctx is honored:
//   - the TCP dial to the SOCKS5 server uses ctx directly;
//   - before the handshake, ctx's deadline (if any) is set on the connection;
//   - if ctx is cancelled or expires during the handshake, the connection is
//     closed so blocked reads/writes return.
//
// Errors wrap their cause with %w, so callers can use errors.Is (e.g. with
// context.DeadlineExceeded).
func (p *SOCKS5Proxy) DialTCPWithContext(ctx context.Context, address string) (net.Conn, error) {
	dialer := &net.Dialer{Timeout: p.timeout}
	conn, err := dialer.DialContext(ctx, "tcp", p.serverAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to SOCKS5 server: %w", err)
	}

	if deadline, ok := ctx.Deadline(); ok {
		if err := conn.SetDeadline(deadline); err != nil {
			conn.Close()
			return nil, fmt.Errorf("failed to set deadline: %w", err)
		}
	}

	// performSOCKS5Handshake does blocking I/O without looking at ctx, so
	// watch ctx here and close conn on cancellation to unblock it.
	handshakeDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-ctx.Done():
			conn.Close()
		case <-handshakeDone:
		}
	}()

	err = p.performSOCKS5Handshake(conn, address)
	close(handshakeDone)
	<-watcherDone // from here on the watcher no longer touches conn

	if err == nil {
		// ctx may have ended (and the watcher closed conn) right as the handshake finished.
		err = ctx.Err()
	}
	if err != nil {
		conn.Close()
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, fmt.Errorf("SOCKS5 handshake aborted: %w", ctxErr)
		}
		return nil, fmt.Errorf("SOCKS5 handshake failed: %w", err)
	}

	// The tunnel is established; let callers manage deadlines from here.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		logger.Debug("Failed to clear deadline: %v", err)
	}

	logger.Debug("Successfully connected to %s via SOCKS5 proxy", address)
	return conn, nil
}
// performSOCKS5Handshake runs the client side of SOCKS5 negotiation
// (RFC 1928) on conn:
//   - method selection;
//   - optional username/password authentication (RFC 1929);
//   - a CONNECT request for targetAddr, after which the reply is consumed.
//
// Only methods the client can complete are offered: username/password is
// offered only when credentials are configured.
//
// When p.timeout is positive, now+p.timeout is set as the connection
// deadline. A zero or negative timeout leaves any existing deadline alone;
// otherwise the deadline would already be expired and every handshake would
// fail immediately.
func (p *SOCKS5Proxy) performSOCKS5Handshake(conn net.Conn, targetAddr string) error {
	if p.timeout > 0 {
		if err := conn.SetDeadline(time.Now().Add(p.timeout)); err != nil {
			return fmt.Errorf("failed to set handshake deadline: %w", err)
		}
	}

	// Step 1: greeting — VER=5, NMETHODS, METHODS...
	hasCredentials := p.username != "" || p.password != ""
	greeting := []byte{0x05, 0x01, 0x00} // no authentication only
	if hasCredentials {
		greeting = []byte{0x05, 0x02, 0x00, 0x02} // no authentication + username/password
	}
	if _, err := conn.Write(greeting); err != nil {
		return fmt.Errorf("failed to send auth methods: %w", err)
	}

	// Server reply: VER, METHOD
	reply := make([]byte, 2)
	if _, err := io.ReadFull(conn, reply); err != nil {
		return fmt.Errorf("failed to read auth response: %w", err)
	}
	if reply[0] != 0x05 {
		return fmt.Errorf("invalid SOCKS version: %d", reply[0])
	}

	// Step 2: authentication for the method the server chose.
	switch reply[1] {
	case 0x00: // no authentication
		logger.Debug("SOCKS5 server requires no authentication")
	case 0x02: // username/password
		if !hasCredentials {
			return fmt.Errorf("server selected username/password authentication, but no credentials are configured")
		}
		if err := p.performUserPassAuth(conn); err != nil {
			return fmt.Errorf("user/pass authentication failed: %w", err)
		}
	case 0xFF: // no acceptable methods
		return fmt.Errorf("no acceptable authentication methods")
	default:
		return fmt.Errorf("unsupported authentication method: %d", reply[1])
	}

	// Step 3: CONNECT request.
	connectReq, err := p.buildConnectRequest(targetAddr)
	if err != nil {
		return fmt.Errorf("failed to build connect request: %w", err)
	}
	if _, err := conn.Write(connectReq); err != nil {
		return fmt.Errorf("failed to send connect request: %w", err)
	}

	// Step 4: consume the CONNECT reply.
	return p.readConnectResponse(conn)
}
// buildConnectRequest encodes a SOCKS5 CONNECT request (RFC 1928 section 4)
// for targetAddr, which must be "host:port" as accepted by net.SplitHostPort.
//
// Address encoding:
//   - a literal IPv4 host uses ATYP 0x01 with 4 address bytes;
//   - any other literal IP uses ATYP 0x04 with 16 address bytes;
//   - anything else is sent as a domain name (ATYP 0x03, at most 255 bytes)
//     for the SOCKS server to resolve.
//
// An empty port string falls back to parsePort's default. An empty host is
// rejected, because a zero-length domain is not a usable target.
func (p *SOCKS5Proxy) buildConnectRequest(targetAddr string) ([]byte, error) {
	host, portStr, err := net.SplitHostPort(targetAddr)
	if err != nil {
		return nil, fmt.Errorf("invalid target address: %w", err)
	}
	if host == "" {
		return nil, fmt.Errorf("invalid target address: empty host in %q", targetAddr)
	}

	portNum, err := parsePort(portStr)
	if err != nil {
		return nil, fmt.Errorf("invalid port: %w", err)
	}

	// VER=5, CMD=CONNECT, RSV=0; ATYP and address follow.
	req := []byte{0x05, 0x01, 0x00}
	if ip := net.ParseIP(host); ip != nil {
		if ip4 := ip.To4(); ip4 != nil {
			req = append(req, 0x01)
			req = append(req, ip4...)
		} else {
			req = append(req, 0x04)
			req = append(req, ip.To16()...)
		}
	} else {
		if len(host) > 255 {
			return nil, fmt.Errorf("domain name too long: %d", len(host))
		}
		req = append(req, 0x03, byte(len(host)))
		req = append(req, host...)
	}

	// DST.PORT in network byte order.
	req = append(req, byte(portNum>>8), byte(portNum&0xFF))
	return req, nil
}
// readConnectResponse consumes the server's reply to a CONNECT request.
//
// It checks the version byte and the REP status; a non-zero REP becomes an
// error carrying getSOCKS5ErrorMessage's description. It then reads and
// discards BND.ADDR and BND.PORT, so the next byte on conn is tunnel data.
// Errors from reading the bound address are returned unwrapped.
func (p *SOCKS5Proxy) readConnectResponse(conn net.Conn) error {
	head := make([]byte, 4) // VER, REP, RSV, ATYP
	if _, err := io.ReadFull(conn, head); err != nil {
		return fmt.Errorf("failed to read connect response header: %v", err)
	}

	version, status, atyp := head[0], head[1], head[3]
	if version != 0x05 {
		return fmt.Errorf("invalid SOCKS version in response: %d", version)
	}
	if status != 0x00 {
		return fmt.Errorf("connection failed, status: %d (%s)", status, getSOCKS5ErrorMessage(status))
	}

	// Address bytes plus the 2-byte port still to drain.
	var remaining int
	switch atyp {
	case 0x01:
		remaining = net.IPv4len + 2
	case 0x03:
		size := make([]byte, 1)
		if _, err := io.ReadFull(conn, size); err != nil {
			return err
		}
		remaining = int(size[0]) + 2
	case 0x04:
		remaining = net.IPv6len + 2
	default:
		return fmt.Errorf("unsupported address type: %d", atyp)
	}

	_, err := io.ReadFull(conn, make([]byte, remaining))
	return err
}
// performUserPassAuth runs the RFC 1929 username/password sub-negotiation on
// conn, using p.username and p.password. It returns nil only when the server
// replies with sub-negotiation version 0x01 and status 0x00.
//
// Each field is prefixed with a single length byte. Credentials longer than
// 255 bytes are therefore rejected up front; sending them with a wrapped
// length byte (e.g. 256 encoded as 0) would corrupt the message.
func (p *SOCKS5Proxy) performUserPassAuth(conn net.Conn) error {
	if len(p.username) > 255 {
		return fmt.Errorf("username too long: %d bytes (max 255)", len(p.username))
	}
	if len(p.password) > 255 {
		return fmt.Errorf("password too long: %d bytes (max 255)", len(p.password))
	}

	// VER=0x01, ULEN, UNAME, PLEN, PASSWD
	authData := make([]byte, 0, 3+len(p.username)+len(p.password))
	authData = append(authData, 0x01, byte(len(p.username)))
	authData = append(authData, p.username...)
	authData = append(authData, byte(len(p.password)))
	authData = append(authData, p.password...)
	if _, err := conn.Write(authData); err != nil {
		return fmt.Errorf("failed to send credentials: %w", err)
	}

	authResult := make([]byte, 2) // VER, STATUS
	if _, err := io.ReadFull(conn, authResult); err != nil {
		return fmt.Errorf("failed to read auth result: %w", err)
	}
	if authResult[0] != 0x01 {
		return fmt.Errorf("invalid auth response version: %d", authResult[0])
	}
	if authResult[1] != 0x00 {
		return fmt.Errorf("authentication failed")
	}

	logger.Debug("SOCKS5 authentication successful")
	return nil
}
// getSOCKS5ErrorMessage maps a SOCKS5 reply code (REP, RFC 1928 section 6)
// to a human-readable description. 0x00 and any code above 0x08 yield
// "unknown error".
func getSOCKS5ErrorMessage(code byte) string {
	messages := [...]string{
		0x01: "general SOCKS server failure",
		0x02: "connection not allowed by ruleset",
		0x03: "network unreachable",
		0x04: "host unreachable",
		0x05: "connection refused",
		0x06: "TTL expired",
		0x07: "command not supported",
		0x08: "address type not supported",
	}
	if idx := int(code); idx < len(messages) && messages[idx] != "" {
		return messages[idx]
	}
	return "unknown error"
}
// parsePort converts a decimal port string to an int in [1, 65535]. An empty
// string means "no port given" and yields the default HTTP port, 80.
func parsePort(portStr string) (int, error) {
	const defaultHTTPPort = 80
	if len(portStr) == 0 {
		return defaultHTTPPort, nil
	}

	port, convErr := strconv.Atoi(portStr)
	switch {
	case convErr != nil:
		return 0, fmt.Errorf("invalid port format: %s", portStr)
	case port < 1, port > 65535:
		return 0, fmt.Errorf("port out of range: %d", port)
	default:
		return port, nil
	}
}