You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

252 lines
5.1 KiB

package dns
import (
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/azoic/wormhole-client/pkg/logger"
)
// DNSProxy is a caching DNS proxy server: it listens for DNS queries on a
// local UDP port, answers from an in-memory response cache when it can, and
// forwards everything else to an upstream resolver.
type DNSProxy struct {
// upstreamDNS is the upstream resolver address handed to net.Dial("udp", ...).
upstreamDNS string
// localPort is the UDP port Start listens on (on all local addresses).
localPort int
// server is the listening socket, opened by Start and closed by Stop.
server *net.UDPConn
// cache holds upstream responses and is shared by all request goroutines.
cache *dnsCache
// running is true between a successful Start and the following Stop.
running bool
// mutex guards running; Start and Stop also hold it while touching server.
// NOTE(review): request-handling goroutines read server without taking it.
mutex sync.RWMutex
}
// dnsCache is an in-memory store of raw DNS response messages, keyed by a
// string derived from each query, with a per-entry expiry time.
type dnsCache struct {
// entries maps a cache key to its stored response.
entries map[string]*cacheEntry
// mutex guards entries: lookups take the read lock, inserts and cleanup
// take the write lock.
mutex sync.RWMutex
}
// cacheEntry is one cached DNS response plus the time after which it is
// considered stale.
type cacheEntry struct {
// response holds the raw response bytes (a private copy made on insert).
response []byte
// expiry is the instant after which lookups ignore the entry and the
// periodic cleanup deletes it.
expiry time.Time
}
// NewDNSProxy returns a proxy that, once started, listens on localPort and
// forwards cache misses to upstreamDNS. The proxy starts with an empty
// response cache and is not running; call Start to begin serving.
func NewDNSProxy(upstreamDNS string, localPort int) *DNSProxy {
	proxy := new(DNSProxy)
	proxy.upstreamDNS = upstreamDNS
	proxy.localPort = localPort
	proxy.cache = &dnsCache{entries: map[string]*cacheEntry{}}
	return proxy
}
// Start binds the proxy's UDP socket on localPort (all local addresses) and
// launches the receive loop and the cache-cleanup goroutine.
//
// It returns an error if the proxy is already running or if the socket
// cannot be opened. Underlying errors are wrapped with %w so callers can
// inspect them with errors.Is / errors.As (e.g. to detect a busy port).
func (d *DNSProxy) Start() error {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	if d.running {
		return fmt.Errorf("DNS proxy is already running")
	}

	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.localPort))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	d.server, err = net.ListenUDP("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to start DNS proxy server: %w", err)
	}

	d.running = true
	logger.Info("DNS proxy started on port %d", d.localPort)

	// Both goroutines consult d.running (under d.mutex) to decide when to
	// exit. They block on the mutex until this method returns, so they
	// observe the fully initialised state.
	go d.cleanupCache()
	go d.handleRequests()
	return nil
}
// Stop shuts the proxy down by clearing the running flag and closing the
// listening socket. Calling Stop on a proxy that is not running is a no-op
// and returns nil. If closing the socket fails, the proxy is still marked
// as stopped and the wrapped close error is returned.
func (d *DNSProxy) Stop() error {
	d.mutex.Lock()
	defer d.mutex.Unlock()

	if !d.running {
		return nil
	}

	// Clear the flag before closing so the receive loop treats the read
	// error caused by Close as a shutdown rather than a fault to log.
	d.running = false

	if d.server != nil {
		if err := d.server.Close(); err != nil {
			return fmt.Errorf("failed to close DNS proxy server: %w", err)
		}
	}

	logger.Info("DNS proxy stopped")
	return nil
}
// handleRequests is the receive loop started by Start. It reads datagrams
// from the listening socket and hands each one to processRequest on its own
// goroutine. It returns when the socket it started with is closed, or when
// the proxy is no longer running.
func (d *DNSProxy) handleRequests() {
	// Pin the socket this loop serves. If the proxy is stopped and quickly
	// restarted, this goroutine then exits on its closed socket instead of
	// also reading from the replacement.
	d.mutex.RLock()
	conn := d.server
	d.mutex.RUnlock()

	// Large enough for any UDP payload. The 512-byte DNS/UDP limit only
	// applies without EDNS(0), and a smaller buffer would cut (or, depending
	// on the OS, fail to read) larger datagrams.
	buffer := make([]byte, 65535)
	for d.isRunning() {
		n, clientAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			if d.isRunning() {
				logger.Error("Failed to read DNS request: %v", err)
			}
			continue
		}
		// buffer is overwritten by the next read, so every goroutine needs
		// its own copy of the packet.
		packet := make([]byte, n)
		copy(packet, buffer[:n])
		go d.processRequest(packet, clientAddr)
	}
}
// processRequest answers one DNS query received from clientAddr.
//
// If a cached reply exists for an identical query, it is sent back with the
// client's transaction ID. Otherwise the query is forwarded upstream and the
// reply is relayed to the client, and cached when it is safe to reuse.
// Upstream failures are logged and the client gets no answer, leaving it to
// retry.
//
// Cache entries are keyed on the whole query except its 2-byte transaction
// ID. Queries that differ in record type, class, flags or EDNS options
// therefore never share an entry.
// NOTE(review): the entry lifetime comes from addToCache and is not derived
// from the TTLs inside the response.
func (d *DNSProxy) processRequest(request []byte, clientAddr *net.UDPAddr) {
	domain := d.extractDomain(request)
	logger.Debug("DNS request for domain: %s", domain)

	// A DNS header is 12 bytes. Shorter packets cannot be matched against
	// the cache (or have their ID patched), so they are only forwarded.
	const headerLen = 12
	useCache := len(request) >= headerLen
	key := ""
	if useCache {
		key = string(request[2:])
		if cached := d.getFromCache(key); len(cached) >= headerLen {
			logger.Debug("Serving %s from cache", domain)
			// Work on a copy: the cached slice is shared with other goroutines.
			reply := make([]byte, len(cached))
			copy(reply, cached)
			// The stored reply carries the ID of the query that filled the
			// entry; clients discard replies whose ID does not match theirs.
			copy(reply[:2], request[:2])
			if _, err := d.server.WriteToUDP(reply, clientAddr); err != nil {
				logger.Error("Failed to send DNS response: %v", err)
			}
			return
		}
	}

	response, err := d.forwardToUpstream(request)
	if err != nil {
		logger.Error("Failed to forward DNS request: %v", err)
		return
	}

	if useCache && cacheableDNSResponse(response) {
		d.addToCache(key, response)
	}

	if _, err := d.server.WriteToUDP(response, clientAddr); err != nil {
		logger.Error("Failed to send DNS response: %v", err)
	}
}

// cacheableDNSResponse reports whether an upstream reply may be reused for
// later identical queries. The reply must:
//   - contain a full 12-byte header;
//   - not be truncated (TC bit, so the client may retry over TCP);
//   - carry RCODE NOERROR (0) or NXDOMAIN (3).
// Transient failures such as SERVFAIL are therefore never cached.
func cacheableDNSResponse(response []byte) bool {
	if len(response) < 12 {
		return false
	}
	if response[2]&0x02 != 0 { // TC flag
		return false
	}
	rcode := response[3] & 0x0F
	return rcode == 0 || rcode == 3
}
// extractDomain returns the first question name of a DNS message in dotted
// form (e.g. "www.example.com"). It is a simplified, best-effort decoder
// meant for logging.
//
// It walks the length-prefixed labels that follow the 12-byte header. It
// stops at the zero-length root label, at the end of the buffer, or at the
// first label that would run past the buffer; labels decoded before that
// point are kept. Compression pointers are not recognised. Messages shorter
// than a header yield "".
func (d *DNSProxy) extractDomain(request []byte) string {
	const headerSize = 12
	if len(request) < headerSize {
		return ""
	}

	var labels []string
	pos := headerSize
	for pos < len(request) && request[pos] != 0 {
		end := pos + 1 + int(request[pos])
		if end > len(request) {
			break
		}
		labels = append(labels, string(request[pos+1:end]))
		pos = end
	}
	return strings.Join(labels, ".")
}
// forwardToUpstream sends a raw DNS query to the upstream resolver over UDP
// and returns the raw reply. A fresh socket is dialled for every query, and
// the whole exchange is bounded by a 5-second deadline.
//
// upstreamDNS is normally "host:port". When it has no port (e.g. "8.8.8.8",
// "dns.google" or "2001:4860:4860::8888"), the standard DNS port 53 is
// assumed. Errors are wrapped with %w.
func (d *DNSProxy) forwardToUpstream(request []byte) ([]byte, error) {
	// Largest possible UDP payload. With EDNS(0) (RFC 6891) a client may
	// advertise a buffer well above 512 bytes, and the upstream answers up
	// to that size. A 512-byte read would cut such replies short.
	const maxUDPPayload = 65535

	upstream := d.upstreamDNS
	if upstream != "" {
		if _, _, err := net.SplitHostPort(upstream); err != nil {
			upstream = net.JoinHostPort(strings.Trim(upstream, "[]"), "53")
		}
	}

	conn, err := net.Dial("udp", upstream)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to upstream DNS: %w", err)
	}
	defer conn.Close()

	if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, fmt.Errorf("failed to set upstream DNS deadline: %w", err)
	}

	if _, err := conn.Write(request); err != nil {
		return nil, fmt.Errorf("failed to send DNS request: %w", err)
	}

	buf := make([]byte, maxUDPPayload)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, fmt.Errorf("failed to read DNS response: %w", err)
	}

	// Copy out so callers do not keep the large scratch buffer alive.
	response := make([]byte, n)
	copy(response, buf[:n])
	return response, nil
}
// getFromCache returns the cached response stored under domain, or nil if
// there is no entry or the entry has already expired. The returned slice is
// the cache's own storage and must not be modified by the caller.
func (d *DNSProxy) getFromCache(domain string) []byte {
	d.cache.mutex.RLock()
	defer d.cache.mutex.RUnlock()

	if entry, found := d.cache.entries[domain]; found && !time.Now().After(entry.expiry) {
		return entry.response
	}
	return nil
}
// addToCache stores a private copy of response under domain, replacing any
// existing entry. The new entry expires five minutes after insertion.
func (d *DNSProxy) addToCache(domain string, response []byte) {
	c := d.cache
	c.mutex.Lock()
	defer c.mutex.Unlock()

	const lifetime = 5 * time.Minute
	entry := &cacheEntry{expiry: time.Now().Add(lifetime)}

	// Copy the bytes so the caller may keep using (or reuse) its buffer.
	entry.response = make([]byte, len(response))
	copy(entry.response, response)

	c.entries[domain] = entry
}
// cleanupCache deletes expired entries from the response cache once a minute.
//
// It is started by Start and exits at the first tick after the proxy stops.
// It also exits once a Stop/Start cycle has replaced the socket it was
// started with, so repeated restarts do not pile up cleanup goroutines.
func (d *DNSProxy) cleanupCache() {
	// Start holds d.mutex while spawning this goroutine, so this read sees
	// the socket of the run that launched it.
	d.mutex.RLock()
	conn := d.server
	d.mutex.RUnlock()

	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		d.mutex.RLock()
		current := d.running && d.server == conn
		d.mutex.RUnlock()
		if !current {
			return
		}

		d.cache.mutex.Lock()
		now := time.Now()
		for key, entry := range d.cache.entries {
			if now.After(entry.expiry) {
				delete(d.cache.entries, key)
			}
		}
		d.cache.mutex.Unlock()
	}
}
// isRunning reports whether the proxy is currently serving, i.e. Start has
// succeeded and Stop has not been called since. It reads the flag under the
// read lock.
func (d *DNSProxy) isRunning() bool {
	d.mutex.RLock()
	running := d.running
	d.mutex.RUnlock()
	return running
}