Spring AI链式调用:构建复杂AI处理流水线的设计模式
Spring AI链式调用:构建复杂AI处理流水线的设计模式
那5000行"AI大泥球"代码
2025年3月,我的学员张帆在群里发了一段话,让我哭笑不得:
"老张,我做了一个合同智能处理系统,功能是这样的:先用AI提取合同关键信息,然后根据类型分类,再对高风险条款做深度分析,同时翻译成英文,最后生成摘要报告存入数据库。
现在一个Service里写了5000行代码,所有逻辑全混在一起,改一个地方要找半天,测试根本没法写,每次上需求都要从头看一遍才敢改……"
我问他:"有没有考虑用Chain模式?"
他说:"Chain是什么?"
Chain Pattern(链式模式)在传统Java开发里用得不多,但在AI应用开发中,它是让代码可维护、可复用、可测试的核心设计模式。
张帆的系统可以用这样的链来表示:
[提取关键信息] → [分类判断] → [风险分析(分支)] → [翻译(并行)] → [摘要生成] → [数据库存储]把5000行大泥球,拆成6个各自独立、可单独测试、可复用的处理节点。这就是AI Chain Pattern的价值。
先说结论(TL;DR)
| 链类型 | 适用场景 | 特点 | 复杂度 |
|---|---|---|---|
| 顺序链 | A→B→C的线性处理 | 最简单,结果逐步传递 | 低 |
| 分支链 | 根据条件走不同路径 | 类似if-else,但可配置 | 中 |
| 并行链 | 多个独立任务同时执行 | 显著减少总耗时 | 中 |
| 循环链 | Agent的迭代改进 | 有收敛条件,防无限循环 | 高 |
| 复合链 | 以上类型任意组合 | 最灵活 | 高 |
什么是AI处理链
AI处理链的核心特征:
- 每个节点都可能调用AI(也可以是纯计算步骤)
- 上一个节点的输出是下一个节点的输入
- 节点间通过上下文(Context)传递状态
- 每个节点独立、可测试、可复用
核心实现一:Chain框架基础接口设计
package com.laozhang.chain.core;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.UUID;
/**
* AI处理链的核心上下文
* 贯穿整个链的执行过程,每个节点都可以读写上下文
*/
public class ChainContext {
private final String chainId;
private final String tenantId;
private final Map<String, Object> data;
private final ChainMetrics metrics;
@SuppressWarnings("unchecked")
public <T> T get(String key) { return (T) data.get(key); }
public <T> T get(String key, T defaultValue) {
Object value = data.get(key);
return value != null ? (T) value : defaultValue;
}
public ChainContext set(String key, Object value) {
data.put(key, value);
return this;
}
public boolean has(String key) { return data.containsKey(key) && data.get(key) != null; }
public static ChainContext create(String tenantId) {
return new ChainContext(UUID.randomUUID().toString(), tenantId,
new ConcurrentHashMap<>(), new ChainMetrics());
}
public static ChainContext of(String tenantId, Map<String, Object> initialData) {
ChainContext ctx = create(tenantId);
ctx.data.putAll(initialData);
return ctx;
}
private ChainContext(String chainId, String tenantId,
Map<String, Object> data, ChainMetrics metrics) {
this.chainId = chainId;
this.tenantId = tenantId;
this.data = data;
this.metrics = metrics;
}
public String getChainId() { return chainId; }
public String getTenantId() { return tenantId; }
public ChainMetrics getMetrics() { return metrics; }
}package com.laozhang.chain.core;
import java.util.concurrent.CompletableFuture;
/**
* AI处理节点接口
* 设计原则:无状态、幂等、单一职责
*/
public interface ChainNode {
String getName();
void execute(ChainContext context);
default CompletableFuture<Void> executeAsync(ChainContext context) {
return CompletableFuture.runAsync(() -> execute(context));
}
default boolean shouldSkip(ChainContext context) { return false; }
default void onError(ChainContext context, Exception e) {
throw new ChainNodeException("节点 " + getName() + " 执行失败", e);
}
}package com.laozhang.chain.core;
import lombok.extern.slf4j.Slf4j;
import java.time.Instant;
import java.util.List;
/**
* 顺序链执行器:N1 → N2 → N3 → ... → Nn
*/
@Slf4j
public class SequentialChain {
private final List<ChainNode> nodes;
private final String chainName;
public SequentialChain(String name, List<ChainNode> nodes) {
this.chainName = name;
this.nodes = List.copyOf(nodes);
}
public String getChainName() { return chainName; }
public ChainContext execute(ChainContext context) {
log.info("开始执行链: {}, chainId={}", chainName, context.getChainId());
Instant chainStart = Instant.now();
for (ChainNode node : nodes) {
if (node.shouldSkip(context)) {
log.debug("跳过节点: {}", node.getName());
continue;
}
Instant nodeStart = Instant.now();
try {
node.execute(context);
long elapsed = Instant.now().toEpochMilli() - nodeStart.toEpochMilli();
context.getMetrics().recordNodeSuccess(node.getName(), elapsed);
log.debug("节点完成: {}, 耗时{}ms", node.getName(), elapsed);
} catch (Exception e) {
context.getMetrics().recordNodeError(node.getName());
log.error("节点失败: {}", node.getName(), e);
try {
node.onError(context, e);
} catch (Exception fe) {
throw new ChainExecutionException("链 " + chainName + " 在节点 " + node.getName() + " 处失败", fe);
}
}
}
long total = Instant.now().toEpochMilli() - chainStart.toEpochMilli();
log.info("链执行完成: {}, 总耗时{}ms", chainName, total);
return context;
}
public static Builder builder(String name) { return new Builder(name); }
public static class Builder {
private final String name;
private final java.util.ArrayList<ChainNode> nodes = new java.util.ArrayList<>();
Builder(String name) { this.name = name; }
public Builder then(ChainNode node) { nodes.add(node); return this; }
public SequentialChain build() { return new SequentialChain(name, nodes); }
}
}核心实现二:合同处理节点
信息提取节点
package com.laozhang.chain.nodes;
import com.laozhang.chain.core.ChainContext;
import com.laozhang.chain.core.ChainNode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Component;
/**
* 合同关键信息提取节点
* 输入:context.get("contractText")
* 输出:context.set("extractedInfo"), context.set("contractType")
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class ContractInfoExtractionNode implements ChainNode {
private final ChatClient chatClient;
@Override
public String getName() { return "ContractInfoExtraction"; }
@Override
public void execute(ChainContext context) {
String contractText = context.get("contractText");
if (contractText == null || contractText.isBlank()) {
throw new IllegalArgumentException("合同文本不能为空");
}
String response = chatClient.prompt()
.system("""
你是合同信息提取专家。从合同文本中提取关键信息,
严格按以下JSON格式返回:
{
"contractType": "合同类型",
"parties": {"partyA": "甲方名称", "partyB": "乙方名称"},
"amount": "合同金额",
"duration": "合同期限",
"keyTerms": ["关键条款摘要,最多5条"]
}
""")
.user("请提取以下合同的关键信息:\n\n" + truncate(contractText, 8000))
.call()
.content();
ContractInfo info = parseContractInfo(response);
context.set("extractedInfo", info);
context.set("contractType", info.getContractType());
log.info("信息提取完成: chainId={}, contractType={}", context.getChainId(), info.getContractType());
}
@Override
public void onError(ChainContext context, Exception e) {
log.warn("信息提取失败,使用降级数据: chainId={}", context.getChainId());
context.set("extractedInfo", ContractInfo.defaultInstance());
context.set("extractionFailed", true);
}
private String truncate(String text, int maxLen) {
return text.length() > maxLen ? text.substring(0, maxLen) + "...[截断]" : text;
}
private ContractInfo parseContractInfo(String json) {
return new ContractInfo(); // 实际用Jackson解析
}
}风险分析节点
package com.laozhang.chain.nodes;
import com.laozhang.chain.core.ChainContext;
import com.laozhang.chain.core.ChainNode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Component;
/**
* 合同风险分析节点
* 输入:contractText, extractedInfo
* 输出:riskAnalysis, riskLevel, riskScore
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class ContractRiskAnalysisNode implements ChainNode {
private final ChatClient chatClient;
@Override
public String getName() { return "ContractRiskAnalysis"; }
@Override
public void execute(ChainContext context) {
String contractText = context.get("contractText");
ContractInfo info = context.get("extractedInfo");
String response = chatClient.prompt()
.system("""
你是资深合同法律专家。分析维度:
1. 违约条款的合法性和对等性
2. 知识产权归属条款
3. 争议解决条款(仲裁机构的合法性)
4. 不可抗力条款的完整性
返回JSON:{"riskLevel":"HIGH/MEDIUM/LOW","riskScore":1-10,"riskItems":[...],"summary":"..."}
""")
.user("合同类型:" + info.getContractType() + "\n\n合同内容:\n" +
contractText.substring(0, Math.min(contractText.length(), 6000)))
.call()
.content();
RiskAnalysis analysis = parseRiskAnalysis(response);
context.set("riskAnalysis", analysis);
context.set("riskLevel", analysis.getRiskLevel());
context.set("riskScore", analysis.getRiskScore());
log.info("风险分析完成: chainId={}, riskLevel={}", context.getChainId(), analysis.getRiskLevel());
}
private RiskAnalysis parseRiskAnalysis(String response) {
return new RiskAnalysis(); // 实际用Jackson解析
}
}核心实现三:分支链
package com.laozhang.chain.core;
import lombok.extern.slf4j.Slf4j;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Predicate;
/**
* 分支链:根据条件选择不同的处理路径
*/
@Slf4j
public class BranchChain implements ChainNode {
private final String name;
private final LinkedHashMap<Predicate<ChainContext>, SequentialChain> branches;
private final SequentialChain defaultChain;
private BranchChain(String name,
LinkedHashMap<Predicate<ChainContext>, SequentialChain> branches,
SequentialChain defaultChain) {
this.name = name;
this.branches = branches;
this.defaultChain = defaultChain;
}
@Override
public String getName() { return name; }
@Override
public void execute(ChainContext context) {
for (Map.Entry<Predicate<ChainContext>, SequentialChain> entry : branches.entrySet()) {
if (entry.getKey().test(context)) {
log.info("分支链命中: {} → 执行分支 {}", name, entry.getValue().getChainName());
entry.getValue().execute(context);
return;
}
}
if (defaultChain != null) {
log.info("分支链无匹配,执行默认分支: {}", defaultChain.getChainName());
defaultChain.execute(context);
}
}
public static Builder builder(String name) { return new Builder(name); }
public static class Builder {
private final String name;
private final LinkedHashMap<Predicate<ChainContext>, SequentialChain> branches = new LinkedHashMap<>();
private SequentialChain defaultChain;
Builder(String name) { this.name = name; }
public Builder when(Predicate<ChainContext> condition, SequentialChain chain) {
branches.put(condition, chain); return this;
}
public Builder otherwise(SequentialChain chain) { this.defaultChain = chain; return this; }
public BranchChain build() { return new BranchChain(name, branches, defaultChain); }
}
}分支链使用示例
@Component
@RequiredArgsConstructor
public class ContractProcessingChainBuilder {
private final ContractInfoExtractionNode extractionNode;
private final ContractRiskAnalysisNode riskAnalysisNode;
private final HighRiskReviewNode highRiskReviewNode;
private final MediumRiskReviewNode mediumRiskReviewNode;
private final StandardReviewNode standardReviewNode;
private final MultiLanguageTranslationNode translationNode;
private final SelfImprovingSummaryNode summaryNode;
private final DatabaseStorageNode storageNode;
public SequentialChain buildContractProcessingChain() {
BranchChain riskBranch = BranchChain.builder("风险路由")
.when(ctx -> "HIGH".equals(ctx.get("riskLevel")),
SequentialChain.builder("高风险处理").then(highRiskReviewNode).build())
.when(ctx -> "MEDIUM".equals(ctx.get("riskLevel")),
SequentialChain.builder("中风险处理").then(mediumRiskReviewNode).build())
.otherwise(
SequentialChain.builder("标准处理").then(standardReviewNode).build())
.build();
return SequentialChain.builder("合同全流程处理")
.then(extractionNode)
.then(riskAnalysisNode)
.then(riskBranch)
.then(translationNode)
.then(summaryNode)
.then(storageNode)
.build();
}
}核心实现四:并行链
package com.laozhang.chain.core;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
/**
* 并行链:同时执行多个独立节点,等所有节点完成后继续
* 并行节点操作的Context Key不能重叠,否则有并发问题!
*/
@Slf4j
public class ParallelChain implements ChainNode {
private final String name;
private final List<ChainNode> parallelNodes;
private final ExecutorService executorService;
private final long timeoutSeconds;
public ParallelChain(String name, List<ChainNode> nodes,
ExecutorService executor, long timeoutSeconds) {
this.name = name;
this.parallelNodes = List.copyOf(nodes);
this.executorService = executor;
this.timeoutSeconds = timeoutSeconds;
}
@Override
public String getName() { return name; }
@Override
public void execute(ChainContext context) {
log.info("开始并行执行: {}, 节点数={}", name, parallelNodes.size());
java.time.Instant start = java.time.Instant.now();
List<CompletableFuture<Void>> futures = parallelNodes.stream()
.map(node -> CompletableFuture.runAsync(() -> {
try {
node.execute(context);
} catch (Exception e) {
log.error("并行节点失败: {}", node.getName(), e);
try { node.onError(context, e); }
catch (Exception fe) { throw new ChainNodeException(node.getName() + " 执行失败", fe); }
}
}, executorService))
.toList();
CompletableFuture<Void> allDone = CompletableFuture.allOf(
futures.toArray(new CompletableFuture[0]));
try {
allDone.get(timeoutSeconds, TimeUnit.SECONDS);
log.info("并行链完成: {}, 耗时{}ms", name,
java.time.Instant.now().toEpochMilli() - start.toEpochMilli());
} catch (java.util.concurrent.TimeoutException e) {
futures.forEach(f -> f.cancel(true));
throw new ChainExecutionException(name + " 并行执行超时", e);
} catch (Exception e) {
throw new ChainExecutionException(name + " 并行执行失败", e);
}
}
public static Builder builder(String name) { return new Builder(name); }
public static class Builder {
private final String name;
private final java.util.ArrayList<ChainNode> nodes = new java.util.ArrayList<>();
private ExecutorService executor;
private long timeoutSeconds = 120;
Builder(String name) { this.name = name; }
public Builder parallel(ChainNode... n) { nodes.addAll(java.util.Arrays.asList(n)); return this; }
public Builder executor(ExecutorService e) { this.executor = e; return this; }
public Builder timeout(long s) { this.timeoutSeconds = s; return this; }
public ParallelChain build() { return new ParallelChain(name, nodes, executor, timeoutSeconds); }
}
}核心实现五:循环链(Agent迭代改进)
package com.laozhang.chain.core;
import lombok.extern.slf4j.Slf4j;
import java.util.function.Predicate;
/**
* 循环链:重复执行节点,直到收敛条件满足或达到最大迭代次数
*/
@Slf4j
public class IterativeChain implements ChainNode {
private final String name;
private final ChainNode iterableNode;
private final Predicate<ChainContext> convergenceCondition;
private final int maxIterations;
public IterativeChain(String name, ChainNode node,
Predicate<ChainContext> convergenceCondition, int maxIterations) {
this.name = name;
this.iterableNode = node;
this.convergenceCondition = convergenceCondition;
this.maxIterations = maxIterations;
}
@Override
public String getName() { return name; }
@Override
public void execute(ChainContext context) {
int iteration = 0;
log.info("开始迭代链: {}, maxIterations={}", name, maxIterations);
while (iteration < maxIterations) {
iteration++;
context.set("iteration", iteration);
log.debug("迭代 {}/{}: {}", iteration, maxIterations, name);
iterableNode.execute(context);
if (convergenceCondition.test(context)) {
log.info("迭代链收敛: {}, 迭代次数={}", name, iteration);
context.set("converged", true);
context.set("totalIterations", iteration);
return;
}
}
log.warn("迭代链达到最大次数未收敛: {}", name);
context.set("converged", false);
context.set("totalIterations", maxIterations);
}
public static Builder builder(String name) { return new Builder(name); }
public static class Builder {
private final String name;
private ChainNode node;
private Predicate<ChainContext> convergenceCondition;
private int maxIterations = 5;
Builder(String name) { this.name = name; }
public Builder iterate(ChainNode n) { this.node = n; return this; }
public Builder until(Predicate<ChainContext> c) { this.convergenceCondition = c; return this; }
public Builder maxIterations(int max) { this.maxIterations = max; return this; }
public IterativeChain build() { return new IterativeChain(name, node, convergenceCondition, maxIterations); }
}
}循环链实战:自我改进摘要
@Component
@RequiredArgsConstructor
public class SelfImprovingSummaryNode implements ChainNode {
private final ChatClient chatClient;
@Override
public String getName() { return "SelfImprovingSummary"; }
@Override
public void execute(ChainContext context) {
ChainNode generateAndEvaluateNode = new ChainNode() {
@Override
public String getName() { return "GenerateAndEvaluate"; }
@Override
public void execute(ChainContext ctx) {
RiskAnalysis riskAnalysis = ctx.get("riskAnalysis");
String previousSummary = ctx.get("summary");
String previousFeedback = ctx.get("summaryFeedback");
int iteration = ctx.get("iteration", 1);
String prompt;
if (previousSummary != null && previousFeedback != null) {
prompt = String.format("请改进以下合同摘要。\n\n之前的摘要:%s\n\n改进意见:%s\n\n请生成改进版本,150字以内:",
previousSummary, previousFeedback);
} else {
prompt = String.format("请为以下合同生成简洁的风险摘要,150字以内。\n风险等级:%s\n风险分析摘要:%s",
riskAnalysis.getRiskLevel(), riskAnalysis.getSummary());
}
String summary = chatClient.prompt().user(prompt).call().content();
ctx.set("summary", summary);
String evaluation = chatClient.prompt()
.user(String.format("评估以下合同摘要质量(1-10分),只返回JSON:{\"score\":分数,\"feedback\":\"改进建议\"}\n\n摘要:%s", summary))
.call().content();
int score = parseScore(evaluation);
ctx.set("summaryQualityScore", score);
ctx.set("summaryFeedback", parseFeedback(evaluation));
log.info("摘要迭代 {}: 质量分数={}/10", iteration, score);
}
};
IterativeChain iterativeChain = IterativeChain.builder("摘要自我改进")
.iterate(generateAndEvaluateNode)
.until(ctx -> { Integer score = ctx.get("summaryQualityScore"); return score != null && score >= 8; })
.maxIterations(3)
.build();
iterativeChain.execute(context);
}
private int parseScore(String json) {
try {
int start = json.indexOf("\"score\":") + 8;
int end = json.indexOf(",", start);
if (end < 0) end = json.indexOf("}", start);
return Integer.parseInt(json.substring(start, end).trim());
} catch (Exception e) { return 5; }
}
private String parseFeedback(String json) {
int start = json.indexOf("\"feedback\":\"") + 12;
int end = json.lastIndexOf("\"");
return start > 11 && end > start ? json.substring(start, end) : "无具体建议";
}
}核心实现六:Spring AI原生Advisor链
package com.laozhang.chain.advisor;
import org.springframework.ai.chat.client.advisor.api.*;
import org.springframework.core.Ordered;
import lombok.extern.slf4j.Slf4j;
/**
* 内容安全检查Advisor
* 在请求发送前检查输入,在响应返回后检查输出
*/
@Slf4j
public class ContentSafetyAdvisor implements CallAroundAdvisor {
private final ContentSafetyService safetyService;
public ContentSafetyAdvisor(ContentSafetyService safetyService) {
this.safetyService = safetyService;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
// 1. 请求前检查用户输入
String userMessage = advisedRequest.userText();
SafetyCheckResult inputCheck = safetyService.check(userMessage);
if (inputCheck.isUnsafe()) {
throw new ContentSafetyException("输入内容违规: " + inputCheck.getReason());
}
// 2. 调用AI
AdvisedResponse response = chain.nextAroundCall(advisedRequest);
// 3. 响应后检查AI输出
String aiOutput = response.response().getResult().getOutput().getText();
SafetyCheckResult outputCheck = safetyService.check(aiOutput);
if (outputCheck.isUnsafe()) {
log.error("AI响应包含违禁内容: {}", outputCheck.getReason());
return buildSafeResponse(response, "AI响应内容不适合展示,请重新提问。");
}
return response;
}
@Override public String getName() { return "ContentSafetyAdvisor"; }
@Override public int getOrder() { return Ordered.HIGHEST_PRECEDENCE; }
private AdvisedResponse buildSafeResponse(AdvisedResponse original, String content) {
return original; // 简化示例
}
}/**
* Advisor链组合示例 - 多个Advisor按Order顺序执行
*/
@Component
public class AdvancedChatService {
private final ChatClient chatClient;
private final ContentSafetyAdvisor contentSafetyAdvisor; // order=-100
private final RateLimitAdvisor rateLimitAdvisor; // order=-50
private final PromptLoggingAdvisor loggingAdvisor; // order=0
private final CostTrackingAdvisor costTrackingAdvisor; // order=100
public String chat(String userId, String userMessage) {
return chatClient.prompt()
.user(userMessage)
.advisors(contentSafetyAdvisor, rateLimitAdvisor, loggingAdvisor, costTrackingAdvisor)
.call()
.content();
}
}核心实现七:链的可视化与调试
package com.laozhang.chain.debug;
import com.laozhang.chain.core.ChainContext;
import com.laozhang.chain.core.ChainMetrics;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.List;
@Slf4j
@Component
public class ChainDebugger {
public void printExecutionReport(ChainContext context) {
ChainMetrics metrics = context.getMetrics();
List<ChainMetrics.NodeRecord> records = metrics.getNodeRecords();
StringBuilder report = new StringBuilder();
report.append("\n======= Chain Execution Report =======\n");
report.append("ChainId: ").append(context.getChainId()).append("\n");
report.append("Total Nodes: ").append(records.size()).append("\n\n");
long maxLatency = records.stream().mapToLong(ChainMetrics.NodeRecord::latencyMs).max().orElse(0);
for (ChainMetrics.NodeRecord record : records) {
int barLength = (int) ((double) record.latencyMs() / maxLatency * 30);
String bar = "█".repeat(Math.max(1, barLength));
String status = record.success() ? "✓" : "✗";
report.append(String.format("%-30s %s %s %dms\n",
record.nodeName(), status, bar, record.latencyMs()));
}
report.append("\nTotal latency: ").append(metrics.getTotalLatencyMs()).append("ms");
report.append("\n=====================================\n");
log.info(report.toString());
}
}public class ChainMetrics {
private final List<NodeRecord> nodeRecords = new java.util.ArrayList<>();
public void recordNodeSuccess(String nodeName, long latencyMs) {
nodeRecords.add(new NodeRecord(nodeName, latencyMs, true, null));
}
public void recordNodeError(String nodeName) {
nodeRecords.add(new NodeRecord(nodeName, -1, false, "ERROR"));
}
public List<NodeRecord> getNodeRecords() {
return java.util.Collections.unmodifiableList(nodeRecords);
}
public long getTotalLatencyMs() {
return nodeRecords.stream().mapToLong(r -> r.latencyMs() > 0 ? r.latencyMs() : 0).sum();
}
public Map<String, Object> getContextSnapshots() { return new java.util.LinkedHashMap<>(); }
public record NodeRecord(String nodeName, long latencyMs, boolean success, String errorMessage) {}
}实战:完整合同处理服务
package com.laozhang.chain.service;
import com.laozhang.chain.core.*;
import com.laozhang.chain.debug.ChainDebugger;
import com.laozhang.chain.nodes.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Map;
/**
* 合同智能处理服务
* 张帆的5000行大泥球,重构后的样子
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class ContractProcessingService {
private final ContractInfoExtractionNode extractionNode;
private final ContractRiskAnalysisNode riskAnalysisNode;
private final HighRiskReviewNode highRiskReviewNode;
private final MediumRiskReviewNode mediumRiskReviewNode;
private final StandardReviewNode standardReviewNode;
private final SelfImprovingSummaryNode summaryNode;
private final MultiLanguageTranslationNode translationNode;
private final DatabaseStorageNode storageNode;
private final ChainDebugger debugger;
public ContractProcessingResult process(String contractText, ProcessingOptions options) {
ChainContext context = ChainContext.of(options.getTenantId(), Map.of(
"contractText", contractText,
"requiresTranslation", options.isRequiresTranslation(),
"processingMode", options.getMode()
));
SequentialChain processingChain = buildProcessingChain(options);
try {
processingChain.execute(context);
if (options.isDebugMode()) debugger.printExecutionReport(context);
return buildResult(context);
} catch (ChainExecutionException e) {
log.error("合同处理链执行失败: chainId={}", context.getChainId(), e);
throw new ContractProcessingException("合同处理失败", e);
}
}
private SequentialChain buildProcessingChain(ProcessingOptions options) {
BranchChain riskBranch = BranchChain.builder("风险路由")
.when(ctx -> "HIGH".equals(ctx.get("riskLevel")),
SequentialChain.builder("高风险处理").then(highRiskReviewNode).build())
.when(ctx -> "MEDIUM".equals(ctx.get("riskLevel")),
SequentialChain.builder("中风险处理").then(mediumRiskReviewNode).build())
.otherwise(
SequentialChain.builder("标准处理").then(standardReviewNode).build())
.build();
SequentialChain.Builder builder = SequentialChain.builder("合同智能处理")
.then(extractionNode)
.then(riskAnalysisNode)
.then(riskBranch)
.then(summaryNode);
if (options.isRequiresTranslation()) builder.then(translationNode);
builder.then(storageNode);
return builder.build();
}
private ContractProcessingResult buildResult(ChainContext context) {
return ContractProcessingResult.builder()
.chainId(context.getChainId())
.contractType(context.get("contractType"))
.riskLevel(context.get("riskLevel"))
.riskScore(context.get("riskScore"))
.summary(context.get("summary"))
.summaryQualityScore(context.get("summaryQualityScore"))
.totalLatencyMs(context.getMetrics().getTotalLatencyMs())
.build();
}
}生产环境注意事项
带超时的链执行
public ChainContext executeWithTimeout(SequentialChain chain, ChainContext context, long timeoutSeconds) {
CompletableFuture<ChainContext> future = CompletableFuture.supplyAsync(
() -> chain.execute(context), chainExecutor);
try {
return future.get(timeoutSeconds, TimeUnit.SECONDS);
} catch (TimeoutException e) {
future.cancel(true);
throw new ChainExecutionException("链执行超时", e);
} catch (Exception e) {
throw new ChainExecutionException("链执行异常", e);
}
}踩坑1:并行链的Context线程安全 并行节点共享同一个ChainContext。规范:并行节点只写自己专属的Key(约定命名规则),避免并发写入同一Key。
踩坑2:迭代链的无限循环 迭代链必须设置maxIterations上限。超过5次迭代通常意味着质量标准设置太严,而不是需要更多迭代。
踩坑3:节点降级后链继续执行 当节点onError设置了降级值(不抛异常),链会继续。下游节点必须检查降级标志(如context.get("extractionFailed"))。
踩坑4:上下文数据膨胀 每个节点写数据,Context会越来越大。处理完毕后,清除不再需要的大型中间数据(如原始AI响应文本),只保留结构化结果。
常见问题解答
Q1:Spring AI Advisor链和自定义Chain有什么区别? A:Advisor链是Spring AI内置的,用于拦截和增强单次ChatClient调用(类似Servlet Filter)。自定义Chain是业务层编排框架,用于组织多个AI调用和非AI步骤的执行顺序。两者互补:Advisor处理横切关注点,Chain处理业务流程编排。
Q2:并行链中,一个节点失败,其他节点会被取消吗? A:取决于实现策略。当前实现中,一个节点失败触发onError降级,不会取消其他并行节点。如果需要"任一节点失败则全部取消"的语义,用CompletableFuture.anyOf()配合取消逻辑实现。
Q3:如何在链执行过程中向用户推送实时进度? A:在SequentialChain.execute()中,每个节点完成后调用progressReporter.report(chainId, progress, message)。结合SSE方案,用户可以看到"正在提取→正在分析→正在生成摘要"的实时进度。
Q4:链的配置能不能运行时动态修改? A:完全可以。把链的结构(节点列表、分支条件、最大迭代次数)存入数据库,运行时动态组装SequentialChain。结合Nacos动态配置,可以在不重启的情况下改变链的行为。
Q5:ChainContext在并行链中线程安全吗? A:ChainContext内部使用ConcurrentHashMap,单次读写是线程安全的。但"读-改-写"操作不是原子的,如果多个并行节点需要更新同一Key,要么用不同的Key,要么加锁。
Q6:如何测试单个链节点? A:这正是Chain Pattern的优势!每个节点可以独立单元测试:创建ChainContext,填入输入数据,调用node.execute(context),检查输出数据。不需要启动整个链,不需要Mock其他节点。
总结
张帆的5000行大泥球,用Chain Pattern重构后变成了6个清晰的节点,每个不超过100行,有独立测试,可以复用,可以动态组装。
这就是设计模式的价值:不是为了炫技,而是让代码能活得更久。
可操作行动清单:
代码写得好,不是让机器更快,而是让人更自信。
