第1835篇:多向量表示——ColBERT的延迟交互检索在精排中的应用
第1835篇:多向量表示——ColBERT的延迟交互检索在精排中的应用
说到 RAG 的检索链路,大多数人第一反应是双塔模型:用 Embedding 把查询和文档分别编码成单向量,然后做向量相似度搜索。这种方案的好处是快,支持离线索引、在线毫秒级查询。
但单向量有个本质的局限:整个文档的语义被压缩进一个固定维度的向量里,细粒度的词级别交互信息全部丢失了。
比如用户问"Python 中 GIL 对多线程的影响",文档里有一段讲 GIL、另一段讲多线程,但这两段是分开的。单向量 Embedding 可能捕捉到文档的整体主题,但查询中"影响"这个词与文档中具体描述影响的句子之间的精细匹配,单向量很难做到。
ColBERT(Contextualized Late Interaction over BERT)就是为解决这个问题而设计的。
一、从单向量到多向量:ColBERT 的核心思想
1.1 传统双塔模型的问题
传统双塔(Bi-Encoder)的工作方式:
整个文档被压缩成一个向量,信息压缩比极高,细节必然丢失。
1.2 全交互(Cross-Encoder)的问题
全交互模型(如直接用 BERT 把查询和文档拼在一起做 cross-attention)能捕捉所有词级别交互,效果极好,但致命的问题是:文档必须在查询时实时编码,无法预计算索引。
1000 万个文档,每次查询要对所有文档做全交互——这显然不可行。
1.3 ColBERT 的折中:延迟交互
ColBERT 的创新是把交互推迟到最后一步(Late Interaction),但保留了词级别的多向量表示:
- 文档向量:离线预计算,每个 token 一个向量,存入向量库
- 查询向量:在线计算,通常只有 32 个 token,速度很快
- MaxSim 交互:查询的每个 token 找文档中最相似的 token,求和得到最终得分
这样做到了:预计算+高精度。
二、MaxSim 操作详解
ColBERT 的相关性得分公式:
含义:
- 是查询第 i 个 token 的向量
- 是文档第 j 个 token 的向量
- 对每个查询 token,找文档里与它最相似的那个 token
- 把所有查询 token 的最大相似度加起来
这个操作保证了:
- 查询的每个关键词都能找到文档中最匹配的位置(不会因为平均而稀释)
- 如果文档里有多个相关片段,每个都对得分有贡献
举例理解
查询:"Python GIL 对多线程的影响"
分词后大约有 8 个有意义的 token:[Python, GIL, 对, 多线程, 的, 影响, ...]
文档有段落讲"GIL 导致多线程在 CPU 密集型任务上无法并行..."
MaxSim 计算:
- 查询 "GIL" token → 找文档中与之最相似的 token → "GIL"(高分)
- 查询 "多线程" token → 找文档中最相似的 → "多线程"(高分)
- 查询 "影响" token → 找文档中最相似的 → "导致"、"无法并行"(较高分)
最终得分是这些最大值的总和,能精确捕捉到"文档讨论了GIL对多线程的具体影响"。
三、Java工程实现
3.1 ColBERT 向量的生成与存储
ColBERT 通常用 Python 侧的模型生成向量,Java 侧负责存储和检索。
/**
* ColBERT 多向量文档表示
* 每个文档由多个 token 向量组成
*/
public record ColBERTDocument(
String docId,
String text,
List<float[]> tokenVectors, // 每个token一个向量,通常128维
int tokenCount // token数量
) {
/**
* 估算存储大小(字节)
*/
public int storageBytes() {
return tokenCount * tokenVectors.get(0).length * 4;
}
/**
* 从Python端返回的JSON格式解析
* 格式: {"doc_id": "...", "token_vectors": [[0.1, 0.2, ...], ...]}
*/
public static ColBERTDocument fromJson(String json) {
// 实际用 Jackson 解析
var mapper = new com.fasterxml.jackson.databind.ObjectMapper();
try {
var node = mapper.readTree(json);
String docId = node.get("doc_id").asText();
String text = node.get("text").asText();
var vectorsNode = node.get("token_vectors");
List<float[]> vectors = new ArrayList<>();
for (var vecNode : vectorsNode) {
float[] vec = new float[vecNode.size()];
for (int i = 0; i < vecNode.size(); i++) {
vec[i] = (float) vecNode.get(i).asDouble();
}
vectors.add(vec);
}
return new ColBERTDocument(docId, text, vectors, vectors.size());
} catch (Exception e) {
throw new RuntimeException("解析ColBERT向量失败", e);
}
}
}3.2 Python 端的 ColBERT 向量生成脚本
# colbert_encoder.py
# 依赖: pip install colbert-ai torch transformers
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
import json
import numpy as np
class ColBERTEncoder:
"""
ColBERT 多向量编码器
输出每个 token 的向量(128维,已归一化)
"""
def __init__(self, checkpoint_path: str = "colbert-ir/colbertv2.0"):
self.checkpoint = Checkpoint(checkpoint_path,
colbert_config=ColBERTConfig(doc_maxlen=512))
def encode_query(self, query: str) -> np.ndarray:
"""
编码查询,返回 shape [q_len, 128] 的向量矩阵
query 会 padding 到 32 个 token
"""
queries = [query]
Q = self.checkpoint.queryFromText(queries)
return Q[0].cpu().numpy() # [32, 128]
def encode_document(self, doc: str) -> np.ndarray:
"""
编码文档,返回 shape [d_len, 128] 的向量矩阵
"""
docs = [doc]
D, mask = self.checkpoint.docFromText(docs, keep_dims=False)
# 过滤掉 padding token
d_vectors = D[0][:mask[0].sum()].cpu().numpy()
return d_vectors # [actual_token_count, 128]
def encode_documents_batch(self, docs: list, output_file: str):
"""
批量编码文档,结果写入 JSON Lines 文件
"""
with open(output_file, 'w', encoding='utf-8') as f:
for i, doc_info in enumerate(docs):
vectors = self.encode_document(doc_info['text'])
result = {
'doc_id': doc_info['id'],
'text': doc_info['text'],
'token_vectors': vectors.tolist()
}
f.write(json.dumps(result, ensure_ascii=False) + '\n')
if (i + 1) % 100 == 0:
print(f"已处理 {i+1}/{len(docs)} 篇文档")
print(f"编码完成,结果保存到: {output_file}")3.3 Milvus 中存储多向量
ColBERT 的多向量存储有两种方案:
方案一:把一篇文档的所有 token 向量展平,每个 token 一条记录
/**
* 方案一:每个 token 向量作为独立记录
* 优点:直接用标准向量库
* 缺点:文档数据膨胀(平均200个token→200倍记录数)
*/
public class ColBERTMilvusService {
private final MilvusClientV2 milvusClient;
private static final int TOKEN_DIM = 128;
/**
* 创建 token 级别的 Collection
*/
public void createTokenCollection(String collectionName) {
var schema = CreateCollectionReq.CollectionSchema.newBuilder()
.addField(AddFieldReq.newBuilder()
.fieldName("id").dataType(DataType.Int64)
.isPrimaryKey(true).autoID(true).build())
.addField(AddFieldReq.newBuilder()
.fieldName("doc_id").dataType(DataType.VarChar)
.maxLength(64).build())
.addField(AddFieldReq.newBuilder()
.fieldName("token_pos").dataType(DataType.Int32).build())
.addField(AddFieldReq.newBuilder()
.fieldName("embedding").dataType(DataType.FloatVector)
.dimension(TOKEN_DIM).build())
.build();
milvusClient.createCollection(
CreateCollectionReq.newBuilder()
.collectionName(collectionName)
.collectionSchema(schema)
.build()
);
}
/**
* 插入文档的所有 token 向量
*/
public void insertDocument(String collectionName, ColBERTDocument doc) {
List<JsonObject> rows = new ArrayList<>();
for (int tokenPos = 0; tokenPos < doc.tokenVectors().size(); tokenPos++) {
JsonObject row = new JsonObject();
row.addProperty("doc_id", doc.docId());
row.addProperty("token_pos", tokenPos);
var arr = new com.google.gson.JsonArray();
for (float v : doc.tokenVectors().get(tokenPos)) arr.add(v);
row.add("embedding", arr);
rows.add(row);
}
milvusClient.insert(InsertReq.newBuilder()
.collectionName(collectionName)
.data(rows)
.build());
}
/**
* MaxSim 检索:查询所有 token 向量,按 doc_id 聚合
*
* @param queryVectors 查询的 token 向量列表(通常32个)
* @param topK 返回文档数
* @param candidateK 每个 token 向量返回的候选 token 数
*/
public List<String> maxSimSearch(String collectionName,
List<float[]> queryVectors,
int topK, int candidateK) {
// 存储每个 doc_id 的累积 MaxSim 得分
Map<String, float[]> docScores = new HashMap<>();
// key: doc_id, value: [当前累积得分, 已处理的查询token数]
for (int qi = 0; qi < queryVectors.size(); qi++) {
float[] queryToken = queryVectors.get(qi);
// 对每个查询 token,搜索最相似的 candidateK 个文档 token
var results = milvusClient.search(SearchReq.newBuilder()
.collectionName(collectionName)
.data(List.of(floatArrayToList(queryToken)))
.annsField("embedding")
.topK(candidateK)
.outputFields(List.of("doc_id"))
.build());
// 收集这个查询 token 对各个文档的最大相似度
Map<String, Float> tokenMaxSim = new HashMap<>();
for (var result : results.getSearchResults().get(0)) {
String docId = (String) result.getEntity().get("doc_id");
float score = result.getScore();
tokenMaxSim.merge(docId, score, Math::max);
}
// 累加到文档得分
for (var entry : tokenMaxSim.entrySet()) {
docScores.computeIfAbsent(entry.getKey(),
k -> new float[]{0.0f});
docScores.get(entry.getKey())[0] += entry.getValue();
}
}
// 按得分排序,返回 top-K 文档
return docScores.entrySet().stream()
.sorted((a, b) -> Float.compare(b.getValue()[0], a.getValue()[0]))
.limit(topK)
.map(Map.Entry::getKey)
.collect(Collectors.toList());
}
private List<Float> floatArrayToList(float[] arr) {
List<Float> list = new ArrayList<>(arr.length);
for (float v : arr) list.add(v);
return list;
}
}方案二:使用 Milvus 的 Multi-Vector Collection(Milvus 2.4+)
/**
* 方案二:Milvus 原生多向量支持(更高效)
* Milvus 2.4 开始支持一个 field 存储多个向量
*/
public class ColBERTMultiVecService {
/**
* 使用 Array of Vectors 字段存储 token 向量
* 注意:目前 Milvus 对多向量字段的原生支持还在发展中
* 生产上更稳定的方案是用 token_id 做分片 + 应用层 MaxSim
*/
public void createMultiVecCollection(String collectionName, int maxTokens) {
// 方案:把多个向量 flatten 成一个大向量
// 128维 × 最大200个token = 25600维(不推荐,太大)
// 更实用的方案:存储 PLAID 格式(ColBERT v2 的压缩存储)
// 每个 doc 的 token 向量用 int8 量化后存储为 blob
var schema = CreateCollectionReq.CollectionSchema.newBuilder()
.addField(AddFieldReq.newBuilder()
.fieldName("doc_id").dataType(DataType.VarChar)
.maxLength(64).isPrimaryKey(true).build())
.addField(AddFieldReq.newBuilder()
.fieldName("text").dataType(DataType.VarChar)
.maxLength(4000).build())
.addField(AddFieldReq.newBuilder()
.fieldName("token_count").dataType(DataType.Int32).build())
// 存储量化后的 token 向量(int8,展平)
.addField(AddFieldReq.newBuilder()
.fieldName("token_vectors_raw")
.dataType(DataType.JSON).build()) // 用JSON存byte数组
// 同时存储文档级别的向量用于粗排
.addField(AddFieldReq.newBuilder()
.fieldName("doc_embedding")
.dataType(DataType.FloatVector)
.dimension(128).build())
.build();
milvusClient.createCollection(
CreateCollectionReq.newBuilder()
.collectionName(collectionName)
.collectionSchema(schema)
.build()
);
}
}3.4 应用层 MaxSim 计算
/**
* MaxSim 得分计算(应用层实现)
* 当 Milvus 直接支持 ColBERT 之前,在应用层做
*/
public class MaxSimScorer {
/**
* 计算查询和文档之间的 MaxSim 得分
*
* @param queryTokenVecs 查询的 token 向量列表 [qlen, dim]
* @param docTokenVecs 文档的 token 向量列表 [dlen, dim]
* @return ColBERT 相关性得分
*/
public static float compute(List<float[]> queryTokenVecs,
List<float[]> docTokenVecs) {
float totalScore = 0.0f;
for (float[] qVec : queryTokenVecs) {
float maxSim = Float.NEGATIVE_INFINITY;
for (float[] dVec : docTokenVecs) {
float sim = dotProduct(qVec, dVec);
if (sim > maxSim) maxSim = sim;
}
// 跳过无意义的 padding token(相似度接近0)
if (maxSim > 0) totalScore += maxSim;
}
return totalScore;
}
/**
* 批量计算:一个查询 vs 多个文档(重排序用)
* 返回按得分降序的文档索引
*/
public static List<Integer> rerank(List<float[]> queryTokenVecs,
List<List<float[]>> candidateDocVecs,
int topK) {
// 计算每个候选文档的 MaxSim 得分
float[] scores = new float[candidateDocVecs.size()];
for (int i = 0; i < candidateDocVecs.size(); i++) {
scores[i] = compute(queryTokenVecs, candidateDocVecs.get(i));
}
// 排序取 top-K
Integer[] indices = new Integer[scores.length];
for (int i = 0; i < indices.length; i++) indices[i] = i;
Arrays.sort(indices, (a, b) -> Float.compare(scores[b], scores[a]));
List<Integer> result = new ArrayList<>();
for (int i = 0; i < Math.min(topK, indices.length); i++) {
result.add(indices[i]);
}
return result;
}
/**
* 向量内积(ColBERT 使用内积而非余弦,因为向量已归一化)
*/
private static float dotProduct(float[] a, float[] b) {
float sum = 0;
for (int i = 0; i < a.length; i++) sum += a[i] * b[i];
return sum;
}
/**
* SIMD 优化版本(需要 JDK 16+ 的 Vector API 或手动循环展开)
* 对 128 维向量,这个优化可以带来 2-4x 加速
*/
public static float dotProductOptimized(float[] a, float[] b) {
int len = a.length;
float sum0 = 0, sum1 = 0, sum2 = 0, sum3 = 0;
int i = 0;
// 4路展开
for (; i + 3 < len; i += 4) {
sum0 += a[i] * b[i];
sum1 += a[i+1] * b[i+1];
sum2 += a[i+2] * b[i+2];
sum3 += a[i+3] * b[i+3];
}
// 处理剩余
for (; i < len; i++) sum0 += a[i] * b[i];
return sum0 + sum1 + sum2 + sum3;
}
}四、ColBERT 在 RAG 链路中的定位
ColBERT 很少单独作为召回阶段使用(因为需要对候选集计算 MaxSim,太慢),更多用作精排(Reranker):
这套链路的延迟分析:
- 第一阶段:HNSW 召回 100 个候选,耗时 < 10ms
- 第二阶段:对 100 个候选用 ColBERT MaxSim 重排,耗时约 50-100ms(取决于文档长度和硬件)
- 总延迟:< 120ms,可以接受
五、延迟优化技巧
5.1 PLAID 量化(ColBERT v2)
ColBERT v2 提出了 PLAID(Passage-Level Approximate IVF Decomposition),核心是:
- 用 k-means 对所有 token 向量做聚类(通常 65536 个中心)
- 每个 token 向量只存储最近的中心索引(2字节)+ 残差(int8)
- 存储压缩约 8-16x,检索时先用中心向量做粗排,再用残差精排
# PLAID 量化的 Python 实现概念(实际用 colbert-ai 库)
from colbert.indexing.codecs.residual import ResidualCodec
# 训练 k-means 中心
codec = ResidualCodec.train(all_token_vectors, num_bits=2)
# 编码文档向量
encoded_doc = codec.compress(doc_token_vectors)
# 检索时解压(只解压候选文档)
decoded_doc = codec.decompress(encoded_doc)5.2 候选集预筛选
不对全库做 MaxSim,先用粗糙的单向量召回缩小候选集:
/**
* 两阶段 ColBERT 检索
* 第一阶段用文档级别的平均向量粗排
* 第二阶段用 MaxSim 精排
*/
public class TwoStageColBERTSearch {
private final VectorSearchService roughSearchService;
private final ColBERTDocumentStore docStore;
/**
* @param query 查询文本
* @param queryTokenVecs 查询的 token 向量
* @param finalK 最终返回数量
* @param roughK 第一阶段粗排候选数(越大越准但越慢)
*/
public List<SearchResult> search(String query,
List<float[]> queryTokenVecs,
int finalK, int roughK) {
// Step 1: 用查询平均向量做粗排召回 roughK 个候选
float[] queryMeanVec = computeMeanVector(queryTokenVecs);
List<String> roughCandidates = roughSearchService.search(
queryMeanVec, roughK);
// Step 2: 加载候选文档的 token 向量
List<ColBERTDocument> candidateDocs = roughCandidates.stream()
.map(docStore::getDocument)
.filter(Objects::nonNull)
.collect(Collectors.toList());
// Step 3: MaxSim 精排
List<float[]> candidateVecs = candidateDocs.stream()
.map(ColBERTDocument::tokenVectors)
.collect(Collectors.toList());
List<Integer> rerankedIndices = MaxSimScorer.rerank(
queryTokenVecs, candidateVecs, finalK);
// 构建返回结果
return rerankedIndices.stream()
.map(idx -> {
var doc = candidateDocs.get(idx);
float score = MaxSimScorer.compute(
queryTokenVecs, doc.tokenVectors());
return new SearchResult(doc.docId(), doc.text(), score);
})
.collect(Collectors.toList());
}
private float[] computeMeanVector(List<float[]> vecs) {
int dim = vecs.get(0).length;
float[] mean = new float[dim];
for (float[] v : vecs) {
for (int i = 0; i < dim; i++) mean[i] += v[i];
}
float norm = 0;
for (int i = 0; i < dim; i++) {
mean[i] /= vecs.size();
norm += mean[i] * mean[i];
}
norm = (float) Math.sqrt(norm);
for (int i = 0; i < dim; i++) mean[i] /= norm;
return mean;
}
public record SearchResult(String docId, String text, float score) {}
}六、与其他 Reranker 的对比
ColBERT 不是唯一的精排方案,和 Cross-Encoder 类 Reranker 相比:
| 维度 | ColBERT | Cross-Encoder (如 BGE-Reranker) |
|---|---|---|
| 精度 | 高 | 最高 |
| 延迟(100候选) | 50-100ms | 200-500ms |
| 文档预计算 | 支持 | 不支持 |
| 存储开销 | 高(多向量) | 无(实时编码) |
| 部署难度 | 中 | 低 |
工程上的选择逻辑:
- 候选集 < 50,延迟要求宽松(< 500ms):Cross-Encoder 更准确,直接用
- 候选集 50-200,延迟要求中等(< 200ms):ColBERT 是最佳折中
- 候选集 > 200,延迟要求严格(< 100ms):只能用 ColBERT 或更轻量的方案
七、踩坑经验
坑1:查询 token 的 [MASK] padding 影响得分
ColBERT 对查询做 padding 到 32 token,[MASK] token 的向量会参与 MaxSim 计算。如果不过滤这些 padding token,它们会给所有文档贡献一个近似相等的基础分,稀释真实的区分度。
解决:计算 MaxSim 时,跳过相似度接近 0(< 0.01)的查询 token 贡献。
坑2:文档 token 向量占用的存储被严重低估
一篇 500 字的中文文档,分词后大约 200 个 token,每个 128 维 float32 向量: 200 × 128 × 4 = 102400 字节 = 100KB
100 万篇文档:100GB。这是很多团队没想到的。
必须使用 int8 量化(25GB)或 PLAID 压缩(< 10GB)。
坑3:maxSim 实现时没有区分 CLS token
ColBERT 编码文档时,第一个 [CLS] token 的向量通常代表文档整体语义,其他 token 代表局部语义。在某些使用场景下,把 [CLS] token 也纳入 MaxSim 计算会引入混淆。建议在工程实现中明确是否包含 [CLS],与训练时保持一致。
八、总结
ColBERT 的核心价值在于用多向量表示解决了单向量的信息压缩损失问题,通过"延迟交互"在可扩展性和精度之间找到了最优平衡点。
适合用 ColBERT 的场景:
- RAG 的精排阶段(候选集 50-300 个)
- 对精度要求高但实时编码太慢的场景
- 需要可解释性(能知道哪个 token 贡献了相关性)
不适合的场景:
- 直接作为一阶段召回(存储和计算成本太高)
- 文档更新极频繁(每次更新要重新计算多个向量)
