pytest 参数化测试实战——@pytest.mark.parametrize 高级用法与数据驱动
pytest 参数化测试实战——@pytest.mark.parametrize 高级用法与数据驱动
适读人群:Python 测试工程师 / 关注测试代码质量的开发者 | 阅读时长:约 14 分钟 | 核心价值:掌握参数化测试的全部技巧,用数据驱动消除测试代码重复
复制粘贴 30 次的测试代码
我第一次认真审查一个同事的测试代码时,看到了这样的场景:
def test_validate_email_empty():
assert not validate_email("")
def test_validate_email_no_at():
assert not validate_email("userexample.com")
def test_validate_email_no_domain():
assert not validate_email("user@")
def test_validate_email_double_at():
assert not validate_email("user@@example.com")
# ... 还有 26 个类似的函数30 个测试函数,30 份几乎一样的代码,差别只是输入不同。
如果验证逻辑变了,要改 30 个地方。如果要加新场景,要再复制粘贴一次。
这种代码在 pytest 里可以优化成这样:
@pytest.mark.parametrize("email", [
"",
"userexample.com",
"user@",
"user@@example.com",
# ...其余 26 个
])
def test_validate_email_invalid(email):
assert not validate_email(email)30 行代码变成了一个函数 + 一个列表。
parametrize 基础语法
单参数
@pytest.mark.parametrize("value", [1, 2, 3, 100, -1, 0])
def test_is_positive_integer(value):
result = is_positive_integer(value)
expected = value > 0
assert result == expected多参数
@pytest.mark.parametrize("input_text,expected_length", [
("hello", 5),
("", 0),
("中文字符", 4),
(" spaces ", 10), # 不去除空格
])
def test_text_length(input_text, expected_length):
assert len(input_text) == expected_length带预期结果的验证测试
@pytest.mark.parametrize("price,quantity,discount_rate,expected_total", [
(100.0, 1, 0.0, 100.0), # 无折扣
(100.0, 2, 0.0, 200.0), # 数量倍数
(100.0, 1, 0.1, 90.0), # 9折
(100.0, 3, 0.2, 240.0), # 8折+数量
(99.99, 1, 0.0, 99.99), # 小数精度
(0.01, 1000, 0.0, 10.0), # 最小单价大量购买
])
def test_order_total_calculation(price, quantity, discount_rate, expected_total):
total = calculate_order_total(price, quantity, discount_rate)
assert abs(total - expected_total) < 0.001 # 允许浮点误差高级用法:pytest.param 和标记
pytest.param 允许为单个参数组设置 ID、标记(skip、xfail 等):
@pytest.mark.parametrize("input,expected", [
pytest.param("hello world", "Hello World", id="basic_title_case"),
pytest.param("ALREADY UPPER", "Already Upper", id="from_uppercase"),
pytest.param("", "", id="empty_string"),
pytest.param(
"café",
"Café",
id="unicode_chars",
marks=pytest.mark.xfail(reason="已知 bug:非 ASCII 字符处理异常,Issue #456")
),
pytest.param(
None,
None,
id="none_input",
marks=pytest.mark.skip(reason="尚未实现 None 处理,待 v2.0")
),
])
def test_title_case(input, expected):
assert title_case(input) == expected运行结果会显示有意义的 ID:
test_title_case[basic_title_case] PASSED
test_title_case[from_uppercase] PASSED
test_title_case[empty_string] PASSED
test_title_case[unicode_chars] XFAIL
test_title_case[none_input] SKIPPED数据驱动:从外部文件加载测试数据
当测试场景很多时,把数据放在代码里不方便管理。可以从 CSV、JSON、YAML 等文件加载:
从 CSV 文件
import csv
import pytest
from pathlib import Path
def load_test_cases_from_csv(filename):
"""从 CSV 文件加载测试用例数据"""
data_file = Path(__file__).parent / "test_data" / filename
test_cases = []
with open(data_file, encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
test_cases.append(pytest.param(
row["input"],
row["expected"],
id=row.get("id", row["input"][:20]),
))
return test_cases
# test_data/validation_cases.csv:
# id,input,expected
# valid_email,user@example.com,true
# empty_email,,false
# invalid_format,not-an-email,false
@pytest.mark.parametrize("email,expected_valid",
load_test_cases_from_csv("validation_cases.csv"))
def test_email_validation_from_csv(email, expected_valid):
result = validate_email(email)
assert str(result).lower() == expected_valid.lower()从 JSON 文件
import json
def load_test_scenarios(filename):
data_file = Path(__file__).parent / "test_data" / filename
with open(data_file, encoding="utf-8") as f:
scenarios = json.load(f)
return [
pytest.param(
scenario["input"],
scenario["expected"],
id=scenario["name"]
)
for scenario in scenarios
]
# test_data/pricing_scenarios.json:
# [
# {"name": "standard_discount", "input": {"price": 100, "tier": "gold"}, "expected": 80},
# {"name": "no_discount", "input": {"price": 100, "tier": "bronze"}, "expected": 100}
# ]
@pytest.mark.parametrize("pricing_input,expected_price",
load_test_scenarios("pricing_scenarios.json"))
def test_pricing_calculation(pricing_input, expected_price):
result = calculate_price(**pricing_input)
assert result == expected_price复杂参数化场景:异常测试
@pytest.mark.parametrize("invalid_input,expected_exception,expected_message", [
(None, TypeError, "输入不能为 None"),
(-1, ValueError, "数量必须大于 0"),
(1001, ValueError, "数量超过库存上限"),
("abc", TypeError, "数量必须为整数"),
])
def test_add_to_cart_invalid_quantity(invalid_input, expected_exception, expected_message):
with pytest.raises(expected_exception) as exc_info:
add_to_cart(product_id="P001", quantity=invalid_input)
assert expected_message in str(exc_info.value)参数化 fixture(indirect 用法进阶)
@pytest.fixture
def http_client(request):
"""根据参数创建不同配置的 HTTP 客户端"""
config = request.param
client = HTTPClient(
base_url=config["base_url"],
timeout=config.get("timeout", 30),
headers=config.get("headers", {}),
)
yield client
client.close()
@pytest.mark.parametrize("http_client", [
{"base_url": "http://api-v1.example.com", "timeout": 10},
{"base_url": "http://api-v2.example.com", "timeout": 20},
], indirect=True, ids=["api-v1", "api-v2"])
def test_get_user_endpoint(http_client):
"""在 v1 和 v2 API 上都运行相同的测试"""
response = http_client.get("/users/123")
assert response.status_code == 200
assert "id" in response.json()组合参数化矩阵
多个 @pytest.mark.parametrize 叠加会生成笛卡尔积:
@pytest.mark.parametrize("browser", ["chromium", "firefox", "webkit"])
@pytest.mark.parametrize("viewport", [
{"width": 375, "height": 812}, # iPhone X
{"width": 1920, "height": 1080}, # 桌面
])
@pytest.mark.parametrize("theme", ["light", "dark"])
def test_homepage_renders(browser, viewport, theme, playwright_factory):
"""
3 浏览器 × 2 视口 × 2 主题 = 12 个测试用例
"""
page = playwright_factory.create_page(browser, viewport)
page.goto(f"/?theme={theme}")
assert page.locator("body").is_visible()
# 截图对比注意:笛卡尔积可能会让测试数量爆炸,谨慎使用。
条件参数化
根据环境动态决定测试参数:
import os
import pytest
# 只在有外部 API Key 时才运行真实 API 测试
REAL_API_PARAMS = [
pytest.param("real_api", marks=pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY"),
reason="需要 OPENAI_API_KEY 环境变量"
)),
pytest.param("mock_api"),
]
@pytest.mark.parametrize("api_mode", REAL_API_PARAMS)
def test_text_generation(api_mode, make_client):
client = make_client(mode=api_mode)
result = client.generate("写一句话")
assert len(result) > 0踩坑实录
坑一:参数化 ID 包含特殊字符导致报告乱码
现象: 参数是中文或包含特殊字符时,测试报告中的 ID 显示为 URL 编码或乱码。
解法: 显式指定 ids 参数:
@pytest.mark.parametrize("message,expected", [
("你好世界", 4),
("Hello World", 11),
], ids=["chinese_4chars", "english_11chars"]) # 用 ASCII ID
def test_message_length(message, expected):
assert len(message) == expected坑二:参数化和 fixture 的 scope 不匹配
现象: 参数化测试中使用了 module scope 的 fixture,但参数化导致每次都重建 fixture。
原因: 参数化会为每个参数组生成独立的测试节点,module scope 的 fixture 在同一模块内共享,但如果参数化配合 indirect=True,每次都会调用 fixture。
解法: 仔细区分:哪些数据是"测试输入"(放在参数化里),哪些是"测试基础设施"(放在 fixture 里)。
坑三:参数化中的可变对象被意外共享
现象: 参数化列表中的 dict 对象在多个测试之间被共享修改,导致后续测试数据不对。
解法: 参数化的参数是不可变类型(字符串、数字、元组)最安全;如果是 dict,在测试中不要修改它:
# 危险:如果测试修改了 config dict,后续测试看到的是修改后的值
@pytest.mark.parametrize("config", [
{"timeout": 30, "retries": 3},
])
def test_something(config):
config["timeout"] = 60 # 修改了参数化数据!下次此参数化 case 看到的是 60
# 安全:在测试中 copy
def test_something(config):
test_config = dict(config) # 或 config.copy()
test_config["timeout"] = 60小结
参数化测试是消除测试代码重复的最有效手段。关键原则:
- 相同逻辑不同数据:凡是测试函数体一样、只有输入输出不同的场景,都应该参数化
- 用
pytest.param设置 ID:让测试报告中的失败信息有意义 - 从外部文件加载大量数据:超过 10-15 个场景时,考虑 CSV/JSON 文件
- 谨慎使用笛卡尔积:多层参数化容易产生大量用例,只在真正需要矩阵测试时使用
