You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

264 lines
6.1 KiB

package routing
import (
	"fmt"
	"net"
	"regexp"
	"strings"

	"github.com/azoic/wormhole-client/internal/config"
	"github.com/azoic/wormhole-client/pkg/logger"
)
// RouteMatcher decides, per destination host, whether traffic should go
// direct, through the proxy, or be left to an automatic decision.
//
// The compiled domain pattern lists are derived from config by
// compilePatterns; privateNetworks is populated by initPrivateNetworks.
type RouteMatcher struct {
	// config is the routing configuration currently in effect.
	config *config.Routing

	// Anchored, case-insensitive regular expressions compiled from the
	// configured bypass and force-proxy domain patterns.
	bypassDomains, forceDomains []*regexp.Regexp

	// privateNetworks holds the CIDR ranges consulted by isPrivateIP.
	privateNetworks []*net.IPNet
}
// MatchResult is the routing decision produced by RouteMatcher.Match.
type MatchResult int

// Routing decisions returned by RouteMatcher.Match. The numeric values are
// stable and MatchBypass is the zero value.
const (
	// MatchBypass means connect directly, bypassing the proxy.
	MatchBypass MatchResult = 0
	// MatchProxy means send the connection through the proxy.
	MatchProxy MatchResult = 1
	// MatchAuto means no explicit rule applied and the decision is deferred
	// to automatic handling.
	MatchAuto MatchResult = 2
)
// NewRouteMatcher builds a RouteMatcher from cfg.
//
// The bypass and force-proxy domain patterns are compiled up front, so an
// invalid pattern is reported here rather than at match time. The built-in
// private network table is initialized as well.
//
// It returns an error if cfg is nil or if any domain pattern fails to compile.
func NewRouteMatcher(cfg *config.Routing) (*RouteMatcher, error) {
	// Every matching path dereferences the config; reject nil here instead of
	// panicking inside compilePatterns.
	if cfg == nil {
		return nil, fmt.Errorf("routing: nil routing config")
	}

	matcher := &RouteMatcher{config: cfg}

	if err := matcher.compilePatterns(); err != nil {
		return nil, err
	}
	matcher.initPrivateNetworks()

	logger.Debug("Route matcher initialized with %d bypass domains, %d force domains",
		len(matcher.bypassDomains), len(matcher.forceDomains))
	return matcher, nil
}
// Match decides how a connection to host should be routed.
//
// host may be a bare hostname or IP literal, a "host:port" pair, or a
// bracketed IPv6 literal with or without a port; any port is ignored. A single
// trailing dot (fully-qualified form, e.g. "example.com.") is dropped so it
// cannot sidestep the domain rules. Rules are applied in priority order:
//
//  1. force-proxy domain patterns -> MatchProxy
//  2. bypass domain patterns      -> MatchBypass
//  3. IP literals                 -> decided by matchIP
//  4. local domain names          -> MatchBypass when BypassLocal is set
//  5. anything else               -> MatchAuto
func (rm *RouteMatcher) Match(host string) MatchResult {
	// Strip the port. SplitHostPort rejects a bracketed IPv6 literal that has
	// no port ("[::1]"), so unwrap the brackets in that case to let ParseIP
	// recognize it.
	if hostOnly, _, err := net.SplitHostPort(host); err == nil {
		host = hostOnly
	} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}
	// "example.com." names the same host as "example.com"; without this the
	// anchored domain patterns would never match the FQDN form.
	host = strings.TrimSuffix(host, ".")

	logger.Debug("Matching route for host: %s", host)

	if rm.matchesForceDomains(host) {
		logger.Debug("Host %s matches force domains - using proxy", host)
		return MatchProxy
	}

	if rm.matchesBypassDomains(host) {
		logger.Debug("Host %s matches bypass domains - using direct", host)
		return MatchBypass
	}

	if ip := net.ParseIP(host); ip != nil {
		return rm.matchIP(ip)
	}

	if rm.config.BypassLocal && rm.isLocalDomain(host) {
		logger.Debug("Host %s is local domain - using direct", host)
		return MatchBypass
	}

	logger.Debug("Host %s no specific rule - using auto", host)
	return MatchAuto
}
// matchesForceDomains reports whether host matches at least one compiled
// force-proxy domain pattern. Evaluation stops at the first match.
func (rm *RouteMatcher) matchesForceDomains(host string) bool {
	matched := false
	for i := 0; i < len(rm.forceDomains) && !matched; i++ {
		matched = rm.forceDomains[i].MatchString(host)
	}
	return matched
}
// matchesBypassDomains reports whether host matches at least one compiled
// bypass domain pattern. Evaluation stops at the first match.
func (rm *RouteMatcher) matchesBypassDomains(host string) bool {
	matched := false
	for i := 0; i < len(rm.bypassDomains) && !matched; i++ {
		matched = rm.bypassDomains[i].MatchString(host)
	}
	return matched
}
// matchIP applies the IP-literal rules.
//
// With BypassLocal set, addresses accepted by isLocalIP go direct. With
// BypassPrivate set, addresses inside the ranges checked by isPrivateIP go
// direct. Everything else yields MatchAuto.
func (rm *RouteMatcher) matchIP(ip net.IP) MatchResult {
	cfg := rm.config
	switch {
	case cfg.BypassLocal && rm.isLocalIP(ip):
		logger.Debug("IP %s is local - using direct", ip.String())
	case cfg.BypassPrivate && rm.isPrivateIP(ip):
		logger.Debug("IP %s is private - using direct", ip.String())
	default:
		return MatchAuto
	}
	return MatchBypass
}
// isLocalIP reports whether ip refers to this machine or its directly attached
// link. That covers loopback, the unspecified address (0.0.0.0 or ::), and
// link-local unicast/multicast addresses.
func (rm *RouteMatcher) isLocalIP(ip net.IP) bool {
	// Go's net.Dial treats a literal unspecified address as the local system.
	// Sending it to the remote proxy would instead reach the proxy server
	// itself, so it belongs with loopback.
	return ip.IsLoopback() ||
		ip.IsUnspecified() ||
		ip.IsLinkLocalUnicast() ||
		ip.IsLinkLocalMulticast()
}
// isPrivateIP reports whether ip lies inside any of the CIDR ranges held in
// rm.privateNetworks.
func (rm *RouteMatcher) isPrivateIP(ip net.IP) bool {
	for i := range rm.privateNetworks {
		if cidr := rm.privateNetworks[i]; cidr.Contains(ip) {
			return true
		}
	}
	return false
}
// isLocalDomain reports whether host looks like a name that is only meaningful
// on the local network. Match bypasses the proxy for such names when
// BypassLocal is set.
//
// The check is case-insensitive and ignores a single trailing dot, so
// "NAS.local." and "nas.local" are treated alike. A name counts as local when
// it is single-label (no dot at all, e.g. "localhost", "printer", or the empty
// host) or ends in one of the private-use suffixes listed below.
func (rm *RouteMatcher) isLocalDomain(host string) bool {
	// Normalize the FQDN form; otherwise "printer." would fail the
	// single-label test and "nas.local." would fail every suffix test.
	host = strings.TrimSuffix(strings.ToLower(host), ".")

	// Single-label names are normally resolved only via local search domains
	// or hosts files. The empty host also lands here; net.Dial treats an
	// empty host as the local system.
	if !strings.Contains(host, ".") {
		return true
	}

	localSuffixes := []string{
		".local",
		".localhost",
		".lan",
		".internal",
		".intranet",
		".home",
		".home.arpa", // RFC 8375 special-use domain for residential networks
		".corp",
	}
	for _, suffix := range localSuffixes {
		if strings.HasSuffix(host, suffix) {
			return true
		}
	}
	return false
}
// compilePatterns compiles rm.config.BypassDomains and rm.config.ForceDomains
// (via domainToRegexp) and appends the results to rm.bypassDomains and
// rm.forceDomains.
//
// The operation is all-or-nothing: both lists are compiled into local slices
// first, so on error the matcher's pattern slices are left exactly as they
// were. The returned error names the list and the offending pattern.
func (rm *RouteMatcher) compilePatterns() error {
	bypass := make([]*regexp.Regexp, 0, len(rm.config.BypassDomains))
	for _, domain := range rm.config.BypassDomains {
		re, err := rm.domainToRegexp(domain)
		if err != nil {
			return fmt.Errorf("compiling bypass domain pattern %q: %w", domain, err)
		}
		bypass = append(bypass, re)
	}

	force := make([]*regexp.Regexp, 0, len(rm.config.ForceDomains))
	for _, domain := range rm.config.ForceDomains {
		re, err := rm.domainToRegexp(domain)
		if err != nil {
			return fmt.Errorf("compiling force domain pattern %q: %w", domain, err)
		}
		force = append(force, re)
	}

	// Commit only once everything compiled.
	rm.bypassDomains = append(rm.bypassDomains, bypass...)
	rm.forceDomains = append(rm.forceDomains, force...)
	return nil
}
// domainToRegexp converts a domain pattern into an anchored, case-insensitive
// regular expression.
//
// Every character of domain is matched literally except '*', which matches
// any run of characters (possibly empty, dots included).
func (rm *RouteMatcher) domainToRegexp(domain string) (*regexp.Regexp, error) {
	// Quote each literal segment between wildcards, then rejoin the segments
	// with ".*" in place of each '*'.
	segments := strings.Split(domain, "*")
	for i, seg := range segments {
		segments[i] = regexp.QuoteMeta(seg)
	}
	return regexp.Compile("(?i)^" + strings.Join(segments, ".*") + "$")
}
// initPrivateNetworks (re)builds rm.privateNetworks from a fixed table of CIDR
// ranges. Match bypasses the proxy for these ranges (via isPrivateIP) when
// BypassPrivate is set.
//
// The table is constant, so a parse failure can only be a programming error.
// It panics rather than silently dropping a range. The slice is replaced, not
// appended to, so repeated calls never accumulate duplicates.
func (rm *RouteMatcher) initPrivateNetworks() {
	cidrs := []string{
		"10.0.0.0/8",     // RFC 1918 private (class A)
		"172.16.0.0/12",  // RFC 1918 private (class B)
		"192.168.0.0/16", // RFC 1918 private (class C)
		"169.254.0.0/16", // IPv4 link-local
		"127.0.0.0/8",    // IPv4 loopback
		"224.0.0.0/4",    // IPv4 multicast
		"240.0.0.0/4",    // IPv4 reserved
		"::1/128",        // IPv6 loopback
		"fe80::/10",      // IPv6 link-local
		"fc00::/7",       // IPv6 unique local
		"ff00::/8",       // IPv6 multicast (counterpart of 224.0.0.0/4)
	}

	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, ipNet, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("routing: invalid built-in CIDR " + cidr + ": " + err.Error())
		}
		nets = append(nets, ipNet)
	}
	rm.privateNetworks = nets
}
// GetStats returns a fresh map describing the matcher. It holds the number of
// compiled bypass and force domain patterns, the number of private networks,
// and the BypassLocal/BypassPrivate flags of the active config.
func (rm *RouteMatcher) GetStats() map[string]interface{} {
	stats := make(map[string]interface{}, 5)
	stats["bypass_domains_count"] = len(rm.bypassDomains)
	stats["force_domains_count"] = len(rm.forceDomains)
	stats["private_networks_count"] = len(rm.privateNetworks)
	stats["bypass_local"] = rm.config.BypassLocal
	stats["bypass_private"] = rm.config.BypassPrivate
	return stats
}
// ReloadConfig replaces the active routing configuration with cfg and
// recompiles its bypass and force domain patterns.
//
// The swap is all-or-nothing. If cfg is nil or a pattern fails to compile, the
// error is returned and the matcher keeps its previous config and patterns.
// The built-in private network table does not depend on the config and is
// left untouched.
//
// NOTE(review): the matcher is mutated without synchronization; callers
// presumably must not run this concurrently with Match. Confirm at call sites.
func (rm *RouteMatcher) ReloadConfig(cfg *config.Routing) error {
	if cfg == nil {
		return fmt.Errorf("routing: nil routing config")
	}

	prevCfg, prevBypass, prevForce := rm.config, rm.bypassDomains, rm.forceDomains

	rm.config = cfg
	rm.bypassDomains = nil
	rm.forceDomains = nil
	if err := rm.compilePatterns(); err != nil {
		// Roll back so a bad reload cannot leave a half-applied rule set
		// (new config paired with empty or partial pattern lists).
		rm.config, rm.bypassDomains, rm.forceDomains = prevCfg, prevBypass, prevForce
		return err
	}

	logger.Info("Route matcher configuration reloaded")
	return nil
}