You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
253 lines
5.1 KiB
253 lines
5.1 KiB
2 weeks ago
|
package dns
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/azoic/wormhole-client/pkg/logger"
|
||
|
)
|
||
|
|
||
|
// DNSProxy is a UDP DNS forwarding proxy: it listens on a local UDP port,
// forwards each query to an upstream resolver and caches the responses.
type DNSProxy struct {
	upstreamDNS string       // upstream resolver address handed to net.Dial("udp", ...); presumably "host:port" — TODO confirm with callers
	localPort   int          // local UDP port to listen on (Start binds ":<localPort>", i.e. all interfaces)
	server      *net.UDPConn // listening socket; assigned in Start, closed in Stop
	cache       *dnsCache    // response cache shared by all request goroutines
	running     bool         // true between a successful Start and the next Stop
	mutex       sync.RWMutex // guards running; Start/Stop also hold it while assigning/closing server
}
|
||
|
|
||
|
// dnsCache is an in-memory cache of raw DNS responses with per-entry expiry.
type dnsCache struct {
	entries map[string]*cacheEntry // keyed by the string passed to addToCache/getFromCache
	mutex   sync.RWMutex           // guards entries
}
|
||
|
|
||
|
// cacheEntry is one cached upstream response together with its expiry time.
type cacheEntry struct {
	response []byte    // private copy of the raw response message as received from upstream (including its transaction ID)
	expiry   time.Time // the entry is treated as absent once time.Now() is after this instant
}
|
||
|
|
||
|
// NewDNSProxy 创建DNS代理
|
||
|
func NewDNSProxy(upstreamDNS string, localPort int) *DNSProxy {
|
||
|
return &DNSProxy{
|
||
|
upstreamDNS: upstreamDNS,
|
||
|
localPort: localPort,
|
||
|
cache: &dnsCache{
|
||
|
entries: make(map[string]*cacheEntry),
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Start 启动DNS代理服务器
|
||
|
func (d *DNSProxy) Start() error {
|
||
|
d.mutex.Lock()
|
||
|
defer d.mutex.Unlock()
|
||
|
|
||
|
if d.running {
|
||
|
return fmt.Errorf("DNS proxy is already running")
|
||
|
}
|
||
|
|
||
|
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.localPort))
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to resolve UDP address: %v", err)
|
||
|
}
|
||
|
|
||
|
d.server, err = net.ListenUDP("udp", addr)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to start DNS proxy server: %v", err)
|
||
|
}
|
||
|
|
||
|
d.running = true
|
||
|
logger.Info("DNS proxy started on port %d", d.localPort)
|
||
|
|
||
|
// 启动清理缓存的goroutine
|
||
|
go d.cleanupCache()
|
||
|
|
||
|
// 处理DNS请求
|
||
|
go d.handleRequests()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Stop 停止DNS代理服务器
|
||
|
func (d *DNSProxy) Stop() error {
|
||
|
d.mutex.Lock()
|
||
|
defer d.mutex.Unlock()
|
||
|
|
||
|
if !d.running {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
d.running = false
|
||
|
if d.server != nil {
|
||
|
d.server.Close()
|
||
|
}
|
||
|
|
||
|
logger.Info("DNS proxy stopped")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// handleRequests 处理DNS请求
|
||
|
func (d *DNSProxy) handleRequests() {
|
||
|
buffer := make([]byte, 512) // DNS消息最大512字节(UDP)
|
||
|
|
||
|
for d.isRunning() {
|
||
|
n, clientAddr, err := d.server.ReadFromUDP(buffer)
|
||
|
if err != nil {
|
||
|
if d.isRunning() {
|
||
|
logger.Error("Failed to read DNS request: %v", err)
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
go d.processRequest(buffer[:n], clientAddr)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// processRequest 处理单个DNS请求
|
||
|
func (d *DNSProxy) processRequest(request []byte, clientAddr *net.UDPAddr) {
|
||
|
// 简单的DNS请求解析
|
||
|
domain := d.extractDomain(request)
|
||
|
logger.Debug("DNS request for domain: %s", domain)
|
||
|
|
||
|
// 检查缓存
|
||
|
if cachedResponse := d.getFromCache(domain); cachedResponse != nil {
|
||
|
logger.Debug("Serving %s from cache", domain)
|
||
|
d.server.WriteToUDP(cachedResponse, clientAddr)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// 转发到上游DNS服务器
|
||
|
response, err := d.forwardToUpstream(request)
|
||
|
if err != nil {
|
||
|
logger.Error("Failed to forward DNS request: %v", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// 缓存响应
|
||
|
d.addToCache(domain, response)
|
||
|
|
||
|
// 返回响应给客户端
|
||
|
d.server.WriteToUDP(response, clientAddr)
|
||
|
}
|
||
|
|
||
|
// extractDomain 从DNS请求中提取域名(简化实现)
|
||
|
func (d *DNSProxy) extractDomain(request []byte) string {
|
||
|
if len(request) < 12 {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// 跳过DNS头部(12字节)
|
||
|
offset := 12
|
||
|
var domain strings.Builder
|
||
|
|
||
|
for offset < len(request) {
|
||
|
length := int(request[offset])
|
||
|
if length == 0 {
|
||
|
break
|
||
|
}
|
||
|
|
||
|
offset++
|
||
|
if offset+length > len(request) {
|
||
|
break
|
||
|
}
|
||
|
|
||
|
if domain.Len() > 0 {
|
||
|
domain.WriteByte('.')
|
||
|
}
|
||
|
domain.Write(request[offset : offset+length])
|
||
|
offset += length
|
||
|
}
|
||
|
|
||
|
return domain.String()
|
||
|
}
|
||
|
|
||
|
// forwardToUpstream 转发DNS请求到上游服务器
|
||
|
func (d *DNSProxy) forwardToUpstream(request []byte) ([]byte, error) {
|
||
|
conn, err := net.Dial("udp", d.upstreamDNS)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to connect to upstream DNS: %v", err)
|
||
|
}
|
||
|
defer conn.Close()
|
||
|
|
||
|
// 设置超时
|
||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||
|
|
||
|
// 发送请求
|
||
|
if _, err := conn.Write(request); err != nil {
|
||
|
return nil, fmt.Errorf("failed to send DNS request: %v", err)
|
||
|
}
|
||
|
|
||
|
// 读取响应
|
||
|
response := make([]byte, 512)
|
||
|
n, err := conn.Read(response)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to read DNS response: %v", err)
|
||
|
}
|
||
|
|
||
|
return response[:n], nil
|
||
|
}
|
||
|
|
||
|
// getFromCache 从缓存获取DNS响应
|
||
|
func (d *DNSProxy) getFromCache(domain string) []byte {
|
||
|
d.cache.mutex.RLock()
|
||
|
defer d.cache.mutex.RUnlock()
|
||
|
|
||
|
entry, exists := d.cache.entries[domain]
|
||
|
if !exists || time.Now().After(entry.expiry) {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
return entry.response
|
||
|
}
|
||
|
|
||
|
// addToCache 添加DNS响应到缓存
|
||
|
func (d *DNSProxy) addToCache(domain string, response []byte) {
|
||
|
d.cache.mutex.Lock()
|
||
|
defer d.cache.mutex.Unlock()
|
||
|
|
||
|
// 设置缓存过期时间(5分钟)
|
||
|
expiry := time.Now().Add(5 * time.Minute)
|
||
|
|
||
|
d.cache.entries[domain] = &cacheEntry{
|
||
|
response: make([]byte, len(response)),
|
||
|
expiry: expiry,
|
||
|
}
|
||
|
copy(d.cache.entries[domain].response, response)
|
||
|
}
|
||
|
|
||
|
// cleanupCache 清理过期的缓存条目
|
||
|
func (d *DNSProxy) cleanupCache() {
|
||
|
ticker := time.NewTicker(1 * time.Minute)
|
||
|
defer ticker.Stop()
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-ticker.C:
|
||
|
if !d.isRunning() {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
d.cache.mutex.Lock()
|
||
|
now := time.Now()
|
||
|
for domain, entry := range d.cache.entries {
|
||
|
if now.After(entry.expiry) {
|
||
|
delete(d.cache.entries, domain)
|
||
|
}
|
||
|
}
|
||
|
d.cache.mutex.Unlock()
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// isRunning 检查DNS代理是否在运行
|
||
|
func (d *DNSProxy) isRunning() bool {
|
||
|
d.mutex.RLock()
|
||
|
defer d.mutex.RUnlock()
|
||
|
return d.running
|
||
|
}
|