第2136篇:Reranking实战——让RAG检索结果从"相关"变"精准"
第2136篇:Reranking实战——让RAG检索结果从"相关"变"精准"
适读人群:需要提升RAG精确度的工程师 | 阅读时长:约18分钟 | 核心价值:掌握Cross-Encoder重排序的工程实现,通过两阶段检索显著提升RAG回答质量
向量检索有一个先天的精度问题:它使用双编码器(Bi-Encoder)独立编码查询和文档,然后计算相似度。这种方式速度快,但"查询"和"文档"之间没有交互,容易把"相关但不准确"的文档排在前面。
Reranking(重排序)用Cross-Encoder模型对候选结果做精排:查询和每个文档拼在一起,模型同时看到两者,给出更精确的相关性分数。代价是速度慢(不能批量索引),但精度高。
这是RAG系统提升质量性价比最高的改进之一。
Bi-Encoder vs Cross-Encoder
/**
* 两种检索模型的对比
*
* ===== Bi-Encoder(向量检索)=====
*
* 工作方式:
* - 文档和查询分别独立编码成向量
* - 用余弦相似度比较
*
* 优点:
* - 文档向量可以预计算(离线索引)
* - 查询时只需要计算查询向量,然后ANN搜索
* - 速度极快:毫秒级,支持千万级文档库
*
* 缺点:
* - 查询和文档没有交互,相关性判断不精确
* - 对于需要比较两段文字差异的场景效果差
*
* ===== Cross-Encoder(重排序)=====
*
* 工作方式:
* - 查询和文档拼在一起,输入到模型
* - 模型输出一个0-1的相关性分数
* - 模型能"看到"两者的完整上下文和交互
*
* 优点:
* - 精度显著高于Bi-Encoder
* - 在BEIR基准上,Cross-Encoder比Bi-Encoder高10-20%
*
* 缺点:
* - 不能预计算(每次查询都要重新运行模型)
* - 速度慢:处理100个候选文档需要100ms-500ms
* - 不能用于大规模初召(10万+文档就太慢了)
*
* ===== 结合使用(两阶段检索)=====
*
* 阶段1:Bi-Encoder快速召回Top-100
* 阶段2:Cross-Encoder精排,取Top-5
*
* 效果:最终结果的精度接近Cross-Encoder
* 速度:接近Bi-Encoder(只对100个候选做精排)
*/本地Cross-Encoder服务
/**
* 本地CrossEncoder重排序服务
*
* 推荐模型:
* - BAAI/bge-reranker-base(中英文,768MB,CPU可用)
* - ms-marco-MiniLM-L-6-v2(英文,轻量级)
* - cross-encoder/ms-marco-electra-base(英文,高精度)
*/
@Service
@Slf4j
public class LocalRerankerService {
private final OrtEnvironment ortEnv;
private final OrtSession session;
private final HuggingFaceTokenizer tokenizer;
// 批量推理的大小
private static final int BATCH_SIZE = 32;
public LocalRerankerService(
@Value("${reranker.model.path}") String modelPath) throws Exception {
this.ortEnv = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
this.session = ortEnv.createSession(modelPath, opts);
this.tokenizer = HuggingFaceTokenizer.newInstance(modelPath,
Map.of("maxLength", "512", "truncation", "true", "padding", "true"));
log.info("Reranker模型已加载: {}", modelPath);
}
/**
* 对候选文档重排序
*
* @param query 用户查询
* @param candidates 候选文档列表
* @return 重排序后的结果(按相关性降序)
*/
public List<RerankResult> rerank(String query, List<String> candidates) {
if (candidates.isEmpty()) return List.of();
long startMs = System.currentTimeMillis();
// 批量计算相关性分数
float[] scores = computeScoresBatch(query, candidates);
// 组合结果并排序
List<RerankResult> results = new ArrayList<>();
for (int i = 0; i < candidates.size(); i++) {
results.add(new RerankResult(i, candidates.get(i), scores[i]));
}
results.sort(Comparator.comparingDouble(RerankResult::score).reversed());
log.debug("重排序完成: candidates={}, latency={}ms",
candidates.size(), System.currentTimeMillis() - startMs);
return results;
}
/**
* 批量计算相关性分数
*/
private float[] computeScoresBatch(String query, List<String> documents) {
float[] allScores = new float[documents.size()];
// 按批次处理
for (int batchStart = 0; batchStart < documents.size(); batchStart += BATCH_SIZE) {
int batchEnd = Math.min(batchStart + BATCH_SIZE, documents.size());
List<String> batch = documents.subList(batchStart, batchEnd);
float[] batchScores = computeBatchScores(query, batch);
System.arraycopy(batchScores, 0, allScores, batchStart, batchScores.length);
}
return allScores;
}
private float[] computeBatchScores(String query, List<String> batch) {
try {
// 构建"查询 [SEP] 文档"格式的输入
List<String> pairs = batch.stream()
.map(doc -> query)
.toList();
// Tokenize所有pair
long[][] inputIds = new long[batch.size()][];
long[][] attentionMasks = new long[batch.size()][];
for (int i = 0; i < batch.size(); i++) {
// 对于CrossEncoder,使用对格式:query + sep + document
Encoding encoding = tokenizer.encode(query, batch.get(i));
inputIds[i] = encoding.getIds();
attentionMasks[i] = encoding.getAttentionMask();
}
// 找最长序列,做padding
int maxLen = Arrays.stream(inputIds).mapToInt(arr -> arr.length).max().orElse(0);
// ONNX推理
// (完整的张量创建和推理代码略,参考article-2094的实现)
// 简化返回:实际需要从logits计算sigmoid
float[] scores = new float[batch.size()];
// ... ONNX推理逻辑
return scores;
} catch (Exception e) {
log.error("批量推理失败: {}", e.getMessage(), e);
// 返回默认分数
float[] fallback = new float[batch.size()];
Arrays.fill(fallback, 0.5f);
return fallback;
}
}
public record RerankResult(int originalIndex, String content, float score) {}
}两阶段RAG集成
/**
* 两阶段检索:向量召回 + Reranking精排
*/
@Service
@RequiredArgsConstructor
@Slf4j
public class TwoStageRagService {
private final VectorStore vectorStore;
private final EmbeddingModel embeddingModel;
private final LocalRerankerService reranker;
private final ChatLanguageModel llm;
/**
* 两阶段检索并生成答案
*
* @param query 用户查询
* @param recallTopK 第一阶段召回数量(通常20-50)
* @param rerankTopK 最终保留数量(通常3-5)
*/
public RagAnswer retrieveAndGenerate(String query, int recallTopK, int rerankTopK) {
long totalStart = System.currentTimeMillis();
// 阶段1:向量召回
long retrievalStart = System.currentTimeMillis();
float[] queryVector = embeddingModel.embed(query).content().vector();
List<VectorStore.SearchResult> recalled = vectorStore.search(queryVector, recallTopK, null);
long retrievalMs = System.currentTimeMillis() - retrievalStart;
if (recalled.isEmpty()) {
return RagAnswer.noResults();
}
// 阶段2:Reranking精排
long rerankStart = System.currentTimeMillis();
List<String> candidateContents = recalled.stream()
.map(VectorStore.SearchResult::getContent)
.toList();
List<LocalRerankerService.RerankResult> reranked =
reranker.rerank(query, candidateContents);
long rerankMs = System.currentTimeMillis() - rerankStart;
// 取Top-K精排结果
List<String> topContext = reranked.stream()
.limit(rerankTopK)
.filter(r -> r.score() > 0.3) // 过滤掉相关度太低的
.map(LocalRerankerService.RerankResult::content)
.toList();
if (topContext.isEmpty()) {
return RagAnswer.noResults();
}
// 生成答案
long generationStart = System.currentTimeMillis();
String context = String.join("\n\n---\n\n", topContext);
String answer = generateAnswer(query, context);
long generationMs = System.currentTimeMillis() - generationStart;
long totalMs = System.currentTimeMillis() - totalStart;
log.debug("两阶段RAG完成: recall={}ms, rerank={}ms, generation={}ms, total={}ms",
retrievalMs, rerankMs, generationMs, totalMs);
return new RagAnswer(
answer, topContext, recalled.size(),
retrievalMs, rerankMs, generationMs, totalMs
);
}
/**
* 评估Reranking的效果
*
* 对比有无Reranking的答案质量
*/
public RerankingEffectivenessReport evaluateRerankingEffect(
List<String> testQueries, List<String> expectedAnswers) {
int improvements = 0;
int regressions = 0;
int neutral = 0;
for (int i = 0; i < testQueries.size(); i++) {
String query = testQueries.get(i);
// 无Reranking(只用向量检索的Top-5)
float[] queryVector = embeddingModel.embed(query).content().vector();
List<VectorStore.SearchResult> recalled = vectorStore.search(queryVector, 5, null);
String withoutRerank = generateAnswer(query, recalled.stream()
.map(VectorStore.SearchResult::getContent)
.collect(Collectors.joining("\n\n")));
// 有Reranking
RagAnswer withRerank = retrieveAndGenerate(query, 20, 5);
// 对比质量(简化:用LLM判断)
int comparison = compareAnswers(withoutRerank, withRerank.answer(), expectedAnswers.get(i));
if (comparison > 0) improvements++;
else if (comparison < 0) regressions++;
else neutral++;
}
return new RerankingEffectivenessReport(
testQueries.size(), improvements, regressions, neutral
);
}
private int compareAnswers(String without, String with, String expected) {
// 用LLM比较两个答案,返回1(with更好),-1(without更好),0(差不多)
// 简化实现
return 0;
}
private String generateAnswer(String query, String context) {
return llm.generate("根据以下资料回答:\n\n" + context + "\n\n问题:" + query);
}
@Data
@Builder
public static class RagAnswer {
private String answer;
private List<String> usedContexts;
private int totalRecalled;
private long retrievalMs;
private long rerankMs;
private long generationMs;
private long totalMs;
public static RagAnswer noResults() {
return RagAnswer.builder()
.answer("根据现有资料无法回答此问题。")
.usedContexts(List.of())
.build();
}
}
record RerankingEffectivenessReport(int total, int improvements,
int regressions, int neutral) {
public double improvementRate() { return (double) improvements / total; }
}
}实践建议
Reranking的收益在"hard negatives"场景最明显
所谓"hard negatives",是指那些向量相似度高但实际不相关的文档。比如:用户问"如何取消订单",向量检索可能返回"如何创建订单"(两者在语义空间很接近),但Cross-Encoder能区分"取消"和"创建"的实质差异。如果你的知识库里有很多主题相近但内容不同的文档,Reranking的提升会特别显著。反之,如果文档主题很分散,向量检索效果本来就好,Reranking提升有限。
召回数量(recall_top_k)要足够大
两阶段检索的前提是第一阶段召回够全。如果第一阶段只召回10个文档,正确答案可能根本没在里面,Reranking再好也没用。建议第一阶段召回50-100个文档,第二阶段精排取Top-5。这样相当于在100个候选里挑最好的5个,正确文档被召回的概率大幅提升。
本地部署Reranker vs 使用API
Cohere、Voyage AI等都提供Reranking API,接入简单,但每次调用有费用。对于低QPS应用,API是合理选择。对于高QPS(日均10万+次检索),本地部署BGE-Reranker等开源模型更经济——CPU机器单核可以处理约10 QPS,4核机器就能支撑日均300万次请求,成本比API低一个数量级。
