// Package dns implements a small caching UDP DNS forwarding proxy.
package dns

import (
	"encoding/binary"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/azoic/wormhole-client/pkg/logger"
)

const (
	// dnsHeaderLen is the fixed size of a DNS message header (RFC 1035 §4.1.1).
	dnsHeaderLen = 12
	// maxDNSPacketSize is the largest payload a UDP datagram can carry. EDNS(0)
	// clients may advertise, and upstreams may send, messages far larger than
	// the classic 512-byte limit, so smaller buffers would silently truncate.
	maxDNSPacketSize = 65535
	// upstreamTimeout bounds a single dial+send+receive exchange with upstream.
	upstreamTimeout = 5 * time.Second
	// maxCacheTTL caps how long any response may stay in the cache.
	maxCacheTTL = 5 * time.Minute
	// cacheCleanupInterval is how often expired cache entries are purged.
	cacheCleanupInterval = 1 * time.Minute
	// defaultDNSPort is used for upstream addresses configured without a port.
	defaultDNSPort = "53"

	// Fields of the 16-bit header flags word (RFC 1035 §4.1.1).
	flagQR      = 1 << 15 // set in responses, clear in queries
	flagTC      = 1 << 9  // message was truncated
	rcodeMask   = 0x000F  // response code; 0 is NOERROR
	opcodeQuery = 0       // standard query opcode
)

// packetBufPool recycles maximum-size buffers used to receive upstream
// responses, so each forwarded query does not allocate 64 KiB of garbage.
var packetBufPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, maxDNSPacketSize)
		return &buf
	},
}

// DNSProxy is a UDP DNS proxy that forwards queries to a single upstream
// server and caches successful responses in memory.
type DNSProxy struct {
	upstreamDNS string       // upstream server, "host:port" or bare host (port 53 assumed)
	localPort   int          // UDP port to listen on
	server      *net.UDPConn // listening socket; guarded by mutex
	cache       *dnsCache
	running     bool          // guarded by mutex
	stopCh      chan struct{} // closed by Stop to end background goroutines; guarded by mutex
	mutex       sync.RWMutex
}

// dnsCache is an in-memory, mutex-guarded cache of raw DNS responses keyed by
// a query fingerprint (see cacheKey).
type dnsCache struct {
	entries map[string]*cacheEntry
	mutex   sync.RWMutex
}

// cacheEntry is a cached response packet and the instant it stops being valid.
type cacheEntry struct {
	response []byte
	expiry   time.Time
}

// NewDNSProxy creates a proxy that will listen on localPort and forward to
// upstreamDNS. If upstreamDNS has no port, port 53 is used. The proxy does not
// listen until Start is called.
func NewDNSProxy(upstreamDNS string, localPort int) *DNSProxy {
	return &DNSProxy{
		upstreamDNS: upstreamDNS,
		localPort:   localPort,
		cache: &dnsCache{
			entries: make(map[string]*cacheEntry),
		},
	}
}

// Start binds the UDP listener and launches the request loop and the cache
// cleanup goroutine. It returns an error if the proxy is already running or
// the socket cannot be opened.
//
// NOTE(review): the listener binds ":<port>", i.e. every interface, which can
// expose an open resolver on non-loopback networks — confirm this is intended.
func (d *DNSProxy) Start() error {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	if d.running {
		return errors.New("DNS proxy is already running")
	}

	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.localPort))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to start DNS proxy server: %w", err)
	}

	d.server = conn
	d.stopCh = make(chan struct{})
	d.running = true
	logger.Info("DNS proxy started on port %d", d.localPort)

	// Both goroutines read their state under d.mutex, so they proceed only
	// after this method has released the lock.
	go d.cleanupCache()
	go d.handleRequests()

	return nil
}

// Stop closes the listener and signals the background goroutines to exit.
// Stopping a proxy that is not running is a no-op. The only error returned is
// a failure to close the listening socket.
func (d *DNSProxy) Stop() error {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	if !d.running {
		return nil
	}
	d.running = false

	if d.stopCh != nil {
		close(d.stopCh)
		d.stopCh = nil
	}

	var err error
	if d.server != nil {
		if cerr := d.server.Close(); cerr != nil {
			err = fmt.Errorf("failed to close DNS proxy server: %w", cerr)
		}
	}

	logger.Info("DNS proxy stopped")
	return err
}

// handleRequests reads datagrams from the listener until it is closed and
// handles each one on its own goroutine.
func (d *DNSProxy) handleRequests() {
	// Capture the socket once so that a later Stop/Start cycle cannot make this
	// (old) loop start reading from the new socket.
	conn := d.serverConn()
	if conn == nil {
		return
	}

	buffer := make([]byte, maxDNSPacketSize)
	for d.isRunning() {
		n, clientAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			if errors.Is(err, net.ErrClosed) || !d.isRunning() {
				return
			}
			logger.Error("Failed to read DNS request: %v", err)
			continue
		}

		// buffer is overwritten by the next read while the worker may still
		// be using the request, so each worker gets its own copy.
		request := make([]byte, n)
		copy(request, buffer[:n])
		go d.processRequest(request, clientAddr)
	}
}

// processRequest answers one query, either from the cache or by forwarding it
// upstream. Requests shorter than a DNS header are dropped.
func (d *DNSProxy) processRequest(request []byte, clientAddr *net.UDPAddr) {
	if len(request) < dnsHeaderLen {
		logger.Debug("Ignoring malformed DNS request from %s (%d bytes)", clientAddr, len(request))
		return
	}

	domain := d.extractDomain(request)
	logger.Debug("DNS request for domain: %s", domain)

	key, cacheable := cacheKey(request)
	if cacheable {
		if cachedResponse := d.getFromCache(key); len(cachedResponse) >= 2 {
			logger.Debug("Serving %s from cache", domain)
			// The cached packet still carries the transaction ID of the query
			// that populated it; clients reject responses whose ID differs.
			copy(cachedResponse[:2], request[:2])
			d.reply(cachedResponse, clientAddr)
			return
		}
	}

	response, err := d.forwardToUpstream(request)
	if err != nil {
		logger.Error("Failed to forward DNS request: %v", err)
		return
	}

	if cacheable {
		d.addToCache(key, response)
	}
	d.reply(response, clientAddr)
}

// reply sends response to clientAddr over the current listener. A failed send
// is logged unless the proxy has already been stopped.
func (d *DNSProxy) reply(response []byte, clientAddr *net.UDPAddr) {
	conn := d.serverConn()
	if conn == nil {
		return
	}
	if _, err := conn.WriteToUDP(response, clientAddr); err != nil && d.isRunning() {
		logger.Error("Failed to send DNS response to %s: %v", clientAddr, err)
	}
}

// extractDomain returns the dotted QNAME of the first question in request, for
// logging. It returns "" when request is shorter than a DNS header. If the name
// is truncated or uses a non-plain label (e.g. a compression pointer), it
// returns the labels decoded before that point.
func (d *DNSProxy) extractDomain(request []byte) string {
	name, _, _ := readQName(request)
	return name
}

// readQName decodes the QNAME that starts right after the 12-byte header of msg.
//
// It returns:
//   - the labels joined with '.';
//   - the offset just past the terminating zero-length label;
//   - ok=true only if that terminator was reached and every label was a
//     plain, in-bounds length-prefixed label.
func readQName(msg []byte) (name string, end int, ok bool) {
	if len(msg) < dnsHeaderLen {
		return "", 0, false
	}

	var b strings.Builder
	off := dnsHeaderLen
	for off < len(msg) {
		labelLen := int(msg[off])
		off++
		if labelLen == 0 {
			return b.String(), off, true
		}
		// The top two bits mark compression pointers / reserved label types,
		// which do not belong in a query's question.
		if labelLen&0xC0 != 0 || off+labelLen > len(msg) {
			break
		}
		if b.Len() > 0 {
			b.WriteByte('.')
		}
		b.Write(msg[off : off+labelLen])
		off += labelLen
	}
	return b.String(), off, false
}

// cacheKey fingerprints a query for cache lookup.
//
// The key combines:
//   - the header flags;
//   - whether the query carries additional records (e.g. an EDNS OPT RR);
//   - the exact wire bytes of QNAME, QTYPE and QCLASS.
//
// Queries that differ in record type, class or name casing therefore never
// share an entry. ok is false for anything other than a well-formed
// standard query with exactly one question.
func cacheKey(request []byte) (string, bool) {
	if len(request) < dnsHeaderLen {
		return "", false
	}

	flags := binary.BigEndian.Uint16(request[2:4])
	if flags&flagQR != 0 || (flags>>11)&0xF != opcodeQuery {
		return "", false
	}
	if binary.BigEndian.Uint16(request[4:6]) != 1 { // QDCOUNT
		return "", false
	}

	_, end, ok := readQName(request)
	if !ok || end+4 > len(request) { // QTYPE + QCLASS follow the name
		return "", false
	}

	var key strings.Builder
	key.Grow(3 + end + 4 - dnsHeaderLen)
	key.Write(request[2:4])
	if binary.BigEndian.Uint16(request[10:12]) != 0 { // ARCOUNT
		key.WriteByte(1)
	} else {
		key.WriteByte(0)
	}
	key.Write(request[dnsHeaderLen : end+4])
	return key.String(), true
}

// upstreamAddr returns addr unchanged if it already has a port. Otherwise it
// appends defaultDNSPort, stripping any IPv6 brackets first. An empty addr is
// returned as-is so dialing still fails loudly.
func upstreamAddr(addr string) string {
	if addr == "" {
		return addr
	}
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	host := strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	return net.JoinHostPort(host, defaultDNSPort)
}

// forwardToUpstream sends request to the upstream server and waits up to
// upstreamTimeout for its reply. The returned slice is owned by the caller.
// A reply whose transaction ID differs from the request's is rejected.
func (d *DNSProxy) forwardToUpstream(request []byte) ([]byte, error) {
	conn, err := net.DialTimeout("udp", upstreamAddr(d.upstreamDNS), upstreamTimeout)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to upstream DNS: %w", err)
	}
	defer conn.Close()

	if err := conn.SetDeadline(time.Now().Add(upstreamTimeout)); err != nil {
		return nil, fmt.Errorf("failed to set upstream DNS deadline: %w", err)
	}

	if _, err := conn.Write(request); err != nil {
		return nil, fmt.Errorf("failed to send DNS request: %w", err)
	}

	bufp := packetBufPool.Get().(*[]byte)
	defer packetBufPool.Put(bufp)
	buf := *bufp

	n, err := conn.Read(buf)
	if err != nil {
		return nil, fmt.Errorf("failed to read DNS response: %w", err)
	}
	response := buf[:n]

	if len(request) >= 2 && (n < 2 || response[0] != request[0] || response[1] != request[1]) {
		return nil, errors.New("upstream DNS response ID does not match request")
	}

	// Copy out of the pooled buffer before it is returned to the pool.
	return append([]byte(nil), response...), nil
}

// getFromCache returns a private copy of the unexpired response cached under
// key, or nil if there is none. The copy may be modified freely by the caller.
func (d *DNSProxy) getFromCache(key string) []byte {
	return d.cache.get(key, time.Now())
}

// addToCache stores a copy of response under key, if the response is worth
// caching. See cacheTTL for which responses qualify and for how long.
func (d *DNSProxy) addToCache(key string, response []byte) {
	ttl, ok := cacheTTL(response)
	if !ok {
		return
	}
	d.cache.put(key, response, time.Now().Add(ttl))
}

// cacheTTL reports whether response may be cached and for how long.
//
// Only complete (TC clear), successful (NOERROR) responses qualify. The
// lifetime is the smallest TTL among the answer and authority records,
// capped at maxCacheTTL. Responses without such records, with a zero TTL,
// or that cannot be parsed are not cached.
func cacheTTL(response []byte) (time.Duration, bool) {
	if len(response) < dnsHeaderLen {
		return 0, false
	}
	flags := binary.BigEndian.Uint16(response[2:4])
	if flags&flagQR == 0 || flags&flagTC != 0 || flags&rcodeMask != 0 {
		return 0, false
	}

	minTTL, ok := minRecordTTL(response)
	if !ok || minTTL == 0 {
		return 0, false
	}

	ttl := time.Duration(minTTL) * time.Second
	if ttl > maxCacheTTL {
		ttl = maxCacheTTL
	}
	return ttl, true
}

// minRecordTTL walks the question, answer and authority sections of msg. It
// returns the smallest TTL among the answer and authority records, and
// ok=false if there are none or the message is malformed.
func minRecordTTL(msg []byte) (uint32, bool) {
	if len(msg) < dnsHeaderLen {
		return 0, false
	}
	questions := int(binary.BigEndian.Uint16(msg[4:6]))
	records := int(binary.BigEndian.Uint16(msg[6:8])) + int(binary.BigEndian.Uint16(msg[8:10]))

	off := dnsHeaderLen
	for i := 0; i < questions; i++ {
		var ok bool
		if off, ok = skipName(msg, off); !ok || off+4 > len(msg) {
			return 0, false
		}
		off += 4 // QTYPE + QCLASS
	}

	var minTTL uint32
	found := false
	for i := 0; i < records; i++ {
		var ok bool
		// Fixed RR part after the owner name: TYPE(2) CLASS(2) TTL(4) RDLENGTH(2).
		if off, ok = skipName(msg, off); !ok || off+10 > len(msg) {
			return 0, false
		}
		ttl := binary.BigEndian.Uint32(msg[off+4 : off+8])
		rdLen := int(binary.BigEndian.Uint16(msg[off+8 : off+10]))
		off += 10 + rdLen
		if off > len(msg) {
			return 0, false
		}
		if !found || ttl < minTTL {
			minTTL, found = ttl, true
		}
	}
	return minTTL, found
}

// skipName returns the offset just past the (possibly compressed) domain name
// starting at off in msg, or ok=false if the name runs off the end of msg or
// uses a reserved label type.
func skipName(msg []byte, off int) (int, bool) {
	for off < len(msg) {
		labelLen := int(msg[off])
		switch {
		case labelLen == 0:
			return off + 1, true
		case labelLen&0xC0 == 0xC0:
			// A two-byte compression pointer always ends the name.
			if off+2 > len(msg) {
				return 0, false
			}
			return off + 2, true
		case labelLen&0xC0 != 0:
			return 0, false
		default:
			off += 1 + labelLen
		}
	}
	return 0, false
}

// cleanupCache periodically purges expired cache entries. It exits when Stop
// closes the stop channel, or on the next tick after the proxy stops running.
func (d *DNSProxy) cleanupCache() {
	d.mutex.RLock()
	stop := d.stopCh
	d.mutex.RUnlock()

	ticker := time.NewTicker(cacheCleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
			if !d.isRunning() {
				return
			}
			d.cache.removeExpired(time.Now())
		}
	}
}

// get returns a copy of the entry for key if it exists and has not expired
// as of now, otherwise nil.
func (c *dnsCache) get(key string, now time.Time) []byte {
	c.mutex.RLock()
	defer c.mutex.RUnlock()

	entry, exists := c.entries[key]
	if !exists || now.After(entry.expiry) {
		return nil
	}
	return append([]byte(nil), entry.response...)
}

// put stores a copy of response under key, valid until expiry.
func (c *dnsCache) put(key string, response []byte, expiry time.Time) {
	entry := &cacheEntry{
		response: append([]byte(nil), response...),
		expiry:   expiry,
	}

	c.mutex.Lock()
	c.entries[key] = entry
	c.mutex.Unlock()
}

// removeExpired deletes every entry whose expiry is before now.
func (c *dnsCache) removeExpired(now time.Time) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	for key, entry := range c.entries {
		if now.After(entry.expiry) {
			delete(c.entries, key)
		}
	}
}

// serverConn returns the current listening socket (nil before the first Start).
func (d *DNSProxy) serverConn() *net.UDPConn {
	d.mutex.RLock()
	defer d.mutex.RUnlock()
	return d.server
}

// isRunning reports whether the proxy has been started and not yet stopped.
func (d *DNSProxy) isRunning() bool {
	d.mutex.RLock()
	defer d.mutex.RUnlock()
	return d.running
}