package socks5 import ( "encoding/binary" "fmt" "io" "net" "time" "github.com/sirupsen/logrus" ) // SOCKS5 协议常量 const ( // SOCKS版本 Version5 = 0x05 // 认证方法 AuthNone = 0x00 AuthPassword = 0x02 AuthNoSupported = 0xFF // 命令类型 CmdConnect = 0x01 CmdBind = 0x02 CmdUDP = 0x03 // 地址类型 AddrIPv4 = 0x01 AddrDomain = 0x03 AddrIPv6 = 0x04 // 响应状态 StatusSuccess = 0x00 StatusServerFailure = 0x01 StatusConnectionNotAllowed = 0x02 StatusNetworkUnreachable = 0x03 StatusHostUnreachable = 0x04 StatusConnectionRefused = 0x05 StatusTTLExpired = 0x06 StatusCommandNotSupported = 0x07 StatusAddressNotSupported = 0x08 ) type Server struct { logger *logrus.Logger auth AuthHandler dialer Dialer resolver Resolver rules []Rule } type Config struct { Auth AuthConfig Timeout time.Duration Rules []RuleConfig } type AuthConfig struct { Methods []string Username string Password string } type RuleConfig struct { Action string // allow, deny IPs []string Ports []int } type AuthHandler interface { Authenticate(username, password string) bool Methods() []byte } type Dialer interface { Dial(network, address string) (net.Conn, error) } type Resolver interface { Resolve(domain string) (net.IP, error) } type Rule interface { Allow(addr net.Addr) bool } // NewServer 创建新的SOCKS5服务器 func NewServer(config Config, logger *logrus.Logger) *Server { server := &Server{ logger: logger, dialer: &DirectDialer{timeout: config.Timeout}, resolver: &DirectResolver{}, } // 设置认证处理器 server.auth = NewAuthHandler(config.Auth) // 设置访问规则 for _, ruleConfig := range config.Rules { rule := NewRule(ruleConfig) server.rules = append(server.rules, rule) } return server } // HandleConnection 处理SOCKS5连接 func (s *Server) HandleConnection(conn net.Conn) error { defer conn.Close() s.logger.WithField("remote_addr", conn.RemoteAddr()).Debug("New SOCKS5 connection") // 设置连接超时 conn.SetDeadline(time.Now().Add(30 * time.Second)) // 1. 协议版本协商 if err := s.handleVersionNegotiation(conn); err != nil { s.logger.WithError(err).Error("Version negotiation failed") return err } // 2. 认证 if err := s.handleAuthentication(conn); err != nil { s.logger.WithError(err).Error("Authentication failed") return err } // 3. 处理请求 return s.handleRequest(conn) } // handleVersionNegotiation 处理版本协商 func (s *Server) handleVersionNegotiation(conn net.Conn) error { // 读取客户端版本协商请求 buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { return fmt.Errorf("failed to read version: %w", err) } version := buf[0] nmethods := buf[1] if version != Version5 { return fmt.Errorf("unsupported SOCKS version: %d", version) } // 读取支持的认证方法 methods := make([]byte, nmethods) if _, err := io.ReadFull(conn, methods); err != nil { return fmt.Errorf("failed to read methods: %w", err) } // 选择认证方法 authMethods := s.auth.Methods() selectedMethod := byte(AuthNoSupported) for _, method := range methods { for _, supported := range authMethods { if method == supported { selectedMethod = method break } } if selectedMethod != AuthNoSupported { break } } // 发送选择的认证方法 response := []byte{Version5, selectedMethod} if _, err := conn.Write(response); err != nil { return fmt.Errorf("failed to write method selection: %w", err) } if selectedMethod == AuthNoSupported { return fmt.Errorf("no supported authentication method") } return nil } // handleAuthentication 处理认证 func (s *Server) handleAuthentication(conn net.Conn) error { // 读取认证请求 buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { return fmt.Errorf("failed to read auth version: %w", err) } version := buf[0] usernameLen := buf[1] if version != 0x01 { return fmt.Errorf("unsupported auth version: %d", version) } // 读取用户名 username := make([]byte, usernameLen) if _, err := io.ReadFull(conn, username); err != nil { return fmt.Errorf("failed to read username: %w", err) } // 读取密码长度 if _, err := io.ReadFull(conn, buf[:1]); err != nil { return fmt.Errorf("failed to read password length: %w", err) } passwordLen := buf[0] // 读取密码 password := make([]byte, passwordLen) if _, err := io.ReadFull(conn, password); err != nil { return fmt.Errorf("failed to read password: %w", err) } // 验证认证 success := s.auth.Authenticate(string(username), string(password)) // 发送认证结果 status := byte(0x01) // 失败 if success { status = 0x00 // 成功 } response := []byte{0x01, status} if _, err := conn.Write(response); err != nil { return fmt.Errorf("failed to write auth response: %w", err) } if !success { return fmt.Errorf("authentication failed") } return nil } // handleRequest 处理SOCKS5请求 func (s *Server) handleRequest(conn net.Conn) error { // 读取请求头 buf := make([]byte, 4) if _, err := io.ReadFull(conn, buf); err != nil { return fmt.Errorf("failed to read request header: %w", err) } version := buf[0] cmd := buf[1] // rsv := buf[2] // 保留字段 addrType := buf[3] if version != Version5 { return fmt.Errorf("invalid SOCKS version: %d", version) } // 读取目标地址 addr, err := s.readAddress(conn, addrType) if err != nil { s.sendResponse(conn, StatusServerFailure, "0.0.0.0", 0) return fmt.Errorf("failed to read address: %w", err) } // 检查访问规则 if !s.checkRules(addr) { s.sendResponse(conn, StatusConnectionNotAllowed, "0.0.0.0", 0) return fmt.Errorf("connection not allowed: %s", addr) } // 处理不同的命令 switch cmd { case CmdConnect: return s.handleConnect(conn, addr) case CmdBind: return s.handleBind(conn, addr) case CmdUDP: return s.handleUDP(conn, addr) default: s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) return fmt.Errorf("unsupported command: %d", cmd) } } // readAddress 读取地址信息 func (s *Server) readAddress(conn net.Conn, addrType byte) (string, error) { switch addrType { case AddrIPv4: buf := make([]byte, 6) // 4字节IP + 2字节端口 if _, err := io.ReadFull(conn, buf); err != nil { return "", err } ip := net.IP(buf[:4]) port := binary.BigEndian.Uint16(buf[4:6]) return fmt.Sprintf("%s:%d", ip.String(), port), nil case AddrDomain: buf := make([]byte, 1) if _, err := io.ReadFull(conn, buf); err != nil { return "", err } domainLen := buf[0] domain := make([]byte, domainLen+2) // 域名 + 2字节端口 if _, err := io.ReadFull(conn, domain); err != nil { return "", err } port := binary.BigEndian.Uint16(domain[domainLen:]) return fmt.Sprintf("%s:%d", string(domain[:domainLen]), port), nil case AddrIPv6: buf := make([]byte, 18) // 16字节IP + 2字节端口 if _, err := io.ReadFull(conn, buf); err != nil { return "", err } ip := net.IP(buf[:16]) port := binary.BigEndian.Uint16(buf[16:18]) return fmt.Sprintf("[%s]:%d", ip.String(), port), nil default: return "", fmt.Errorf("unsupported address type: %d", addrType) } } // handleConnect 处理CONNECT命令 func (s *Server) handleConnect(conn net.Conn, addr string) error { s.logger.WithField("target", addr).Debug("Handling CONNECT request") // 连接到目标服务器 target, err := s.dialer.Dial("tcp", addr) if err != nil { s.logger.WithError(err).WithField("target", addr).Error("Failed to connect to target") s.sendResponse(conn, StatusConnectionRefused, "0.0.0.0", 0) return err } defer target.Close() // 发送成功响应 localAddr := target.LocalAddr().(*net.TCPAddr) s.sendResponse(conn, StatusSuccess, localAddr.IP.String(), uint16(localAddr.Port)) // 开始数据转发 s.logger.WithField("target", addr).Info("Starting data relay") return s.relay(conn, target) } // handleBind 处理BIND命令 func (s *Server) handleBind(conn net.Conn, addr string) error { // BIND命令实现(简化版本) s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) return fmt.Errorf("BIND command not implemented") } // handleUDP 处理UDP命令 func (s *Server) handleUDP(conn net.Conn, addr string) error { // UDP关联实现(简化版本) s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0) return fmt.Errorf("UDP command not implemented") } // sendResponse 发送SOCKS5响应 func (s *Server) sendResponse(conn net.Conn, status byte, ip string, port uint16) error { response := []byte{ Version5, status, 0x00, // 保留字段 AddrIPv4, } // 添加IP地址 ipAddr := net.ParseIP(ip) if ipAddr == nil { ipAddr = net.ParseIP("0.0.0.0") } if ipv4 := ipAddr.To4(); ipv4 != nil { response = append(response, ipv4...) } else { response[3] = AddrIPv6 response = append(response, ipAddr.To16()...) } // 添加端口 portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, port) response = append(response, portBytes...) _, err := conn.Write(response) return err } // checkRules 检查访问规则 func (s *Server) checkRules(addr string) bool { if len(s.rules) == 0 { return true // 没有规则时允许所有连接 } host, _, err := net.SplitHostPort(addr) if err != nil { return false } targetAddr, err := net.ResolveIPAddr("ip", host) if err != nil { return false } for _, rule := range s.rules { if rule.Allow(targetAddr) { return true } } return false } // relay 数据转发 func (s *Server) relay(client, target net.Conn) error { // 创建双向数据转发 errChan := make(chan error, 2) // 客户端到目标服务器 go func() { _, err := io.Copy(target, client) errChan <- err }() // 目标服务器到客户端 go func() { _, err := io.Copy(client, target) errChan <- err }() // 等待任一方向的连接关闭 err := <-errChan return err }