Python Performance Optimization in Practice: cProfile, line_profiler, Cython, and numba
2026/4/30
Audience: engineers whose Python code runs slowly and who want to optimize it systematically | Reading time: ~16 minutes | Takeaway: a complete Python performance-optimization methodology, from diagnosis to acceleration
"先测量,再优化"——我吃过的亏
When I started doing Python performance work, I made the classic mistake: guessing at the bottleneck.
A data-processing script took 40 seconds. I assumed JSON parsing was the problem and spent two hours replacing the hand-rolled parser with orjson. The script still took 40 seconds, because the bottleneck was not JSON parsing at all; it was string concatenation inside a nested loop.
One run of cProfile made it obvious: the string-concatenation function accounted for 87% of the total time.
That lesson stuck with me: the first step of performance optimization is always measurement, never optimization.
1. Profiling Tools
1.1 cProfile: function-level profiling
import cProfile
import pstats
import io

def profile_code(func, *args, **kwargs):
    """Profile a single function call with cProfile."""
    pr = cProfile.Profile()
    pr.enable()
    result = func(*args, **kwargs)
    pr.disable()
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats("cumulative")
    ps.print_stats(20)  # print the 20 most expensive functions
    print(s.getvalue())
    return result
# Or from the command line:
# python -m cProfile -s cumtime my_script.py
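For anything longer than a one-off run, it helps to persist the profile and dig into it afterwards. A minimal sketch using only the standard library (the file name profile.out is arbitrary):
# shell: dump raw profile data to a file first
# python -m cProfile -o profile.out my_script.py
import pstats
stats = pstats.Stats("profile.out")
stats.strip_dirs().sort_stats("tottime").print_stats(10)  # top 10 by own time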
1.2 line_profiler: line-by-line profiling
pip install line_profiler

from line_profiler import LineProfiler
def slow_function(data: list) -> list:
    result = []
    for item in data:
        processed = item ** 2
        result.append(processed)  # this line allocates on every iteration
    return result
# Use the @profile decorator (together with the kernprof -l command)
@profile  # noqa
def slow_function_with_profile(data: list) -> list:
    result = []
    for item in data:
        processed = item ** 2
        result.append(processed)
    return result
# Command line: kernprof -l -v script.py
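If you would rather not go through kernprof, the LineProfiler object imported above can also be driven programmatically; a minimal sketch (the input list is arbitrary):
lp = LineProfiler()
profiled = lp(slow_function)   # wrap the function defined earlier
profiled(list(range(100_000)))
lp.print_stats()               # per-line timings to stdout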
1.3 memory_profiler: memory profiling
pip install memory_profiler

from memory_profiler import profile
@profile
def memory_heavy_function():
    # @profile prints the memory delta for each line
    big_list = [i for i in range(1_000_000)]   # roughly +40 MB (list + int objects)
    result = {str(i): i for i in big_list}     # roughly +100 MB
    del big_list                               # frees the list (the OS may reclaim lazily)
    return result
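The memory_profiler package also ships an mprof command that samples memory over time instead of per line, which is handy for spotting leaks; a quick sketch of that workflow (script.py stands in for your own script):
# shell:
# mprof run script.py    # record a memory-usage timeline
# mprof plot             # plot it (requires matplotlib)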
2. Python-Level Optimization: Zero-Cost Wins
Before reaching for Cython or numba, get the Python-level optimizations right first:
2.1 Choose the right data structure
import time

# List lookup: O(n)
def search_list(items: list, target: int) -> bool:
    return target in items

# Set lookup: O(1)
def search_set(items: set, target: int) -> bool:
    return target in items

data = list(range(1_000_000))
data_set = set(data)

start = time.perf_counter()
for _ in range(1000):
    search_list(data, 999_999)
print(f"list lookup: {(time.perf_counter()-start)*1000:.1f}ms")

start = time.perf_counter()
for _ in range(1000):
    search_set(data_set, 999_999)
print(f"set lookup: {(time.perf_counter()-start)*1000:.1f}ms")
# set wins by orders of magnitude!
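When the data has to stay ordered and a set is not an option, a sorted list plus the standard library's bisect module is a useful middle ground at O(log n); a minimal sketch:
import bisect

sorted_data = list(range(1_000_000))  # already sorted

def search_sorted(items: list, target: int) -> bool:
    # binary search: O(log n), no extra memory for a set
    i = bisect.bisect_left(items, target)
    return i < len(items) and items[i] == target

print(search_sorted(sorted_data, 999_999))  # True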
2.2 Avoid repeated attribute lookups
import time
import math

class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

points = [Point(i, i) for i in range(100_000)]

# Slow: every iteration looks up math.sqrt again
start = time.perf_counter()
result = [math.sqrt(p.x**2 + p.y**2) for p in points]
print(f"attribute lookup math.sqrt: {time.perf_counter()-start:.4f}s")

# Fast: cache the function in a local variable
sqrt = math.sqrt  # local-variable cache
start = time.perf_counter()
result = [sqrt(p.x**2 + p.y**2) for p in points]
print(f"局部变量 sqrt: {time.perf_counter()-start:.4f}s")2.3 字符串拼接
2.3 String concatenation
# Slow: += concatenation (builds a new string every time)
result = ""
for i in range(10000):
    result += str(i)  # O(n^2) in the worst case

# Fast: join (one allocation)
result = "".join(str(i) for i in range(10000))  # O(n)

# Fast: f-string
name = "Lao Zhang"
greeting = f"你好,{name}!" # 比 % 格式化和 .format() 都快三、NumPy 向量化:数值计算的银弹
3. NumPy Vectorization: the Silver Bullet for Numerical Code
import numpy as np
import time

# Python loop (slow)
def python_loop(arr):
    result = []
    for x in arr:
        result.append(x ** 2 + 2 * x + 1)
    return result

# NumPy vectorization (roughly 100x faster)
def numpy_vectorized(arr):
    return arr ** 2 + 2 * arr + 1

data_list = list(range(1_000_000))
data_np = np.array(data_list, dtype=np.float64)

start = time.perf_counter()
r1 = python_loop(data_list)
print(f"Python loop: {time.perf_counter()-start:.3f}s")

start = time.perf_counter()
r2 = numpy_vectorized(data_np)
print(f"NumPy 向量化: {time.perf_counter()-start:.3f}s")四、numba:即时编译加速
4. numba: JIT-Compiled Acceleration
pip install numba

import numba
import numpy as np
import time

# Pure Python version
def mandelbrot_python(c: complex, max_iter: int = 100) -> int:
    z = 0
    for n in range(max_iter):
        if abs(z) > 2:
            return n
        z = z * z + c
    return max_iter

# numba JIT version
@numba.jit(nopython=True, cache=True)
def mandelbrot_numba(c_real: float, c_imag: float, max_iter: int = 100) -> int:
    z_real = 0.0
    z_imag = 0.0
    for n in range(max_iter):
        if z_real * z_real + z_imag * z_imag > 4.0:
            return n
        new_real = z_real * z_real - z_imag * z_imag + c_real
        z_imag = 2 * z_real * z_imag + c_imag
        z_real = new_real
    return max_iter

# numba parallel version
@numba.jit(nopython=True, parallel=True)
def compute_mandelbrot_parallel(real_arr, imag_arr, max_iter=100):
    result = np.empty(len(real_arr), dtype=np.int32)
    for i in numba.prange(len(real_arr)):
        result[i] = mandelbrot_numba(real_arr[i], imag_arr[i], max_iter)
    return result
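A quick driver to exercise these functions; note the tiny warm-up call, which triggers compilation before timing starts (the array sizes are arbitrary):
reals = np.random.uniform(-2.0, 1.0, 100_000)
imags = np.random.uniform(-1.5, 1.5, 100_000)

compute_mandelbrot_parallel(reals[:16], imags[:16])  # warm up: compile first

start = time.perf_counter()
escape = compute_mandelbrot_parallel(reals, imags)
print(f"numba parallel, {len(reals)} points: {time.perf_counter()-start:.3f}s")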
5. A Complete Runnable Example: the Full Optimization Workflow
#!/usr/bin/env python3
"""
Python performance optimization, end to end: from slow to fast
"""
import cProfile
import io
import pstats
import time
from functools import lru_cache
import numpy as np

# ===== The functions under optimization (all versions) =====
# Version 1: naive Python (slowest)
def compute_stats_v1(data: list[float]) -> dict:
    n = len(data)
    mean = sum(data) / n
    variance = sum((x - mean) ** 2 for x in data) / n
    std = variance ** 0.5
    return {"n": n, "mean": mean, "std": std, "min": min(data), "max": max(data)}

# Version 2: optimized Python (slightly faster)
def compute_stats_v2(data: list[float]) -> dict:
    n = len(data)
    mean = sum(data) / n
    # keep mean in a local variable to avoid recomputation
    diffs_sq = [(x - mean) ** 2 for x in data]
    variance = sum(diffs_sq) / n
    return {
        "n": n,
        "mean": mean,
        "std": variance ** 0.5,
        "min": min(data),
        "max": max(data),
    }
# Version 3: NumPy vectorization (fastest; 50-100x over v1 when CPU-bound)
def compute_stats_v3(data: list[float]) -> dict:
    arr = np.array(data)
    return {
        "n": len(arr),
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }

# Version 4: when data is already a numpy array (skips the conversion cost)
def compute_stats_v4(arr: np.ndarray) -> dict:
    return {
        "n": len(arr),
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }
# ===== Recursion + LRU cache =====
@lru_cache(maxsize=1000)
def fib_cached(n: int) -> int:
    if n < 2:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)

def fib_uncached(n: int) -> int:
    if n < 2:
        return n
    return fib_uncached(n - 1) + fib_uncached(n - 2)
# ===== Profiling / benchmarking helpers =====
def benchmark(funcs: list, *args, repeat: int = 3, label: str = "") -> None:
    print(f"\n{'='*50}")
    print(f"Benchmark: {label}")
    print(f"{'='*50}")
    times = {}
    for func in funcs:
        elapsed_list = []
        for _ in range(repeat):
            start = time.perf_counter()
            result = func(*args)
            elapsed_list.append((time.perf_counter() - start) * 1000)
        avg = sum(elapsed_list) / len(elapsed_list)
        times[func.__name__] = avg
        print(f"  {func.__name__:<30} {avg:>8.3f}ms")
    fastest = min(times, key=times.get)
    slowest = max(times, key=times.get)
    if times[fastest] > 0:  # guard the division below
        speedup = times[slowest] / times[fastest]
        print(f"\n  fastest: {fastest}, slowest: {slowest}, speedup: {speedup:.1f}x")

def profile_function(func, *args):
    """Profile a function with cProfile."""
    pr = cProfile.Profile()
    pr.enable()
    func(*args)
    pr.disable()
    s = io.StringIO()
    ps = pstats.Stats(pr, stream=s).sort_stats("tottime")
    ps.print_stats(10)
    print(s.getvalue())
def main():
    print("=== Python performance optimization demo ===\n")

    # Generate test data
    import random
    data = [random.gauss(100, 15) for _ in range(100_000)]
    data_np = np.array(data)

    # Benchmark: statistics
    benchmark(
        [compute_stats_v1, compute_stats_v2, compute_stats_v3],
        data,
        repeat=5,
        label="statistics (list input)",
    )
    benchmark(
        [compute_stats_v4],
        data_np,
        repeat=5,
        label="statistics (numpy array input)",
    )

    # Cache effect
    print(f"\n{'='*50}")
    print("LRU cache effect")
    print(f"{'='*50}")
    n = 35
    start = time.perf_counter()
    r1 = fib_uncached(n)
    t1 = (time.perf_counter() - start) * 1000
    print(f"  uncached fib({n}): {t1:.2f}ms")
    fib_cached.cache_clear()
    start = time.perf_counter()
    r2 = fib_cached(n)
    t2 = (time.perf_counter() - start) * 1000
    print(f"  cached fib({n}), first call: {t2:.2f}ms")
    start = time.perf_counter()
    r3 = fib_cached(n)
    t3 = (time.perf_counter() - start) * 1000
    print(f"  cached fib({n}), second call: {t3:.4f}ms (served from cache)")
    print(f"  cache info: {fib_cached.cache_info()}")

    # Data structure choice
    print(f"\n{'='*50}")
    print("Data structures: list vs set lookup")
    print(f"{'='*50}")
    big_list = list(range(1_000_000))
    big_set = set(big_list)
    target = 999_999
    start = time.perf_counter()
    for _ in range(100):
        _ = target in big_list
    t_list = (time.perf_counter() - start) * 1000
    print(f"  list lookup x100: {t_list:.2f}ms")
    start = time.perf_counter()
    for _ in range(100):
        _ = target in big_set
    t_set = (time.perf_counter() - start) * 1000
    print(f"  set lookup x100: {t_set:.4f}ms")
    if t_set > 0:
        print(f"  set beats list by: {t_list/t_set:.0f}x")

if __name__ == "__main__":
    main()
6. Pitfall 1: Premature Optimization
# Wrong: guess at the bottleneck and burn hours optimizing a non-critical path,
# e.g. swapping json.loads for orjson saves 5ms while the whole script takes 30s
# Right: let cProfile find the real hot spots first
import cProfile
cProfile.run("my_slow_function()", sort="cumulative")
# the functions with the highest cumulative time are the real bottleneck
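On Python 3.8+, Profile also works as a context manager, which is convenient for profiling just one block (my_slow_function is the same placeholder as above):
import cProfile
import pstats

with cProfile.Profile() as pr:  # Python 3.8+
    my_slow_function()
pstats.Stats(pr).sort_stats("cumulative").print_stats(10)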
7. Pitfall 2: numba's Slow First Call (JIT Warm-up)
import time
import numba

# Problem: the first call to a @jit function triggers compilation and can take seconds
@numba.jit(nopython=True)
def my_func(x):
    return x ** 2

start = time.time()
my_func(1.0)  # first call: compile + run, can take 2-5 seconds
print(f"first call: {time.time()-start:.2f}s")

start = time.time()
my_func(2.0)  # later calls: straight to machine code, very fast
print(f"second call: {time.time()-start:.6f}s")

# Fix: cache=True persists the compiled code to disk
@numba.jit(nopython=True, cache=True)  # no recompilation after a restart
def my_func_cached(x):
    return x ** 2
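Another option is eager compilation: give numba an explicit signature and it compiles at decoration time instead of on the first call. A minimal sketch:
@numba.jit("float64(float64)", nopython=True)  # compiled immediately, not lazily
def my_func_eager(x):
    return x ** 2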
8. Pitfall 3: NumPy Broadcasting Memory Traps
import numpy as np

# Problem: broadcasting silently materializes a huge intermediate array
a = np.ones((10_000, 1))
b = np.ones((1, 10_000))
c = a + b  # builds a 10000x10000 float64 matrix = ~800 MB!

# Note: an outer-product einsum allocates the same full matrix,
# so it is not a fix by itself
outer = np.einsum("i,j->ij", a.ravel(), b.ravel())  # outer *product*, same ~800 MB

# Better: skip the full matrix when you only need a reduction over it.
# For c[i, j] = a[i] + b[j], the row sums are n*a[i] + sum(b):
n = b.size
row_sums = n * a.ravel() + b.sum()  # row sums of c without ever building c
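And when you genuinely need the full broadcast result, but only to reduce it, processing it in row blocks keeps peak memory bounded; a minimal sketch (the block size is arbitrary):
block = 1_000
total = 0.0
for i in range(0, a.shape[0], block):
    chunk = a[i:i + block] + b  # only block x 10000 at a time (~80 MB)
    total += chunk.sum()
print(total)  # equals (a + b).sum() without the 800 MB peak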
Summary
A methodology for Python performance optimization:
- Measure first: cProfile for function-level hot spots, line_profiler for line-level hot spots
- Python-level optimization: the right data structures, no repeated attribute lookups, join instead of += for string building
- NumPy vectorization: the first choice for numerical work, 50-100x faster than pure Python
- lru_cache: for recursive functions with repeated subproblems, caching pays off immediately
- numba: @jit the numerical hot spots to get close to C speed
- Multiprocessing: CPU-bound work on multiple cores, sidestepping the GIL (a minimal sketch follows)
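The article itself has no multiprocessing example, so here is a minimal sketch of the pattern the last bullet refers to, using the standard library's multiprocessing.Pool (cpu_bound and the task sizes are illustrative):
import math
from multiprocessing import Pool

def cpu_bound(n: int) -> float:
    # stand-in for any heavy, pure-CPU computation
    return sum(math.sqrt(i) for i in range(n))

if __name__ == "__main__":
    tasks = [2_000_000] * 8
    with Pool() as pool:  # one worker process per core by default
        results = pool.map(cpu_bound, tasks)  # each process has its own GIL
    print(sum(results))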
