You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
433 lines
9.8 KiB
433 lines
9.8 KiB
2 weeks ago
|
package socks5
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"time"
|
||
|
|
||
|
"github.com/sirupsen/logrus"
|
||
|
)
|
||
|
|
||
|
// SOCKS5 protocol constants.

// Version5 is the protocol version byte carried by every SOCKS5 message.
const Version5 = 0x05

// Authentication method identifiers exchanged during method negotiation.
const (
	AuthNone        = 0x00 // no authentication
	AuthPassword    = 0x02 // username/password
	AuthNoSupported = 0xFF // none of the offered methods is acceptable
)

// Request command codes.
const (
	CmdConnect = 0x01
	CmdBind    = 0x02
	CmdUDP     = 0x03
)

// Address type codes used in requests and replies.
const (
	AddrIPv4   = 0x01
	AddrDomain = 0x03
	AddrIPv6   = 0x04
)

// Reply status codes. The values are consecutive, starting at 0x00.
const (
	StatusSuccess              = iota // 0x00
	StatusServerFailure               // 0x01
	StatusConnectionNotAllowed        // 0x02
	StatusNetworkUnreachable          // 0x03
	StatusHostUnreachable             // 0x04
	StatusConnectionRefused           // 0x05
	StatusTTLExpired                  // 0x06
	StatusCommandNotSupported         // 0x07
	StatusAddressNotSupported         // 0x08
)
|
||
|
|
||
|
type Server struct {
|
||
|
logger *logrus.Logger
|
||
|
auth AuthHandler
|
||
|
dialer Dialer
|
||
|
resolver Resolver
|
||
|
rules []Rule
|
||
|
}
|
||
|
|
||
|
type Config struct {
|
||
|
Auth AuthConfig
|
||
|
Timeout time.Duration
|
||
|
Rules []RuleConfig
|
||
|
}
|
||
|
|
||
|
// AuthConfig describes how clients authenticate.
type AuthConfig struct {
	// Methods lists the enabled authentication methods by name.
	// NOTE(review): the accepted names are defined by NewAuthHandler, which
	// is not visible here — confirm there.
	Methods []string

	// Username and Password are the credentials for password authentication.
	Username, Password string
}
|
||
|
|
||
|
// RuleConfig is the declarative form of one access rule.
type RuleConfig struct {
	// Action is the rule verdict: "allow" or "deny".
	Action string

	// IPs lists the addresses the rule applies to.
	// NOTE(review): the accepted notation (single IPs, CIDR, ...) is decided
	// by NewRule, which is not visible here — confirm there.
	IPs []string

	// Ports lists the ports the rule applies to.
	Ports []int
}
|
||
|
|
||
|
// AuthHandler decides which authentication methods the server accepts and
// validates username/password credentials.
type AuthHandler interface {
	// Authenticate reports whether the given credentials are accepted.
	Authenticate(username, password string) bool

	// Methods returns the method identifiers the server is willing to
	// select during negotiation.
	Methods() []byte
}
|
||
|
|
||
|
// Dialer opens outbound connections on behalf of the server.
type Dialer interface {
	// Dial connects to address on the named network.
	Dial(network, address string) (net.Conn, error)
}
|
||
|
|
||
|
// Resolver maps a domain name to an IP address.
type Resolver interface {
	// Resolve returns an IP address for domain, or an error.
	Resolve(domain string) (net.IP, error)
}
|
||
|
|
||
|
// Rule is a single access-control check applied to a target address.
type Rule interface {
	// Allow reports whether a connection to addr is permitted by this rule.
	Allow(addr net.Addr) bool
}
|
||
|
|
||
|
// NewServer 创建新的SOCKS5服务器
|
||
|
func NewServer(config Config, logger *logrus.Logger) *Server {
|
||
|
server := &Server{
|
||
|
logger: logger,
|
||
|
dialer: &DirectDialer{timeout: config.Timeout},
|
||
|
resolver: &DirectResolver{},
|
||
|
}
|
||
|
|
||
|
// 设置认证处理器
|
||
|
server.auth = NewAuthHandler(config.Auth)
|
||
|
|
||
|
// 设置访问规则
|
||
|
for _, ruleConfig := range config.Rules {
|
||
|
rule := NewRule(ruleConfig)
|
||
|
server.rules = append(server.rules, rule)
|
||
|
}
|
||
|
|
||
|
return server
|
||
|
}
|
||
|
|
||
|
// HandleConnection 处理SOCKS5连接
|
||
|
func (s *Server) HandleConnection(conn net.Conn) error {
|
||
|
defer conn.Close()
|
||
|
|
||
|
s.logger.WithField("remote_addr", conn.RemoteAddr()).Debug("New SOCKS5 connection")
|
||
|
|
||
|
// 设置连接超时
|
||
|
conn.SetDeadline(time.Now().Add(30 * time.Second))
|
||
|
|
||
|
// 1. 协议版本协商
|
||
|
if err := s.handleVersionNegotiation(conn); err != nil {
|
||
|
s.logger.WithError(err).Error("Version negotiation failed")
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// 2. 认证
|
||
|
if err := s.handleAuthentication(conn); err != nil {
|
||
|
s.logger.WithError(err).Error("Authentication failed")
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// 3. 处理请求
|
||
|
return s.handleRequest(conn)
|
||
|
}
|
||
|
|
||
|
// handleVersionNegotiation 处理版本协商
|
||
|
func (s *Server) handleVersionNegotiation(conn net.Conn) error {
|
||
|
// 读取客户端版本协商请求
|
||
|
buf := make([]byte, 2)
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return fmt.Errorf("failed to read version: %w", err)
|
||
|
}
|
||
|
|
||
|
version := buf[0]
|
||
|
nmethods := buf[1]
|
||
|
|
||
|
if version != Version5 {
|
||
|
return fmt.Errorf("unsupported SOCKS version: %d", version)
|
||
|
}
|
||
|
|
||
|
// 读取支持的认证方法
|
||
|
methods := make([]byte, nmethods)
|
||
|
if _, err := io.ReadFull(conn, methods); err != nil {
|
||
|
return fmt.Errorf("failed to read methods: %w", err)
|
||
|
}
|
||
|
|
||
|
// 选择认证方法
|
||
|
authMethods := s.auth.Methods()
|
||
|
selectedMethod := byte(AuthNoSupported)
|
||
|
|
||
|
for _, method := range methods {
|
||
|
for _, supported := range authMethods {
|
||
|
if method == supported {
|
||
|
selectedMethod = method
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if selectedMethod != AuthNoSupported {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// 发送选择的认证方法
|
||
|
response := []byte{Version5, selectedMethod}
|
||
|
if _, err := conn.Write(response); err != nil {
|
||
|
return fmt.Errorf("failed to write method selection: %w", err)
|
||
|
}
|
||
|
|
||
|
if selectedMethod == AuthNoSupported {
|
||
|
return fmt.Errorf("no supported authentication method")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// handleAuthentication 处理认证
|
||
|
func (s *Server) handleAuthentication(conn net.Conn) error {
|
||
|
// 读取认证请求
|
||
|
buf := make([]byte, 2)
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return fmt.Errorf("failed to read auth version: %w", err)
|
||
|
}
|
||
|
|
||
|
version := buf[0]
|
||
|
usernameLen := buf[1]
|
||
|
|
||
|
if version != 0x01 {
|
||
|
return fmt.Errorf("unsupported auth version: %d", version)
|
||
|
}
|
||
|
|
||
|
// 读取用户名
|
||
|
username := make([]byte, usernameLen)
|
||
|
if _, err := io.ReadFull(conn, username); err != nil {
|
||
|
return fmt.Errorf("failed to read username: %w", err)
|
||
|
}
|
||
|
|
||
|
// 读取密码长度
|
||
|
if _, err := io.ReadFull(conn, buf[:1]); err != nil {
|
||
|
return fmt.Errorf("failed to read password length: %w", err)
|
||
|
}
|
||
|
passwordLen := buf[0]
|
||
|
|
||
|
// 读取密码
|
||
|
password := make([]byte, passwordLen)
|
||
|
if _, err := io.ReadFull(conn, password); err != nil {
|
||
|
return fmt.Errorf("failed to read password: %w", err)
|
||
|
}
|
||
|
|
||
|
// 验证认证
|
||
|
success := s.auth.Authenticate(string(username), string(password))
|
||
|
|
||
|
// 发送认证结果
|
||
|
status := byte(0x01) // 失败
|
||
|
if success {
|
||
|
status = 0x00 // 成功
|
||
|
}
|
||
|
|
||
|
response := []byte{0x01, status}
|
||
|
if _, err := conn.Write(response); err != nil {
|
||
|
return fmt.Errorf("failed to write auth response: %w", err)
|
||
|
}
|
||
|
|
||
|
if !success {
|
||
|
return fmt.Errorf("authentication failed")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// handleRequest 处理SOCKS5请求
|
||
|
func (s *Server) handleRequest(conn net.Conn) error {
|
||
|
// 读取请求头
|
||
|
buf := make([]byte, 4)
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return fmt.Errorf("failed to read request header: %w", err)
|
||
|
}
|
||
|
|
||
|
version := buf[0]
|
||
|
cmd := buf[1]
|
||
|
// rsv := buf[2] // 保留字段
|
||
|
addrType := buf[3]
|
||
|
|
||
|
if version != Version5 {
|
||
|
return fmt.Errorf("invalid SOCKS version: %d", version)
|
||
|
}
|
||
|
|
||
|
// 读取目标地址
|
||
|
addr, err := s.readAddress(conn, addrType)
|
||
|
if err != nil {
|
||
|
s.sendResponse(conn, StatusServerFailure, "0.0.0.0", 0)
|
||
|
return fmt.Errorf("failed to read address: %w", err)
|
||
|
}
|
||
|
|
||
|
// 检查访问规则
|
||
|
if !s.checkRules(addr) {
|
||
|
s.sendResponse(conn, StatusConnectionNotAllowed, "0.0.0.0", 0)
|
||
|
return fmt.Errorf("connection not allowed: %s", addr)
|
||
|
}
|
||
|
|
||
|
// 处理不同的命令
|
||
|
switch cmd {
|
||
|
case CmdConnect:
|
||
|
return s.handleConnect(conn, addr)
|
||
|
case CmdBind:
|
||
|
return s.handleBind(conn, addr)
|
||
|
case CmdUDP:
|
||
|
return s.handleUDP(conn, addr)
|
||
|
default:
|
||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
|
||
|
return fmt.Errorf("unsupported command: %d", cmd)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// readAddress 读取地址信息
|
||
|
func (s *Server) readAddress(conn net.Conn, addrType byte) (string, error) {
|
||
|
switch addrType {
|
||
|
case AddrIPv4:
|
||
|
buf := make([]byte, 6) // 4字节IP + 2字节端口
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
ip := net.IP(buf[:4])
|
||
|
port := binary.BigEndian.Uint16(buf[4:6])
|
||
|
return fmt.Sprintf("%s:%d", ip.String(), port), nil
|
||
|
|
||
|
case AddrDomain:
|
||
|
buf := make([]byte, 1)
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
domainLen := buf[0]
|
||
|
|
||
|
domain := make([]byte, domainLen+2) // 域名 + 2字节端口
|
||
|
if _, err := io.ReadFull(conn, domain); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
port := binary.BigEndian.Uint16(domain[domainLen:])
|
||
|
return fmt.Sprintf("%s:%d", string(domain[:domainLen]), port), nil
|
||
|
|
||
|
case AddrIPv6:
|
||
|
buf := make([]byte, 18) // 16字节IP + 2字节端口
|
||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
ip := net.IP(buf[:16])
|
||
|
port := binary.BigEndian.Uint16(buf[16:18])
|
||
|
return fmt.Sprintf("[%s]:%d", ip.String(), port), nil
|
||
|
|
||
|
default:
|
||
|
return "", fmt.Errorf("unsupported address type: %d", addrType)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// handleConnect 处理CONNECT命令
|
||
|
func (s *Server) handleConnect(conn net.Conn, addr string) error {
|
||
|
s.logger.WithField("target", addr).Debug("Handling CONNECT request")
|
||
|
|
||
|
// 连接到目标服务器
|
||
|
target, err := s.dialer.Dial("tcp", addr)
|
||
|
if err != nil {
|
||
|
s.logger.WithError(err).WithField("target", addr).Error("Failed to connect to target")
|
||
|
s.sendResponse(conn, StatusConnectionRefused, "0.0.0.0", 0)
|
||
|
return err
|
||
|
}
|
||
|
defer target.Close()
|
||
|
|
||
|
// 发送成功响应
|
||
|
localAddr := target.LocalAddr().(*net.TCPAddr)
|
||
|
s.sendResponse(conn, StatusSuccess, localAddr.IP.String(), uint16(localAddr.Port))
|
||
|
|
||
|
// 开始数据转发
|
||
|
s.logger.WithField("target", addr).Info("Starting data relay")
|
||
|
return s.relay(conn, target)
|
||
|
}
|
||
|
|
||
|
// handleBind 处理BIND命令
|
||
|
func (s *Server) handleBind(conn net.Conn, addr string) error {
|
||
|
// BIND命令实现(简化版本)
|
||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
|
||
|
return fmt.Errorf("BIND command not implemented")
|
||
|
}
|
||
|
|
||
|
// handleUDP 处理UDP命令
|
||
|
func (s *Server) handleUDP(conn net.Conn, addr string) error {
|
||
|
// UDP关联实现(简化版本)
|
||
|
s.sendResponse(conn, StatusCommandNotSupported, "0.0.0.0", 0)
|
||
|
return fmt.Errorf("UDP command not implemented")
|
||
|
}
|
||
|
|
||
|
// sendResponse 发送SOCKS5响应
|
||
|
func (s *Server) sendResponse(conn net.Conn, status byte, ip string, port uint16) error {
|
||
|
response := []byte{
|
||
|
Version5,
|
||
|
status,
|
||
|
0x00, // 保留字段
|
||
|
AddrIPv4,
|
||
|
}
|
||
|
|
||
|
// 添加IP地址
|
||
|
ipAddr := net.ParseIP(ip)
|
||
|
if ipAddr == nil {
|
||
|
ipAddr = net.ParseIP("0.0.0.0")
|
||
|
}
|
||
|
if ipv4 := ipAddr.To4(); ipv4 != nil {
|
||
|
response = append(response, ipv4...)
|
||
|
} else {
|
||
|
response[3] = AddrIPv6
|
||
|
response = append(response, ipAddr.To16()...)
|
||
|
}
|
||
|
|
||
|
// 添加端口
|
||
|
portBytes := make([]byte, 2)
|
||
|
binary.BigEndian.PutUint16(portBytes, port)
|
||
|
response = append(response, portBytes...)
|
||
|
|
||
|
_, err := conn.Write(response)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// checkRules 检查访问规则
|
||
|
func (s *Server) checkRules(addr string) bool {
|
||
|
if len(s.rules) == 0 {
|
||
|
return true // 没有规则时允许所有连接
|
||
|
}
|
||
|
|
||
|
host, _, err := net.SplitHostPort(addr)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
targetAddr, err := net.ResolveIPAddr("ip", host)
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
for _, rule := range s.rules {
|
||
|
if rule.Allow(targetAddr) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// relay 数据转发
|
||
|
func (s *Server) relay(client, target net.Conn) error {
|
||
|
// 创建双向数据转发
|
||
|
errChan := make(chan error, 2)
|
||
|
|
||
|
// 客户端到目标服务器
|
||
|
go func() {
|
||
|
_, err := io.Copy(target, client)
|
||
|
errChan <- err
|
||
|
}()
|
||
|
|
||
|
// 目标服务器到客户端
|
||
|
go func() {
|
||
|
_, err := io.Copy(client, target)
|
||
|
errChan <- err
|
||
|
}()
|
||
|
|
||
|
// 等待任一方向的连接关闭
|
||
|
err := <-errChan
|
||
|
return err
|
||
|
}
|