package pool import ( "errors" "net" "sync" "time" "github.com/sirupsen/logrus" ) var ( ErrPoolClosed = errors.New("connection pool is closed") ErrPoolFull = errors.New("connection pool is full") ErrConnExpired = errors.New("connection has expired") ErrConnInvalid = errors.New("connection is invalid") ) // ConnectionPool 连接池 type ConnectionPool struct { logger *logrus.Logger factory Factory pool chan *PooledConnection mu sync.RWMutex closed bool maxSize int maxLifetime time.Duration maxIdle time.Duration // 统计信息 stats Stats } // PooledConnection 池化连接 type PooledConnection struct { conn net.Conn createdAt time.Time lastUsed time.Time pool *ConnectionPool } // Factory 连接工厂 type Factory interface { Create() (net.Conn, error) Validate(net.Conn) bool Close(net.Conn) error } // Config 连接池配置 type Config struct { MaxSize int `json:"maxSize"` MaxLifetime time.Duration `json:"maxLifetime"` MaxIdle time.Duration `json:"maxIdle"` InitialSize int `json:"initialSize"` } // Stats 统计信息 type Stats struct { mu sync.RWMutex Created int64 `json:"created"` Reused int64 `json:"reused"` Closed int64 `json:"closed"` Active int64 `json:"active"` Idle int64 `json:"idle"` Failures int64 `json:"failures"` } // NewConnectionPool 创建新的连接池 func NewConnectionPool(config Config, factory Factory, logger *logrus.Logger) (*ConnectionPool, error) { if config.MaxSize <= 0 { config.MaxSize = 100 } if config.MaxLifetime == 0 { config.MaxLifetime = 30 * time.Minute } if config.MaxIdle == 0 { config.MaxIdle = 5 * time.Minute } pool := &ConnectionPool{ logger: logger, factory: factory, pool: make(chan *PooledConnection, config.MaxSize), maxSize: config.MaxSize, maxLifetime: config.MaxLifetime, maxIdle: config.MaxIdle, } // 预创建连接 for i := 0; i < config.InitialSize && i < config.MaxSize; i++ { conn, err := pool.factory.Create() if err != nil { pool.logger.WithError(err).Warn("Failed to create initial connection") continue } pooledConn := &PooledConnection{ conn: conn, createdAt: time.Now(), lastUsed: time.Now(), pool: pool, } select { case pool.pool <- pooledConn: pool.stats.incCreated() pool.stats.incIdle() default: conn.Close() } } // 启动清理goroutine go pool.cleaner() return pool, nil } // Get 获取连接 func (p *ConnectionPool) Get() (*PooledConnection, error) { p.mu.RLock() if p.closed { p.mu.RUnlock() return nil, ErrPoolClosed } p.mu.RUnlock() // 尝试从池中获取连接 for { select { case conn := <-p.pool: p.stats.decIdle() // 检查连接是否有效 if p.isConnValid(conn) { conn.lastUsed = time.Now() p.stats.incReused() p.stats.incActive() return conn, nil } // 连接无效,关闭并继续尝试 p.closeConn(conn) continue default: // 池中没有可用连接,创建新连接 return p.createConnection() } } } // Put 归还连接到池 func (p *ConnectionPool) Put(conn *PooledConnection) error { if conn == nil { return nil } p.mu.RLock() if p.closed { p.mu.RUnlock() p.closeConn(conn) return ErrPoolClosed } p.mu.RUnlock() p.stats.decActive() // 检查连接是否有效 if !p.isConnValid(conn) { p.closeConn(conn) return ErrConnInvalid } // 尝试归还到池 select { case p.pool <- conn: p.stats.incIdle() return nil default: // 池已满,关闭连接 p.closeConn(conn) return ErrPoolFull } } // Close 关闭连接池 func (p *ConnectionPool) Close() error { p.mu.Lock() if p.closed { p.mu.Unlock() return nil } p.closed = true p.mu.Unlock() // 关闭所有池中的连接 close(p.pool) for conn := range p.pool { p.closeConn(conn) } p.logger.Info("Connection pool closed") return nil } // GetStats 获取统计信息 func (p *ConnectionPool) GetStats() Stats { p.stats.mu.RLock() defer p.stats.mu.RUnlock() stats := p.stats stats.Idle = int64(len(p.pool)) return stats } // createConnection 创建新连接 func (p *ConnectionPool) createConnection() (*PooledConnection, error) { conn, err := p.factory.Create() if err != nil { p.stats.incFailures() return nil, err } pooledConn := &PooledConnection{ conn: conn, createdAt: time.Now(), lastUsed: time.Now(), pool: p, } p.stats.incCreated() p.stats.incActive() return pooledConn, nil } // isConnValid 检查连接是否有效 func (p *ConnectionPool) isConnValid(conn *PooledConnection) bool { now := time.Now() // 检查连接生命周期 if now.Sub(conn.createdAt) > p.maxLifetime { return false } // 检查空闲时间 if now.Sub(conn.lastUsed) > p.maxIdle { return false } // 使用工厂验证连接 return p.factory.Validate(conn.conn) } // closeConn 关闭连接 func (p *ConnectionPool) closeConn(conn *PooledConnection) { if conn != nil && conn.conn != nil { p.factory.Close(conn.conn) p.stats.incClosed() } } // cleaner 清理过期连接 func (p *ConnectionPool) cleaner() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: p.cleanExpiredConnections() } } } // cleanExpiredConnections 清理过期连接 func (p *ConnectionPool) cleanExpiredConnections() { p.mu.RLock() if p.closed { p.mu.RUnlock() return } p.mu.RUnlock() // 检查池中的连接 poolSize := len(p.pool) cleaned := 0 for i := 0; i < poolSize; i++ { select { case conn := <-p.pool: if p.isConnValid(conn) { // 连接有效,放回池中 select { case p.pool <- conn: default: // 池满了,关闭连接 p.closeConn(conn) cleaned++ } } else { // 连接无效,关闭 p.closeConn(conn) p.stats.decIdle() cleaned++ } default: break } } if cleaned > 0 { p.logger.WithField("count", cleaned).Debug("Cleaned expired connections") } } // PooledConnection 方法 // Read 读取数据 func (pc *PooledConnection) Read(b []byte) (n int, err error) { return pc.conn.Read(b) } // Write 写入数据 func (pc *PooledConnection) Write(b []byte) (n int, err error) { return pc.conn.Write(b) } // Close 关闭连接(归还到池) func (pc *PooledConnection) Close() error { return pc.pool.Put(pc) } // ForceClose 强制关闭连接 func (pc *PooledConnection) ForceClose() error { pc.pool.closeConn(pc) return nil } // LocalAddr 获取本地地址 func (pc *PooledConnection) LocalAddr() net.Addr { return pc.conn.LocalAddr() } // RemoteAddr 获取远程地址 func (pc *PooledConnection) RemoteAddr() net.Addr { return pc.conn.RemoteAddr() } // SetDeadline 设置截止时间 func (pc *PooledConnection) SetDeadline(t time.Time) error { return pc.conn.SetDeadline(t) } // SetReadDeadline 设置读截止时间 func (pc *PooledConnection) SetReadDeadline(t time.Time) error { return pc.conn.SetReadDeadline(t) } // SetWriteDeadline 设置写截止时间 func (pc *PooledConnection) SetWriteDeadline(t time.Time) error { return pc.conn.SetWriteDeadline(t) } // Stats 方法 func (s *Stats) incCreated() { s.mu.Lock() s.Created++ s.mu.Unlock() } func (s *Stats) incReused() { s.mu.Lock() s.Reused++ s.mu.Unlock() } func (s *Stats) incClosed() { s.mu.Lock() s.Closed++ s.mu.Unlock() } func (s *Stats) incActive() { s.mu.Lock() s.Active++ s.mu.Unlock() } func (s *Stats) decActive() { s.mu.Lock() s.Active-- s.mu.Unlock() } func (s *Stats) incIdle() { s.mu.Lock() s.Idle++ s.mu.Unlock() } func (s *Stats) decIdle() { s.mu.Lock() s.Idle-- s.mu.Unlock() } func (s *Stats) incFailures() { s.mu.Lock() s.Failures++ s.mu.Unlock() }