You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

433 lines
9.8 KiB

2 weeks ago
package socks5
import (
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/sirupsen/logrus"
)
// Version5 is the SOCKS protocol version byte (VER) spoken by this server.
const Version5 = 0x05

// Authentication method identifiers exchanged during method negotiation.
const (
	AuthNone        = 0x00 // no authentication required
	AuthPassword    = 0x02 // username/password sub-negotiation
	AuthNoSupported = 0xFF // reply value meaning "no acceptable method"
)

// Request command codes (CMD).
const (
	CmdConnect = 0x01
	CmdBind    = 0x02
	CmdUDP     = 0x03 // UDP ASSOCIATE
)

// Address type codes (ATYP) used in requests and replies.
const (
	AddrIPv4   = 0x01
	AddrDomain = 0x03
	AddrIPv6   = 0x04
)

// Reply codes (REP) sent back to the client in a request response.
const (
	StatusSuccess              = 0x00
	StatusServerFailure        = 0x01
	StatusConnectionNotAllowed = 0x02
	StatusNetworkUnreachable   = 0x03
	StatusHostUnreachable      = 0x04
	StatusConnectionRefused    = 0x05
	StatusTTLExpired           = 0x06
	StatusCommandNotSupported  = 0x07
	StatusAddressNotSupported  = 0x08
)
// Server holds the collaborators used to serve SOCKS5 client connections.
// Construct it with NewServer.
type Server struct {
	logger   *logrus.Logger // destination for connection and diagnostic logs
	auth     AuthHandler    // lists accepted auth methods and verifies credentials
	dialer   Dialer         // opens outbound connections for CONNECT requests
	resolver Resolver       // NOTE(review): set by NewServer but not read by any method visible here
	rules    []Rule         // access rules; when empty every destination is allowed
}
// Config is the input to NewServer.
type Config struct {
	// Auth is handed to NewAuthHandler to build the server's AuthHandler.
	Auth AuthConfig
	// Timeout is passed by NewServer to the outbound DirectDialer.
	Timeout time.Duration
	// Rules are turned into access rules, one per entry, via NewRule.
	Rules []RuleConfig
}
// AuthConfig configures client authentication; it is consumed by
// NewAuthHandler.
type AuthConfig struct {
	// Methods names the authentication methods to offer.
	// NOTE(review): the accepted spellings are defined by NewAuthHandler — confirm there.
	Methods []string
	// Username and Password are the credentials for username/password auth.
	Username string
	Password string
}
// RuleConfig describes one access rule; NewServer converts each entry into a
// Rule with NewRule.
type RuleConfig struct {
	Action string   // "allow" or "deny"
	IPs    []string // NOTE(review): presumably IP or CIDR strings — confirm against NewRule
	Ports  []int    // destination ports; interpretation is up to NewRule
}
// AuthHandler decides how clients authenticate.
type AuthHandler interface {
	// Authenticate reports whether the username/password pair is accepted.
	Authenticate(username, password string) bool
	// Methods returns the SOCKS5 method bytes the server is willing to use.
	Methods() []byte
}
// Dialer opens outbound connections on behalf of clients; its signature
// mirrors net.Dial.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
}
// Resolver maps a domain name to an IP address.
type Resolver interface {
	Resolve(domain string) (net.IP, error)
}
// Rule is a single access-control check on a destination address.
type Rule interface {
	// Allow reports whether a connection to addr is permitted by this rule.
	Allow(addr net.Addr) bool
}
// NewServer builds a Server from config.
//
// Outbound connections go through a DirectDialer configured with
// config.Timeout. Name lookups use a DirectResolver. Authentication is
// delegated to NewAuthHandler(config.Auth). Every entry of config.Rules
// becomes a Rule via NewRule, in order.
func NewServer(config Config, logger *logrus.Logger) *Server {
	srv := &Server{
		logger:   logger,
		dialer:   &DirectDialer{timeout: config.Timeout},
		resolver: &DirectResolver{},
		auth:     NewAuthHandler(config.Auth),
	}
	for i := range config.Rules {
		srv.rules = append(srv.rules, NewRule(config.Rules[i]))
	}
	return srv
}
// HandleConnection serves one SOCKS5 client connection end to end.
//
// The steps are: method negotiation, the username/password sub-negotiation
// (skipped when "no authentication" was selected), then the client request.
// conn is always closed before returning. The returned error explains why
// the session ended.
func (s *Server) HandleConnection(conn net.Conn) error {
	defer conn.Close()
	s.logger.WithField("remote_addr", conn.RemoteAddr()).Debug("New SOCKS5 connection")

	// Absolute deadline so a client that stalls mid-handshake cannot pin this
	// goroutine forever. It stays in effect for later reads/writes on conn
	// unless it is reset.
	if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		s.logger.WithError(err).Warn("Failed to set client deadline")
	}

	// 1. Version / method negotiation.
	method, err := s.negotiateMethod(conn)
	if err != nil {
		s.logger.WithError(err).Error("Version negotiation failed")
		return err
	}

	// 2. Sub-negotiation. With "no authentication required" the client
	// sends its request right after the method selection. Parsing an auth
	// message here would swallow the request header and fail on its
	// version byte.
	if method != AuthNone {
		if err := s.handleAuthentication(conn); err != nil {
			s.logger.WithError(err).Error("Authentication failed")
			return err
		}
	}

	// 3. Request.
	return s.handleRequest(conn)
}

// handleVersionNegotiation performs method negotiation and reports only
// whether it succeeded. It is kept so any other callers in the package keep
// working; HandleConnection calls negotiateMethod directly because it needs
// the selected method.
func (s *Server) handleVersionNegotiation(conn net.Conn) error {
	_, err := s.negotiateMethod(conn)
	return err
}

// negotiateMethod reads the client greeting (VER NMETHODS METHODS...), picks
// the first client-offered method that s.auth.Methods() also lists, and
// writes the two-byte selection reply (VER METHOD). It returns the selected
// method.
//
// If no offered method is supported, it still replies with AuthNoSupported
// and then returns an error. On any error the returned method is
// AuthNoSupported, so a caller that forgets to check err fails closed.
func (s *Server) negotiateMethod(conn net.Conn) (byte, error) {
	header := make([]byte, 2)
	if _, err := io.ReadFull(conn, header); err != nil {
		return AuthNoSupported, fmt.Errorf("failed to read version: %w", err)
	}
	version, nmethods := header[0], header[1]
	if version != Version5 {
		return AuthNoSupported, fmt.Errorf("unsupported SOCKS version: %d", version)
	}

	offered := make([]byte, nmethods)
	if _, err := io.ReadFull(conn, offered); err != nil {
		return AuthNoSupported, fmt.Errorf("failed to read methods: %w", err)
	}

	// The client's preference order wins. AuthNoSupported is only a reply
	// value and is never selectable, even if the client offers it.
	supported := s.auth.Methods()
	selected := byte(AuthNoSupported)
pick:
	for _, m := range offered {
		if m == AuthNoSupported {
			continue
		}
		for _, sm := range supported {
			if m == sm {
				selected = m
				break pick
			}
		}
	}

	if _, err := conn.Write([]byte{Version5, selected}); err != nil {
		return AuthNoSupported, fmt.Errorf("failed to write method selection: %w", err)
	}
	if selected == AuthNoSupported {
		return AuthNoSupported, fmt.Errorf("no supported authentication method")
	}
	return selected, nil
}

// handleAuthentication runs the username/password sub-negotiation.
//
// It reads VER(=0x01) ULEN UNAME PLEN PASSWD and checks the credentials with
// s.auth.Authenticate. It then replies with VER(=0x01) STATUS, where 0x00
// means success and 0x01 failure. It returns an error if the message is
// malformed, the reply cannot be written, or the credentials are rejected.
func (s *Server) handleAuthentication(conn net.Conn) error {
	// VER and ULEN.
	buf := make([]byte, 2)
	if _, err := io.ReadFull(conn, buf); err != nil {
		return fmt.Errorf("failed to read auth version: %w", err)
	}
	version := buf[0]
	usernameLen := buf[1]
	if version != 0x01 {
		return fmt.Errorf("unsupported auth version: %d", version)
	}

	username := make([]byte, usernameLen)
	if _, err := io.ReadFull(conn, username); err != nil {
		return fmt.Errorf("failed to read username: %w", err)
	}

	// PLEN, read into the first byte of buf.
	if _, err := io.ReadFull(conn, buf[:1]); err != nil {
		return fmt.Errorf("failed to read password length: %w", err)
	}
	password := make([]byte, buf[0])
	if _, err := io.ReadFull(conn, password); err != nil {
		return fmt.Errorf("failed to read password: %w", err)
	}

	success := s.auth.Authenticate(string(username), string(password))

	status := byte(0x01) // failure
	if success {
		status = 0x00 // success
	}
	if _, err := conn.Write([]byte{0x01, status}); err != nil {
		return fmt.Errorf("failed to write auth response: %w", err)
	}
	if !success {
		return fmt.Errorf("authentication failed")
	}
	return nil
}
// handleRequest reads the SOCKS5 request (VER CMD RSV ATYP DST.ADDR DST.PORT).
// It checks the destination against the access rules and then dispatches on
// CMD.
//
// After the version check, every rejection made in this function sends a
// reply to the client before returning the error.
func (s *Server) handleRequest(conn net.Conn) error {
	header := make([]byte, 4)
	if _, err := io.ReadFull(conn, header); err != nil {
		return fmt.Errorf("failed to read request header: %w", err)
	}
	version := header[0]
	cmd := header[1]
	// header[2] is the reserved RSV byte and is ignored.
	addrType := header[3]

	if version != Version5 {
		return fmt.Errorf("invalid SOCKS version: %d", version)
	}

	// An unknown ATYP has a dedicated reply code (X'08'). A generic
	// server-failure reply would hide the real cause from the client.
	if addrType != AddrIPv4 && addrType != AddrDomain && addrType != AddrIPv6 {
		s.sendResponse(conn, StatusAddressNotSupported, "0.0.0.0", 0)
		return fmt.Errorf("failed to read address: unsupported address type: %d", addrType)
	}

	addr, err := s.readAddress(conn, addrType)
	if err != nil {
		s.sendResponse(conn, StatusServerFailure, "0.0.0.0", 0)
		return fmt.Errorf("failed to read address: %w", err)
	}

	if !s.checkRules(addr) {
		s.sendResponse(conn, StatusConnectionNotAllowed, "0.0.0.0", 0)
		return fmt.Errorf("connection not allowed: %s", addr)
	}

	switch cmd {
	case CmdConnect:
		return s.handleConnect(conn, addr)
	case CmdBind:
		return s.handleBind(conn, addr)
	case CmdUDP:
		return s.handleUDP(conn, addr)
	default:
		s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
		return fmt.Errorf("unsupported command: %d", cmd)
	}
}
// readAddress reads DST.ADDR and DST.PORT for the given ATYP from conn.
//
// It returns them as a "host:port" string suitable for net.Dial.
// net.JoinHostPort brackets IPv6 literals and any domain containing a colon.
// The error is either the underlying read error, an unsupported address
// type, or an empty domain name.
func (s *Server) readAddress(conn net.Conn, addrType byte) (string, error) {
	var host string
	switch addrType {
	case AddrIPv4:
		ip := make(net.IP, net.IPv4len)
		if _, err := io.ReadFull(conn, ip); err != nil {
			return "", err
		}
		host = ip.String()
	case AddrDomain:
		var lenBuf [1]byte
		if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
			return "", err
		}
		// The name is read on its own with an int-sized length. Sizing a
		// combined name+port buffer with byte arithmetic would wrap for
		// 254/255-byte names and panic when slicing out the port.
		name := make([]byte, int(lenBuf[0]))
		if _, err := io.ReadFull(conn, name); err != nil {
			return "", err
		}
		if len(name) == 0 {
			// ":port" would make net.Dial connect to the local system.
			return "", fmt.Errorf("empty domain name")
		}
		host = string(name)
	case AddrIPv6:
		ip := make(net.IP, net.IPv6len)
		if _, err := io.ReadFull(conn, ip); err != nil {
			return "", err
		}
		host = ip.String()
	default:
		return "", fmt.Errorf("unsupported address type: %d", addrType)
	}

	var portBuf [2]byte
	if _, err := io.ReadFull(conn, portBuf[:]); err != nil {
		return "", err
	}
	port := binary.BigEndian.Uint16(portBuf[:])
	return net.JoinHostPort(host, fmt.Sprint(port)), nil
}
// handleConnect serves a CONNECT request.
//
// It dials addr through s.dialer and replies to the client with the
// outbound connection's local address as BND.ADDR/BND.PORT. It then relays
// data between the client and the target until the relay ends. A dial
// failure is reported to the client as StatusConnectionRefused.
func (s *Server) handleConnect(conn net.Conn, addr string) error {
	s.logger.WithField("target", addr).Debug("Handling CONNECT request")

	target, err := s.dialer.Dial("tcp", addr)
	if err != nil {
		s.logger.WithError(err).WithField("target", addr).Error("Failed to connect to target")
		s.sendResponse(conn, StatusConnectionRefused, "0.0.0.0", 0)
		return err
	}
	defer target.Close()

	// Dialer is an interface, so LocalAddr is not guaranteed to be a
	// *net.TCPAddr. An unchecked type assertion would panic; fall back to
	// the unspecified address instead.
	bindIP, bindPort := "0.0.0.0", uint16(0)
	if tcpAddr, ok := target.LocalAddr().(*net.TCPAddr); ok {
		bindIP, bindPort = tcpAddr.IP.String(), uint16(tcpAddr.Port)
	}
	if err := s.sendResponse(conn, StatusSuccess, bindIP, bindPort); err != nil {
		return fmt.Errorf("failed to write connect response: %w", err)
	}

	// HandleConnection puts an absolute deadline on conn for the handshake.
	// Left in place, it would cut off every relayed session once it expires,
	// so clear it before relaying.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		s.logger.WithError(err).Warn("Failed to clear client deadline")
	}

	s.logger.WithField("target", addr).Info("Starting data relay")
	return s.relay(conn, target)
}
// handleBind answers a BIND request. BIND is not implemented, so the client
// receives a "command not supported" reply and an error is returned. addr is
// currently ignored.
func (s *Server) handleBind(conn net.Conn, addr string) error {
	notImplemented := fmt.Errorf("BIND command not implemented")
	s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
	return notImplemented
}
// handleUDP answers a UDP ASSOCIATE request. UDP relaying is not
// implemented, so the client receives a "command not supported" reply and an
// error is returned. addr is currently ignored.
func (s *Server) handleUDP(conn net.Conn, addr string) error {
	notImplemented := fmt.Errorf("UDP command not implemented")
	s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
	return notImplemented
}
// sendResponse writes a SOCKS5 reply: VER REP RSV ATYP BND.ADDR BND.PORT.
//
// ip is parsed with net.ParseIP; an unparsable value is replaced by 0.0.0.0.
// Addresses with a 4-byte form are sent as IPv4, all others as IPv6. The
// port is written big-endian. The error from conn.Write is returned.
func (s *Server) sendResponse(conn net.Conn, status byte, ip string, port uint16) error {
	bound := net.ParseIP(ip)
	if bound == nil {
		bound = net.ParseIP("0.0.0.0")
	}

	atyp, raw := byte(AddrIPv4), bound.To4()
	if raw == nil {
		atyp, raw = AddrIPv6, bound.To16()
	}

	reply := make([]byte, 0, 4+len(raw)+2)
	reply = append(reply, Version5, status, 0x00, atyp)
	reply = append(reply, raw...)
	reply = append(reply, byte(port>>8), byte(port))

	_, err := conn.Write(reply)
	return err
}
// checkRules reports whether addr (a "host:port" string) may be proxied.
//
// With no rules configured every destination is allowed. Otherwise the host
// part is resolved with net.ResolveIPAddr, and the destination is allowed as
// soon as any rule's Allow accepts it. A malformed address or a failed
// resolution denies the connection.
func (s *Server) checkRules(addr string) bool {
	if len(s.rules) == 0 {
		return true
	}

	host, _, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		return false
	}
	resolved, resolveErr := net.ResolveIPAddr("ip", host)
	if resolveErr != nil {
		return false
	}

	allowed := false
	for i := 0; i < len(s.rules) && !allowed; i++ {
		allowed = s.rules[i].Allow(resolved)
	}
	return allowed
}
// relay copies data between client and target in both directions.
//
// When one direction reaches EOF, the destination's write side is
// half-closed if it supports CloseWrite (as *net.TCPConn does). The peer
// then sees EOF while data can still flow the other way, and relay waits
// for the opposite direction to finish. If the destination cannot be
// half-closed, relay returns immediately; a copy error is also returned
// immediately.
//
// relay never fully closes either connection. The callers in this file
// close both with deferred Close calls, which also unblocks a copy
// goroutine still running when relay returns early.
func (s *Server) relay(client, target net.Conn) error {
	type copyResult struct {
		err        error
		halfClosed bool
	}
	// Capacity 2 lets both goroutines send even after relay has returned,
	// so neither can leak blocked on the channel.
	results := make(chan copyResult, 2)

	pipe := func(dst, src net.Conn) {
		_, err := io.Copy(dst, src)
		res := copyResult{err: err}
		if err == nil {
			if cw, ok := dst.(interface{ CloseWrite() error }); ok {
				res.halfClosed = cw.CloseWrite() == nil
			}
		}
		results <- res
	}
	go pipe(target, client) // client -> target
	go pipe(client, target) // target -> client

	for i := 0; i < 2; i++ {
		res := <-results
		if res.err != nil || !res.halfClosed {
			return res.err
		}
	}
	return nil
}