Python 科学计算工程化——NumPy/SciPy 在生产服务中的正确使用姿势
Python 科学计算工程化——NumPy/SciPy 在生产服务中的正确使用姿势
适读人群:在生产服务中使用 NumPy/SciPy 做计算的 Python 工程师、有科学计算背景想做工程化的开发者 | 阅读时长:约14分钟 | 核心价值:从"能跑"到"生产可用",NumPy/SciPy 的工程化要点
有个很常见的现象:数据科学家在 Jupyter Notebook 里跑得好好的代码,交给工程师部署到生产,立刻出各种问题——内存爆炸、数值溢出、并发错误、边界条件崩溃。
不是代码本身有问题,是从"实验代码"到"生产代码"需要跨越的工程化距离被低估了。
我这几年做了不少把科学计算代码工程化的工作,今天把核心要点写出来。
先搞清楚:NumPy 在生产服务里能做什么、不能做什么
能做:
- 大批量数值计算(向量化操作,利用 BLAS/LAPACK 的多线程)
- 数组变换、切片、索引操作
- 信号处理(FFT、滤波)
- 线性代数(矩阵乘法、分解)
- 统计计算
不能替代的:
- 实时流式处理(NumPy 的批处理模型不适合)
- 超大规模数据(内存放不下的数据,用 Dask 或者分批处理)
- 高并发请求处理(NumPy 操作不是线程安全的,要注意并发写入)
工程化第一要点:数值稳定性
科学计算里最容易被忽视的问题:数值稳定性。在实验环境里数据分布比较好,问题不出现。到了生产,遇到边界输入,直接 nan 或者 inf。
常见的数值陷阱:
import numpy as np
# 陷阱1:对数的数值稳定性
def safe_log_prob(x: np.ndarray) -> np.ndarray:
"""计算对数概率,避免 log(0) = -inf"""
# 错误:np.log(x),x 为 0 时得到 -inf
# 正确:加一个极小的 epsilon
eps = np.finfo(x.dtype).eps # 根据数据类型自动选 epsilon
return np.log(np.maximum(x, eps))
# 陷阱2:Softmax 溢出
def safe_softmax(x: np.ndarray) -> np.ndarray:
"""数值稳定的 softmax"""
# 错误:np.exp(x) / np.sum(np.exp(x)),x 很大时 exp(x) 会 overflow
# 正确:先减去最大值
shifted = x - np.max(x, axis=-1, keepdims=True)
exp_x = np.exp(shifted)
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
# 陷阱3:除法的零除
def safe_divide(a: np.ndarray, b: np.ndarray, fill_value: float = 0.0) -> np.ndarray:
"""安全除法,分母为零时返回 fill_value"""
return np.where(np.abs(b) > 1e-10, a / b, fill_value)
# 陷阱4:开方域错误
def safe_sqrt(x: np.ndarray) -> np.ndarray:
"""安全开方,避免负数开方得到 nan"""
return np.sqrt(np.maximum(x, 0.0))生产代码必须有数值检查:
def validate_numeric_output(arr: np.ndarray, name: str) -> None:
"""验证计算结果,及时发现数值问题"""
if np.any(np.isnan(arr)):
nan_count = np.sum(np.isnan(arr))
raise ValueError(f"{name}: contains {nan_count} NaN values")
if np.any(np.isinf(arr)):
inf_count = np.sum(np.isinf(arr))
raise ValueError(f"{name}: contains {inf_count} Inf values")
if arr.size > 0:
if np.all(arr == arr[0] if arr.ndim == 0 else arr.flat[0]):
import warnings
warnings.warn(f"{name}: all values are identical, possible issue")踩坑实录一:生产数据里的 NaN 级联传播
现象: 一个信号处理服务,某天开始某些用户的输出全是 NaN,但大多数用户正常。排查了很久,因为 NaN 会在计算中传播,找到根源很困难。
原因: 输入数据里有一个传感器故障,某个通道数据全部变成了 NaN。这个 NaN 进入计算之后,传播到了所有相关的输出。
解法:
import numpy as np
from typing import Optional
def preprocess_signal(
raw_data: np.ndarray,
max_nan_ratio: float = 0.1,
fill_strategy: str = "interpolate",
) -> np.ndarray:
"""
信号预处理,包含 NaN 处理
Args:
raw_data: 原始数据,shape (n_samples, n_channels)
max_nan_ratio: 超过这个比例的 NaN,认为数据不可用
fill_strategy: NaN 填充策略,"interpolate" 或 "forward_fill" 或 "zero"
"""
nan_ratio = np.sum(np.isnan(raw_data)) / raw_data.size
if nan_ratio > max_nan_ratio:
raise ValueError(
f"Too many NaN values: {nan_ratio:.1%} > {max_nan_ratio:.1%}. "
f"Data quality is unacceptable."
)
if nan_ratio == 0:
return raw_data
result = raw_data.copy()
if fill_strategy == "interpolate":
# 对每个通道独立插值
for ch in range(result.shape[1] if result.ndim > 1 else 1):
channel = result[:, ch] if result.ndim > 1 else result
nan_mask = np.isnan(channel)
if np.any(nan_mask):
x = np.where(~nan_mask)[0]
y = channel[~nan_mask]
channel[nan_mask] = np.interp(np.where(nan_mask)[0], x, y)
elif fill_strategy == "forward_fill":
for i in range(1, len(result)):
if np.any(np.isnan(result[i])):
result[i] = result[i - 1]
elif fill_strategy == "zero":
result[np.isnan(result)] = 0.0
return resultNumPy 在多线程服务里的并发问题
这是很多人不知道的问题:NumPy 操作在多线程环境里不总是安全的。
NumPy 在读操作上是线程安全的,但写操作(修改数组内容)不是。如果多个线程同时写同一个数组,结果是不确定的。
import numpy as np
import threading
# 危险:多线程同时写同一个共享数组
shared_buffer = np.zeros(1000)
def worker(thread_id: int):
# 同时写 shared_buffer,数据竞争!
shared_buffer[thread_id * 10: (thread_id + 1) * 10] = thread_id
threads = [threading.Thread(target=worker, args=(i,)) for i in range(100)]
for t in threads:
t.start()
for t in threads:
t.join()
# shared_buffer 的内容是不确定的
# 正确做法1:每个线程有自己的数组,结束后合并
def worker_safe(thread_id: int, results: list):
local_result = np.zeros(10)
local_result[:] = thread_id
results[thread_id] = local_result # 写 list 的不同位置,相对安全
results = [None] * 100
threads = [threading.Thread(target=worker_safe, args=(i, results)) for i in range(100)]
for t in threads:
t.start()
for t in threads:
t.join()
shared_buffer = np.concatenate(results)
# 正确做法2:用锁(但会降低并发性)
lock = threading.Lock()
def worker_with_lock(thread_id: int):
local_data = np.ones(10) * thread_id
with lock:
shared_buffer[thread_id * 10: (thread_id + 1) * 10] = local_data在 FastAPI 这类异步服务里,NumPy 的计算密集操作应该放到线程池里执行:
from fastapi import FastAPI
import asyncio
import numpy as np
from concurrent.futures import ThreadPoolExecutor
app = FastAPI()
executor = ThreadPoolExecutor(max_workers=4)
def cpu_intensive_numpy_work(data: list) -> dict:
"""CPU 密集型 NumPy 计算,在线程池里运行"""
arr = np.array(data, dtype=np.float64)
result = {
"mean": float(np.mean(arr)),
"std": float(np.std(arr)),
"quantile_95": float(np.percentile(arr, 95)),
}
return result
@app.post("/analyze")
async def analyze(data: list[float]):
# 把 CPU 密集型操作放到线程池,不阻塞事件循环
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(executor, cpu_intensive_numpy_work, data)
return result踩坑实录二:内存使用失控
现象: 一个特征提取服务,处理每个请求时内存涨一点,几个小时后 OOM。
原因: NumPy 数组的内存管理有时候不如你想象的那样即时释放。特别是视图(view)和引用的问题:
import numpy as np
# 问题代码:切片是视图,持有原数组引用
def extract_features_bad(large_matrix: np.ndarray) -> np.ndarray:
# 这个切片是 large_matrix 的视图,不是独立的拷贝
# large_matrix 的引用通过切片被隐式保留了
features = large_matrix[:100, :50]
return features # 返回出去,large_matrix 无法被 GC
# 正确做法:显式拷贝
def extract_features_good(large_matrix: np.ndarray) -> np.ndarray:
features = large_matrix[:100, :50].copy() # 拷贝,切断引用
# large_matrix 在函数结束后可以被 GC
return features另外,NumPy 的 astype() 等操作会创建新数组,要记得把旧数组显式删除:
def process_large_data(data: np.ndarray) -> np.ndarray:
# float64 转 float32,节省内存
data_f32 = data.astype(np.float32)
del data # 显式删除原数组,不等 GC
result = expensive_computation(data_f32)
del data_f32 # 用完就删
return resultSciPy 的工程化注意事项
from scipy import signal, stats, optimize
import numpy as np
# SciPy 的很多函数有边界条件,生产环境必须处理
def fit_distribution(data: np.ndarray) -> dict:
"""拟合数据分布,带完整的错误处理"""
if len(data) < 10:
raise ValueError(f"Insufficient data for fitting: {len(data)} < 10")
# 检查数据质量
if np.any(np.isnan(data)) or np.any(np.isinf(data)):
raise ValueError("Data contains NaN or Inf values")
try:
# 拟合正态分布
mu, sigma = stats.norm.fit(data)
# 验证拟合结果
if np.isnan(mu) or np.isnan(sigma) or sigma <= 0:
raise ValueError(f"Invalid fit result: mu={mu}, sigma={sigma}")
# 计算拟合优度
ks_stat, p_value = stats.kstest(data, 'norm', args=(mu, sigma))
return {
"distribution": "normal",
"mu": float(mu),
"sigma": float(sigma),
"ks_statistic": float(ks_stat),
"p_value": float(p_value),
"good_fit": p_value > 0.05, # p > 0.05 认为拟合较好
}
except Exception as e:
raise RuntimeError(f"Distribution fitting failed: {e}") from e
def filter_signal(
signal_data: np.ndarray,
cutoff_hz: float,
sample_rate_hz: float,
filter_order: int = 4,
) -> np.ndarray:
"""带验证的信号滤波"""
# 验证参数
nyquist = sample_rate_hz / 2
if cutoff_hz >= nyquist:
raise ValueError(
f"Cutoff frequency {cutoff_hz}Hz must be less than Nyquist frequency {nyquist}Hz"
)
normalized_cutoff = cutoff_hz / nyquist
try:
b, a = signal.butter(filter_order, normalized_cutoff, btype='low', analog=False)
# filtfilt 做零相位滤波,不会引入相位延迟
filtered = signal.filtfilt(b, a, signal_data)
validate_numeric_output(filtered, "filtered_signal")
return filtered
except Exception as e:
raise RuntimeError(f"Signal filtering failed: {e}") from e踩坑实录三:dtype 不一致导致的静默精度损失
现象: 一个计算结果,在 Python 里看是整数(100),但存到数据库后变成了 99.99999...,导致后续的整数比较全部失败。
原因: 中间有一步把 float64 转成了 float32,精度损失了,但没有报错,悄悄地变了。
import numpy as np
# 注意 dtype 的传播
a = np.array([100.0], dtype=np.float64)
b = np.array([3.0], dtype=np.float32)
# float64 和 float32 运算,结果是 float64
result = a / b
print(result.dtype) # float64
print(result[0]) # 33.333333333333336(精度正常)
# 但如果 a 也是 float32:
a32 = np.array([100.0], dtype=np.float32)
result32 = a32 / b
print(result32.dtype) # float32
print(result32[0]) # 33.333332(精度损失)解法: 在接口处明确 dtype,不要让 dtype 隐式传播:
def compute_ratio(numerator: np.ndarray, denominator: np.ndarray) -> np.ndarray:
"""计算比率,确保使用 float64 精度"""
# 强制转换为 float64,避免精度问题
num = numerator.astype(np.float64)
den = denominator.astype(np.float64)
return safe_divide(num, den)生产就绪的科学计算服务 checklist
| 检查项 | 要点 |
|---|---|
| 数值稳定性 | log/sqrt/除法都要处理边界,用 safe 版本 |
| NaN/Inf 检测 | 输入验证 + 输出验证,及时发现数值问题 |
| dtype 一致性 | 明确 dtype,不让 float64/float32 混用 |
| 内存管理 | 用 .copy() 切断视图引用,及时 del 大数组 |
| 并发安全 | 写操作加锁,或者每个请求用独立数组 |
| 异步兼容 | CPU 密集操作放线程池,不阻塞事件循环 |
| 边界条件 | 空数组、单元素数组、全相同值都要测试 |
科学计算代码的工程化,本质是把"研究人员假设输入总是好的"变成"工程代码假设输入可能是任何值"。 这个心态的转变,比学会任何具体技术都重要。
