You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

289 lines
6.5 KiB

package dns
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)
// DNSProxy is a UDP DNS forwarder with a domain blocklist and an in-memory
// response cache. Construct it with NewDNSProxy; the zero value has a nil
// cache map.
type DNSProxy struct {
	logger      *logrus.Logger
	listenPort  int      // UDP port bound by Start (":port", i.e. all local addresses)
	upstreamDNS []string // upstream resolver addresses, tried in order when forwarding
	// blockedDomains is the blocklist, keyed by lower-cased domain name.
	blockedDomains map[string]bool
	// cache maps a lookup key to a stored upstream response; guarded by cacheMutex.
	cache      map[string]*DNSCacheEntry
	cacheMutex sync.RWMutex
	server     *net.UDPConn // listening socket; set by Start, closed by Stop
}
// DNSCacheEntry is one cached upstream response.
type DNSCacheEntry struct {
	Response []byte    // raw DNS message as returned by the upstream resolver
	Expiry   time.Time // the entry is treated as absent once time.Now() is after Expiry
}
// Config holds the settings accepted by NewDNSProxy.
type Config struct {
	ListenPort     int           // UDP port to listen on
	UpstreamDNS    []string      // upstream resolvers as "host:port"; 8.8.8.8:53 and 8.8.4.4:53 are used when empty
	BlockedDomains []string      // names to block together with all their subdomains; matched case-insensitively
	CacheTTL       time.Duration // NOTE(review): not currently read when computing cache expiry; entries use a fixed 5-minute lifetime
}
// NewDNSProxy builds a DNSProxy from config. No sockets are opened until Start.
//
// When config.UpstreamDNS is empty, 8.8.8.8:53 and 8.8.4.4:53 are used. The
// upstream list is copied, so later changes to the caller's slice do not
// affect the proxy.
//
// Blocked domains are trimmed of surrounding whitespace and of a trailing root
// dot ("example.com." -> "example.com"), then lower-cased. Empty entries are
// skipped, because an empty key would block every name ending in a dot.
//
// A nil logger is replaced with logrus.New().
//
// NOTE(review): config.CacheTTL is not stored on the proxy, so it currently
// has no effect; cached entries use putToCache's fixed lifetime.
func NewDNSProxy(config Config, logger *logrus.Logger) *DNSProxy {
	if logger == nil {
		logger = logrus.New()
	}

	upstreams := make([]string, len(config.UpstreamDNS))
	copy(upstreams, config.UpstreamDNS)
	if len(upstreams) == 0 {
		upstreams = []string{"8.8.8.8:53", "8.8.4.4:53"}
	}

	blocked := make(map[string]bool, len(config.BlockedDomains))
	for _, domain := range config.BlockedDomains {
		name := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(domain)), ".")
		if name == "" {
			continue
		}
		blocked[name] = true
	}

	return &DNSProxy{
		logger:         logger,
		listenPort:     config.ListenPort,
		upstreamDNS:    upstreams,
		blockedDomains: blocked,
		cache:          make(map[string]*DNSCacheEntry),
	}
}
// Start binds a UDP socket on dp.listenPort (all local addresses) and serves
// DNS requests until ctx is cancelled or the socket is closed (for example by
// Stop). It also launches the cache-sweeping goroutine bound to ctx.
//
// Each datagram is copied out of a reused read buffer and handled on its own
// goroutine.
//
// Return values:
//   - an error if the socket cannot be opened;
//   - ctx.Err() after cancellation;
//   - nil once the socket has been closed by someone else.
//
// NOTE(review): dp.server is written here and read by Stop without
// synchronization, so calling Stop concurrently with Start's setup is a data
// race. Per-request goroutines are neither bounded nor awaited.
func (dp *DNSProxy) Start(ctx context.Context) error {
	// Large enough for any UDP datagram; a 512-byte buffer silently truncates
	// EDNS(0) queries.
	const maxUDPPayload = 65535

	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", dp.listenPort))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP: %w", err)
	}
	dp.server = conn
	dp.logger.WithField("port", dp.listenPort).Info("DNS proxy started")

	// Periodically evict expired cache entries for as long as ctx lives.
	go dp.cacheCleanup(ctx)

	// ReadFromUDP does not observe ctx, so close the socket on cancellation to
	// unblock it. done releases this watcher when Start returns for any other
	// reason.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.Close()
		case <-done:
		}
	}()

	buf := make([]byte, maxUDPPayload)
	for {
		n, clientAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			if ctxErr := ctx.Err(); ctxErr != nil {
				// The read may have failed before the watcher ran. Close here
				// too; the error from a repeated Close is deliberately ignored.
				_ = conn.Close()
				return ctxErr
			}
			if errors.Is(err, net.ErrClosed) {
				// Closed externally (e.g. by Stop). Retrying would spin on the
				// same error forever.
				return nil
			}
			dp.logger.WithError(err).Error("Failed to read UDP packet")
			continue
		}
		// buf is reused by the next read, so give the handler its own copy.
		packet := make([]byte, n)
		copy(packet, buf[:n])
		go dp.handleDNSRequest(packet, clientAddr)
	}
}
// Stop closes the listening socket, which makes any pending read in Start
// fail. It is a no-op when no socket has been opened. Closing an
// already-closed socket is not reported as an error, so Stop may safely be
// called more than once.
func (dp *DNSProxy) Stop() error {
	if dp.server == nil {
		return nil
	}
	dp.logger.Info("Stopping DNS proxy")
	if err := dp.server.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
		return err
	}
	return nil
}
// handleDNSRequest answers one client query:
//   - packets shorter than a DNS header are dropped;
//   - blocked names get an NXDOMAIN reply;
//   - cached replies are replayed with the client's transaction ID;
//   - everything else is forwarded upstream, and the reply is stored when it
//     is cacheable.
//
// NOTE(review): when every upstream fails the client receives no reply and
// must time out; answering SERVFAIL would be friendlier.
func (dp *DNSProxy) handleDNSRequest(query []byte, clientAddr *net.UDPAddr) {
	const headerLen = 12
	if len(query) < headerLen {
		dp.logger.WithField("client", clientAddr.String()).Debug("Dropping malformed DNS request")
		return
	}

	domain := dp.extractDomain(query)
	dp.logger.WithFields(logrus.Fields{
		"domain": domain,
		"client": clientAddr.String(),
	}).Debug("DNS request received")

	if dp.isBlocked(domain) {
		dp.logger.WithField("domain", domain).Info("Blocked domain request")
		dp.sendBlockedResponse(query, clientAddr)
		return
	}

	// Key the cache on the whole query except its 2-byte transaction ID.
	// Queries that differ in name, type, class, flags or EDNS options then
	// never share an answer; keying on the name alone would serve A answers
	// to AAAA queries.
	cacheKey := string(query[2:])

	if cached := dp.getFromCache(cacheKey); cached != nil {
		dp.logger.WithField("domain", domain).Debug("DNS cache hit")
		// The cached reply carries the ID of the query that populated it, and
		// clients discard replies whose ID does not match their own. Stamp in
		// this query's ID on a private copy, because the cached bytes are
		// shared with other request goroutines.
		response := make([]byte, len(cached))
		copy(response, cached)
		if len(response) >= 2 {
			copy(response[:2], query[:2])
		}
		dp.sendResponse(response, clientAddr)
		return
	}

	response := dp.forwardToUpstream(query)
	if response == nil {
		dp.logger.WithField("domain", domain).Error("Failed to resolve domain")
		return
	}
	if isCacheableResponse(response) {
		dp.putToCache(cacheKey, response)
	}
	dp.sendResponse(response, clientAddr)
}

// isCacheableResponse reports whether an upstream reply may be replayed from
// the cache. It must contain a full header, must not have the TC (truncated)
// bit set, and must carry RCODE NOERROR (0) or NXDOMAIN (3). Transient
// failures such as SERVFAIL are therefore not cached.
func isCacheableResponse(response []byte) bool {
	const (
		headerLen     = 12
		flagTC        = 0x02 // truncation bit in header byte 2
		rcodeMask     = 0x0F // RCODE is the low nibble of header byte 3
		rcodeNoError  = 0
		rcodeNXDomain = 3
	)
	if len(response) < headerLen {
		return false
	}
	if response[2]&flagTC != 0 {
		return false
	}
	rcode := response[3] & rcodeMask
	return rcode == rcodeNoError || rcode == rcodeNXDomain
}
// extractDomain returns the name being queried, i.e. the QNAME of the first
// question in a wire-format DNS message (RFC 1035 §4.1), e.g.
// "www.example.com". Labels are joined with "." and no trailing root dot is
// added. A query for the root itself yields ".". The case of the name is
// preserved.
//
// Malformed input yields "unknown". That covers:
//   - a packet shorter than the 12-byte header;
//   - a header with QDCOUNT = 0;
//   - a label running past the end of the packet;
//   - a compression pointer or reserved label type;
//   - a name longer than 255 bytes;
//   - a name not followed by QTYPE/QCLASS.
//
// Label bytes are copied verbatim and not escaped, so a label that itself
// contains '.' is indistinguishable from two labels.
func (dp *DNSProxy) extractDomain(query []byte) string {
	const (
		headerLen    = 12
		maxLabelLen  = 63
		maxNameLen   = 255
		questionTail = 4 // QTYPE (2 bytes) + QCLASS (2 bytes)
		unknown      = "unknown"
	)

	// QDCOUNT occupies header bytes 4-5; with no question there is no name.
	if len(query) < headerLen || (query[4] == 0 && query[5] == 0) {
		return unknown
	}

	var name strings.Builder
	pos := headerLen
	for {
		if pos >= len(query) {
			return unknown
		}
		labelLen := int(query[pos])
		pos++
		if labelLen == 0 {
			break // a zero-length label terminates the name
		}
		// Lengths above 63 have the top bits set: a compression pointer or a
		// reserved label type. The first question's name has no earlier name
		// to point back to, so treat it as malformed.
		if labelLen > maxLabelLen || pos+labelLen > len(query) {
			return unknown
		}
		if name.Len() > 0 {
			name.WriteByte('.')
		}
		name.Write(query[pos : pos+labelLen])
		pos += labelLen
		if name.Len() > maxNameLen {
			return unknown
		}
	}

	if pos+questionTail > len(query) {
		return unknown
	}
	if name.Len() == 0 {
		return "."
	}
	return name.String()
}
// isBlocked reports whether domain is on the blocklist, either exactly or as
// a subdomain of a listed name. With "example.com" listed, both "example.com"
// and "ads.example.com" are blocked.
//
// The queried name is lower-cased before lookup. Blocklist keys are expected
// to be lower-case already; NewDNSProxy lower-cases them.
func (dp *DNSProxy) isBlocked(domain string) bool {
	candidate := strings.ToLower(domain)
	for {
		if dp.blockedDomains[candidate] {
			return true
		}
		dot := strings.IndexByte(candidate, '.')
		if dot < 0 {
			return false
		}
		// Drop the leftmost label and retry with the parent domain.
		candidate = candidate[dot+1:]
	}
}
// getFromCache returns the response stored under the key domain. It returns
// nil when there is no entry or the entry's Expiry has passed. Expired entries
// are not deleted here.
//
// The returned slice is the cached one itself, not a copy.
func (dp *DNSProxy) getFromCache(domain string) []byte {
	dp.cacheMutex.RLock()
	defer dp.cacheMutex.RUnlock()

	if entry, ok := dp.cache[domain]; ok && !time.Now().After(entry.Expiry) {
		return entry.Response
	}
	return nil
}
// putToCache stores response under the key domain with a fixed five-minute
// lifetime, replacing any existing entry for that key. The slice is stored
// as-is, not copied.
//
// NOTE(review): the lifetime ignores Config.CacheTTL and the record TTLs
// carried inside response.
func (dp *DNSProxy) putToCache(domain string, response []byte) {
	const entryLifetime = 5 * time.Minute
	expiry := time.Now().Add(entryLifetime)

	dp.cacheMutex.Lock()
	defer dp.cacheMutex.Unlock()
	dp.cache[domain] = &DNSCacheEntry{Expiry: expiry, Response: response}
}
// forwardToUpstream relays query to the configured upstream resolvers in
// order. It returns the first valid reply, or nil when every upstream fails.
// Each attempt is bounded by a 3-second timeout, and each failure is logged.
func (dp *DNSProxy) forwardToUpstream(query []byte) []byte {
	for _, upstream := range dp.upstreamDNS {
		response := dp.exchangeWithUpstream(upstream, query)
		if response == nil {
			continue
		}
		dp.logger.WithField("upstream", upstream).Debug("DNS query forwarded successfully")
		return response
	}
	return nil
}

// exchangeWithUpstream performs one query/response round trip with a single
// upstream server. On any failure it logs the reason and returns nil.
// The connection is closed before it returns.
func (dp *DNSProxy) exchangeWithUpstream(upstream string, query []byte) []byte {
	const (
		upstreamTimeout = 3 * time.Second
		// Large enough for any UDP datagram; a 512-byte buffer silently
		// truncates EDNS(0) replies.
		maxUDPPayload = 65535
	)
	logEntry := dp.logger.WithField("upstream", upstream)

	conn, err := net.DialTimeout("udp", upstream, upstreamTimeout)
	if err != nil {
		logEntry.WithError(err).Warn("Failed to connect to upstream DNS")
		return nil
	}
	// Deferred inside this helper, so each attempt's socket is released before
	// the next upstream is tried rather than at the end of the whole loop.
	defer conn.Close()

	if err := conn.SetDeadline(time.Now().Add(upstreamTimeout)); err != nil {
		logEntry.WithError(err).Warn("Failed to set deadline on upstream DNS connection")
		return nil
	}
	if _, err := conn.Write(query); err != nil {
		logEntry.WithError(err).Warn("Failed to write to upstream DNS")
		return nil
	}

	buf := make([]byte, maxUDPPayload)
	n, err := conn.Read(buf)
	if err != nil {
		logEntry.WithError(err).Warn("Failed to read from upstream DNS")
		return nil
	}
	// A genuine reply echoes the query's 16-bit transaction ID.
	if n < 2 || len(query) < 2 || buf[0] != query[0] || buf[1] != query[1] {
		logEntry.Warn("Invalid response from upstream DNS")
		return nil
	}

	// Copy out so the result, which may be cached, does not pin the 64 KiB
	// read buffer.
	response := make([]byte, n)
	copy(response, buf[:n])
	return response
}
// sendResponse writes response to clientAddr through the listening socket.
// A write failure is logged and otherwise ignored.
func (dp *DNSProxy) sendResponse(response []byte, clientAddr *net.UDPAddr) {
	_, err := dp.server.WriteToUDP(response, clientAddr)
	if err == nil {
		return
	}
	dp.logger.WithError(err).Error("Failed to send DNS response")
}
// sendBlockedResponse answers query with an NXDOMAIN reply built from a copy
// of the query, so the ID, counts and question section are echoed unchanged.
//
// Header byte 2 gets QR=1, keeps the query's Opcode and RD bits (RFC 1035
// says RD is copied into the response), and clears AA and TC. Header byte 3
// becomes RA=1 with RCODE=3 (NXDOMAIN). Packets shorter than the 12-byte
// header are ignored. The caller's slice is not modified.
func (dp *DNSProxy) sendBlockedResponse(query []byte, clientAddr *net.UDPAddr) {
	const (
		headerLen     = 12
		flagQR        = 0x80 // byte 2: this message is a response
		maskOpcodeRD  = 0x79 // byte 2: Opcode (bits 6-3) and RD (bit 0)
		flagRA        = 0x80 // byte 3: recursion available
		rcodeNXDomain = 0x03 // byte 3: low nibble
	)
	if len(query) < headerLen {
		return
	}

	response := make([]byte, len(query))
	copy(response, query)
	response[2] = flagQR | (query[2] & maskOpcodeRD)
	response[3] = flagRA | rcodeNXDomain
	dp.sendResponse(response, clientAddr)
}
// cacheCleanup calls cleanExpiredCache once a minute until ctx is cancelled.
// It blocks, so callers run it on its own goroutine.
func (dp *DNSProxy) cacheCleanup(ctx context.Context) {
	const sweepEvery = time.Minute

	stop := ctx.Done()
	tick := time.NewTicker(sweepEvery)
	defer tick.Stop()

	for {
		select {
		case <-stop:
			return
		case <-tick.C:
			dp.cleanExpiredCache()
		}
	}
}
// cleanExpiredCache deletes every cache entry whose Expiry has passed. When at
// least one entry was removed, it logs the count at debug level. It holds the
// cache write lock for the whole sweep.
func (dp *DNSProxy) cleanExpiredCache() {
	dp.cacheMutex.Lock()
	defer dp.cacheMutex.Unlock()

	now := time.Now()
	removed := 0
	// The Go spec permits deleting map entries while ranging over the map.
	for key, entry := range dp.cache {
		if now.After(entry.Expiry) {
			delete(dp.cache, key)
			removed++
		}
	}
	if removed == 0 {
		return
	}
	dp.logger.WithField("count", removed).Debug("Cleaned expired DNS cache entries")
}
// GetStats returns a snapshot of the proxy's configuration and cache size:
//
//   - "cache_size" (int): entries currently in the cache, including expired
//     ones not yet swept;
//   - "upstream_dns" ([]string): a copy of the upstream resolver list;
//   - "blocked_domains" (int): number of blocklist entries;
//   - "listen_port" (int): the configured UDP port.
//
// The upstream list is copied so callers cannot mutate the proxy's internal
// slice through the returned map.
func (dp *DNSProxy) GetStats() map[string]interface{} {
	dp.cacheMutex.RLock()
	cacheSize := len(dp.cache)
	dp.cacheMutex.RUnlock()

	upstreams := make([]string, len(dp.upstreamDNS))
	copy(upstreams, dp.upstreamDNS)

	return map[string]interface{}{
		"cache_size":      cacheSize,
		"upstream_dns":    upstreams,
		"blocked_domains": len(dp.blockedDomains),
		"listen_port":     dp.listenPort,
	}
}