第1837篇:Reranker模型的选型与集成——BGE、Cohere Rerank的工程实践
第1837篇:Reranker模型的选型与集成——BGE、Cohere Rerank的工程实践
做 RAG 系统的时候,有个很常见的现象:向量检索召回的 Top-10 结果里,真正最相关的那个往往不在第一位。
Embedding 模型做的是语义相似度的粗估计,它的优势是快,但精度不足以直接做精排。Reranker(重排器)是专门用来解决这个问题的:输入一个查询和一批候选文档,输出每个文档与查询的精确相关性得分,重新排序。
我在一个实际项目里的数据:加 Reranker 之前,MRR@5(前5个结果里最相关文档的平均倒数排名)是 0.61,加上之后提到了 0.79,提升很显著。代价是每次查询多了 50-150ms 的延迟。
这篇讲 Reranker 的原理、主流模型对比,以及在 Java 服务里的集成实践。
一、Reranker 的工作原理
1.1 Cross-Encoder 架构
目前主流的 Reranker 都是基于 Cross-Encoder 架构:把查询和文档拼接成一个输入序列,通过 BERT 类模型做全序列 attention,最后输出相关性得分。
输入格式:
[CLS] 查询文本 [SEP] 文档文本 [SEP]
模型:
BERT → [CLS] token 的向量 → 线性层 → 相关性分数 (0~1)Cross-Encoder 和双塔(Bi-Encoder)的本质区别:
双塔模型的查询和文档独立编码,因此查询向量和文档向量之间没有直接的词级别交互,只有最后的余弦相似度;Cross-Encoder 从一开始就让查询和文档每个词互相 attention,捕捉到了更细粒度的匹配信息。
这就是为什么 Reranker 精度更高,代价是文档必须实时编码(不能预计算),只适合对少量候选做精排。
1.2 为什么要两阶段检索
单用 Reranker 对全库扫描不可行(100 万文档 × 单次 10ms = 2.7 小时),但对 50-200 个候选做精排完全可行。
二、主流 Reranker 模型对比
2.1 BGE-Reranker 系列
由 BAAI 开源,有多个版本:
| 模型 | 参数量 | 中文质量 | 英文质量 | 推理速度 |
|---|---|---|---|---|
| bge-reranker-base | 278M | ★★★ | ★★★ | 最快 |
| bge-reranker-large | 560M | ★★★★ | ★★★★ | 中 |
| bge-reranker-v2-m3 | 568M | ★★★★★ | ★★★★ | 中 |
| bge-reranker-v2-gemma | ~2B | ★★★★★ | ★★★★★ | 较慢 |
bge-reranker-v2-m3 是目前中文场景下最推荐的,基于 BGE-M3 的基座,同时支持多语言。
2.2 Cohere Rerank
Cohere 的 Rerank 是商业 API,主要优势是:
- 不需要自己部署模型
- 多语言支持很好
- 延迟可预期(通常 100-200ms)
劣势:
- 有 API 调用成本(按调用量计费)
- 数据要发送给第三方(隐私场景慎用)
- 国内访问稳定性有时不好
2.3 在线 vs 离线部署对比
三、BGE-Reranker 的 Java 集成
3.1 通过 ONNX 本地推理
import ai.onnxruntime.*;
import java.nio.*;
import java.util.*;
/**
* BGE-Reranker 本地推理服务
* 通过 ONNX Runtime 调用,无需 GPU(CPU 可用)
*/
public class BGERerankerService {
private final OrtEnvironment env;
private final OrtSession session;
private final TokenizerService tokenizer;
// BGE-Reranker-v2-m3 的最大输入长度
private static final int MAX_LENGTH = 512;
public BGERerankerService(String modelOnnxPath,
String tokenizerPath) throws OrtException {
this.env = OrtEnvironment.getEnvironment();
var opts = new OrtSession.SessionOptions();
// 利用全部 CPU 核心
opts.setIntraOpNumThreads(Runtime.getRuntime().availableProcessors());
opts.setInterOpNumThreads(2);
// 优化等级:图优化
opts.setOptimizationLevel(
OrtSession.SessionOptions.OptLevel.ALL_OPT);
this.session = env.createSession(modelOnnxPath, opts);
this.tokenizer = new TokenizerService(tokenizerPath);
}
/**
* 对单个 (查询, 文档) 对计算相关性分数
*
* @return 相关性分数,值域约为 [-10, 10],不是概率
* 通过 sigmoid 可以转成 [0, 1]
*/
public float score(String query, String document) throws OrtException {
return scoreBatch(List.of(new QueryDocPair(query, document))).get(0);
}
/**
* 批量计算相关性分数(推荐,效率更高)
*/
public List<Float> scoreBatch(List<QueryDocPair> pairs) throws OrtException {
int batchSize = pairs.size();
// 分词:查询和文档一起编码
var tokenized = tokenizer.tokenizePairs(
pairs.stream().map(p -> p.query()).collect(Collectors.toList()),
pairs.stream().map(p -> p.document()).collect(Collectors.toList()),
MAX_LENGTH
);
long seqLen = tokenized.maxSeqLen();
var inputIds = OnnxTensor.createTensor(env,
LongBuffer.wrap(tokenized.inputIds()),
new long[]{batchSize, seqLen});
var attentionMask = OnnxTensor.createTensor(env,
LongBuffer.wrap(tokenized.attentionMask()),
new long[]{batchSize, seqLen});
var tokenTypeIds = OnnxTensor.createTensor(env,
LongBuffer.wrap(tokenized.tokenTypeIds()),
new long[]{batchSize, seqLen});
try (var outputs = session.run(Map.of(
"input_ids", inputIds,
"attention_mask", attentionMask,
"token_type_ids", tokenTypeIds))) {
var logitsOutput = (OnnxTensor) outputs.get("logits").get();
float[][] logits = (float[][]) logitsOutput.getValue();
List<Float> scores = new ArrayList<>();
for (float[] logit : logits) {
// BGE-Reranker 输出 logits[0] 作为相关性分数
scores.add(sigmoid(logit[0]));
}
return scores;
}
}
/**
* 对候选文档重排
*
* @param query 查询
* @param candidates 候选文档列表(通常 50-200 个)
* @param topK 返回前 K 个
*/
public List<RankedDocument> rerank(String query,
List<String> candidates,
int topK) throws OrtException {
if (candidates.isEmpty()) return Collections.emptyList();
long startTime = System.currentTimeMillis();
// 批次推理(避免单批次太大导致OOM)
int inferenceBatchSize = 32;
List<Float> allScores = new ArrayList<>();
for (int i = 0; i < candidates.size(); i += inferenceBatchSize) {
List<QueryDocPair> batch = new ArrayList<>();
int end = Math.min(i + inferenceBatchSize, candidates.size());
for (int j = i; j < end; j++) {
batch.add(new QueryDocPair(query, candidates.get(j)));
}
allScores.addAll(scoreBatch(batch));
}
long inferenceTime = System.currentTimeMillis() - startTime;
// 排序
List<RankedDocument> ranked = new ArrayList<>();
for (int i = 0; i < candidates.size(); i++) {
ranked.add(new RankedDocument(candidates.get(i), allScores.get(i), i));
}
ranked.sort((a, b) -> Float.compare(b.score(), a.score()));
System.out.printf("Rerank 完成: %d 个候选,耗时 %d ms%n",
candidates.size(), inferenceTime);
return ranked.subList(0, Math.min(topK, ranked.size()));
}
private float sigmoid(float x) {
return 1.0f / (1.0f + (float) Math.exp(-x));
}
public record QueryDocPair(String query, String document) {}
public record RankedDocument(String text, float score, int originalRank) {}
}3.2 Cohere Rerank API 集成
import com.cohere.api.Cohere;
import com.cohere.api.resources.v2.rerank.*;
/**
* Cohere Rerank API 集成
*/
public class CohereRerankService {
private final Cohere cohere;
// 可用模型
public static final String RERANK_ENGLISH = "rerank-english-v3.0";
public static final String RERANK_MULTILINGUAL = "rerank-multilingual-v3.0";
public CohereRerankService(String apiKey) {
this.cohere = Cohere.builder().token(apiKey).build();
}
/**
* 重排序
*
* @param query 查询
* @param documents 候选文档
* @param topK 返回数量
* @param model 模型选择(英文/多语言)
*/
public List<RerankResult> rerank(String query,
List<String> documents,
int topK, String model) {
var response = cohere.v2().rerank(V2RerankRequest.builder()
.query(query)
.documents(documents.stream()
.map(d -> RerankRequestDocumentsItem.of(d))
.collect(Collectors.toList()))
.topN(topK)
.model(model)
.returnDocuments(false)
.build());
return response.getResults().stream()
.map(r -> new RerankResult(
documents.get(r.getIndex()),
r.getIndex(),
r.getRelevanceScore().floatValue()
))
.collect(Collectors.toList());
}
public record RerankResult(String text, int originalIndex, float score) {}
}四、完整的两阶段检索服务
/**
* 完整的两阶段检索服务
* 第一阶段: 向量检索召回候选
* 第二阶段: Reranker 精排
*/
@Service
public class TwoStageRetrievalService {
private final VectorSearchService vectorSearch;
private final BGERerankerService rerankerService;
private final DocumentRepository documentRepo;
// 性能监控
private final MeterRegistry meterRegistry;
@Autowired
public TwoStageRetrievalService(VectorSearchService vectorSearch,
BGERerankerService rerankerService,
DocumentRepository documentRepo,
MeterRegistry meterRegistry) {
this.vectorSearch = vectorSearch;
this.rerankerService = rerankerService;
this.documentRepo = documentRepo;
this.meterRegistry = meterRegistry;
}
/**
* 两阶段检索
*
* @param query 查询文本
* @param queryEmbedding 查询向量(由调用方提前计算)
* @param finalK 最终返回数量
* @param recallK 第一阶段召回数量(建议 5x~10x finalK)
*/
public List<RetrievalResult> retrieve(String query,
float[] queryEmbedding,
int finalK, int recallK) {
var totalTimer = meterRegistry.timer("retrieval.total");
return totalTimer.record(() -> {
try {
return doRetrieve(query, queryEmbedding, finalK, recallK);
} catch (Exception e) {
throw new RuntimeException("检索失败", e);
}
});
}
private List<RetrievalResult> doRetrieve(String query,
float[] queryEmbedding,
int finalK,
int recallK) throws Exception {
// === 第一阶段:向量检索 ===
long phase1Start = System.currentTimeMillis();
List<String> candidateDocIds = vectorSearch.search(queryEmbedding, recallK);
List<Document> candidateDocs = documentRepo.findAllById(candidateDocIds);
long phase1Time = System.currentTimeMillis() - phase1Start;
meterRegistry.timer("retrieval.phase1").record(
phase1Time, java.util.concurrent.TimeUnit.MILLISECONDS);
if (candidateDocs.isEmpty()) {
return Collections.emptyList();
}
// === 第二阶段:Reranker 精排 ===
long phase2Start = System.currentTimeMillis();
// 对文档内容做截断(Reranker 输入长度限制)
List<String> docTexts = candidateDocs.stream()
.map(doc -> truncateForReranker(doc.getContent(), 400))
.collect(Collectors.toList());
List<BGERerankerService.RankedDocument> reranked =
rerankerService.rerank(query, docTexts, finalK);
long phase2Time = System.currentTimeMillis() - phase2Start;
meterRegistry.timer("retrieval.phase2").record(
phase2Time, java.util.concurrent.TimeUnit.MILLISECONDS);
// 构建结果
return reranked.stream()
.map(r -> {
Document doc = candidateDocs.get(r.originalRank());
return new RetrievalResult(
doc.getId(),
doc.getContent(),
r.score(),
phase1Time,
phase2Time
);
})
.collect(Collectors.toList());
}
/**
* 截断文档内容,但尽量在句子边界截断
*/
private String truncateForReranker(String text, int maxChars) {
if (text.length() <= maxChars) return text;
// 在 maxChars 附近找句子边界
int idx = maxChars;
while (idx > maxChars * 0.8 && idx > 0) {
char c = text.charAt(idx);
if (c == '。' || c == '.' || c == '\n') break;
idx--;
}
return text.substring(0, idx + 1);
}
public record RetrievalResult(
String docId, String content, float relevanceScore,
long phase1Ms, long phase2Ms
) {}
}五、性能优化技巧
5.1 批量推理 vs 逐条推理
Reranker 的推理延迟在批次大小上不是线性的:
/**
* 批量大小对推理性能的影响(示意数据,实际因硬件不同)
* CPU (8核) + bge-reranker-base:
* batch_size=1: 每条 8ms,处理50条 = 400ms
* batch_size=8: 每条 4ms,处理50条 = 200ms(分7次推理)
* batch_size=32: 每条 2ms,处理50条 = 100ms(分2次推理)
* batch_size=64: 每条 2.5ms,处理50条 = 125ms(单次推理,但OOM风险)
*
* 推荐: batch_size = 16~32
*/
public class OptimizedBatchReranker {
private static final int OPTIMAL_BATCH_SIZE = 32;
public List<Float> scoreAllWithOptimalBatch(String query,
List<String> documents,
BGERerankerService reranker)
throws OrtException {
List<Float> allScores = new ArrayList<>(documents.size());
for (int i = 0; i < documents.size(); i += OPTIMAL_BATCH_SIZE) {
int end = Math.min(i + OPTIMAL_BATCH_SIZE, documents.size());
List<BGERerankerService.QueryDocPair> batch = new ArrayList<>();
for (int j = i; j < end; j++) {
batch.add(new BGERerankerService.QueryDocPair(query, documents.get(j)));
}
allScores.addAll(reranker.scoreBatch(batch));
}
return allScores;
}
}5.2 基于相似度的提前终止
如果向量检索的前几个结果相似度已经很高,可以不用对所有候选做 Rerank:
/**
* 自适应 Rerank:根据向量相似度决定是否 Rerank
*/
public class AdaptiveRerankStrategy {
// 如果 Top-1 向量相似度高于此阈值,认为召回质量够好,跳过精排
private static final float SKIP_RERANK_THRESHOLD = 0.92f;
// 只对相似度超过此阈值的候选做 Rerank(过滤低质量候选)
private static final float RERANK_CANDIDATE_THRESHOLD = 0.60f;
public List<RetrievalResult> adaptiveRerank(
String query,
List<VectorSearchResult> vectorResults,
int finalK,
BGERerankerService reranker) throws OrtException {
if (vectorResults.isEmpty()) return Collections.emptyList();
// 策略1:向量检索置信度很高,直接返回,不做 Rerank
float topScore = vectorResults.get(0).score();
if (topScore >= SKIP_RERANK_THRESHOLD && vectorResults.size() >= finalK) {
System.out.println("向量置信度高,跳过 Rerank");
return vectorResults.stream()
.limit(finalK)
.map(r -> new RetrievalResult(r.docId(), r.text(),
r.score(), false))
.collect(Collectors.toList());
}
// 策略2:过滤低质量候选,只对高质量候选做 Rerank
List<VectorSearchResult> rerankCandidates = vectorResults.stream()
.filter(r -> r.score() >= RERANK_CANDIDATE_THRESHOLD)
.collect(Collectors.toList());
// 至少保留 finalK 个候选
if (rerankCandidates.size() < finalK) {
rerankCandidates = vectorResults.subList(
0, Math.min(finalK * 2, vectorResults.size()));
}
// 执行 Rerank
List<String> texts = rerankCandidates.stream()
.map(VectorSearchResult::text)
.collect(Collectors.toList());
List<BGERerankerService.RankedDocument> ranked =
reranker.rerank(query, texts, finalK);
return ranked.stream()
.map(r -> new RetrievalResult(
rerankCandidates.get(r.originalRank()).docId(),
r.text(), r.score(), true))
.collect(Collectors.toList());
}
public record VectorSearchResult(String docId, String text, float score) {}
public record RetrievalResult(String docId, String text,
float score, boolean wasReranked) {}
}5.3 异步预热和缓存
/**
* Reranker 结果缓存
* 对相同的 (查询, 文档集合) 缓存结果,避免重复推理
*/
@Component
public class CachedRerankService {
private final BGERerankerService rerankerService;
private final Cache<String, List<Float>> scoreCache;
public CachedRerankService(BGERerankerService rerankerService) {
this.rerankerService = rerankerService;
// 缓存 1000 个查询的结果,过期时间 5 分钟
this.scoreCache = Caffeine.newBuilder()
.maximumSize(1000)
.expireAfterWrite(5, TimeUnit.MINUTES)
.recordStats()
.build();
}
/**
* 带缓存的打分
*/
public List<BGERerankerService.RankedDocument> rerankWithCache(
String query, List<String> documents, int topK) throws OrtException {
// 生成缓存 key(查询 + 文档 MD5)
String cacheKey = generateCacheKey(query, documents);
List<Float> cachedScores = scoreCache.getIfPresent(cacheKey);
if (cachedScores != null) {
// 命中缓存
return buildRankedFromScores(documents, cachedScores, topK);
}
// 推理
List<BGERerankerService.RankedDocument> result =
rerankerService.rerank(query, documents, topK);
// 存入缓存
List<Float> scores = new ArrayList<>();
Map<Integer, Float> scoreByOriginalRank = result.stream()
.collect(Collectors.toMap(
BGERerankerService.RankedDocument::originalRank,
BGERerankerService.RankedDocument::score
));
for (int i = 0; i < documents.size(); i++) {
scores.add(scoreByOriginalRank.getOrDefault(i, 0.0f));
}
scoreCache.put(cacheKey, scores);
return result;
}
private String generateCacheKey(String query, List<String> docs) {
// 简化版:实际用 MD5 或 SHA256
return query.hashCode() + "_" +
docs.stream().mapToInt(String::hashCode).sum();
}
private List<BGERerankerService.RankedDocument> buildRankedFromScores(
List<String> docs, List<Float> scores, int topK) {
List<BGERerankerService.RankedDocument> ranked = new ArrayList<>();
for (int i = 0; i < docs.size(); i++) {
ranked.add(new BGERerankerService.RankedDocument(
docs.get(i), scores.get(i), i));
}
ranked.sort((a, b) -> Float.compare(b.score(), a.score()));
return ranked.subList(0, Math.min(topK, ranked.size()));
}
}六、Cohere Rerank 和 BGE-Reranker 的实测对比
用一个企业知识库数据集(5000 篇文档,600 个测试查询)做对比:
| 指标 | 无Rerank | BGE-Reranker-v2-m3 | Cohere-Multilingual |
|---|---|---|---|
| MRR@5 | 0.62 | 0.79 | 0.81 |
| Recall@5 | 0.71 | 0.85 | 0.87 |
| Recall@10 | 0.83 | 0.91 | 0.92 |
| 平均延迟(50候选) | 8ms | 85ms | 180ms |
| 成本(百万次查询) | - | ~¥50(GPU 折算) | ~$600 |
结论:
- Cohere 略优于 BGE-Reranker-v2-m3,差距约 1-2 个百分点
- BGE-Reranker 延迟更低,成本更可控
- 对于有隐私合规要求的企业场景,BGE-Reranker 本地部署是必选
七、踩坑经验
坑1:文档截断对 Reranker 的影响比 Embedding 更大
Reranker 对输入长度的截断很敏感。如果把文档截到 128 字符,可能把最关键的信息截掉了,反而比不做 Rerank 还差。
建议:
- Reranker 输入长度至少保留 300-500 字符
- 如果文档很长,用滑动窗口方式(对文档的多个片段分别打分,取最高分)
/**
* 长文档的 Reranker 策略:分段打分取最大值
*/
public float scoreLongDocument(String query, String longDoc,
BGERerankerService reranker,
int windowSize, int stride) throws OrtException {
if (longDoc.length() <= windowSize) {
return reranker.score(query, longDoc);
}
float maxScore = Float.NEGATIVE_INFINITY;
for (int i = 0; i < longDoc.length(); i += stride) {
String window = longDoc.substring(i,
Math.min(i + windowSize, longDoc.length()));
float score = reranker.score(query, window);
if (score > maxScore) maxScore = score;
if (i + windowSize >= longDoc.length()) break;
}
return maxScore;
}坑2:Reranker 的分数不能直接比较不同查询
BGE-Reranker 输出的 sigmoid(logit) 在 [0, 1],但不同查询的绝对分数范围差异很大。有些查询的相关文档得分普遍较低(0.3-0.5),有些查询的相关文档得分普遍较高(0.7-0.9)。
不能用统一的阈值(比如"分数 > 0.5 才使用")来过滤,这样会漏掉一批查询的所有结果。
建议:不用绝对阈值,只用相对排名(取 Top-K),或者用标准化的相对分数(每次 Rerank 后对分数做 softmax 归一化)。
坑3:Reranker 对噪声文档比 Embedding 更鲁棒
这其实是个好消息。我发现 Embedding 检索会受到一类"关键词高度重叠但语义不相关"的文档干扰,而 Reranker 因为做了词级别交互,能更好地识别出这类噪声。
这也意味着:如果你的向量库里有很多"关键词相似但语义不相关"的文档(比如不同领域的技术文档),加 Reranker 的收益会更大。
八、生产部署建议
部署架构
Reranker 服务的部署规格建议:
- CPU 部署(8核):bge-reranker-base,QPS 约 5-10(50候选/次)
- CPU 部署(16核):bge-reranker-v2-m3,QPS 约 3-5
- GPU(T4):bge-reranker-v2-m3,QPS 约 50-100
降级策略
/**
* Reranker 降级:当 Reranker 服务超时时,回退到纯向量检索结果
*/
public List<RetrievalResult> retrieveWithFallback(String query,
float[] queryVec,
int topK) {
// 向量检索总是执行
List<String> vectorResults = vectorSearch.search(queryVec, topK * 3);
try {
// Reranker 带超时
CompletableFuture<List<RetrievalResult>> rerankFuture =
CompletableFuture.supplyAsync(() -> {
try {
return rerank(query, vectorResults, topK);
} catch (Exception e) {
throw new RuntimeException(e);
}
});
return rerankFuture.get(150, TimeUnit.MILLISECONDS);
} catch (TimeoutException | ExecutionException e) {
// 降级:直接用向量检索结果
log.warn("Reranker 超时,降级使用向量检索结果");
return vectorResults.stream()
.limit(topK)
.map(docId -> new RetrievalResult(docId, getDocText(docId),
0.0f, false))
.collect(Collectors.toList());
}
}九、总结
Reranker 是 RAG 链路里性价比最高的单点优化,通常能带来 10-20 个百分点的 MRR 提升。
选型建议:
- 中文为主的企业场景:BGE-Reranker-v2-m3,本地部署,成本可控
- 多语言国际化产品:Cohere Rerank Multilingual 或 Jina Reranker v2
- 快速 POC 验证:Cohere Rerank API,最省事
- 延迟极度敏感(< 50ms):BGE-Reranker-base,精度略低但够用
集成注意点:长文档分段打分、不用绝对阈值、批量推理、做好超时降级。
