You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
552 lines
15 KiB
552 lines
15 KiB
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/azoic/wormhole-client/pkg/logger"
|
|
)
|
|
|
|
// SOCKS5Proxy SOCKS5代理客户端
|
|
type SOCKS5Proxy struct {
|
|
serverAddr string
|
|
username string
|
|
password string
|
|
timeout time.Duration
|
|
connPool *connectionPool
|
|
stats *ProxyStats
|
|
}
|
|
|
|
// connectionPool holds the fields of a fixed-capacity connection pool.
//
// NOTE(review): no code in this file puts connections into or takes them out
// of the pool; NewSOCKS5Proxy only allocates it. Confirm whether another file
// in the package uses it before relying on it.
type connectionPool struct {
	connections chan net.Conn // buffered channel used as the pool storage
	maxSize     int           // intended upper bound; NewSOCKS5Proxy sets it to cap(connections)

	mutex sync.Mutex
}
|
|
|
|
// NewSOCKS5Proxy 创建SOCKS5代理客户端
|
|
func NewSOCKS5Proxy(serverAddr, username, password string, timeout time.Duration) *SOCKS5Proxy {
|
|
return &SOCKS5Proxy{
|
|
serverAddr: serverAddr,
|
|
username: username,
|
|
password: password,
|
|
timeout: timeout,
|
|
connPool: &connectionPool{
|
|
connections: make(chan net.Conn, 10),
|
|
maxSize: 10,
|
|
},
|
|
stats: NewProxyStats(),
|
|
}
|
|
}
|
|
|
|
// CreateHTTPProxy 创建HTTP代理服务器
|
|
func (p *SOCKS5Proxy) CreateHTTPProxy(localPort int) *http.Server {
|
|
proxyHandler := &httpProxyHandler{
|
|
socks5Proxy: p,
|
|
}
|
|
|
|
// 创建ServeMux来处理不同的路径
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/", proxyHandler)
|
|
mux.HandleFunc("/stats", p.handleStats)
|
|
mux.HandleFunc("/health", p.handleHealth)
|
|
|
|
server := &http.Server{
|
|
Addr: fmt.Sprintf(":%d", localPort),
|
|
Handler: mux,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
MaxHeaderBytes: 1 << 20, // 1MB
|
|
}
|
|
|
|
return server
|
|
}
|
|
|
|
// httpProxyHandler HTTP代理处理器
|
|
type httpProxyHandler struct {
|
|
socks5Proxy *SOCKS5Proxy
|
|
}
|
|
|
|
func (h *httpProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// 统计连接
|
|
h.socks5Proxy.stats.IncrementConnections()
|
|
defer h.socks5Proxy.stats.DecrementActiveConnections()
|
|
|
|
logger.Debug("Processing request: %s %s from %s", r.Method, r.URL.String(), r.RemoteAddr)
|
|
|
|
if r.Method == http.MethodConnect {
|
|
h.handleHTTPSProxy(w, r)
|
|
} else {
|
|
h.handleHTTPProxy(w, r)
|
|
}
|
|
}
|
|
|
|
// handleHTTPSProxy 处理HTTPS代理请求 (CONNECT方法)
|
|
func (h *httpProxyHandler) handleHTTPSProxy(w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout)
|
|
defer cancel()
|
|
|
|
destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, r.Host)
|
|
if err != nil {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
h.socks5Proxy.stats.IncrementSOCKS5Error("connection_failed")
|
|
logger.Error("Failed to connect via SOCKS5 to %s: %v", r.Host, err)
|
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer destConn.Close()
|
|
|
|
// 发送200 Connection established响应
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
logger.Error("Hijacking not supported")
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
logger.Error("Failed to hijack connection: %v", err)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
h.socks5Proxy.stats.IncrementSuccessfulRequests()
|
|
logger.Debug("Established HTTPS tunnel to %s", r.Host)
|
|
|
|
// 双向数据转发
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
written := h.copyData(clientConn, destConn, "client->server")
|
|
h.socks5Proxy.stats.AddBytesTransferred(written, 0)
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
written := h.copyData(destConn, clientConn, "server->client")
|
|
h.socks5Proxy.stats.AddBytesTransferred(0, written)
|
|
}()
|
|
|
|
wg.Wait()
|
|
logger.Debug("HTTPS tunnel to %s closed", r.Host)
|
|
}
|
|
|
|
// handleHTTPProxy 处理HTTP代理请求
|
|
func (h *httpProxyHandler) handleHTTPProxy(w http.ResponseWriter, r *http.Request) {
|
|
ctx, cancel := context.WithTimeout(r.Context(), h.socks5Proxy.timeout)
|
|
defer cancel()
|
|
|
|
// 确保URL包含Host
|
|
if r.URL.Host == "" {
|
|
r.URL.Host = r.Host
|
|
}
|
|
if r.URL.Scheme == "" {
|
|
r.URL.Scheme = "http"
|
|
}
|
|
|
|
// 通过SOCKS5连接到目标服务器
|
|
destConn, err := h.socks5Proxy.DialTCPWithContext(ctx, r.Host)
|
|
if err != nil {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
h.socks5Proxy.stats.IncrementSOCKS5Error("connection_failed")
|
|
logger.Error("Failed to connect via SOCKS5 to %s: %v", r.Host, err)
|
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer destConn.Close()
|
|
|
|
// 发送HTTP请求
|
|
if err := r.Write(destConn); err != nil {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
logger.Error("Failed to write request to %s: %v", r.Host, err)
|
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// 设置响应头
|
|
w.Header().Set("Via", "1.1 wormhole-proxy")
|
|
|
|
// 使用自定义ResponseWriter来统计字节数
|
|
statsWriter := &statsResponseWriter{
|
|
ResponseWriter: w,
|
|
stats: h.socks5Proxy.stats,
|
|
}
|
|
|
|
// 读取响应并返回给客户端
|
|
written, err := io.Copy(statsWriter, destConn)
|
|
if err != nil {
|
|
h.socks5Proxy.stats.IncrementFailedRequests()
|
|
logger.Error("Failed to copy response from %s: %v", r.Host, err)
|
|
return
|
|
}
|
|
|
|
h.socks5Proxy.stats.IncrementSuccessfulRequests()
|
|
h.socks5Proxy.stats.AddBytesTransferred(0, written)
|
|
logger.Debug("HTTP request to %s completed, %d bytes", r.Host, written)
|
|
}
|
|
|
|
// statsResponseWriter 带统计功能的ResponseWriter
|
|
type statsResponseWriter struct {
|
|
http.ResponseWriter
|
|
stats *ProxyStats
|
|
}
|
|
|
|
func (w *statsResponseWriter) Write(data []byte) (int, error) {
|
|
n, err := w.ResponseWriter.Write(data)
|
|
if n > 0 {
|
|
w.stats.AddBytesTransferred(int64(n), 0)
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// copyData 数据复制,带方向标识和字节统计
|
|
func (h *httpProxyHandler) copyData(dst, src net.Conn, direction string) int64 {
|
|
defer dst.Close()
|
|
defer src.Close()
|
|
|
|
written, err := io.Copy(dst, src)
|
|
if err != nil {
|
|
logger.Debug("Copy %s finished with error: %v, bytes: %d", direction, err, written)
|
|
} else {
|
|
logger.Debug("Copy %s finished successfully, bytes: %d", direction, written)
|
|
}
|
|
|
|
return written
|
|
}
|
|
|
|
// handleStats 处理统计信息请求
|
|
func (p *SOCKS5Proxy) handleStats(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
stats := p.stats.GetStats()
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
|
if err := json.NewEncoder(w).Encode(stats); err != nil {
|
|
logger.Error("Failed to encode stats: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// handleHealth 处理健康检查请求
|
|
func (p *SOCKS5Proxy) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
stats := p.stats.GetStats()
|
|
|
|
health := map[string]interface{}{
|
|
"status": "healthy",
|
|
"uptime": stats.Uptime.String(),
|
|
"active_connections": stats.ActiveConnections,
|
|
"success_rate": stats.GetSuccessRate(),
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
if err := json.NewEncoder(w).Encode(health); err != nil {
|
|
logger.Error("Failed to encode health: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// GetStats 获取代理统计信息
|
|
func (p *SOCKS5Proxy) GetStats() ProxyStatsSnapshot {
|
|
return p.stats.GetStats()
|
|
}
|
|
|
|
// IncrementConnections 增加连接计数
|
|
func (p *SOCKS5Proxy) IncrementConnections() {
|
|
p.stats.IncrementConnections()
|
|
}
|
|
|
|
// DecrementActiveConnections 减少活跃连接计数
|
|
func (p *SOCKS5Proxy) DecrementActiveConnections() {
|
|
p.stats.DecrementActiveConnections()
|
|
}
|
|
|
|
// IncrementSuccessfulRequests 增加成功请求计数
|
|
func (p *SOCKS5Proxy) IncrementSuccessfulRequests() {
|
|
p.stats.IncrementSuccessfulRequests()
|
|
}
|
|
|
|
// IncrementFailedRequests 增加失败请求计数
|
|
func (p *SOCKS5Proxy) IncrementFailedRequests() {
|
|
p.stats.IncrementFailedRequests()
|
|
}
|
|
|
|
// AddBytesTransferred 添加传输字节数
|
|
func (p *SOCKS5Proxy) AddBytesTransferred(sent, received int64) {
|
|
p.stats.AddBytesTransferred(sent, received)
|
|
}
|
|
|
|
// DialTCP 通过SOCKS5连接到目标地址
|
|
func (p *SOCKS5Proxy) DialTCP(address string) (net.Conn, error) {
|
|
return p.DialTCPWithContext(context.Background(), address)
|
|
}
|
|
|
|
// DialTCPWithContext 通过SOCKS5连接到目标地址(带上下文)
|
|
func (p *SOCKS5Proxy) DialTCPWithContext(ctx context.Context, address string) (net.Conn, error) {
|
|
// 连接到SOCKS5代理服务器
|
|
dialer := &net.Dialer{
|
|
Timeout: p.timeout,
|
|
}
|
|
|
|
conn, err := dialer.DialContext(ctx, "tcp", p.serverAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to connect to SOCKS5 server: %v", err)
|
|
}
|
|
|
|
// 设置连接超时
|
|
deadline, ok := ctx.Deadline()
|
|
if ok {
|
|
if err := conn.SetDeadline(deadline); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("failed to set deadline: %v", err)
|
|
}
|
|
}
|
|
|
|
// 执行SOCKS5握手
|
|
if err := p.performSOCKS5Handshake(conn, address); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("SOCKS5 handshake failed: %v", err)
|
|
}
|
|
|
|
// 清除deadline,让连接正常使用
|
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
|
logger.Debug("Failed to clear deadline: %v", err)
|
|
}
|
|
|
|
logger.Debug("Successfully connected to %s via SOCKS5 proxy", address)
|
|
return conn, nil
|
|
}
|
|
|
|
// performSOCKS5Handshake 执行SOCKS5握手协议
|
|
func (p *SOCKS5Proxy) performSOCKS5Handshake(conn net.Conn, targetAddr string) error {
|
|
// 设置握手超时
|
|
deadline := time.Now().Add(p.timeout)
|
|
if err := conn.SetDeadline(deadline); err != nil {
|
|
return fmt.Errorf("failed to set handshake deadline: %v", err)
|
|
}
|
|
|
|
// 第一步:发送认证方法选择
|
|
authMethods := []byte{0x05, 0x02, 0x00, 0x02} // 版本5,2个方法,无认证+用户名密码认证
|
|
if _, err := conn.Write(authMethods); err != nil {
|
|
return fmt.Errorf("failed to send auth methods: %v", err)
|
|
}
|
|
|
|
// 读取服务器响应
|
|
response := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, response); err != nil {
|
|
return fmt.Errorf("failed to read auth response: %v", err)
|
|
}
|
|
|
|
if response[0] != 0x05 {
|
|
return fmt.Errorf("invalid SOCKS version: %d", response[0])
|
|
}
|
|
|
|
// 第二步:处理认证
|
|
switch response[1] {
|
|
case 0x00: // 无认证
|
|
logger.Debug("SOCKS5 server requires no authentication")
|
|
case 0x02: // 用户名密码认证
|
|
if err := p.performUserPassAuth(conn); err != nil {
|
|
return fmt.Errorf("user/pass authentication failed: %v", err)
|
|
}
|
|
case 0xFF: // 无可接受的认证方法
|
|
return fmt.Errorf("no acceptable authentication methods")
|
|
default:
|
|
return fmt.Errorf("unsupported authentication method: %d", response[1])
|
|
}
|
|
|
|
// 第三步:发送连接请求
|
|
connectReq, err := p.buildConnectRequest(targetAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to build connect request: %v", err)
|
|
}
|
|
|
|
if _, err := conn.Write(connectReq); err != nil {
|
|
return fmt.Errorf("failed to send connect request: %v", err)
|
|
}
|
|
|
|
// 第四步:读取连接响应
|
|
return p.readConnectResponse(conn)
|
|
}
|
|
|
|
// buildConnectRequest 构建连接请求
|
|
func (p *SOCKS5Proxy) buildConnectRequest(targetAddr string) ([]byte, error) {
|
|
host, portStr, err := net.SplitHostPort(targetAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid target address: %v", err)
|
|
}
|
|
|
|
// 解析端口号
|
|
portNum, err := parsePort(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid port: %v", err)
|
|
}
|
|
|
|
var connectReq []byte
|
|
|
|
// 检测地址类型并构建请求
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
if ip4 := ip.To4(); ip4 != nil {
|
|
// IPv4地址
|
|
connectReq = []byte{0x05, 0x01, 0x00, 0x01}
|
|
connectReq = append(connectReq, ip4...)
|
|
} else if ip6 := ip.To16(); ip6 != nil {
|
|
// IPv6地址
|
|
connectReq = []byte{0x05, 0x01, 0x00, 0x04}
|
|
connectReq = append(connectReq, ip6...)
|
|
}
|
|
} else {
|
|
// 域名
|
|
if len(host) > 255 {
|
|
return nil, fmt.Errorf("domain name too long: %d", len(host))
|
|
}
|
|
connectReq = []byte{0x05, 0x01, 0x00, 0x03}
|
|
connectReq = append(connectReq, byte(len(host)))
|
|
connectReq = append(connectReq, []byte(host)...)
|
|
}
|
|
|
|
// 添加端口
|
|
connectReq = append(connectReq, byte(portNum>>8), byte(portNum&0xFF))
|
|
|
|
return connectReq, nil
|
|
}
|
|
|
|
// readConnectResponse 读取连接响应
|
|
func (p *SOCKS5Proxy) readConnectResponse(conn net.Conn) error {
|
|
// 读取响应头部
|
|
header := make([]byte, 4)
|
|
if _, err := io.ReadFull(conn, header); err != nil {
|
|
return fmt.Errorf("failed to read connect response header: %v", err)
|
|
}
|
|
|
|
if header[0] != 0x05 {
|
|
return fmt.Errorf("invalid SOCKS version in response: %d", header[0])
|
|
}
|
|
|
|
if header[1] != 0x00 {
|
|
return fmt.Errorf("connection failed, status: %d (%s)", header[1], getSOCKS5ErrorMessage(header[1]))
|
|
}
|
|
|
|
// 读取绑定地址和端口
|
|
addrType := header[3]
|
|
switch addrType {
|
|
case 0x01: // IPv4
|
|
skipBytes := make([]byte, 6) // 4字节IP + 2字节端口
|
|
_, err := io.ReadFull(conn, skipBytes)
|
|
return err
|
|
case 0x03: // 域名
|
|
lenByte := make([]byte, 1)
|
|
if _, err := io.ReadFull(conn, lenByte); err != nil {
|
|
return err
|
|
}
|
|
skipBytes := make([]byte, int(lenByte[0])+2) // 域名长度 + 2字节端口
|
|
_, err := io.ReadFull(conn, skipBytes)
|
|
return err
|
|
case 0x04: // IPv6
|
|
skipBytes := make([]byte, 18) // 16字节IP + 2字节端口
|
|
_, err := io.ReadFull(conn, skipBytes)
|
|
return err
|
|
default:
|
|
return fmt.Errorf("unsupported address type: %d", addrType)
|
|
}
|
|
}
|
|
|
|
// performUserPassAuth 执行用户名密码认证
|
|
func (p *SOCKS5Proxy) performUserPassAuth(conn net.Conn) error {
|
|
// 发送用户名密码
|
|
authData := []byte{0x01} // 子协议版本
|
|
authData = append(authData, byte(len(p.username)))
|
|
authData = append(authData, []byte(p.username)...)
|
|
authData = append(authData, byte(len(p.password)))
|
|
authData = append(authData, []byte(p.password)...)
|
|
|
|
if _, err := conn.Write(authData); err != nil {
|
|
return fmt.Errorf("failed to send credentials: %v", err)
|
|
}
|
|
|
|
// 读取认证结果
|
|
authResult := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, authResult); err != nil {
|
|
return fmt.Errorf("failed to read auth result: %v", err)
|
|
}
|
|
|
|
if authResult[0] != 0x01 {
|
|
return fmt.Errorf("invalid auth response version: %d", authResult[0])
|
|
}
|
|
|
|
if authResult[1] != 0x00 {
|
|
return fmt.Errorf("authentication failed")
|
|
}
|
|
|
|
logger.Debug("SOCKS5 authentication successful")
|
|
return nil
|
|
}
|
|
|
|
// getSOCKS5ErrorMessage maps a SOCKS5 CONNECT reply code (REP field, RFC 1928
// §6) to a human-readable description. Codes 0x01–0x08 have fixed texts;
// anything else, including 0x00 (success), yields "unknown error".
func getSOCKS5ErrorMessage(code byte) string {
	messages := [...]string{
		0x01: "general SOCKS server failure",
		0x02: "connection not allowed by ruleset",
		0x03: "network unreachable",
		0x04: "host unreachable",
		0x05: "connection refused",
		0x06: "TTL expired",
		0x07: "command not supported",
		0x08: "address type not supported",
	}

	if int(code) < len(messages) && messages[code] != "" {
		return messages[code]
	}
	return "unknown error"
}
|
|
|
|
// parsePort converts a decimal port string to an int in [1, 65535].
// An empty string yields 80, the default HTTP port. A non-numeric string
// and an out-of-range value each produce a distinct error.
func parsePort(portStr string) (int, error) {
	if len(portStr) == 0 {
		return 80, nil // default HTTP port
	}

	port, convErr := strconv.Atoi(portStr)
	switch {
	case convErr != nil:
		return 0, fmt.Errorf("invalid port format: %s", portStr)
	case port < 1, port > 65535:
		return 0, fmt.Errorf("port out of range: %d", port)
	default:
		return port, nil
	}
}
|
|
|