package routing

import (
	"net"
	"testing"

	"github.com/azoic/wormhole-client/internal/config"
)

// TestNewRouteMatcher checks that NewRouteMatcher accepts a routing config
// and keeps one entry per configured bypass and force domain pattern.
func TestNewRouteMatcher(t *testing.T) {
	cfg := &config.Routing{
		BypassLocal:   true,
		BypassPrivate: true,
		BypassDomains: []string{"*.local", "*.lan"},
		ForceDomains:  []string{"*.google.com", "*.github.com"},
	}

	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		t.Fatalf("Failed to create route matcher: %v", err)
	}
	if matcher == nil {
		t.Fatal("Route matcher should not be nil")
	}
	if len(matcher.bypassDomains) != 2 {
		t.Errorf("Expected 2 bypass domains, got %d", len(matcher.bypassDomains))
	}
	if len(matcher.forceDomains) != 2 {
		t.Errorf("Expected 2 force domains, got %d", len(matcher.forceDomains))
	}
}

// TestRouteMatcher_Match checks the routing decision returned by Match for
// force-proxy domains, bypass domains, local host names, loopback/private IP
// literals, and hosts that match no rule (MatchAuto). Some hosts carry a
// ":port" suffix to confirm it does not affect matching.
func TestRouteMatcher_Match(t *testing.T) {
	cfg := &config.Routing{
		BypassLocal:   true,
		BypassPrivate: true,
		BypassDomains: []string{"*.local", "*.baidu.com"},
		ForceDomains:  []string{"*.google.com", "*.github.com"},
	}

	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		host     string
		expected MatchResult
	}{
		// Force-proxy domains.
		{"www.google.com", MatchProxy},
		{"api.github.com", MatchProxy},
		{"www.google.com:443", MatchProxy},

		// Bypass domains.
		{"test.local", MatchBypass},
		{"www.baidu.com", MatchBypass},
		{"search.baidu.com:80", MatchBypass},

		// Local host names (covered by BypassLocal, not by BypassDomains).
		{"localhost", MatchBypass},
		{"router.lan", MatchBypass},
		{"printer.internal", MatchBypass},

		// IP addresses - loopback.
		{"127.0.0.1", MatchBypass},
		{"::1", MatchBypass},

		// IP addresses - private networks.
		{"192.168.1.1", MatchBypass},
		{"10.0.0.1", MatchBypass},
		{"172.16.1.1", MatchBypass},

		// Other domains - left for automatic decision.
		{"example.com", MatchAuto},
		{"stackoverflow.com", MatchAuto},
	}

	for _, test := range tests {
		t.Run(test.host, func(t *testing.T) {
			result := matcher.Match(test.host)
			if result != test.expected {
				t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
			}
		})
	}
}

// TestRouteMatcher_MatchesForceDomains checks wildcard and exact force-domain
// patterns: "*." requires at least one label in its place, and an exact
// pattern does not match its subdomains.
func TestRouteMatcher_MatchesForceDomains(t *testing.T) {
	cfg := &config.Routing{
		ForceDomains: []string{"*.google.com", "github.com", "*.example.*"},
	}

	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		host     string
		expected bool
	}{
		{"www.google.com", true},
		{"api.google.com", true},
		{"google.com", false}, // the bare apex does not match *.google.com
		{"github.com", true},
		{"api.github.com", false}, // exact pattern github.com does not match subdomains
		{"test.example.org", true},
		{"sub.example.net", true},
		{"example.com", false}, // *.example.* needs a label before "example"
		{"other.com", false},
	}

	for _, test := range tests {
		t.Run(test.host, func(t *testing.T) {
			result := matcher.matchesForceDomains(test.host)
			if result != test.expected {
				t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
			}
		})
	}
}

// TestRouteMatcher_MatchesBypassDomains checks wildcard and exact
// bypass-domain patterns, including a TLD-only wildcard ("*.cn").
func TestRouteMatcher_MatchesBypassDomains(t *testing.T) {
	cfg := &config.Routing{
		BypassDomains: []string{"*.local", "localhost", "*.cn"},
	}

	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		host     string
		expected bool
	}{
		{"test.local", true},
		{"printer.local", true},
		{"local", false}, // *.local requires a label before "local"
		{"localhost", true},
		{"www.baidu.cn", true},
		{"qq.cn", true},
		{"china.com", false}, // different TLD, does not match *.cn
		{"example.com", false},
	}

	for _, test := range tests {
		t.Run(test.host, func(t *testing.T) {
			result := matcher.matchesBypassDomains(test.host)
			if result != test.expected {
				t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
			}
		})
	}
}

// TestRouteMatcher_IsLocalDomain checks which host names isLocalDomain treats
// as local: "localhost", intranet-style suffixes, and single-label names.
func TestRouteMatcher_IsLocalDomain(t *testing.T) {
	cfg := &config.Routing{}
	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		// Fail clearly instead of panicking on a nil matcher below.
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		host     string
		expected bool
	}{
		{"localhost", true},
		{"test.local", true},
		{"printer.lan", true},
		{"server.internal", true},
		{"router.home", true},
		{"pc.corp", true},
		{"singleword", true}, // single-label host name
		{"example.com", false},
		{"www.google.com", false},
		{"192.168.1.1", false}, // an IP literal is not a domain name
	}

	for _, test := range tests {
		t.Run(test.host, func(t *testing.T) {
			result := matcher.isLocalDomain(test.host)
			if result != test.expected {
				t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
			}
		})
	}
}

// TestRouteMatcher_IsPrivateIP checks isPrivateIP against private, loopback,
// link-local and unique-local addresses (expected true) and public addresses
// (expected false), for both IPv4 and IPv6.
func TestRouteMatcher_IsPrivateIP(t *testing.T) {
	cfg := &config.Routing{}
	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		// Fail clearly instead of panicking on a nil matcher below.
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		ip       string
		expected bool
	}{
		// Private IPv4 addresses.
		{"192.168.1.1", true},
		{"10.0.0.1", true},
		{"172.16.1.1", true},
		{"127.0.0.1", true}, // loopback

		// Public IPv4 addresses.
		{"8.8.8.8", false},
		{"1.1.1.1", false},
		{"114.114.114.114", false},

		// IPv6 addresses.
		{"::1", true},          // loopback
		{"fe80::1", true},      // link-local
		{"fc00::1", true},      // unique local
		{"2001:db8::1", false}, // documentation prefix, expected to be treated as public
	}

	for _, test := range tests {
		t.Run(test.ip, func(t *testing.T) {
			ip := parseIPForTest(test.ip)
			if ip == nil {
				t.Fatalf("Failed to parse IP: %s", test.ip)
			}

			result := matcher.isPrivateIP(ip)
			if result != test.expected {
				t.Errorf("For IP %s, expected %v, got %v", test.ip, test.expected, result)
			}
		})
	}
}

// TestRouteMatcher_GetStats checks that GetStats reports the number of
// configured bypass/force domains and echoes the BypassLocal and
// BypassPrivate flags.
func TestRouteMatcher_GetStats(t *testing.T) {
	cfg := &config.Routing{
		BypassLocal:   true,
		BypassPrivate: false,
		BypassDomains: []string{"*.local", "*.lan", "*.cn"},
		ForceDomains:  []string{"*.google.com", "*.github.com"},
	}

	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	stats := matcher.GetStats()

	if stats["bypass_domains_count"] != 3 {
		t.Errorf("Expected 3 bypass domains, got %v", stats["bypass_domains_count"])
	}
	if stats["force_domains_count"] != 2 {
		t.Errorf("Expected 2 force domains, got %v", stats["force_domains_count"])
	}
	if stats["bypass_local"] != true {
		t.Errorf("Expected bypass_local to be true, got %v", stats["bypass_local"])
	}
	if stats["bypass_private"] != false {
		t.Errorf("Expected bypass_private to be false, got %v", stats["bypass_private"])
	}
}

// TestDomainToRegexp checks that domainToRegexp compiles wildcard and exact
// domain patterns into regular expressions with the expected match results.
func TestDomainToRegexp(t *testing.T) {
	cfg := &config.Routing{}
	matcher, err := NewRouteMatcher(cfg)
	if err != nil {
		// Fail clearly instead of panicking on a nil matcher below.
		t.Fatalf("Failed to create route matcher: %v", err)
	}

	tests := []struct {
		domain  string
		host    string
		matches bool
	}{
		{"*.google.com", "www.google.com", true},
		{"*.google.com", "api.google.com", true},
		{"*.google.com", "google.com", false},
		{"github.com", "github.com", true},
		{"github.com", "api.github.com", false},
		{"*.example.*", "test.example.org", true},
		{"*.example.*", "sub.example.net", true},
		{"*.example.*", "example.com", false},
	}

	for _, test := range tests {
		t.Run(test.domain+"->"+test.host, func(t *testing.T) {
			pattern, err := matcher.domainToRegexp(test.domain)
			if err != nil {
				t.Fatalf("Failed to compile pattern %s: %v", test.domain, err)
			}

			matches := pattern.MatchString(test.host)
			if matches != test.matches {
				t.Errorf("Pattern %s against host %s: expected %v, got %v",
					test.domain, test.host, test.matches, matches)
			}
		})
	}
}

// Helpers

// parseIPForTest parses an IPv4 or IPv6 literal for use in tests.
// IPv4 addresses are returned in their 4-byte form and IPv6 addresses in
// their 16-byte form. It returns nil when s is not a valid IP literal.
func parseIPForTest(s string) []byte {
	ip := net.ParseIP(s)
	if ip == nil {
		return nil
	}
	// net.ParseIP returns the 16-byte (IPv4-in-IPv6) form even for IPv4
	// input. Collapse it so IPv4 addresses reach isPrivateIP as 4-byte slices.
	if v4 := ip.To4(); v4 != nil {
		return v4
	}
	return ip
}