You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
289 lines
6.5 KiB
289 lines
6.5 KiB
package dns
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
)
|
|
|
|
type DNSProxy struct {
|
|
logger *logrus.Logger
|
|
listenPort int
|
|
upstreamDNS []string
|
|
blockedDomains map[string]bool
|
|
cache map[string]*DNSCacheEntry
|
|
cacheMutex sync.RWMutex
|
|
server *net.UDPConn
|
|
}
|
|
|
|
// DNSCacheEntry is a cached upstream reply together with the moment it
// stops being valid.
type DNSCacheEntry struct {
	// Response is the raw DNS message bytes stored for the domain.
	Response []byte

	// Expiry is the instant after which the entry is no longer served.
	Expiry time.Time
}
|
|
|
|
// Config carries the settings consumed by NewDNSProxy.
type Config struct {
	// ListenPort is the UDP port the proxy binds to.
	ListenPort int

	// UpstreamDNS lists the resolver addresses to forward to. When it is
	// empty, NewDNSProxy substitutes 8.8.8.8:53 and 8.8.4.4:53.
	UpstreamDNS []string

	// BlockedDomains lists the names to refuse. They are lower-cased on
	// load, and a blocked name also blocks its subdomains.
	BlockedDomains []string

	// CacheTTL is the intended lifetime of a cached reply.
	// NOTE(review): NewDNSProxy defaults this to 5 minutes but does not
	// store it; cache entries currently use a fixed 5-minute lifetime.
	CacheTTL time.Duration
}
|
|
|
|
func NewDNSProxy(config Config, logger *logrus.Logger) *DNSProxy {
|
|
if len(config.UpstreamDNS) == 0 {
|
|
config.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"}
|
|
}
|
|
|
|
if config.CacheTTL == 0 {
|
|
config.CacheTTL = 5 * time.Minute
|
|
}
|
|
|
|
blocked := make(map[string]bool)
|
|
for _, domain := range config.BlockedDomains {
|
|
blocked[strings.ToLower(domain)] = true
|
|
}
|
|
|
|
return &DNSProxy{
|
|
logger: logger,
|
|
listenPort: config.ListenPort,
|
|
upstreamDNS: config.UpstreamDNS,
|
|
blockedDomains: blocked,
|
|
cache: make(map[string]*DNSCacheEntry),
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) Start(ctx context.Context) error {
|
|
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", dp.listenPort))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to resolve UDP address: %w", err)
|
|
}
|
|
|
|
dp.server, err = net.ListenUDP("udp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on UDP: %w", err)
|
|
}
|
|
|
|
dp.logger.WithField("port", dp.listenPort).Info("DNS proxy started")
|
|
|
|
// 启动缓存清理goroutine
|
|
go dp.cacheCleanup(ctx)
|
|
|
|
// 处理DNS请求
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
buffer := make([]byte, 512)
|
|
n, clientAddr, err := dp.server.ReadFromUDP(buffer)
|
|
if err != nil {
|
|
dp.logger.WithError(err).Error("Failed to read UDP packet")
|
|
continue
|
|
}
|
|
|
|
go dp.handleDNSRequest(buffer[:n], clientAddr)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) Stop() error {
|
|
if dp.server != nil {
|
|
dp.logger.Info("Stopping DNS proxy")
|
|
return dp.server.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dp *DNSProxy) handleDNSRequest(query []byte, clientAddr *net.UDPAddr) {
|
|
// 简单的DNS解析(这里为了演示简化处理)
|
|
domain := dp.extractDomain(query)
|
|
|
|
dp.logger.WithFields(logrus.Fields{
|
|
"domain": domain,
|
|
"client": clientAddr.String(),
|
|
}).Debug("DNS request received")
|
|
|
|
// 检查是否为被阻止的域名
|
|
if dp.isBlocked(domain) {
|
|
dp.logger.WithField("domain", domain).Info("Blocked domain request")
|
|
dp.sendBlockedResponse(query, clientAddr)
|
|
return
|
|
}
|
|
|
|
// 检查缓存
|
|
if response := dp.getFromCache(domain); response != nil {
|
|
dp.logger.WithField("domain", domain).Debug("DNS cache hit")
|
|
dp.sendResponse(response, clientAddr)
|
|
return
|
|
}
|
|
|
|
// 转发到上游DNS
|
|
response := dp.forwardToUpstream(query)
|
|
if response != nil {
|
|
// 缓存响应
|
|
dp.putToCache(domain, response)
|
|
dp.sendResponse(response, clientAddr)
|
|
} else {
|
|
dp.logger.WithField("domain", domain).Error("Failed to resolve domain")
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) extractDomain(query []byte) string {
|
|
// 这里简化处理,实际需要解析DNS包格式
|
|
// 假设域名在查询包中的位置
|
|
if len(query) < 20 {
|
|
return "unknown"
|
|
}
|
|
|
|
// 简单的域名提取(实际需要完整的DNS解析)
|
|
return "example.com"
|
|
}
|
|
|
|
func (dp *DNSProxy) isBlocked(domain string) bool {
|
|
domain = strings.ToLower(domain)
|
|
|
|
// 检查完全匹配
|
|
if dp.blockedDomains[domain] {
|
|
return true
|
|
}
|
|
|
|
// 检查子域名
|
|
parts := strings.Split(domain, ".")
|
|
for i := 1; i < len(parts); i++ {
|
|
parent := strings.Join(parts[i:], ".")
|
|
if dp.blockedDomains[parent] {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (dp *DNSProxy) getFromCache(domain string) []byte {
|
|
dp.cacheMutex.RLock()
|
|
defer dp.cacheMutex.RUnlock()
|
|
|
|
entry, exists := dp.cache[domain]
|
|
if !exists || time.Now().After(entry.Expiry) {
|
|
return nil
|
|
}
|
|
|
|
return entry.Response
|
|
}
|
|
|
|
func (dp *DNSProxy) putToCache(domain string, response []byte) {
|
|
dp.cacheMutex.Lock()
|
|
defer dp.cacheMutex.Unlock()
|
|
|
|
dp.cache[domain] = &DNSCacheEntry{
|
|
Response: response,
|
|
Expiry: time.Now().Add(5 * time.Minute),
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) forwardToUpstream(query []byte) []byte {
|
|
for _, upstream := range dp.upstreamDNS {
|
|
conn, err := net.DialTimeout("udp", upstream, 3*time.Second)
|
|
if err != nil {
|
|
dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to connect to upstream DNS")
|
|
continue
|
|
}
|
|
defer conn.Close()
|
|
|
|
// 设置读写超时
|
|
conn.SetDeadline(time.Now().Add(3 * time.Second))
|
|
|
|
// 发送查询
|
|
if _, err := conn.Write(query); err != nil {
|
|
dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to write to upstream DNS")
|
|
continue
|
|
}
|
|
|
|
// 读取响应
|
|
response := make([]byte, 512)
|
|
n, err := conn.Read(response)
|
|
if err != nil {
|
|
dp.logger.WithError(err).WithField("upstream", upstream).Warn("Failed to read from upstream DNS")
|
|
continue
|
|
}
|
|
|
|
dp.logger.WithField("upstream", upstream).Debug("DNS query forwarded successfully")
|
|
return response[:n]
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dp *DNSProxy) sendResponse(response []byte, clientAddr *net.UDPAddr) {
|
|
if _, err := dp.server.WriteToUDP(response, clientAddr); err != nil {
|
|
dp.logger.WithError(err).Error("Failed to send DNS response")
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) sendBlockedResponse(query []byte, clientAddr *net.UDPAddr) {
|
|
// 创建一个NXDOMAIN响应(简化处理)
|
|
if len(query) < 12 {
|
|
return
|
|
}
|
|
|
|
response := make([]byte, len(query))
|
|
copy(response, query)
|
|
|
|
// 设置响应标志(简化处理)
|
|
response[2] = 0x81 // QR=1, RCODE=3 (NXDOMAIN)
|
|
response[3] = 0x83
|
|
|
|
dp.sendResponse(response, clientAddr)
|
|
}
|
|
|
|
func (dp *DNSProxy) cacheCleanup(ctx context.Context) {
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
dp.cleanExpiredCache()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) cleanExpiredCache() {
|
|
dp.cacheMutex.Lock()
|
|
defer dp.cacheMutex.Unlock()
|
|
|
|
now := time.Now()
|
|
expired := make([]string, 0)
|
|
|
|
for domain, entry := range dp.cache {
|
|
if now.After(entry.Expiry) {
|
|
expired = append(expired, domain)
|
|
}
|
|
}
|
|
|
|
for _, domain := range expired {
|
|
delete(dp.cache, domain)
|
|
}
|
|
|
|
if len(expired) > 0 {
|
|
dp.logger.WithField("count", len(expired)).Debug("Cleaned expired DNS cache entries")
|
|
}
|
|
}
|
|
|
|
func (dp *DNSProxy) GetStats() map[string]interface{} {
|
|
dp.cacheMutex.RLock()
|
|
defer dp.cacheMutex.RUnlock()
|
|
|
|
return map[string]interface{}{
|
|
"cache_size": len(dp.cache),
|
|
"upstream_dns": dp.upstreamDNS,
|
|
"blocked_domains": len(dp.blockedDomains),
|
|
"listen_port": dp.listenPort,
|
|
}
|
|
}
|
|
|