RAG 的多跳推理——当答案分散在多个文档里
RAG 的多跳推理——当答案分散在多个文档里
我记得那个需求是这么来的:法务同事希望知识库能回答合同问题。
她的问题是这样的:"我们与某供应商签了三年的框架协议,上面写了违约金按合同金额的 10%。但是每次具体采购都签了补充协议,补充协议里有没有关于违约的特殊约定?如果有冲突,应该以哪个为准?"
这个问题有多复杂?它需要:
- 找到框架协议(文档 A)里的违约条款
- 找到补充协议(文档 B、C、D……)里的相关条款
- 判断两类文档之间的条款冲突
- 给出法律适用的优先级
这四个步骤,每一步都依赖上一步的结果。不可能用一次向量检索完成,因为你根本不知道要搜什么——你得先找到框架协议,才知道要去补充协议里找什么。
这就是 RAG 的多跳推理问题:答案不在单一文档里,需要从多个文档中提取信息,经过推理链条才能组合出答案。
单跳 RAG 为什么搞不定这个问题
普通的 RAG 是"一次检索 + 一次生成"的流程:用户问题 → 向量检索 → 组合上下文 → LLM 回答。
这个流程有一个隐含假设:答案所需的信息在相似度最高的那几个文档块里。
但这个假设在很多场景下是错的:
场景 1:因果链。"这个政策的执行效果如何?"——要回答这个,你得先找到政策文档,再找执行报告,再找效果评估。三个文档,有先后依赖关系。
场景 2:比较分析。"A 产品和 B 产品在退货条款上有什么差异?"——需要分别从两份合同里找退货条款,再进行比较。
场景 3:条件推断。"如果我满足了 X 条件,我可以申请 Y 吗?"——需要找到 X 的定义、Y 的申请条件,再做逻辑推断。
这些场景有个共同特点:需要多轮检索,后一轮的 Query 依赖前一轮的结果。
三种多跳 RAG 的实现思路
思路 1:迭代式检索
最直观:先用原始问题检索一次,把结果喂给 LLM,让 LLM 判断"还需要什么信息",再检索,如此循环,直到 LLM 觉得信息足够了。
思路 2:Sub-question 分解
在检索之前,先让 LLM 把复杂问题分解成多个独立的子问题,然后分别检索每个子问题,最后汇总所有答案让 LLM 综合回答。
思路 3:Chain-of-Thought RAG
让 LLM 先生成推理步骤(不使用外部知识),然后根据每个推理步骤需要什么信息,逐步检索。
Spring AI 实现多跳 RAG
我来展示一个实际可用的实现,把迭代式检索和 Sub-question 分解结合起来:
首先是 Sub-question 分解器:
@Component
@Slf4j
public class SubQuestionPlanner {
private final ChatClient chatClient;
private static final String PLAN_PROMPT = """
你是一个复杂问题分析专家。
用户问题:{question}
请分析这个问题,将其分解为需要独立检索的子问题序列。
注意:
1. 子问题之间可能有依赖关系(后一个依赖前一个的结果)
2. 标注依赖关系:如果子问题 B 依赖子问题 A 的结果,用 [DEPENDS_ON: A] 标注
3. 每个子问题应该足够具体,能独立检索
输出格式(JSON):
{
"sub_questions": [
{"id": "q1", "question": "...", "depends_on": []},
{"id": "q2", "question": "...", "depends_on": ["q1"]},
{"id": "q3", "question": "...", "depends_on": ["q1", "q2"]}
]
}
用户问题:{question}
""";
public SubQuestionPlan plan(String question) {
try {
String response = chatClient.prompt()
.user(u -> u.text(PLAN_PROMPT).param("question", question))
.call()
.content();
return parseSubQuestionPlan(response);
} catch (Exception e) {
log.error("Sub-question planning failed", e);
// 降级:把原始问题作为单个子问题
SubQuestion single = new SubQuestion("q1", question, Collections.emptyList());
return new SubQuestionPlan(List.of(single));
}
}
private SubQuestionPlan parseSubQuestionPlan(String jsonResponse) {
// 提取 JSON 内容(LLM 有时会在 JSON 前后加说明文字)
String json = extractJson(jsonResponse);
ObjectMapper mapper = new ObjectMapper();
try {
JsonNode root = mapper.readTree(json);
JsonNode subQuestionsNode = root.get("sub_questions");
List<SubQuestion> subQuestions = new ArrayList<>();
for (JsonNode node : subQuestionsNode) {
String id = node.get("id").asText();
String question = node.get("question").asText();
List<String> dependsOn = new ArrayList<>();
JsonNode depsNode = node.get("depends_on");
if (depsNode != null && depsNode.isArray()) {
depsNode.forEach(dep -> dependsOn.add(dep.asText()));
}
subQuestions.add(new SubQuestion(id, question, dependsOn));
}
return new SubQuestionPlan(subQuestions);
} catch (Exception e) {
log.warn("Failed to parse sub-question plan JSON", e);
return new SubQuestionPlan(List.of(new SubQuestion("q1", jsonResponse, Collections.emptyList())));
}
}
private String extractJson(String text) {
int start = text.indexOf('{');
int end = text.lastIndexOf('}');
if (start >= 0 && end > start) {
return text.substring(start, end + 1);
}
return text;
}
@Data
@AllArgsConstructor
public static class SubQuestion {
private String id;
private String question;
private List<String> dependsOn;
}
@Data
@AllArgsConstructor
public static class SubQuestionPlan {
private List<SubQuestion> subQuestions;
public boolean hasOnlyOneQuestion() {
return subQuestions.size() <= 1;
}
}
}多跳检索执行引擎:
@Service
@Slf4j
public class MultiHopRagService {
private final SubQuestionPlanner planner;
private final VectorStore vectorStore;
private final ChatClient chatClient;
private static final int TOP_K_PER_SUBQUESTION = 5;
private static final int MAX_HOPS = 4; // 防止无限循环
private static final String ANSWER_SUBQUESTION_PROMPT = """
基于以下参考文档,回答子问题。
如果文档中没有足够信息,请明确说明"文档中未找到相关信息",不要编造。
子问题:{sub_question}
参考文档:
{context}
回答:
""";
private static final String SYNTHESIZE_PROMPT = """
基于以下各子问题的回答,综合回答用户的原始问题。
原始问题:{original_question}
子问题及回答:
{sub_answers}
请给出综合性的最终回答,逻辑清晰,有理有据:
""";
public MultiHopRagService(SubQuestionPlanner planner,
VectorStore vectorStore,
ChatClient.Builder builder) {
this.planner = planner;
this.vectorStore = vectorStore;
this.chatClient = builder.build();
}
public MultiHopRagResult query(String originalQuestion) {
log.info("Starting multi-hop RAG for: {}", originalQuestion);
// Step 1: 分解问题
SubQuestionPlanner.SubQuestionPlan plan = planner.plan(originalQuestion);
log.info("Plan: {} sub-questions", plan.getSubQuestions().size());
// 如果是简单问题,退化为普通 RAG
if (plan.hasOnlyOneQuestion()) {
return singleHopFallback(originalQuestion);
}
// Step 2: 按依赖顺序执行子问题
Map<String, String> subAnswers = new LinkedHashMap<>();
Map<String, List<Document>> subDocuments = new HashMap<>();
int hopCount = 0;
for (SubQuestionPlanner.SubQuestion subQ : plan.getSubQuestions()) {
if (hopCount >= MAX_HOPS) {
log.warn("Reached max hops limit, stopping at {} hops", MAX_HOPS);
break;
}
// 如果有依赖,把依赖的答案注入到当前子问题里
String enrichedQuestion = enrichWithDependencies(
subQ.getQuestion(),
subQ.getDependsOn(),
subAnswers
);
// 检索
List<Document> docs = vectorStore.similaritySearch(
SearchRequest.query(enrichedQuestion).withTopK(TOP_K_PER_SUBQUESTION)
);
subDocuments.put(subQ.getId(), docs);
// 用 LLM 回答子问题
String subAnswer = answerSubQuestion(enrichedQuestion, docs);
subAnswers.put(subQ.getId(), subAnswer);
log.debug("Sub-question {} answered: {}", subQ.getId(),
subAnswer.substring(0, Math.min(100, subAnswer.length())));
hopCount++;
}
// Step 3: 综合所有子答案
String finalAnswer = synthesize(originalQuestion, subAnswers);
// 收集所有引用文档
List<Document> allDocuments = subDocuments.values().stream()
.flatMap(Collection::stream)
.distinct()
.collect(Collectors.toList());
return MultiHopRagResult.builder()
.originalQuestion(originalQuestion)
.subQuestions(plan.getSubQuestions())
.subAnswers(subAnswers)
.finalAnswer(finalAnswer)
.referencedDocuments(allDocuments)
.hopCount(hopCount)
.build();
}
/**
* 把依赖问题的答案注入到当前问题里
* 例如:"根据前面找到的框架协议信息:XXX,现在请查找补充协议中..."
*/
private String enrichWithDependencies(
String question,
List<String> dependsOn,
Map<String, String> subAnswers) {
if (dependsOn.isEmpty()) return question;
StringBuilder enriched = new StringBuilder();
enriched.append("基于以下已知信息:\n");
for (String depId : dependsOn) {
String depAnswer = subAnswers.get(depId);
if (depAnswer != null) {
enriched.append(String.format("- %s\n", depAnswer));
}
}
enriched.append("\n现在请回答:").append(question);
return enriched.toString();
}
private String answerSubQuestion(String subQuestion, List<Document> docs) {
if (docs.isEmpty()) {
return "未在知识库中找到相关信息。";
}
String context = docs.stream()
.map(Document::getContent)
.collect(Collectors.joining("\n\n---\n\n"));
return chatClient.prompt()
.user(u -> u.text(ANSWER_SUBQUESTION_PROMPT)
.param("sub_question", subQuestion)
.param("context", context))
.call()
.content();
}
private String synthesize(String originalQuestion, Map<String, String> subAnswers) {
String subAnswersText = subAnswers.entrySet().stream()
.map(e -> String.format("问题 %s 的答案:\n%s", e.getKey(), e.getValue()))
.collect(Collectors.joining("\n\n"));
return chatClient.prompt()
.user(u -> u.text(SYNTHESIZE_PROMPT)
.param("original_question", originalQuestion)
.param("sub_answers", subAnswersText))
.call()
.content();
}
private MultiHopRagResult singleHopFallback(String question) {
List<Document> docs = vectorStore.similaritySearch(
SearchRequest.query(question).withTopK(5)
);
String answer = answerSubQuestion(question, docs);
return MultiHopRagResult.builder()
.originalQuestion(question)
.subQuestions(Collections.emptyList())
.subAnswers(Map.of("q1", answer))
.finalAnswer(answer)
.referencedDocuments(docs)
.hopCount(1)
.build();
}
}结果对象:
@Data
@Builder
public class MultiHopRagResult {
private String originalQuestion;
private List<SubQuestionPlanner.SubQuestion> subQuestions;
private Map<String, String> subAnswers;
private String finalAnswer;
private List<Document> referencedDocuments;
private int hopCount;
/**
* 生成引用来源信息,方便前端展示
*/
public List<String> getSourceFiles() {
return referencedDocuments.stream()
.map(doc -> (String) doc.getMetadata().getOrDefault("source_file", "未知来源"))
.distinct()
.collect(Collectors.toList());
}
/**
* 获取推理链条的文字描述,适合前端展示"思考过程"
*/
public String getReasoningChain() {
if (subQuestions.isEmpty()) return "直接检索回答";
StringBuilder chain = new StringBuilder("推理步骤:\n");
for (SubQuestionPlanner.SubQuestion sq : subQuestions) {
chain.append(String.format(" - %s: %s\n", sq.getId(), sq.getQuestion()));
String answer = subAnswers.get(sq.getId());
if (answer != null) {
chain.append(String.format(" → %s\n",
answer.substring(0, Math.min(80, answer.length())) + "..."));
}
}
return chain.toString();
}
}控制迭代深度和成本
多跳 RAG 的最大风险是失控——每一跳都有 LLM 调用,如果问题太复杂,跑 8-10 跳,成本和延迟都会爆炸。
我们的控制策略:
@Component
@Slf4j
public class MultiHopCostController {
// 最大跳数
private static final int MAX_HOPS = 4;
// 问题复杂度阈值,超过才走多跳流程
private static final int COMPLEXITY_THRESHOLD = 2;
private final ChatClient chatClient;
/**
* 评估问题复杂度,决定是否需要多跳处理
* 返回 1-5 的复杂度评分
*/
public int assessComplexity(String question) {
String prompt = """
评估以下问题的复杂度,判断它需要几跳推理才能回答(1=单文档可回答,5=需要多文档复杂推理)。
只输出一个 1-5 的数字。
问题:%s
""".formatted(question);
try {
String response = chatClient.prompt()
.user(prompt)
.call()
.content()
.trim();
int score = Integer.parseInt(response.replaceAll("[^1-5]", "").substring(0, 1));
log.debug("Complexity assessment for '{}': {}", question, score);
return score;
} catch (Exception e) {
log.warn("Complexity assessment failed, defaulting to 3");
return 3;
}
}
public boolean needsMultiHop(String question) {
return assessComplexity(question) >= COMPLEXITY_THRESHOLD;
}
}在 Controller 层做路由:
@Service
@Slf4j
public class SmartRagRouter {
private final MultiHopCostController costController;
private final MultiHopRagService multiHopService;
private final SimpleRagService simpleRagService;
public RagResponse query(String question, UserContext userContext) {
if (costController.needsMultiHop(question)) {
log.info("Routing to multi-hop RAG: {}", question);
MultiHopRagResult result = multiHopService.query(question);
return RagResponse.fromMultiHop(result);
} else {
log.info("Routing to simple RAG: {}", question);
return simpleRagService.query(question, userContext);
}
}
}真实效果验证
我拿 50 个需要多文档综合分析的法务问题测试了一下:
| 方案 | 完全正确率 | 部分正确率 | 平均跳数 | 平均耗时 |
|---|---|---|---|---|
| 普通 RAG | 18% | 42% | 1 | 0.8s |
| 多跳 RAG(2跳) | 51% | 31% | 2 | 4.2s |
| 多跳 RAG(3-4跳) | 73% | 18% | 3.2 | 9.8s |
对于这类需要多文档综合推理的问题,多跳 RAG 的完全正确率从 18% 提升到了 73%,代价是耗时从 0.8 秒增加到近 10 秒。
10 秒对于合同分析场景完全可以接受,法务同事愿意等。但如果是普通问答场景,这个延迟是不可接受的,一定要做好路由判断,不能所有问题都走多跳。
适用场景和局限性
适用的场景:
- 合同条款对比和分析
- 跨部门政策联动查询
- 需要综合多个来源才能判断的合规问题
- 历史决策溯源(需要找到当时的背景文档再找决策文档)
不适用的场景:
- 简单的事实查询("我们公司年假是几天?")
- 实时对话要求快速响应
- 文档库内容高度一致,单次检索足够的场景
一个重要局限: 多跳推理的质量上限取决于子问题分解的质量。如果 LLM 把问题分解错了(这种情况确实会发生),后面的跳再多也没用。所以生产环境里需要有人工监控,定期抽查多跳结果是否合理。
总结
多跳 RAG 不是银弹,它增加了系统复杂度、延迟和成本。但对于确实需要多文档综合推理的场景,它是必要的。
关键是做好路由:简单问题走简单 RAG,复杂问题走多跳。不要用大炮打蚊子,也不要用弹弓打坦克。
