Article 1714: Designing a Unit Test Framework for AI Prompts (an Engineering Approach to Repeatable Prompt Verification)
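Prompt engineering is an odd thing: everyone writes prompts, but few people manage them like code. In most teams I've seen, prompts are strings scattered across config files. When someone changes one, nobody knows whether the change is right, and when it is right, nobody can say why. When quality drops, you dig through the commit history and find that three months ago someone changed a single word, and nobody noticed the problem at the time.
This article explores one direction: manage prompts like code, and write unit tests for them.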
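1. Why prompts need unit tests
Let's pin down the problem first. Unit tests for prompts differ fundamentally from unit tests for ordinary code:
- Ordinary code: given an input, verify a deterministic output.
- Prompts: given an input, verify that the output satisfies certain properties or constraints.
You can't say "this prompt's output must be exactly XXX", but you can say "this prompt's output must satisfy the following conditions":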
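- The output is valid JSON
- The JSON contains a sentiment field
- The value of sentiment is one of positive, negative, or neutral
- The output contains no code-block markers (```)
- The output is between 100 and 500 characters long
All of these can be checked programmatically.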
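Also, although AI output is nondeterministic, a well-designed prompt run at a fixed temperature should produce output whose structure is highly consistent. What we test is not the specific content but whether the structure and constraints hold reliably.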
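The claim about structural consistency can itself be measured before any rule is trusted to gate a build: call the model several times with identical input and count how often the output passes a structural check. A minimal sketch, assuming the model call is wrapped as a `Supplier<String>` and the check as a `Predicate<String>`; `StructuralStability` and its methods are hypothetical names, not part of the article's framework:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;

// Hypothetical helper: sample the same call n times and measure how often
// the output passes a structural check.
final class StructuralStability {
    private StructuralStability() {}

    // Outcome of one sampling run; failing outputs are kept for inspection.
    record Report(int samples, int passed, List<String> failures) {
        double passRate() {
            return samples == 0 ? 0.0 : (double) passed / samples;
        }
    }

    static Report sample(Supplier<String> call, Predicate<String> structureCheck, int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n must be positive");
        }
        List<String> failures = new ArrayList<>();
        int passed = 0;
        for (int i = 0; i < n; i++) {
            String output = call.get(); // in practice: one real LLM call at the test temperature
            if (structureCheck.test(output)) {
                passed++;
            } else {
                failures.add(output);
            }
        }
        return new Report(n, passed, failures);
    }
}
```

A pass rate noticeably below 1.0 at the test temperature suggests the prompt (or the temperature), rather than the test, needs attention before the assertion is made a hard gate.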
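2. Framework design
A complete prompt unit testing framework needs these components:
- Prompt templates: versioned system and user prompts with {{variable}} placeholders
- Test cases: input variables, the validation rules to apply, and an execution mode
- Validation rules: independent checks on the output that can be combined
- A test runner: renders the prompt, obtains the output (real call, mock, or recording), and runs the rules
- A recording store: saves real LLM responses for later playback
- Reports: which rules failed in each test, and the pass rate of each suite
Core design principles:
- Record/playback: real LLM responses can be recorded and replayed in later runs, so tests don't call the model every time
- Composable rules: validation rules can be combined freely, like building blocks
- Clear failure reports: show which rule failed, what the output was, and what was expected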
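Composable rules are what let the empty-text edge case in section 6 express "either an error field, or a neutral sentiment" through `AnyOfRule` and `AllOfRule`, two combinators the article uses but never defines. A minimal sketch of both, assuming simplified stand-ins for `ValidationRule` and `ValidationResult` (a record instead of the Lombok builder) so the block compiles on its own; `ContainsRule` is a hypothetical leaf rule added only to exercise the combinators:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Simplified stand-ins for the article's types so the sketch compiles on its own.
record ValidationResult(boolean passed, String ruleName, String message) {}

interface ValidationRule {
    String getRuleName();
    ValidationResult validate(String output, Map<String, Object> context);
}

// Passes only if every child rule passes; the message lists every failing child.
final class AllOfRule implements ValidationRule {
    private final List<ValidationRule> rules;

    AllOfRule(ValidationRule... rules) {
        this.rules = List.of(rules);
    }

    @Override
    public String getRuleName() {
        return "all-of";
    }

    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        List<String> failures = new ArrayList<>();
        for (ValidationRule rule : rules) {
            ValidationResult r = rule.validate(output, context);
            if (!r.passed()) {
                failures.add(r.ruleName() + ": " + r.message());
            }
        }
        boolean ok = failures.isEmpty();
        return new ValidationResult(ok, getRuleName(),
                ok ? "all child rules passed" : String.join("; ", failures));
    }
}

// Passes as soon as one child rule passes; otherwise reports why each alternative failed.
final class AnyOfRule implements ValidationRule {
    private final List<ValidationRule> rules;

    AnyOfRule(ValidationRule... rules) {
        this.rules = List.of(rules);
    }

    @Override
    public String getRuleName() {
        return "any-of";
    }

    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        List<String> failures = new ArrayList<>();
        for (ValidationRule rule : rules) {
            ValidationResult r = rule.validate(output, context);
            if (r.passed()) {
                return new ValidationResult(true, getRuleName(), "matched " + r.ruleName());
            }
            failures.add(r.ruleName() + ": " + r.message());
        }
        return new ValidationResult(false, getRuleName(),
                "no alternative passed: " + String.join("; ", failures));
    }
}

// Hypothetical leaf rule (plain substring check), used only to exercise the combinators.
final class ContainsRule implements ValidationRule {
    private final String needle;

    ContainsRule(String needle) {
        this.needle = needle;
    }

    @Override
    public String getRuleName() {
        return "contains:" + needle;
    }

    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        boolean ok = output != null && output.contains(needle);
        return new ValidationResult(ok, getRuleName(), ok ? "found" : "missing " + needle);
    }
}
```

For example, `new AnyOfRule(new ContainsRule("\"error\""), new AllOfRule(new ContainsRule("\"sentiment\""), new ContainsRule("neutral")))` has the same shape as the edge-case rule; real leaf rules would parse JSON like `JsonStructureRule` rather than match substrings.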
3. Core data structures
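The template below stores the user prompt with {{variable}} placeholders, and the test runner in section 5 fills them through a `PromptRenderer.render(template, variables)` call, but the renderer itself is never shown. A minimal sketch, assuming plain placeholder substitution (no escaping, conditionals, or loops) and a fail-fast policy for missing variables; both the class shape and that policy are assumptions. It relies on the `StringBuilder` overloads of `Matcher.appendReplacement`/`appendTail` (Java 9+):

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical PromptRenderer: fills {{name}} placeholders from a variables map.
final class PromptRenderer {
    // {{ name }} with optional inner whitespace; names are letters, digits, '_', '.', '-'
    private static final Pattern PLACEHOLDER =
            Pattern.compile("\\{\\{\\s*([A-Za-z0-9_.-]+)\\s*\\}\\}");

    String render(String template, Map<String, Object> variables) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder sb = new StringBuilder();
        while (m.find()) {
            String name = m.group(1);
            if (!variables.containsKey(name)) {
                // Fail fast: a typo in a variable name becomes a test error, not a malformed prompt
                throw new IllegalArgumentException("No value for placeholder: " + name);
            }
            // quoteReplacement stops '$' or '\' in the value from being read as group references
            m.appendReplacement(sb, Matcher.quoteReplacement(String.valueOf(variables.get(name))));
        }
        m.appendTail(sb);
        return sb.toString();
    }
}
```

Failing on a missing variable matters here because an unfilled placeholder would otherwise reach the model silently and show up later as a confusing validation failure.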
// Assumes Lombok (@Data, @Builder); in a real project each public type lives in its own file
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
// Prompt template
@Data
@Builder
public class PromptTemplate {
    private String id;
    private String version;
    private String systemPrompt;
    private String userPromptTemplate; // supports {{variable}} placeholders
    private Map<String, Object> defaultParameters;
    private PromptMetadata metadata;
}
@Data
@Builder
public class PromptMetadata {
    private String author;
    private String purpose;
    private List<String> tags;
    private String createdAt;
    private String lastModifiedAt;
}
// How a test case obtains the model output
public enum TestMode { REAL, MOCK, PLAYBACK }
// Prompt test case
@Data
@Builder
public class PromptTestCase {
    private String testId;
    private String description;
    private Map<String, Object> inputVariables;
    private List<ValidationRule> validationRules;
    private TestMode mode;      // REAL / MOCK / PLAYBACK
    private String recordingId; // recording ID used in PLAYBACK mode
}
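// Validation rule interface
public interface ValidationRule {
    String getRuleName();
    ValidationResult validate(String output, Map<String, Object> context);
}
// Validation result
@Data
@Builder
public class ValidationResult {
    private boolean passed;
    private String ruleName;
    private String message;
    private String actualValue;
    private String expectedConstraint;
}
// Test result
@Data
@Builder
public class PromptTestResult {
    private String testId;
    private boolean passed;
    private String output;
    private List<ValidationResult> validationResults;
    private long executionTimeMs;
    private String errorMessage;
}
// Suite-level result returned by PromptTestRunner.runSuite
@Data
@Builder
public class PromptTestSuiteResult {
    private String templateId;
    private String templateVersion;
    private int totalTests;
    private int passedTests;
    private int failedTests;
    private List<PromptTestResult> results;
    // Used by the version-comparison test in section 8
    public double getPassRate() {
        return totalTests == 0 ? 0.0 : (double) passedTests / totalTests;
    }
}
4. Built-in validation rules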
// Imports shared by the rules in this section (Jackson for JSON parsing)
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
// Rule 1: JSON structure validation
public class JsonStructureRule implements ValidationRule {
    private final List<String> requiredFields;
    private final ObjectMapper mapper = new ObjectMapper();
    public JsonStructureRule(String... fields) {
        this.requiredFields = List.of(fields);
    }
    @Override
    public String getRuleName() {
        return "json-structure";
    }
    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        try {
            JsonNode root = mapper.readTree(output);
            // get() returns null for an absent key (or when the root is not an object)
            List<String> missingFields = requiredFields.stream()
                    .filter(field -> root.get(field) == null)
                    .collect(Collectors.toList());
            if (missingFields.isEmpty()) {
                return ValidationResult.builder()
                        .passed(true)
                        .ruleName(getRuleName())
                        .message("JSON structure check passed")
                        .build();
            } else {
                return ValidationResult.builder()
                        .passed(false)
                        .ruleName(getRuleName())
                        .message("Missing required fields: " + missingFields)
                        .actualValue(output)
                        .expectedConstraint("Contains fields: " + requiredFields)
                        .build();
            }
        } catch (Exception e) {
            // Parse failure, or a null output
            return ValidationResult.builder()
                    .passed(false)
                    .ruleName(getRuleName())
                    .message("Output is not valid JSON: " + e.getMessage())
                    .actualValue(output)
                    .expectedConstraint("Valid JSON")
                    .build();
        }
    }
}
// Rule 2: enum value validation
public class EnumValueRule implements ValidationRule {
    private final String fieldPath; // dot-separated path, e.g. "result.sentiment"
    private final Set<String> allowedValues;
    private final ObjectMapper mapper = new ObjectMapper();
    public EnumValueRule(String fieldPath, String... values) {
        this.fieldPath = fieldPath;
        this.allowedValues = new HashSet<>(Arrays.asList(values));
    }
    @Override
    public String getRuleName() {
        return "enum-value:" + fieldPath;
    }
    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        try {
            JsonNode root = mapper.readTree(output);
            JsonNode fieldNode = navigatePath(root, fieldPath);
            if (fieldNode == null || fieldNode.isNull()) {
                return fail("Field not found: " + fieldPath, output);
            }
            String actualValue = fieldNode.asText();
            if (allowedValues.contains(actualValue)) {
                return pass();
            } else {
                return ValidationResult.builder()
                        .passed(false)
                        .ruleName(getRuleName())
                        .message(String.format("Field [%s] has disallowed value [%s]", fieldPath, actualValue))
                        .actualValue(actualValue)
                        .expectedConstraint("Allowed values: " + allowedValues)
                        .build();
            }
        } catch (Exception e) {
            return fail("Exception during validation: " + e.getMessage(), output);
        }
    }
    // Walks a dot-separated path; returns null as soon as a segment is missing
    private JsonNode navigatePath(JsonNode root, String path) {
        JsonNode current = root;
        for (String segment : path.split("\\.")) {
            if (current == null) return null;
            current = current.get(segment);
        }
        return current;
    }
    private ValidationResult pass() {
        return ValidationResult.builder().passed(true).ruleName(getRuleName()).build();
    }
    private ValidationResult fail(String msg, String actual) {
        return ValidationResult.builder().passed(false).ruleName(getRuleName())
                .message(msg).actualValue(actual).build();
    }
}
// Rule 3: length constraint on the whole output string (in chars)
public class LengthConstraintRule implements ValidationRule {
    private final int minLength;
    private final int maxLength;
    public LengthConstraintRule(int min, int max) {
        this.minLength = min;
        this.maxLength = max;
    }
    @Override
    public String getRuleName() {
        return "length-constraint";
    }
    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        int len = output == null ? 0 : output.length();
        boolean passed = len >= minLength && len <= maxLength;
        return ValidationResult.builder()
                .passed(passed)
                .ruleName(getRuleName())
                .message(passed ? "Length check passed" :
                        String.format("Length %d is outside [%d, %d]", len, minLength, maxLength))
                .actualValue(String.valueOf(len))
                .expectedConstraint(String.format("[%d, %d]", minLength, maxLength))
                .build();
    }
}
// Rule 4: required / forbidden keywords
public class KeywordRule implements ValidationRule {
    private final List<String> mustContain;
    private final List<String> mustNotContain;
    public KeywordRule(List<String> must, List<String> mustNot) {
        this.mustContain = must != null ? must : List.of();
        this.mustNotContain = mustNot != null ? mustNot : List.of();
    }
    @Override
    public String getRuleName() {
        return "keyword-constraint";
    }
    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        String text = output == null ? "" : output; // treat a missing output as empty instead of throwing
        List<String> missing = mustContain.stream()
                .filter(kw -> !text.contains(kw))
                .collect(Collectors.toList());
        List<String> forbidden = mustNotContain.stream()
                .filter(text::contains)
                .collect(Collectors.toList());
        if (missing.isEmpty() && forbidden.isEmpty()) {
            return ValidationResult.builder().passed(true).ruleName(getRuleName()).build();
        }
        String message = "";
        if (!missing.isEmpty()) message += "Missing required keywords: " + missing + " ";
        if (!forbidden.isEmpty()) message += "Contains forbidden keywords: " + forbidden;
        return ValidationResult.builder()
                .passed(false)
                .ruleName(getRuleName())
                .message(message.trim())
                .actualValue(output)
                .build();
    }
}
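// Rule 5: regex match (shouldMatch = false means the pattern must NOT appear)
public class RegexRule implements ValidationRule {
    private final java.util.regex.Pattern pattern;
    private final boolean shouldMatch;
    public RegexRule(String pattern, boolean shouldMatch) {
        // Compile once. find() searches anywhere in the output, so the pattern is not wrapped
        // in ".*"; wrapping would break alternation ("a|b") and anchors (^, $).
        this.pattern = java.util.regex.Pattern.compile(pattern);
        this.shouldMatch = shouldMatch;
    }
    @Override
    public String getRuleName() {
        return "regex:" + pattern.pattern();
    }
    @Override
    public ValidationResult validate(String output, Map<String, Object> context) {
        boolean matched = output != null && pattern.matcher(output).find();
        boolean passed = shouldMatch == matched;
        String message = shouldMatch
                ? (passed ? "Pattern matched" : "Output does not match pattern: " + pattern.pattern())
                : (passed ? "Forbidden pattern correctly absent" : "Output contains forbidden pattern: " + pattern.pattern());
        return ValidationResult.builder()
                .passed(passed)
                .ruleName(getRuleName())
                .message(message)
                .build();
    }
}
5. Test execution engine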
// Assumes Lombok and Spring
import java.util.List;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor // constructor injection for the final fields below
public class PromptTestRunner {
    private final LlmExecutor realLlmExecutor;
    private final RecordingStore recordingStore;
    private final PromptRenderer renderer;
    public PromptTestResult run(PromptTemplate template, PromptTestCase testCase) {
        long startTime = System.currentTimeMillis();
        try {
            // Render the prompt
            String renderedPrompt = renderer.render(
                    template.getUserPromptTemplate(),
                    testCase.getInputVariables()
            );
            // Get the LLM output; the executor depends on the test mode
            String output = executeByMode(
                    testCase, template.getSystemPrompt(), renderedPrompt
            );
            // Run every validation rule (all of them run, so the report lists every failure)
            List<ValidationResult> results = testCase.getValidationRules().stream()
                    .map(rule -> rule.validate(output, testCase.getInputVariables()))
                    .collect(Collectors.toList());
            boolean allPassed = results.stream().allMatch(ValidationResult::isPassed);
            return PromptTestResult.builder()
                    .testId(testCase.getTestId())
                    .passed(allPassed)
                    .output(output)
                    .validationResults(results)
                    .executionTimeMs(System.currentTimeMillis() - startTime)
                    .build();
        } catch (Exception e) {
            // Rendering or execution failed: no rule ran, so report the exception instead
            return PromptTestResult.builder()
                    .testId(testCase.getTestId())
                    .passed(false)
                    .validationResults(List.of())
                    .errorMessage(e.toString())
                    .executionTimeMs(System.currentTimeMillis() - startTime)
                    .build();
        }
    }
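    private String executeByMode(PromptTestCase testCase, String system, String user) {
        return switch (testCase.getMode()) {
            case REAL -> realLlmExecutor.execute(system, user);
            case MOCK -> generateMockResponse(testCase);
            // RecordingStore exposes find(String) returning Optional<String> (see section 7)
            case PLAYBACK -> recordingStore.find(testCase.getRecordingId())
                    .orElseThrow(() -> new IllegalStateException(
                            "No recording found for ID: " + testCase.getRecordingId()));
        };
    }
    private String generateMockResponse(PromptTestCase testCase) {
        // Placeholder: return a canned response here to exercise the wiring without a model
        throw new UnsupportedOperationException("MOCK mode not implemented: " + testCase.getTestId());
    }
    // Run a batch of test cases and build a suite report
    public PromptTestSuiteResult runSuite(PromptTemplate template,
                                          List<PromptTestCase> testCases) {
        // parallelStream speeds up REAL-mode runs but can trip provider rate limits
        List<PromptTestResult> results = testCases.parallelStream()
                .map(tc -> run(template, tc))
                .collect(Collectors.toList());
        long passed = results.stream().filter(PromptTestResult::isPassed).count();
        return PromptTestSuiteResult.builder()
                .templateId(template.getId())
                .templateVersion(template.getVersion())
                .totalTests(testCases.size())
                .passedTests((int) passed)
                .failedTests((int) (testCases.size() - passed))
                .results(results)
                .build();
    }
}
6. In practice: a test suite for a sentiment analysis prompt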
Now let's write a complete test suite:
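// Assumes JUnit 5, AssertJ and Spring Boot Test on the test classpath
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;
import com.fasterxml.jackson.databind.node.JsonNodeType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest
class SentimentPromptTestSuite {
    @Autowired
    private PromptTestRunner testRunner;
    @Autowired
    private PromptTemplateRepository templateRepo;
    // Shared validation rules. JsonFieldTypeRule, NumberRangeRule, NumberMinRule, NumberMaxRule,
    // JsonFieldLengthRule and the AnyOfRule/AllOfRule combinators are not defined in this article;
    // the field rules follow the same pattern as EnumValueRule (parse, navigate to the field, check).
    private static final List<ValidationRule> STANDARD_SENTIMENT_RULES = List.of(
            new JsonStructureRule("sentiment", "score", "reasoning"),
            new EnumValueRule("sentiment", "positive", "negative", "neutral"),
            new JsonFieldTypeRule("score", JsonNodeType.NUMBER),
            new NumberRangeRule("score", 0.0, 1.0),
            // LengthConstraintRule measures the whole output, so the reasoning field needs a field-level rule
            new JsonFieldLengthRule("reasoning", 10, 500),
            // No markdown code fences
            new RegexRule("```", false)
    );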
    @Test
    void testPositiveTextAnalysis() {
        PromptTemplate template = templateRepo.findByIdAndVersion(
                "sentiment-analysis", "v2.1"
        );
        PromptTestCase testCase = PromptTestCase.builder()
                .testId("positive-review-basic")
                .description("A typical positive review should be classified as positive")
                .inputVariables(Map.of(
                        // "The sound quality of these headphones is amazing, the best-value thing I've ever bought!"
                        "text", "这款耳机音质真的太棒了,买过最值的东西!",
                        "language", "zh"
                ))
                .validationRules(concat(
                        STANDARD_SENTIMENT_RULES,
                        List.of(
                                new EnumValueRule("sentiment", "positive"), // must be positive
                                new NumberMinRule("score", 0.7)            // score must be high
                        )
                ))
                .mode(TestMode.PLAYBACK)
                .recordingId("positive-review-basic-v1")
                .build();
        PromptTestResult result = testRunner.run(template, testCase);
        assertPromptTestPassed(result);
    }
    @Test
    void testNegativeTextAnalysis() {
        PromptTemplate template = templateRepo.findByIdAndVersion(
                "sentiment-analysis", "v2.1"
        );
        PromptTestCase testCase = PromptTestCase.builder()
                .testId("negative-review-basic")
                .description("A typical negative review should be classified as negative")
                .inputVariables(Map.of(
                        // "A total rip-off, awful quality, a waste of money!"
                        "text", "完全是骗人的,质量差到爆,浪费钱!",
                        "language", "zh"
                ))
                .validationRules(concat(
                        STANDARD_SENTIMENT_RULES,
                        List.of(
                                new EnumValueRule("sentiment", "negative"),
                                new NumberMaxRule("score", 0.4)
                        )
                ))
                .mode(TestMode.PLAYBACK)
                .recordingId("negative-review-basic-v1")
                .build();
        PromptTestResult result = testRunner.run(template, testCase);
        assertPromptTestPassed(result);
    }
    @Test
    void testEdgeCaseEmptyText() {
        PromptTemplate template = templateRepo.findByIdAndVersion(
                "sentiment-analysis", "v2.1"
        );
        PromptTestCase testCase = PromptTestCase.builder()
                .testId("empty-text-edge-case")
                .description("Empty text should degrade gracefully instead of crashing")
                .inputVariables(Map.of("text", "", "language", "zh"))
                .validationRules(List.of(
                        // For empty text, return either an error field or a neutral sentiment
                        new AnyOfRule(
                                new JsonStructureRule("error"),
                                new AllOfRule(
                                        new JsonStructureRule("sentiment"),
                                        new EnumValueRule("sentiment", "neutral")
                                )
                        )
                ))
                .mode(TestMode.REAL) // edge cases use real calls
                .build();
        PromptTestResult result = testRunner.run(template, testCase);
        assertPromptTestPassed(result);
    }
    // Regression test: historical cases must not degrade after a prompt change
    @ParameterizedTest
    @MethodSource("regressionTestCases")
    void regressionTest(PromptTestCase testCase) {
        PromptTemplate template = templateRepo.findByIdAndVersion(
                "sentiment-analysis", "v2.1"
        );
        PromptTestResult result = testRunner.run(template, testCase);
        if (!result.isPassed()) {
            String failureReport = buildFailureReport(result);
            fail("Regression test failed:\n" + failureReport);
        }
    }
    static Stream<PromptTestCase> regressionTestCases() {
        // Load from the library of recorded test cases
        return RegressionTestCaseLoader.loadAll("sentiment-analysis");
    }
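    private void assertPromptTestPassed(PromptTestResult result) {
        if (!result.isPassed()) {
            String report = buildFailureReport(result);
            fail("Prompt test failed:\n" + report);
        }
    }
    // Combines the shared rules with case-specific ones
    private static List<ValidationRule> concat(List<ValidationRule> base, List<ValidationRule> extra) {
        List<ValidationRule> all = new ArrayList<>(base);
        all.addAll(extra);
        return all;
    }
    private String buildFailureReport(PromptTestResult result) {
        StringBuilder sb = new StringBuilder();
        sb.append("Test ID: ").append(result.getTestId()).append("\n");
        if (result.getErrorMessage() != null) {
            // The run itself failed (e.g. a missing recording), so no rule was evaluated
            sb.append("Error: ").append(result.getErrorMessage()).append("\n");
        }
        sb.append("Actual output: ").append(result.getOutput()).append("\n");
        sb.append("Failed rules:\n");
        List<ValidationResult> validations = result.getValidationResults() != null
                ? result.getValidationResults() : List.of();
        validations.stream()
                .filter(vr -> !vr.isPassed())
                .forEach(vr -> {
                    sb.append("  - ").append(vr.getRuleName()).append(": ")
                            .append(vr.getMessage()).append("\n");
                    if (vr.getActualValue() != null) {
                        sb.append("    Actual: ").append(vr.getActualValue()).append("\n");
                    }
                    if (vr.getExpectedConstraint() != null) {
                        sb.append("    Expected: ").append(vr.getExpectedConstraint()).append("\n");
                    }
                });
        return sb.toString();
    }
}
7. Recording and playback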
Real LLM calls are expensive and slow, so we add a record/playback mechanism:
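// Assumes Lombok and Spring (DigestUtils is org.springframework.util.DigestUtils)
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.springframework.util.DigestUtils;
public interface LlmExecutor {
    String execute(String systemPrompt, String userPrompt);
}
@Component
@RequiredArgsConstructor
public class LlmRecordingExecutor implements LlmExecutor {
    private final LlmClient realClient; // the underlying model client
    private final RecordingStore store;
    @Override
    public String execute(String systemPrompt, String userPrompt) {
        String recordingKey = generateKey(systemPrompt, userPrompt);
        // Return the recording if one exists
        Optional<String> recording = store.find(recordingKey);
        if (recording.isPresent()) {
            return recording.get();
        }
        // Otherwise make a real call and record the response
        String response = realClient.complete(systemPrompt, userPrompt);
        store.save(recordingKey, response);
        return response;
    }
    private String generateKey(String systemPrompt, String userPrompt) {
        // Hash of the prompt content as the key, so any prompt change yields a new key.
        // The model version is not part of the key (see problem 2 in section 9).
        String combined = systemPrompt + "|||" + userPrompt;
        return DigestUtils.md5DigestAsHex(combined.getBytes(StandardCharsets.UTF_8));
    }
}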
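// Recording storage (files here; Redis would also work)
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;
import org.springframework.stereotype.Repository;
public interface RecordingStore {
    Optional<String> find(String recordingId);
    void save(String recordingId, String response);
}
@Repository
public class FileBasedRecordingStore implements RecordingStore {
    // Relative to the working directory, so tests must run from the module root
    private static final Path RECORDINGS_DIR = Path.of("src/test/resources/recordings");
    @Override
    public Optional<String> find(String recordingId) {
        Path file = RECORDINGS_DIR.resolve(recordingId + ".json");
        if (!Files.exists(file)) {
            return Optional.empty();
        }
        try {
            return Optional.of(Files.readString(file));
        } catch (IOException e) {
            return Optional.empty();
        }
    }
    @Override
    public void save(String recordingId, String response) {
        try {
            Files.createDirectories(RECORDINGS_DIR);
            Files.writeString(RECORDINGS_DIR.resolve(recordingId + ".json"), response);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to save recording", e);
        }
    }
}
8. Prompt versioning and testing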
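Prompts should be versioned like code. When a new version is ready, run the same regression cases against the old and new versions and compare the results. For the comparison to mean anything, the cases must actually reach the model: in PLAYBACK mode both versions would replay the same fixed recordings. Running them through the content-keyed LlmRecordingExecutor works, because a changed prompt produces a new key and therefore a fresh call.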
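The comparison test that follows calls a `printDiffReport` helper that is never defined. A minimal sketch of that helper, assuming simplified record stand-ins (`TestOutcome`, `SuiteRun`) instead of the Lombok result types; all names here are hypothetical. It splits out a pure `diff` step so the per-test status changes can be checked directly:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Simplified stand-in for PromptTestResult: only the fields a diff needs.
record TestOutcome(String testId, boolean passed) {}

// Simplified stand-in for PromptTestSuiteResult.
record SuiteRun(String templateVersion, List<TestOutcome> results) {
    double passRate() {
        if (results.isEmpty()) {
            return 0.0;
        }
        long passed = results.stream().filter(TestOutcome::passed).count();
        return (double) passed / results.size();
    }
}

final class SuiteDiff {
    private SuiteDiff() {}

    // Lists test IDs whose status changed between two runs of the same regression cases.
    static List<String> diff(SuiteRun oldRun, SuiteRun newRun) {
        Map<String, Boolean> before = new HashMap<>();
        for (TestOutcome r : oldRun.results()) {
            before.put(r.testId(), r.passed());
        }
        List<String> lines = new ArrayList<>();
        for (TestOutcome r : newRun.results()) {
            Boolean was = before.get(r.testId());
            if (was == null) {
                lines.add("NEW: " + r.testId() + (r.passed() ? " (pass)" : " (fail)"));
            } else if (was && !r.passed()) {
                lines.add("REGRESSED: " + r.testId());
            } else if (!was && r.passed()) {
                lines.add("FIXED: " + r.testId());
            }
        }
        return lines;
    }

    static void printDiffReport(SuiteRun oldRun, SuiteRun newRun) {
        System.out.printf("%s -> %s: pass rate %.1f%% -> %.1f%%%n",
                oldRun.templateVersion(), newRun.templateVersion(),
                oldRun.passRate() * 100, newRun.passRate() * 100);
        diff(oldRun, newRun).forEach(line -> System.out.println("  " + line));
    }
}
```

The per-test diff is worth printing because a pass-rate comparison alone can hide a regression: if v2.0 passes a and c out of {a, b, c} while v2.1 passes b, c and a new case d, the rate rises from 67% to 75% and the greater-or-equal check passes, yet case a regressed.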
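// Prompt version comparison test (a method of the test class above)
@Test
void testPromptVersionMigration() {
    PromptTemplate oldVersion = templateRepo.findByIdAndVersion("sentiment-analysis", "v2.0");
    PromptTemplate newVersion = templateRepo.findByIdAndVersion("sentiment-analysis", "v2.1");
    // These cases must run in REAL mode: PLAYBACK would replay identical outputs for both versions
    List<PromptTestCase> regressionCases = RegressionTestCaseLoader.loadAll("sentiment-analysis")
            .collect(Collectors.toList());
    PromptTestSuiteResult oldResults = testRunner.runSuite(oldVersion, regressionCases);
    PromptTestSuiteResult newResults = testRunner.runSuite(newVersion, regressionCases);
    // Print the diff report first, so it is still shown when the assertion below fails
    // (printDiffReport is a helper that lists per-test status changes)
    printDiffReport(oldResults, newResults);
    // The new version's pass rate must not be lower than the old one's
    assertThat(newResults.getPassRate())
            .as("New prompt version must not have a lower pass rate than the old version")
            .isGreaterThanOrEqualTo(oldResults.getPassRate());
}
9. Pitfalls: real-world problems in prompt testing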
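Problem 1: Instability caused by temperature
PLAYBACK replays a fixed recording, but every recording originally came from a real call, and REAL-mode tests call the model directly. At a higher temperature, the same prompt can produce differently structured output from one call to the next, so a recording may capture an atypical sample and REAL-mode tests become flaky. Fix: force the temperature to 0 (or close to it) in the test environment, both when recording and in REAL mode. This improves structural stability, although in practice even temperature 0 does not guarantee identical output on every call.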
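Problem 2: Silent breakage from model upgrades
A third-party LLM provider quietly upgrades its model, and the same prompt starts producing output with a different structure. Fix: include the model version in the recording key, and have CI periodically (once a week) refresh the recordings with real LLM calls to detect drift.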
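Folding the model version into the recording key can be sketched as a variant of `generateKey`. Assumptions: the model identifier string comes from your client configuration (the article does not say where), temperature is hashed too because it also changes what a recording represents, and the JDK's SHA-256 (`MessageDigest` plus `HexFormat`, Java 17+) stands in for Spring's `DigestUtils.md5DigestAsHex`; `RecordingKeys` is a hypothetical name:

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.HexFormat;

// Hypothetical variant of generateKey that also hashes the model identifier and temperature.
final class RecordingKeys {
    private RecordingKeys() {}

    static String key(String modelVersion, double temperature, String systemPrompt, String userPrompt) {
        // Length-prefix each part so different splits of the same text cannot produce the same input
        String combined = String.join("|",
                part(modelVersion),
                part(Double.toString(temperature)),
                part(systemPrompt),
                part(userPrompt));
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            byte[] hash = digest.digest(combined.getBytes(StandardCharsets.UTF_8));
            return HexFormat.of().formatHex(hash);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 is not available", e);
        }
    }

    private static String part(String value) {
        return value.length() + ":" + value;
    }
}
```

A new key only helps when the configured identifier actually changes. An upgrade served behind an unchanged model alias still hits the old recordings, which is why the periodic re-recording in CI remains necessary.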
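Problem 3: Maintenance cost of test cases
As the prompt evolves, the test cases must be updated too, or they produce a flood of false failures. Fix: tier test cases by priority; maintain the core ones carefully and manage the edge ones loosely.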
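Tiering can be as simple as tagging each case and letting each pipeline stage choose which tiers to run and which failures to treat as blocking. A minimal sketch, assuming three hypothetical tiers (the article only distinguishes core cases from edge cases) and a policy in which only core failures block a merge:

```java
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical tiers; the article only distinguishes core cases from edge cases.
enum Tier { CORE, STANDARD, EDGE }

record TieredCase(String testId, Tier tier) {}

final class TierPolicy {
    private TierPolicy() {}

    // Select the cases a given pipeline stage should run (e.g. CORE on every commit).
    static List<TieredCase> select(List<TieredCase> all, Set<Tier> tiers) {
        return all.stream()
                .filter(c -> tiers.contains(c.tier()))
                .collect(Collectors.toList());
    }

    // Only a failing CORE case blocks a merge; other tiers are reported but tolerated.
    static boolean blocksMerge(TieredCase c, boolean passed) {
        return !passed && c.tier() == Tier.CORE;
    }
}
```

With this split, core cases can run in PLAYBACK on every commit and stay strictly maintained, while the full set runs on a schedule and its edge-case failures feed a review queue instead of breaking the build.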
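Summary
A prompt unit testing framework is not high tech, but it solves a real engineering problem: making prompt changes measurable, verifiable, and traceable.
The framework's core value comes down to three things:
- It turns a vague "it feels better" into a measurable pass rate
- Record/playback brings the cost of testing against an LLM down to an acceptable level
- Version comparison checks that each prompt upgrade does not regress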
