Go 开发压测工具——类 wrk 的 HTTP 压测工具从零实现
Go 开发压测工具——类 wrk 的 HTTP 压测工具从零实现
适读人群:Go 开发者、需要定制化 HTTP 压测工具的工程师 | 阅读时长:约 17 分钟 | 核心价值:理解 wrk 的核心机制,用 Go 实现支持自定义场景的 HTTP 压测工具
去年帮一个做电商秒杀的团队做上线前的压测,他们用 wrk 做基准测试,但 wrk 有个限制:所有请求都是一样的,没法模拟真实业务场景——真实的秒杀流量是:先登录获取 token,再拿着 token 去抢购,两个接口有依赖关系,用户身份也不同。
LuaJIT 脚本能在 wrk 里做一些自定义,但学习成本高,调试困难。我建议他们用 Go 实现一个定制化压测工具,逻辑就是 Go 代码,调试方便,而且可以直接复用他们业务代码里的数据结构。
这篇文章把那个工具的核心实现整理出来。
设计目标
- 并发 goroutine 模拟并发用户
- 支持连接复用(HTTP/1.1 Keep-Alive)
- 实时统计:QPS、延迟分布、错误率
- 支持多阶段场景(先 warm-up,再加压)
- 结果输出:百分位延迟(P50/P90/P99/P99.9)
完整实现
package bench
import (
"context"
"fmt"
"math"
"net/http"
"sort"
"sync"
"sync/atomic"
"time"
)
// Config 压测配置
type Config struct {
Workers int // 并发 worker 数
Duration time.Duration // 压测持续时间
RateLimit int // 每秒总请求数限制(0 = 不限)
WarmupTime time.Duration // 预热时间(预热阶段数据不计入结果)
}
// RequestFunc 请求函数类型,由使用方实现具体的 HTTP 请求逻辑
type RequestFunc func(client *http.Client) (statusCode int, latency time.Duration, err error)
// Stats 统计数据
type Stats struct {
TotalRequests int64
SuccessCount int64
ErrorCount int64
TotalLatencyMs int64
Latencies []int64 // 所有请求的延迟(毫秒)
StatusCodes map[int]int64
StartTime time.Time
EndTime time.Time
}
func (s *Stats) Duration() time.Duration {
return s.EndTime.Sub(s.StartTime)
}
func (s *Stats) QPS() float64 {
d := s.Duration().Seconds()
if d == 0 {
return 0
}
return float64(s.TotalRequests) / d
}
func (s *Stats) AvgLatencyMs() float64 {
if s.TotalRequests == 0 {
return 0
}
return float64(s.TotalLatencyMs) / float64(s.TotalRequests)
}
func (s *Stats) Percentile(p float64) int64 {
if len(s.Latencies) == 0 {
return 0
}
sorted := make([]int64, len(s.Latencies))
copy(sorted, s.Latencies)
sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })
idx := int(math.Ceil(p/100.0*float64(len(sorted)))) - 1
if idx < 0 {
idx = 0
}
if idx >= len(sorted) {
idx = len(sorted) - 1
}
return sorted[idx]
}
func (s *Stats) ErrorRate() float64 {
if s.TotalRequests == 0 {
return 0
}
return float64(s.ErrorCount) / float64(s.TotalRequests) * 100
}
// Benchmarker HTTP 压测器
type Benchmarker struct {
config Config
httpClient *http.Client
// 原子计数器(用于实时统计)
reqCount atomic.Int64
errCount atomic.Int64
// 延迟数据(需要加锁)
mu sync.Mutex
latencies []int64
statusCodes map[int]int64
}
func NewBenchmarker(config Config) *Benchmarker {
transport := &http.Transport{
MaxIdleConnsPerHost: config.Workers,
IdleConnTimeout: 90 * time.Second,
DisableCompression: true, // 压测时关闭压缩,减少 CPU 开销
}
return &Benchmarker{
config: config,
httpClient: &http.Client{
Transport: transport,
Timeout: 30 * time.Second,
},
statusCodes: make(map[int]int64),
}
}
// Run 执行压测
func (b *Benchmarker) Run(ctx context.Context, reqFunc RequestFunc) *Stats {
// 预热阶段
if b.config.WarmupTime > 0 {
fmt.Printf("Warming up for %v...\n", b.config.WarmupTime)
warmupCtx, cancel := context.WithTimeout(ctx, b.config.WarmupTime)
b.runWorkers(warmupCtx, reqFunc, true)
cancel()
// 重置计数器
b.reqCount.Store(0)
b.errCount.Store(0)
b.mu.Lock()
b.latencies = nil
b.statusCodes = make(map[int]int64)
b.mu.Unlock()
fmt.Println("Warmup done, starting benchmark...")
}
// 正式压测
pressCtx, cancel := context.WithTimeout(ctx, b.config.Duration)
defer cancel()
start := time.Now()
b.runWorkers(pressCtx, reqFunc, false)
end := time.Now()
// 收集结果
b.mu.Lock()
defer b.mu.Unlock()
var totalLatency int64
for _, l := range b.latencies {
totalLatency += l
}
statusCodes := make(map[int]int64)
for k, v := range b.statusCodes {
statusCodes[k] = v
}
return &Stats{
TotalRequests: b.reqCount.Load(),
SuccessCount: b.reqCount.Load() - b.errCount.Load(),
ErrorCount: b.errCount.Load(),
TotalLatencyMs: totalLatency,
Latencies: b.latencies,
StatusCodes: statusCodes,
StartTime: start,
EndTime: end,
}
}
func (b *Benchmarker) runWorkers(ctx context.Context, reqFunc RequestFunc, isWarmup bool) {
var wg sync.WaitGroup
// 速率限制
var limiter <-chan time.Time
if b.config.RateLimit > 0 {
ticker := time.NewTicker(time.Second / time.Duration(b.config.RateLimit))
defer ticker.Stop()
limiter = ticker.C
}
for i := 0; i < b.config.Workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
}
// 如果有速率限制,等待令牌
if limiter != nil {
select {
case <-limiter:
case <-ctx.Done():
return
}
}
statusCode, latencyMs, err := b.doRequest(reqFunc)
if !isWarmup {
b.reqCount.Add(1)
if err != nil {
b.errCount.Add(1)
}
b.mu.Lock()
b.latencies = append(b.latencies, latencyMs.Milliseconds())
b.statusCodes[statusCode]++
b.mu.Unlock()
}
}
}()
}
wg.Wait()
}
func (b *Benchmarker) doRequest(reqFunc RequestFunc) (int, time.Duration, error) {
start := time.Now()
statusCode, latency, err := reqFunc(b.httpClient)
if latency == 0 {
latency = time.Since(start)
}
return statusCode, latency, err
}
// PrintReport 打印压测报告
func PrintReport(stats *Stats) {
fmt.Println("\n========== 压测报告 ==========")
fmt.Printf("总请求数: %d\n", stats.TotalRequests)
fmt.Printf("成功请求: %d\n", stats.SuccessCount)
fmt.Printf("失败请求: %d (%.2f%%)\n", stats.ErrorCount, stats.ErrorRate())
fmt.Printf("压测时长: %v\n", stats.Duration().Round(time.Millisecond))
fmt.Printf("QPS: %.2f req/s\n", stats.QPS())
fmt.Println()
fmt.Println("--- 延迟分布 ---")
fmt.Printf("平均延迟: %.2f ms\n", stats.AvgLatencyMs())
fmt.Printf("P50: %d ms\n", stats.Percentile(50))
fmt.Printf("P90: %d ms\n", stats.Percentile(90))
fmt.Printf("P99: %d ms\n", stats.Percentile(99))
fmt.Printf("P99.9: %d ms\n", stats.Percentile(99.9))
fmt.Printf("最大延迟: %d ms\n", stats.Percentile(100))
fmt.Println()
fmt.Println("--- HTTP 状态码分布 ---")
for code, count := range stats.StatusCodes {
fmt.Printf(" %d: %d\n", code, count)
}
fmt.Println("================================")
}使用示例:秒杀场景压测
func main() {
// 准备测试账号(实际压测时用真实账号)
tokens := prepareTokens(100) // 100 个账号的 token
tokenIdx := atomic.Int64{}
// 获取 token(轮询)
getToken := func() string {
idx := int(tokenIdx.Add(1)) % len(tokens)
return tokens[idx]
}
bench := bench.NewBenchmarker(bench.Config{
Workers: 50,
Duration: 30 * time.Second,
RateLimit: 500, // 500 QPS
WarmupTime: 5 * time.Second,
})
stats := bench.Run(context.Background(), func(client *http.Client) (int, time.Duration, error) {
req, err := http.NewRequest("POST", "https://api.example.com/seckill/buy", strings.NewReader(`{"item_id":12345,"qty":1}`))
if err != nil {
return 0, 0, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+getToken())
start := time.Now()
resp, err := client.Do(req)
latency := time.Since(start)
if err != nil {
return 0, latency, err
}
defer resp.Body.Close()
io.Discard.Write(resp.Body) // 必须读取 body,否则连接不会复用
return resp.StatusCode, latency, nil
})
bench.PrintReport(stats)
}踩坑实录
踩坑 1:不读 Response Body 导致连接无法复用
现象:QPS 远低于预期,连接数持续增长。
原因:HTTP Keep-Alive 复用连接的前提是必须完整读取并关闭 response body。如果只判断 status code 就 return,body 没读完,连接会被关闭而不是放回连接池。
解法:永远要 io.Copy(io.Discard, resp.Body) + resp.Body.Close()。
踩坑 2:P99 延迟数据失真
现象:压测 30 秒,P99 延迟显示 5000ms,但实际上只有极少数请求这么慢,中位数才 20ms。
原因:把所有延迟数据存到一个切片里,最后排序。数据量大(几十万条)时,sort 消耗很多时间,而且切片内存占用也很大。
解法:用直方图(histogram)代替存储所有延迟。把延迟范围分成 bucket,只记录每个 bucket 的计数:
// 简化:用 HDR Histogram 库
import "github.com/HdrHistogram/hdrhistogram-go"
h := hdrhistogram.New(1, 60000, 3) // 1ms to 60s, 3 sig figs
h.RecordValue(latencyMs)
fmt.Printf("P99: %d ms\n", h.ValueAtQuantile(99))踩坑 3:并发量太大时 goroutine 调度开销大
现象:Workers 设到 500 后,QPS 不再增长,CPU 飙升,但 goroutine 利用率低。
原因:500 个 goroutine 频繁等待网络响应,Go scheduler 的调度开销变大。
解法:Workers 不是越多越好,最优值通常是 CPU 核数 * 2 到 CPU 核数 * 10 之间。对于 IO 密集型压测,100-200 个 goroutine 通常已经能打满网络带宽了。
