package dns

import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)

const (
	// proxyMaxPacketSize bounds a single UDP DNS message read from a client or
	// an upstream. The classic RFC 1035 limit is 512 bytes, but EDNS(0)
	// messages routinely exceed it, and a short read buffer silently truncates
	// the datagram.
	proxyMaxPacketSize = 4096
	// proxyHeaderLen is the size of the fixed DNS message header.
	proxyHeaderLen = 12
	// proxyUpstreamTimeout bounds dialing, writing to and reading from a
	// single upstream server.
	proxyUpstreamTimeout = 3 * time.Second
	// proxyDefaultCacheTTL is applied when Config.CacheTTL is zero.
	proxyDefaultCacheTTL = 5 * time.Minute
)

// DNSProxy is a UDP DNS forwarder. It answers blocked names (and their
// subdomains) with NXDOMAIN, serves repeated questions from an in-memory
// cache, and forwards everything else to the configured upstream servers,
// trying them in order.
type DNSProxy struct {
	logger         *logrus.Logger
	listenPort     int
	upstreamDNS    []string
	blockedDomains map[string]bool
	cacheTTL       time.Duration
	cache          map[string]*DNSCacheEntry
	cacheMutex     sync.RWMutex
	server         *net.UDPConn
}

// DNSCacheEntry is a cached upstream response and the instant after which it
// is no longer served.
type DNSCacheEntry struct {
	Response []byte
	Expiry   time.Time
}

// Config configures a DNSProxy.
type Config struct {
	// ListenPort is the local UDP port to listen on.
	ListenPort int
	// UpstreamDNS lists upstream servers as "host:port", tried in order.
	// Defaults to 8.8.8.8:53 and 8.8.4.4:53 when empty.
	UpstreamDNS []string
	// BlockedDomains are answered with NXDOMAIN, together with all of their
	// subdomains. Matching is case-insensitive and ignores a trailing dot.
	BlockedDomains []string
	// CacheTTL is how long a forwarded response is cached. Defaults to
	// 5 minutes when zero.
	CacheTTL time.Duration
}

// NewDNSProxy builds a proxy from config, filling in the default upstreams and
// cache TTL. The proxy does not listen until Start is called.
func NewDNSProxy(config Config, logger *logrus.Logger) *DNSProxy {
	// Copy so later changes to the caller's slice cannot alter the proxy.
	upstreams := append([]string(nil), config.UpstreamDNS...)
	if len(upstreams) == 0 {
		upstreams = []string{"8.8.8.8:53", "8.8.4.4:53"}
	}

	ttl := config.CacheTTL
	if ttl == 0 {
		ttl = proxyDefaultCacheTTL
	}

	blocked := make(map[string]bool, len(config.BlockedDomains))
	for _, domain := range config.BlockedDomains {
		// Normalize like query names are normalized, so "Ads.Example.com."
		// matches a query for "ads.example.com".
		d := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
		if d != "" {
			blocked[d] = true
		}
	}

	return &DNSProxy{
		logger:         logger,
		listenPort:     config.ListenPort,
		upstreamDNS:    upstreams,
		blockedDomains: blocked,
		cacheTTL:       ttl,
		cache:          make(map[string]*DNSCacheEntry),
	}
}

// Start listens on the configured UDP port and serves requests until ctx is
// cancelled (returns ctx.Err()) or Stop closes the socket (returns nil).
// Each request is handled on its own goroutine.
func (dp *DNSProxy) Start(ctx context.Context) error {
	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", dp.listenPort))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP: %w", err)
	}
	dp.server = conn

	dp.logger.WithField("port", dp.listenPort).Info("DNS proxy started")

	// Start the cache cleanup goroutine; it exits when ctx is cancelled.
	go dp.cacheCleanup(ctx)

	// ReadFromUDP blocks, so the read loop cannot poll ctx itself. Closing the
	// socket on cancellation is what unblocks it (and releases the port).
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			conn.Close()
		case <-stop:
		}
	}()

	buf := make([]byte, proxyMaxPacketSize)
	for {
		n, clientAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			if ctx.Err() != nil {
				return ctx.Err()
			}
			if errors.Is(err, net.ErrClosed) {
				// Stop closed the socket; without this check the loop would
				// spin forever on the same error.
				return nil
			}
			dp.logger.WithError(err).Error("Failed to read UDP packet")
			continue
		}

		// buf is reused by the next read, so the handler gets its own copy.
		query := make([]byte, n)
		copy(query, buf[:n])
		go dp.handleDNSRequest(query, clientAddr)
	}
}

// Stop closes the listening socket, which makes Start return. Calling Stop
// after the socket was already closed (e.g. by context cancellation) is not
// an error.
func (dp *DNSProxy) Stop() error {
	if dp.server == nil {
		return nil
	}
	dp.logger.Info("Stopping DNS proxy")
	if err := dp.server.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
		return err
	}
	return nil
}

// handleDNSRequest answers one client query. The order is:
//  1. drop messages too short to be DNS;
//  2. answer blocked names with NXDOMAIN;
//  3. serve from cache;
//  4. otherwise forward upstream and cache the reply.
func (dp *DNSProxy) handleDNSRequest(query []byte, clientAddr *net.UDPAddr) {
	if len(query) < proxyHeaderLen {
		dp.logger.WithField("client", clientAddr.String()).Debug("Dropping malformed DNS query")
		return
	}

	domain, qtype, qclass, ok := dp.parseQuestion(query)
	if !ok {
		domain = "unknown"
	}

	dp.logger.WithFields(logrus.Fields{
		"domain": domain,
		"client": clientAddr.String(),
	}).Debug("DNS request received")

	// Check whether the domain is blocked.
	if ok && dp.isBlocked(domain) {
		dp.logger.WithField("domain", domain).Info("Blocked domain request")
		dp.sendBlockedResponse(query, clientAddr)
		return
	}

	// Answers differ per record type and class, so all three are part of the
	// key. Unparseable questions are forwarded but never cached.
	// NOTE(review): EDNS options (DO bit, UDP size) are not part of the key;
	// confirm that is acceptable for the clients served.
	cacheKey := ""
	if ok {
		cacheKey = fmt.Sprintf("%s/%d/%d", domain, qtype, qclass)
	}

	// Check the cache.
	if cacheKey != "" {
		if cached := dp.getFromCache(cacheKey); cached != nil {
			dp.logger.WithField("domain", domain).Debug("DNS cache hit")
			// The cached message carries the transaction ID of whichever query
			// populated it, and clients discard replies whose ID differs.
			// Patch a private copy: the cached slice is shared between
			// concurrent handlers.
			response := make([]byte, len(cached))
			copy(response, cached)
			response[0], response[1] = query[0], query[1]
			dp.sendResponse(response, clientAddr)
			return
		}
	}

	// Forward to the upstream DNS servers.
	response := dp.forwardToUpstream(query)
	if response == nil {
		dp.logger.WithField("domain", domain).Error("Failed to resolve domain")
		return
	}
	if cacheKey != "" {
		dp.putToCache(cacheKey, response)
	}
	dp.sendResponse(response, clientAddr)
}

// extractDomain returns the lower-cased name of the first question in query,
// or "unknown" if the question section cannot be parsed.
func (dp *DNSProxy) extractDomain(query []byte) string {
	name, _, _, ok := dp.parseQuestion(query)
	if !ok {
		return "unknown"
	}
	return name
}

// parseQuestion decodes the first question of a DNS message (RFC 1035 §4.1.2).
// It returns the name in dotted form without a trailing dot, with ASCII
// letters lower-cased (DNS names compare case-insensitively only for ASCII),
// plus QTYPE and QCLASS. The root name decodes to "". ok is false when the
// message is truncated, announces no question, or uses label types that do
// not belong in a query name.
func (dp *DNSProxy) parseQuestion(msg []byte) (name string, qtype, qclass uint16, ok bool) {
	if len(msg) < proxyHeaderLen {
		return "", 0, 0, false
	}
	// QDCOUNT (header bytes 4-5) must announce at least one question.
	if msg[4] == 0 && msg[5] == 0 {
		return "", 0, 0, false
	}

	var b strings.Builder
	off := proxyHeaderLen
	for {
		if off >= len(msg) {
			return "", 0, 0, false
		}
		labelLen := int(msg[off])
		off++
		if labelLen == 0 {
			break
		}
		// Lengths above 63 are compression pointers (top bits 11) or reserved
		// label types; neither is expected in a query's question name.
		if labelLen > 63 || off+labelLen > len(msg) {
			return "", 0, 0, false
		}
		if b.Len() > 0 {
			b.WriteByte('.')
		}
		for _, c := range msg[off : off+labelLen] {
			if 'A' <= c && c <= 'Z' {
				c += 'a' - 'A'
			}
			b.WriteByte(c)
		}
		off += labelLen
		if b.Len() > 253 {
			return "", 0, 0, false
		}
	}

	if off+4 > len(msg) {
		return "", 0, 0, false
	}
	qtype = uint16(msg[off])<<8 | uint16(msg[off+1])
	qclass = uint16(msg[off+2])<<8 | uint16(msg[off+3])
	return b.String(), qtype, qclass, true
}

// isBlocked reports whether domain, or any parent domain of it, is on the
// blocklist. Matching is on whole labels, so blocking "example.com" blocks
// "ads.example.com" but not "badexample.com".
func (dp *DNSProxy) isBlocked(domain string) bool {
	domain = strings.TrimSuffix(strings.ToLower(domain), ".")

	// Check the exact name first, then strip one leading label at a time.
	for d := domain; d != ""; {
		if dp.blockedDomains[d] {
			return true
		}
		i := strings.IndexByte(d, '.')
		if i < 0 {
			break
		}
		d = d[i+1:]
	}
	return false
}

// getFromCache returns the unexpired cached response for key, or nil. The
// returned slice is shared; callers must not modify it.
func (dp *DNSProxy) getFromCache(key string) []byte {
	dp.cacheMutex.RLock()
	defer dp.cacheMutex.RUnlock()

	entry, exists := dp.cache[key]
	if !exists || time.Now().After(entry.Expiry) {
		return nil
	}
	return entry.Response
}

// putToCache stores response under key for the configured cache TTL.
func (dp *DNSProxy) putToCache(key string, response []byte) {
	ttl := dp.cacheTTL
	if ttl == 0 {
		// A DNSProxy not built by NewDNSProxy has no TTL set.
		ttl = proxyDefaultCacheTTL
	}

	dp.cacheMutex.Lock()
	defer dp.cacheMutex.Unlock()

	dp.cache[key] = &DNSCacheEntry{
		Response: response,
		Expiry:   time.Now().Add(ttl),
	}
}

// forwardToUpstream sends query to each upstream in order and returns the
// first valid response, or nil if every upstream failed.
func (dp *DNSProxy) forwardToUpstream(query []byte) []byte {
	for _, upstream := range dp.upstreamDNS {
		if response := dp.queryUpstream(upstream, query); response != nil {
			dp.logger.WithField("upstream", upstream).Debug("DNS query forwarded successfully")
			return response
		}
	}
	return nil
}

// queryUpstream performs a single UDP exchange with upstream. Failures are
// logged and reported as nil. It lives in its own function so that the
// deferred Close runs after each attempt, not after the whole retry loop.
func (dp *DNSProxy) queryUpstream(upstream string, query []byte) []byte {
	conn, err := net.DialTimeout("udp", upstream, proxyUpstreamTimeout)
	if err != nil {
		dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to connect to upstream DNS")
		return nil
	}
	defer conn.Close()

	// Set the read/write deadline.
	if err := conn.SetDeadline(time.Now().Add(proxyUpstreamTimeout)); err != nil {
		dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to set deadline on upstream DNS connection")
		return nil
	}

	// Send the query.
	if _, err := conn.Write(query); err != nil {
		dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to write to upstream DNS")
		return nil
	}

	// Read the response.
	buf := make([]byte, proxyMaxPacketSize)
	n, err := conn.Read(buf)
	if err != nil {
		dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to read from upstream DNS")
		return nil
	}

	// Reject replies that are not a DNS message, or that answer a different
	// transaction, before they can be cached and served.
	if n < proxyHeaderLen || (len(query) >= 2 && (buf[0] != query[0] || buf[1] != query[1])) {
		dp.logger.WithField("upstream", upstream).Warn("Invalid response from upstream DNS")
		return nil
	}

	// Copy out so the cache does not pin the whole read buffer.
	response := make([]byte, n)
	copy(response, buf[:n])
	return response
}

// sendResponse writes response to clientAddr through the listening socket,
// logging any failure.
func (dp *DNSProxy) sendResponse(response []byte, clientAddr *net.UDPAddr) {
	if _, err := dp.server.WriteToUDP(response, clientAddr); err != nil {
		dp.logger.WithError(err).Error("Failed to send DNS response")
	}
}

// sendBlockedResponse answers query with NXDOMAIN. It echoes the query back,
// with header flags and answer/authority counts rewritten. Queries shorter
// than a DNS header are ignored.
func (dp *DNSProxy) sendBlockedResponse(query []byte, clientAddr *net.UDPAddr) {
	if len(query) < proxyHeaderLen {
		return
	}

	response := make([]byte, len(query))
	copy(response, query)

	// Flags byte 1 (mask 0x79 keeps only Opcode and RD):
	//   - set QR=1;
	//   - keep the query's Opcode and RD, as RFC 1035 requires;
	//   - clear AA and TC.
	response[2] = 0x80 | (query[2] & 0x79)
	// Flags byte 2: RA=1, Z=0, RCODE=3 (NXDOMAIN).
	response[3] = 0x83
	// ANCOUNT and NSCOUNT are zero: the response carries no records of its own.
	response[6], response[7], response[8], response[9] = 0, 0, 0, 0

	dp.sendResponse(response, clientAddr)
}

// cacheCleanup purges expired cache entries once a minute until ctx is cancelled.
func (dp *DNSProxy) cacheCleanup(ctx context.Context) {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			dp.cleanExpiredCache()
		}
	}
}

// cleanExpiredCache removes every entry whose expiry has passed.
func (dp *DNSProxy) cleanExpiredCache() {
	dp.cacheMutex.Lock()
	defer dp.cacheMutex.Unlock()

	now := time.Now()
	removed := 0
	// Deleting during range is permitted by the Go spec.
	for key, entry := range dp.cache {
		if now.After(entry.Expiry) {
			delete(dp.cache, key)
			removed++
		}
	}

	if removed > 0 {
		dp.logger.WithField("count", removed).Debug("Cleaned expired DNS cache entries")
	}
}

// GetStats returns a snapshot of proxy statistics. cache_size counts entries
// currently stored, including expired ones not yet purged.
func (dp *DNSProxy) GetStats() map[string]interface{} {
	dp.cacheMutex.RLock()
	defer dp.cacheMutex.RUnlock()

	return map[string]interface{}{
		"cache_size":      len(dp.cache),
		"upstream_dns":    append([]string(nil), dp.upstreamDNS...),
		"blocked_domains": len(dp.blockedDomains),
		"listen_port":     dp.listenPort,
	}
}