You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

414 lines
7.6 KiB

2 weeks ago
package pool
import (
"errors"
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
)
var (
	// ErrPoolClosed is returned by Get and Put once the pool has been closed.
	ErrPoolClosed = errors.New("connection pool is closed")

	// ErrPoolFull is returned by Put when the idle buffer has no free slot;
	// the connection is closed instead of being kept.
	ErrPoolFull = errors.New("connection pool is full")

	// ErrConnExpired is not returned by the pool code in this package section;
	// presumably reserved for callers or factories — TODO confirm usage.
	ErrConnExpired = errors.New("connection has expired")

	// ErrConnInvalid is returned by Put when the connection fails validation
	// and has been closed.
	ErrConnInvalid = errors.New("connection is invalid")
)
// ConnectionPool hands out net.Conn values produced by a Factory and keeps
// returned connections in a bounded idle buffer for reuse via Get and Put.
type ConnectionPool struct {
	logger  *logrus.Logger
	factory Factory

	// pool is the idle buffer; its capacity is the configured maximum size.
	pool chan *PooledConnection

	// mu protects the closed flag.
	mu     sync.RWMutex
	closed bool

	maxSize              int
	maxLifetime, maxIdle time.Duration

	// stats holds usage counters and carries its own lock.
	stats Stats
}
// PooledConnection wraps a net.Conn obtained from a ConnectionPool together
// with the timestamps the pool consults before reusing it.
type PooledConnection struct {
	// conn is the underlying connection created by the pool's Factory.
	conn net.Conn

	// createdAt is compared against the pool's maximum lifetime.
	createdAt time.Time
	// lastUsed is set at creation and refreshed on checkout; it is compared
	// against the pool's maximum idle time.
	lastUsed time.Time

	// pool is the owning pool; Close hands the connection back to it.
	pool *ConnectionPool
}
// Factory is the pool's source of connections. The pool calls Create to open
// a connection, Validate to decide whether a connection may be handed out or
// kept, and Close to dispose of a connection it no longer wants.
type Factory interface {
	// Create opens a new connection.
	Create() (net.Conn, error)
	// Validate reports whether conn is still usable.
	Validate(conn net.Conn) bool
	// Close releases conn.
	Close(conn net.Conn) error
}
// Config holds the settings accepted by NewConnectionPool. Zero values for
// MaxSize, MaxLifetime and MaxIdle are replaced with defaults by the
// constructor. Durations are encoded in JSON as integer nanoseconds.
type Config struct {
	// MaxSize is the capacity of the idle-connection buffer (default 100).
	MaxSize int `json:"maxSize"`
	// MaxLifetime is the maximum age of a connection (default 30 minutes).
	MaxLifetime time.Duration `json:"maxLifetime"`
	// MaxIdle is how long a connection may go unused (default 5 minutes).
	MaxIdle time.Duration `json:"maxIdle"`
	// InitialSize is how many connections to open eagerly, capped at MaxSize.
	InitialSize int `json:"initialSize"`
}
// Stats is the set of pool activity counters. The exported fields are
// updated under mu by the pool's internal helpers.
type Stats struct {
	mu sync.RWMutex

	Created  int64 `json:"created"`  // connections obtained from Factory.Create
	Reused   int64 `json:"reused"`   // Get calls served from the idle buffer
	Closed   int64 `json:"closed"`   // connections passed to Factory.Close
	Active   int64 `json:"active"`   // connections currently checked out
	Idle     int64 `json:"idle"`     // connections waiting in the idle buffer
	Failures int64 `json:"failures"` // failed connection attempts recorded by the pool
}
// NewConnectionPool creates a pool whose connections are produced by factory.
//
// Non-positive MaxSize, MaxLifetime and MaxIdle values are replaced with the
// defaults 100, 30 minutes and 5 minutes. A negative duration would otherwise
// mark every connection as expired immediately.
//
// Up to min(InitialSize, MaxSize) connections are opened eagerly. A failure
// there is logged and counted in Stats.Failures but does not abort
// construction. A background cleaner goroutine is started before returning.
//
// If logger is nil, logrus.StandardLogger() is used. An error is returned only
// when factory is nil.
func NewConnectionPool(config Config, factory Factory, logger *logrus.Logger) (*ConnectionPool, error) {
	if factory == nil {
		return nil, errors.New("connection pool factory is nil")
	}
	if logger == nil {
		logger = logrus.StandardLogger()
	}
	if config.MaxSize <= 0 {
		config.MaxSize = 100
	}
	if config.MaxLifetime <= 0 {
		config.MaxLifetime = 30 * time.Minute
	}
	if config.MaxIdle <= 0 {
		config.MaxIdle = 5 * time.Minute
	}

	pool := &ConnectionPool{
		logger:      logger,
		factory:     factory,
		pool:        make(chan *PooledConnection, config.MaxSize),
		maxSize:     config.MaxSize,
		maxLifetime: config.MaxLifetime,
		maxIdle:     config.MaxIdle,
	}

	// Pre-create the initial connections.
	for i := 0; i < config.InitialSize && i < config.MaxSize; i++ {
		conn, err := factory.Create()
		if err != nil {
			// Count it like createConnection does, so Failures is consistent.
			pool.stats.incFailures()
			logger.WithError(err).Warn("Failed to create initial connection")
			continue
		}
		now := time.Now()
		pooledConn := &PooledConnection{
			conn:      conn,
			createdAt: now,
			lastUsed:  now,
			pool:      pool,
		}
		select {
		case pool.pool <- pooledConn:
			pool.stats.incCreated()
			pool.stats.incIdle()
		default:
			// Unreachable while i < MaxSize == cap(pool.pool); kept as a guard.
			// Dispose through the factory like every other close path. The
			// error is ignored because the connection was never handed out.
			_ = factory.Close(conn)
		}
	}

	// Start the background goroutine that evicts expired idle connections.
	go pool.cleaner()
	return pool, nil
}
// Get returns a connection for the caller's exclusive use.
//
// Idle connections are tried first. Any that fail isConnValid (too old, idle
// too long, or rejected by Factory.Validate) are closed and skipped. When no
// idle connection is available, a new one is created through the factory.
// The caller should hand the connection back with Put (PooledConnection.Close
// does this).
//
// Returns ErrPoolClosed if the pool is closed, or the factory's error if a
// new connection cannot be created.
func (p *ConnectionPool) Get() (*PooledConnection, error) {
	p.mu.RLock()
	closed := p.closed
	p.mu.RUnlock()
	if closed {
		return nil, ErrPoolClosed
	}

	for {
		select {
		case conn, ok := <-p.pool:
			if !ok {
				// Close closed the channel after the check above. A receive
				// now yields a nil pointer, which must not be dereferenced.
				return nil, ErrPoolClosed
			}
			p.stats.decIdle()
			if p.isConnValid(conn) {
				conn.lastUsed = time.Now()
				p.stats.incReused()
				p.stats.incActive()
				return conn, nil
			}
			// Stale or broken: discard it and try the next idle connection.
			p.closeConn(conn)
		default:
			// Nothing idle; open a fresh connection.
			return p.createConnection()
		}
	}
}
// Put returns conn to the pool's idle buffer.
//
// The connection stops counting as active in every case. Its idle timer is
// restarted before validation, so time spent checked out does not count
// against the idle limit. The lifetime limit and Factory.Validate still
// apply. Connections that cannot be kept are closed through the factory.
//
// Return values:
//   - nil on success, and for a nil conn.
//   - ErrPoolClosed if the pool is closed.
//   - ErrConnInvalid if the connection failed validation.
//   - ErrPoolFull if the idle buffer had no free slot.
func (p *ConnectionPool) Put(conn *PooledConnection) error {
	if conn == nil {
		return nil
	}
	p.stats.decActive()

	p.mu.RLock()
	closed := p.closed
	p.mu.RUnlock()
	if closed {
		p.closeConn(conn)
		return ErrPoolClosed
	}

	// lastUsed otherwise still holds the checkout time. A connection held
	// longer than maxIdle would then be rejected the moment it was returned.
	conn.lastUsed = time.Now()
	if !p.isConnValid(conn) {
		p.closeConn(conn)
		return ErrConnInvalid
	}

	// Hold the read lock across the non-blocking send. Close sets closed
	// under the write lock before closing the channel. So while the read lock
	// is held and closed is false, the channel is open and the send cannot
	// panic.
	p.mu.RLock()
	if p.closed {
		p.mu.RUnlock()
		p.closeConn(conn)
		return ErrPoolClosed
	}
	select {
	case p.pool <- conn:
		p.mu.RUnlock()
		p.stats.incIdle()
		return nil
	default:
		p.mu.RUnlock()
		// Idle buffer is full; drop this connection.
		p.closeConn(conn)
		return ErrPoolFull
	}
}
// Close shuts the pool down. It marks the pool closed, closes the idle
// channel and closes every connection still buffered in it. Calling Close
// again is a no-op. It always returns nil.
//
// Connections currently checked out are not closed here. Put closes them
// when they are handed back to a closed pool.
func (p *ConnectionPool) Close() error {
	p.mu.Lock()
	wasClosed := p.closed
	p.closed = true
	p.mu.Unlock()

	if wasClosed {
		return nil
	}

	// Close the channel, then drain and close whatever is still buffered.
	close(p.pool)
	for idle := range p.pool {
		p.closeConn(idle)
	}

	p.logger.Info("Connection pool closed")
	return nil
}
// GetStats returns a point-in-time copy of the pool's counters.
//
// Idle reports the current number of buffered connections, not the running
// counter. The result is built field by field rather than copying p.stats
// wholesale. Stats contains a sync.RWMutex, and copying it (here while it is
// read-locked) duplicates its internal lock state. go vet's copylocks check
// flags that.
func (p *ConnectionPool) GetStats() Stats {
	p.stats.mu.RLock()
	defer p.stats.mu.RUnlock()
	return Stats{
		Created:  p.stats.Created,
		Reused:   p.stats.Reused,
		Closed:   p.stats.Closed,
		Active:   p.stats.Active,
		Idle:     int64(len(p.pool)),
		Failures: p.stats.Failures,
	}
}
// createConnection obtains a brand-new connection from the factory and
// returns it already counted as checked out: Created and Active are
// incremented. On failure, Failures is incremented and the factory's error is
// returned unchanged.
func (p *ConnectionPool) createConnection() (*PooledConnection, error) {
	raw, err := p.factory.Create()
	if err != nil {
		p.stats.incFailures()
		return nil, err
	}

	fresh := &PooledConnection{
		conn:      raw,
		createdAt: time.Now(),
		lastUsed:  time.Now(),
		pool:      p,
	}
	p.stats.incCreated()
	p.stats.incActive()
	return fresh, nil
}
// isConnValid reports whether conn may still be used. A usable connection:
//   - is younger than maxLifetime,
//   - has not been unused for longer than maxIdle (measured from lastUsed),
//   - passes Factory.Validate.
//
// The time checks run first, so the factory is consulted only for
// connections within both limits.
func (p *ConnectionPool) isConnValid(conn *PooledConnection) bool {
	now := time.Now()
	switch {
	case now.Sub(conn.createdAt) > p.maxLifetime:
		return false
	case now.Sub(conn.lastUsed) > p.maxIdle:
		return false
	}
	return p.factory.Validate(conn.conn)
}
// closeConn closes the net.Conn wrapped by conn through the factory and
// counts it in Stats.Closed. A nil conn, or one without an underlying
// net.Conn, is ignored.
//
// Closing is best-effort. An error from Factory.Close is logged at debug
// level instead of being silently dropped. Callers are discarding the
// connection anyway and have no recovery path.
func (p *ConnectionPool) closeConn(conn *PooledConnection) {
	if conn == nil || conn.conn == nil {
		return
	}
	if err := p.factory.Close(conn.conn); err != nil && p.logger != nil {
		p.logger.WithError(err).Debug("Failed to close pooled connection")
	}
	p.stats.incClosed()
}
// cleaner runs cleanExpiredConnections once a minute until the pool is
// closed. NewConnectionPool starts it in its own goroutine.
//
// The closed flag is checked on every tick. The goroutine and its ticker are
// therefore released at most one interval after Close, instead of running
// for the lifetime of the process.
func (p *ConnectionPool) cleaner() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		p.mu.RLock()
		closed := p.closed
		p.mu.RUnlock()
		if closed {
			return
		}
		p.cleanExpiredConnections()
	}
}
// cleanExpiredConnections makes one pass over the connections that are idle
// when it starts. It closes those that fail isConnValid and puts the rest
// back. It does nothing once the pool is closed.
//
// The pass stops early if concurrent Get calls empty the buffer or Close
// closes the channel. A healthy connection that cannot be put back is closed
// as well. That happens when the buffer was refilled by concurrent Puts, or
// the pool was closed meanwhile. Every closed connection is removed from the
// Idle counter.
func (p *ConnectionPool) cleanExpiredConnections() {
	p.mu.RLock()
	closed := p.closed
	p.mu.RUnlock()
	if closed {
		return
	}

	// Bound the pass by the size at the start, so connections re-queued
	// below are not examined again in the same pass.
	poolSize := len(p.pool)
	cleaned := 0
scan:
	for i := 0; i < poolSize; i++ {
		var conn *PooledConnection
		select {
		case c, ok := <-p.pool:
			if !ok {
				// Close closed the channel mid-pass; it drains the rest itself.
				break scan
			}
			conn = c
		default:
			// Buffer already emptied by concurrent Get calls. A bare break
			// would only leave the select, so break the labelled loop.
			break scan
		}

		if p.isConnValid(conn) {
			// Hold the read lock across the send, so Close cannot close the
			// channel between the closed check and the send (that would panic).
			p.mu.RLock()
			requeued := false
			if !p.closed {
				select {
				case p.pool <- conn:
					requeued = true
				default:
				}
			}
			p.mu.RUnlock()
			if requeued {
				continue
			}
		}

		// Expired/invalid, or healthy but impossible to put back.
		p.closeConn(conn)
		p.stats.decIdle()
		cleaned++
	}

	if cleaned > 0 {
		p.logger.WithField("count", cleaned).Debug("Cleaned expired connections")
	}
}
// Methods of PooledConnection.

// Read delegates to the wrapped net.Conn's Read.
func (pc *PooledConnection) Read(b []byte) (n int, err error) {
	n, err = pc.conn.Read(b)
	return n, err
}

// Write delegates to the wrapped net.Conn's Write.
func (pc *PooledConnection) Write(b []byte) (n int, err error) {
	n, err = pc.conn.Write(b)
	return n, err
}
// Close hands pc back to its pool via Put instead of closing it directly.
// The pool decides whether to keep the connection or close it. Put's error
// (for example ErrPoolFull) is returned unchanged.
func (pc *PooledConnection) Close() error {
	owner := pc.pool
	return owner.Put(pc)
}
// ForceClose closes the underlying connection immediately instead of
// returning it to the pool. Use it on a connection obtained from Get (for
// example one known to be broken). Do not call it after Close, which has
// already returned the connection to the pool.
//
// It always returns nil; a close failure is not reported to the caller.
func (pc *PooledConnection) ForceClose() error {
	// The connection was checked out, so it must stop counting as active,
	// exactly as it would on Put. Otherwise Active only ever grows.
	pc.pool.stats.decActive()
	pc.pool.closeConn(pc)
	return nil
}
// LocalAddr returns the local address of the wrapped connection.
func (pc *PooledConnection) LocalAddr() net.Addr { return pc.conn.LocalAddr() }

// RemoteAddr returns the remote address of the wrapped connection.
func (pc *PooledConnection) RemoteAddr() net.Addr { return pc.conn.RemoteAddr() }

// SetDeadline sets both the read and write deadlines on the wrapped connection.
func (pc *PooledConnection) SetDeadline(t time.Time) error { return pc.conn.SetDeadline(t) }

// SetReadDeadline sets the read deadline on the wrapped connection.
func (pc *PooledConnection) SetReadDeadline(t time.Time) error { return pc.conn.SetReadDeadline(t) }

// SetWriteDeadline sets the write deadline on the wrapped connection.
func (pc *PooledConnection) SetWriteDeadline(t time.Time) error { return pc.conn.SetWriteDeadline(t) }
// Stats mutators. Each one adjusts exactly one counter while holding s.mu.

// bumpCounter adds delta to *counter under s.mu. counter must point at one
// of s's own fields.
func (s *Stats) bumpCounter(counter *int64, delta int64) {
	s.mu.Lock()
	*counter += delta
	s.mu.Unlock()
}

func (s *Stats) incCreated() { s.bumpCounter(&s.Created, 1) }

func (s *Stats) incReused() { s.bumpCounter(&s.Reused, 1) }

func (s *Stats) incClosed() { s.bumpCounter(&s.Closed, 1) }

func (s *Stats) incActive() { s.bumpCounter(&s.Active, 1) }

func (s *Stats) decActive() { s.bumpCounter(&s.Active, -1) }

func (s *Stats) incIdle() { s.bumpCounter(&s.Idle, 1) }

func (s *Stats) decIdle() { s.bumpCounter(&s.Idle, -1) }

func (s *Stats) incFailures() { s.bumpCounter(&s.Failures, 1) }