You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
330 lines
7.9 KiB
330 lines
7.9 KiB
2 weeks ago
|
package routing
|
||
|
|
||
|
import (
	"net"
	"testing"

	"github.com/azoic/wormhole-client/internal/config"
)
|
||
|
|
||
|
func TestNewRouteMatcher(t *testing.T) {
|
||
|
cfg := &config.Routing{
|
||
|
BypassLocal: true,
|
||
|
BypassPrivate: true,
|
||
|
BypassDomains: []string{"*.local", "*.lan"},
|
||
|
ForceDomains: []string{"*.google.com", "*.github.com"},
|
||
|
}
|
||
|
|
||
|
matcher, err := NewRouteMatcher(cfg)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to create route matcher: %v", err)
|
||
|
}
|
||
|
|
||
|
if matcher == nil {
|
||
|
t.Fatal("Route matcher should not be nil")
|
||
|
}
|
||
|
|
||
|
if len(matcher.bypassDomains) != 2 {
|
||
|
t.Errorf("Expected 2 bypass domains, got %d", len(matcher.bypassDomains))
|
||
|
}
|
||
|
|
||
|
if len(matcher.forceDomains) != 2 {
|
||
|
t.Errorf("Expected 2 force domains, got %d", len(matcher.forceDomains))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_Match(t *testing.T) {
|
||
|
cfg := &config.Routing{
|
||
|
BypassLocal: true,
|
||
|
BypassPrivate: true,
|
||
|
BypassDomains: []string{"*.local", "*.baidu.com"},
|
||
|
ForceDomains: []string{"*.google.com", "*.github.com"},
|
||
|
}
|
||
|
|
||
|
matcher, err := NewRouteMatcher(cfg)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to create route matcher: %v", err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
host string
|
||
|
expected MatchResult
|
||
|
}{
|
||
|
// 强制代理域名
|
||
|
{"www.google.com", MatchProxy},
|
||
|
{"api.github.com", MatchProxy},
|
||
|
{"www.google.com:443", MatchProxy},
|
||
|
|
||
|
// 绕过域名
|
||
|
{"test.local", MatchBypass},
|
||
|
{"www.baidu.com", MatchBypass},
|
||
|
{"search.baidu.com:80", MatchBypass},
|
||
|
|
||
|
// 本地域名
|
||
|
{"localhost", MatchBypass},
|
||
|
{"router.lan", MatchBypass},
|
||
|
{"printer.internal", MatchBypass},
|
||
|
|
||
|
// IP地址 - 本地
|
||
|
{"127.0.0.1", MatchBypass},
|
||
|
{"::1", MatchBypass},
|
||
|
|
||
|
// IP地址 - 私有网络
|
||
|
{"192.168.1.1", MatchBypass},
|
||
|
{"10.0.0.1", MatchBypass},
|
||
|
{"172.16.1.1", MatchBypass},
|
||
|
|
||
|
// 其他域名 - 自动决定
|
||
|
{"example.com", MatchAuto},
|
||
|
{"stackoverflow.com", MatchAuto},
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.host, func(t *testing.T) {
|
||
|
result := matcher.Match(test.host)
|
||
|
if result != test.expected {
|
||
|
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_MatchesForceDomains(t *testing.T) {
|
||
|
cfg := &config.Routing{
|
||
|
ForceDomains: []string{"*.google.com", "github.com", "*.example.*"},
|
||
|
}
|
||
|
|
||
|
matcher, err := NewRouteMatcher(cfg)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to create route matcher: %v", err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
host string
|
||
|
expected bool
|
||
|
}{
|
||
|
{"www.google.com", true},
|
||
|
{"api.google.com", true},
|
||
|
{"google.com", false}, // 不匹配 *.google.com
|
||
|
{"github.com", true},
|
||
|
{"api.github.com", false}, // 不匹配 github.com
|
||
|
{"test.example.org", true},
|
||
|
{"sub.example.net", true},
|
||
|
{"example.com", false}, // 不匹配 *.example.*
|
||
|
{"other.com", false},
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.host, func(t *testing.T) {
|
||
|
result := matcher.matchesForceDomains(test.host)
|
||
|
if result != test.expected {
|
||
|
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_MatchesBypassDomains(t *testing.T) {
|
||
|
cfg := &config.Routing{
|
||
|
BypassDomains: []string{"*.local", "localhost", "*.cn"},
|
||
|
}
|
||
|
|
||
|
matcher, err := NewRouteMatcher(cfg)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to create route matcher: %v", err)
|
||
|
}
|
||
|
|
||
|
tests := []struct {
|
||
|
host string
|
||
|
expected bool
|
||
|
}{
|
||
|
{"test.local", true},
|
||
|
{"printer.local", true},
|
||
|
{"local", false}, // 不匹配 *.local
|
||
|
{"localhost", true},
|
||
|
{"www.baidu.cn", true},
|
||
|
{"qq.cn", true},
|
||
|
{"china.com", false}, // 不匹配 *.cn
|
||
|
{"example.com", false},
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.host, func(t *testing.T) {
|
||
|
result := matcher.matchesBypassDomains(test.host)
|
||
|
if result != test.expected {
|
||
|
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_IsLocalDomain(t *testing.T) {
|
||
|
cfg := &config.Routing{}
|
||
|
matcher, _ := NewRouteMatcher(cfg)
|
||
|
|
||
|
tests := []struct {
|
||
|
host string
|
||
|
expected bool
|
||
|
}{
|
||
|
{"localhost", true},
|
||
|
{"test.local", true},
|
||
|
{"printer.lan", true},
|
||
|
{"server.internal", true},
|
||
|
{"router.home", true},
|
||
|
{"pc.corp", true},
|
||
|
{"singleword", true}, // 单词域名
|
||
|
{"example.com", false},
|
||
|
{"www.google.com", false},
|
||
|
{"192.168.1.1", false}, // IP地址不是域名
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.host, func(t *testing.T) {
|
||
|
result := matcher.isLocalDomain(test.host)
|
||
|
if result != test.expected {
|
||
|
t.Errorf("For host %s, expected %v, got %v", test.host, test.expected, result)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_IsPrivateIP(t *testing.T) {
|
||
|
cfg := &config.Routing{}
|
||
|
matcher, _ := NewRouteMatcher(cfg)
|
||
|
|
||
|
tests := []struct {
|
||
|
ip string
|
||
|
expected bool
|
||
|
}{
|
||
|
// 私有IPv4地址
|
||
|
{"192.168.1.1", true},
|
||
|
{"10.0.0.1", true},
|
||
|
{"172.16.1.1", true},
|
||
|
{"127.0.0.1", true}, // 环回地址
|
||
|
|
||
|
// 公网IPv4地址
|
||
|
{"8.8.8.8", false},
|
||
|
{"1.1.1.1", false},
|
||
|
{"114.114.114.114", false},
|
||
|
|
||
|
// IPv6地址
|
||
|
{"::1", true}, // 环回
|
||
|
{"fe80::1", true}, // 链路本地
|
||
|
{"fc00::1", true}, // 唯一本地
|
||
|
{"2001:db8::1", false}, // 公网(测试用)
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.ip, func(t *testing.T) {
|
||
|
// 解析IP
|
||
|
ip := parseIPForTest(test.ip)
|
||
|
if ip == nil {
|
||
|
t.Fatalf("Failed to parse IP: %s", test.ip)
|
||
|
}
|
||
|
|
||
|
result := matcher.isPrivateIP(ip)
|
||
|
if result != test.expected {
|
||
|
t.Errorf("For IP %s, expected %v, got %v", test.ip, test.expected, result)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRouteMatcher_GetStats(t *testing.T) {
|
||
|
cfg := &config.Routing{
|
||
|
BypassLocal: true,
|
||
|
BypassPrivate: false,
|
||
|
BypassDomains: []string{"*.local", "*.lan", "*.cn"},
|
||
|
ForceDomains: []string{"*.google.com", "*.github.com"},
|
||
|
}
|
||
|
|
||
|
matcher, err := NewRouteMatcher(cfg)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to create route matcher: %v", err)
|
||
|
}
|
||
|
|
||
|
stats := matcher.GetStats()
|
||
|
|
||
|
if stats["bypass_domains_count"] != 3 {
|
||
|
t.Errorf("Expected 3 bypass domains, got %v", stats["bypass_domains_count"])
|
||
|
}
|
||
|
|
||
|
if stats["force_domains_count"] != 2 {
|
||
|
t.Errorf("Expected 2 force domains, got %v", stats["force_domains_count"])
|
||
|
}
|
||
|
|
||
|
if stats["bypass_local"] != true {
|
||
|
t.Errorf("Expected bypass_local to be true, got %v", stats["bypass_local"])
|
||
|
}
|
||
|
|
||
|
if stats["bypass_private"] != false {
|
||
|
t.Errorf("Expected bypass_private to be false, got %v", stats["bypass_private"])
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestDomainToRegexp(t *testing.T) {
|
||
|
cfg := &config.Routing{}
|
||
|
matcher, _ := NewRouteMatcher(cfg)
|
||
|
|
||
|
tests := []struct {
|
||
|
domain string
|
||
|
host string
|
||
|
matches bool
|
||
|
}{
|
||
|
{"*.google.com", "www.google.com", true},
|
||
|
{"*.google.com", "api.google.com", true},
|
||
|
{"*.google.com", "google.com", false},
|
||
|
{"github.com", "github.com", true},
|
||
|
{"github.com", "api.github.com", false},
|
||
|
{"*.example.*", "test.example.org", true},
|
||
|
{"*.example.*", "sub.example.net", true},
|
||
|
{"*.example.*", "example.com", false},
|
||
|
}
|
||
|
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.domain+"->"+test.host, func(t *testing.T) {
|
||
|
pattern, err := matcher.domainToRegexp(test.domain)
|
||
|
if err != nil {
|
||
|
t.Fatalf("Failed to compile pattern %s: %v", test.domain, err)
|
||
|
}
|
||
|
|
||
|
matches := pattern.MatchString(test.host)
|
||
|
if matches != test.matches {
|
||
|
t.Errorf("Pattern %s against host %s: expected %v, got %v",
|
||
|
test.domain, test.host, test.matches, matches)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Helper functions.

// parseIPForTest parses an IPv4 or IPv6 literal for use in tests.
//
// IPv4 addresses come back in their 4-byte form and IPv6 addresses in their
// 16-byte form, matching hand-written literals such as
// []byte{192, 168, 1, 1}. It returns nil if s is not a valid IP literal, so
// callers can treat nil as a parse failure.
func parseIPForTest(s string) []byte {
	ip := net.ParseIP(s)
	if ip == nil {
		return nil
	}
	// net.ParseIP always yields the 16-byte (IPv4-mapped) form, even for
	// dotted-decimal input. Collapse IPv4 back to 4 bytes to keep the byte
	// layout this helper has always returned.
	if v4 := ip.To4(); v4 != nil {
		return v4
	}
	return ip
}
|