第2094篇:端侧推理的工程实践——ONNX Runtime在Java服务中的落地
第2094篇:端侧推理的工程实践——ONNX Runtime在Java服务中的落地
适读人群:需要降低AI推理延迟或成本的后端工程师 | 阅读时长:约20分钟 | 核心价值:掌握ONNX Runtime Java API的使用方式,把轻量模型嵌入Java服务,彻底绕过远程LLM调用
有一类需求,我在多个项目里反复碰到:
"这个模型要实时调用,延迟必须在10ms以内,调LLM的API肯定不行。"
比如:用户输入内容的实时语言检测,商品标题的多分类标签打标,文本的情感极性判断,短文本的相关度排序。
这些任务的共同特点是:简单、高频、对延迟敏感。用大模型API是杀鸡用牛刀,而且调一次要几百毫秒,还要花token。
正确的做法是把轻量模型直接嵌进Java进程,本地推理。ONNX Runtime是目前最成熟的跨平台推理引擎,支持几乎所有框架导出的模型,Java的支持也很完善。
这篇文章把ONNX Runtime在Java里的完整用法讲清楚,包括模型加载、输入预处理、推理执行、输出后处理,以及生产环境中的线程安全和性能调优。
ONNX Runtime的定位
/**
* 理解ONNX Runtime的分工
*
* 什么场景适合ONNX本地推理:
* - 分类模型:情感分析、语言检测、意图分类(延迟要求<10ms)
* - Embedding模型:把文本转成向量(不想每次都调远程API)
* - Reranking模型:对搜索结果重排(高频调用,>100次/秒)
* - NER模型:实体识别(批量处理时)
*
* 什么场景不适合ONNX:
* - 需要生成长文本(还是用LLM API)
* - 模型太大(超过1GB的模型本地部署内存压力大)
* - 多模态任务(代码复杂度高)
*
* 本文的典型用例:
* bge-m3 embedding (570MB) 本地跑
* BAAI/bge-reranker-base 本地Rerank
* 情感分类模型 (<100MB) 实时推理
*/依赖配置
<!-- pom.xml -->
<dependencies>
<!-- ONNX Runtime Java API -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.17.0</version>
</dependency>
<!-- GPU版本(如果有CUDA环境,用这个替换上面的)
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime_gpu</artifactId>
<version>1.17.0</version>
</dependency>
-->
<!-- HuggingFace Tokenizers(Java版,处理BERT类模型的分词)
这个库把tokenizers的Rust实现包装成Java可用的形式 -->
<dependency>
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
<version>0.27.0</version>
</dependency>
</dependencies>模型加载和Session管理
/**
* ONNX模型加载器
*
* OrtSession是线程安全的,可以多线程并发推理
* OrtEnvironment全局单例,JVM内只需要一个
*/
@Component
@Slf4j
public class OnnxModelLoader {
// OrtEnvironment是全局资源,负责线程池和日志
private static final OrtEnvironment ENV = OrtEnvironment.getEnvironment();
/**
* 加载ONNX模型
*
* @param modelPath 模型文件路径(.onnx文件)
* @param useGpu 是否使用GPU(需要GPU版本的onnxruntime依赖)
*/
public OrtSession loadModel(String modelPath, boolean useGpu) throws OrtException {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
// 推理线程数设置
// 对于实时服务,通常设为CPU核数的一半(剩余CPU留给业务线程)
options.setIntraOpNumThreads(Runtime.getRuntime().availableProcessors() / 2);
options.setInterOpNumThreads(1);
// 开启图优化(生产环境务必开启)
options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
// 开启内存优化(减少显存占用,推理速度稍微慢一点点)
options.setMemoryPatternOptimization(true);
if (useGpu) {
// 使用CUDA
// 0 = 第一块GPU
options.addCUDA(0);
log.info("ONNX模型将在GPU上运行: {}", modelPath);
} else {
// CPU推理:开启OpenMP加速
// 注意:这需要系统安装了OpenMP库
// 在某些embedding模型上,CPU+OpenMP比GPU更快(数据传输开销小)
log.info("ONNX模型将在CPU上运行: {}", modelPath);
}
// 加载模型文件
long startMs = System.currentTimeMillis();
OrtSession session = ENV.createSession(modelPath, options);
long loadMs = System.currentTimeMillis() - startMs;
// 打印模型输入输出信息(调试时有用)
log.info("模型加载完成: path={}, 耗时={}ms", modelPath, loadMs);
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
log.debug("输入节点: name={}, type={}",
entry.getKey(), entry.getValue().getInfo());
}
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
log.debug("输出节点: name={}, type={}",
entry.getKey(), entry.getValue().getInfo());
}
return session;
}
/**
* 从classpath加载模型(模型打包进jar)
* 小于50MB的模型可以这样做,省去外部文件管理
*/
public OrtSession loadModelFromClasspath(String resourcePath) throws Exception {
try (InputStream is = getClass().getResourceAsStream(resourcePath)) {
if (is == null) {
throw new FileNotFoundException("模型资源不存在: " + resourcePath);
}
byte[] modelBytes = is.readAllBytes();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setIntraOpNumThreads(2);
options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
return ENV.createSession(modelBytes, options);
}
}
}Embedding模型推理(BGE系列)
/**
* 本地Embedding推理服务
*
* 使用BGE-M3或BGE-small-zh,把文本转成稠密向量
* 比调远程Embedding API快10-50倍,成本归零
*/
@Service
@Slf4j
public class LocalEmbeddingService {
private final OrtSession session;
private final HuggingFaceTokenizer tokenizer;
private final int maxLength = 512;
public LocalEmbeddingService(
@Value("${ai.model.embedding.path}") String modelPath,
@Value("${ai.model.tokenizer.path}") String tokenizerPath) throws Exception {
OnnxModelLoader loader = new OnnxModelLoader();
this.session = loader.loadModel(modelPath, false);
// 加载分词器(从tokenizer.json文件)
this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath));
log.info("本地Embedding服务启动完成");
}
/**
* 单文本embedding
*/
public float[] embed(String text) throws OrtException {
return batchEmbed(List.of(text))[0];
}
/**
* 批量embedding(更高效,减少重复的JNI调用开销)
*/
public float[][] batchEmbed(List<String> texts) throws OrtException {
int batchSize = texts.size();
// Step 1: Tokenize
// HuggingFace tokenizer批量处理,返回input_ids/attention_mask/token_type_ids
Encoding[] encodings = tokenizer.batchEncode(
texts.toArray(new String[0]),
true, // add special tokens (CLS/SEP)
true // pad to same length
);
// Step 2: 构造模型输入tensor
// BERT类模型需要:input_ids, attention_mask, token_type_ids
long[][] inputIds = new long[batchSize][maxLength];
long[][] attentionMask = new long[batchSize][maxLength];
long[][] tokenTypeIds = new long[batchSize][maxLength];
for (int i = 0; i < batchSize; i++) {
long[] ids = encodings[i].getIds();
long[] mask = encodings[i].getAttentionMask();
long[] typeIds = encodings[i].getTypeIds();
// 截断或填充到maxLength
int copyLen = Math.min(ids.length, maxLength);
System.arraycopy(ids, 0, inputIds[i], 0, copyLen);
System.arraycopy(mask, 0, attentionMask[i], 0, copyLen);
System.arraycopy(typeIds, 0, tokenTypeIds[i], 0, copyLen);
}
// Step 3: 创建OnnxTensor
// 注意:shape必须和模型期望的输入完全匹配
long[] shape = {batchSize, maxLength};
try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), inputIds);
OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), attentionMask);
OnnxTensor tokenTypeIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), tokenTypeIds)) {
// Step 4: 推理
Map<String, OnnxTensor> inputs = Map.of(
"input_ids", inputIdsTensor,
"attention_mask", attentionMaskTensor,
"token_type_ids", tokenTypeIdsTensor
);
try (OrtSession.Result results = session.run(inputs)) {
// Step 5: 提取CLS token的embedding(第一个token的输出)
// last_hidden_state的shape: [batch, seq_len, hidden_dim]
float[][][] lastHiddenState = (float[][][]) results.get(0).getValue();
float[][] embeddings = new float[batchSize][];
for (int i = 0; i < batchSize; i++) {
// CLS token是位置0的输出
float[] clsEmbedding = lastHiddenState[i][0];
// L2归一化(对于余弦相似度计算是必要的)
embeddings[i] = l2Normalize(clsEmbedding);
}
return embeddings;
}
}
}
/**
* L2归一化
* 使向量长度为1,余弦相似度就等于点积,计算更快
*/
private float[] l2Normalize(float[] vector) {
double norm = 0;
for (float v : vector) {
norm += v * v;
}
norm = Math.sqrt(norm);
if (norm < 1e-8) return vector;
float[] normalized = new float[vector.length];
for (int i = 0; i < vector.length; i++) {
normalized[i] = (float)(vector[i] / norm);
}
return normalized;
}
/**
* 余弦相似度(归一化后就是点积)
*/
public float cosineSimilarity(float[] a, float[] b) {
if (a.length != b.length) throw new IllegalArgumentException("向量维度不匹配");
float dot = 0;
for (int i = 0; i < a.length; i++) {
dot += a[i] * b[i];
}
return dot;
}
}文本分类模型推理
/**
* 文本分类推理服务
*
* 适用于:情感分析、语言检测、内容分类等
* 延迟目标:< 5ms(单文本,CPU)
*/
@Service
@Slf4j
public class TextClassificationService {
private final OrtSession session;
private final HuggingFaceTokenizer tokenizer;
private final List<String> labels;
private final int maxLength = 128; // 分类任务通常不需要512
public TextClassificationService(
@Value("${ai.model.classifier.path}") String modelPath,
@Value("${ai.model.classifier.tokenizer}") String tokenizerPath,
@Value("${ai.model.classifier.labels}") List<String> labels) throws Exception {
OnnxModelLoader loader = new OnnxModelLoader();
this.session = loader.loadModel(modelPath, false);
this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath));
this.labels = labels;
log.info("文本分类服务启动: labels={}", labels);
}
/**
* 分类预测
*
* @return 各分类的概率分布
*/
public ClassificationResult classify(String text) throws OrtException {
// Tokenize
Encoding encoding = tokenizer.encode(text, true);
int seqLen = Math.min((int)encoding.getIds().length, maxLength);
long[][] inputIds = new long[1][seqLen];
long[][] attentionMask = new long[1][seqLen];
long[] encodedIds = encoding.getIds();
long[] encodedMask = encoding.getAttentionMask();
for (int i = 0; i < seqLen; i++) {
inputIds[0][i] = encodedIds[i];
attentionMask[0][i] = encodedMask[i];
}
try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), inputIds);
OnnxTensor maskTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), attentionMask)) {
Map<String, OnnxTensor> inputs = new LinkedHashMap<>();
inputs.put("input_ids", inputIdsTensor);
inputs.put("attention_mask", maskTensor);
try (OrtSession.Result results = session.run(inputs)) {
// logits shape: [1, num_labels]
float[][] logits = (float[][]) results.get(0).getValue();
float[] scores = softmax(logits[0]);
// 找最高分
int maxIdx = 0;
for (int i = 1; i < scores.length; i++) {
if (scores[i] > scores[maxIdx]) maxIdx = i;
}
return new ClassificationResult(
labels.get(maxIdx),
scores[maxIdx],
buildScoreMap(scores)
);
}
}
}
/**
* 批量分类(比循环单个classify效率高很多)
*/
public List<ClassificationResult> batchClassify(List<String> texts) throws OrtException {
if (texts.isEmpty()) return List.of();
int batchSize = texts.size();
// 批量tokenize,padding到同一长度
Encoding[] encodings = tokenizer.batchEncode(
texts.toArray(new String[0]), true, true
);
int seqLen = Math.min((int)encodings[0].getIds().length, maxLength);
long[][] inputIds = new long[batchSize][seqLen];
long[][] attentionMask = new long[batchSize][seqLen];
for (int i = 0; i < batchSize; i++) {
long[] ids = encodings[i].getIds();
long[] mask = encodings[i].getAttentionMask();
int copyLen = Math.min(ids.length, seqLen);
System.arraycopy(ids, 0, inputIds[i], 0, copyLen);
System.arraycopy(mask, 0, attentionMask[i], 0, copyLen);
}
try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), inputIds);
OnnxTensor maskTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), attentionMask)) {
Map<String, OnnxTensor> inputs = Map.of(
"input_ids", inputIdsTensor,
"attention_mask", maskTensor
);
try (OrtSession.Result results = session.run(inputs)) {
float[][] allLogits = (float[][]) results.get(0).getValue();
List<ClassificationResult> resultList = new ArrayList<>(batchSize);
for (int i = 0; i < batchSize; i++) {
float[] scores = softmax(allLogits[i]);
int maxIdx = 0;
for (int j = 1; j < scores.length; j++) {
if (scores[j] > scores[maxIdx]) maxIdx = j;
}
resultList.add(new ClassificationResult(
labels.get(maxIdx), scores[maxIdx], buildScoreMap(scores)));
}
return resultList;
}
}
}
private float[] softmax(float[] logits) {
float max = logits[0];
for (float v : logits) if (v > max) max = v;
float sum = 0;
float[] exp = new float[logits.length];
for (int i = 0; i < logits.length; i++) {
exp[i] = (float) Math.exp(logits[i] - max);
sum += exp[i];
}
for (int i = 0; i < exp.length; i++) exp[i] /= sum;
return exp;
}
private Map<String, Float> buildScoreMap(float[] scores) {
Map<String, Float> map = new LinkedHashMap<>();
for (int i = 0; i < labels.size() && i < scores.length; i++) {
map.put(labels.get(i), scores[i]);
}
return map;
}
public record ClassificationResult(
String label,
float confidence,
Map<String, Float> allScores
) {}
}Reranker模型:提升RAG检索精度
/**
* 本地CrossEncoder Reranker
*
* CrossEncoder比BiEncoder的embedding相似度精度更高
* 代价是每对query-document都要单独推理(不能预计算)
*
* 典型用法:先用向量检索召回Top-50,再用Reranker选Top-5
* 这个两阶段策略在效果和性能上都优于纯向量搜索
*/
@Service
@Slf4j
public class LocalRerankerService {
private final OrtSession session;
private final HuggingFaceTokenizer tokenizer;
private final int maxLength = 512;
public LocalRerankerService(
@Value("${ai.model.reranker.path}") String modelPath,
@Value("${ai.model.reranker.tokenizer}") String tokenizerPath) throws Exception {
OnnxModelLoader loader = new OnnxModelLoader();
this.session = loader.loadModel(modelPath, false);
this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath));
}
/**
* 对候选文档重排序
*
* @param query 用户查询
* @param documents 待排序的候选文档
* @return 按相关度降序排列的文档
*/
public List<RankedDocument> rerank(String query, List<String> documents) throws OrtException {
if (documents.isEmpty()) return List.of();
// CrossEncoder:query和每个document拼接后一起输入
// 格式:[CLS] query [SEP] document [SEP]
// tokenizer会自动处理这个格式(传入text pair)
int batchSize = documents.size();
// Batch tokenize query-document pairs
String[] queries = new String[batchSize];
String[] docs = new String[batchSize];
Arrays.fill(queries, query);
documents.toArray(docs);
Encoding[] encodings = tokenizer.batchEncode(
queries, docs, // (queryArray, documentArray) -> text pair encoding
true, true
);
int seqLen = Math.min((int)encodings[0].getIds().length, maxLength);
long[][] inputIds = new long[batchSize][seqLen];
long[][] attentionMask = new long[batchSize][seqLen];
long[][] tokenTypeIds = new long[batchSize][seqLen];
for (int i = 0; i < batchSize; i++) {
long[] ids = encodings[i].getIds();
long[] mask = encodings[i].getAttentionMask();
long[] types = encodings[i].getTypeIds();
int copyLen = Math.min(ids.length, seqLen);
System.arraycopy(ids, 0, inputIds[i], 0, copyLen);
System.arraycopy(mask, 0, attentionMask[i], 0, copyLen);
System.arraycopy(types, 0, tokenTypeIds[i], 0, copyLen);
}
// 推理
try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), inputIds);
OnnxTensor maskTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), attentionMask);
OnnxTensor typesTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), tokenTypeIds)) {
Map<String, OnnxTensor> inputs = new LinkedHashMap<>();
inputs.put("input_ids", inputIdsTensor);
inputs.put("attention_mask", maskTensor);
inputs.put("token_type_ids", typesTensor);
try (OrtSession.Result results = session.run(inputs)) {
// CrossEncoder输出logits,shape: [batch, 1] 或 [batch, 2]
float[][] logits = (float[][]) results.get(0).getValue();
// 计算relevance score
List<RankedDocument> ranked = new ArrayList<>(batchSize);
for (int i = 0; i < batchSize; i++) {
float score;
if (logits[i].length == 1) {
// 单输出:直接sigmoid
score = sigmoid(logits[i][0]);
} else {
// 双输出:softmax取正类概率
float[] probs = softmax(logits[i]);
score = probs[1];
}
ranked.add(new RankedDocument(documents.get(i), score, i));
}
// 按相关度降序
ranked.sort(Comparator.comparingDouble(RankedDocument::score).reversed());
return ranked;
}
}
}
private float sigmoid(float x) {
return (float)(1.0 / (1.0 + Math.exp(-x)));
}
private float[] softmax(float[] logits) {
float max = Math.max(logits[0], logits[1]);
float e0 = (float) Math.exp(logits[0] - max);
float e1 = (float) Math.exp(logits[1] - max);
float sum = e0 + e1;
return new float[]{e0/sum, e1/sum};
}
public record RankedDocument(String content, float score, int originalIndex) {}
}生产级别的线程安全和对象池
/**
* ONNX推理的线程安全注意事项
*
* OrtSession:线程安全,可以并发调用session.run()
* OnnxTensor:不是线程安全的,每个请求必须创建新的tensor
* OrtEnvironment:线程安全的全局单例
*
* 常见的错误:复用OnnxTensor导致数据竞争
*/
@Service
@Slf4j
public class ThreadSafeInferenceService {
// session是线程安全的,一个实例服务所有线程
private final OrtSession session;
private final HuggingFaceTokenizer tokenizer;
// 监控并发推理数量
private final AtomicInteger activeInferences = new AtomicInteger(0);
private final int maxConcurrent;
private final Semaphore concurrencyLimiter;
public ThreadSafeInferenceService(
OrtSession session,
HuggingFaceTokenizer tokenizer,
int maxConcurrent) {
this.session = session;
this.tokenizer = tokenizer;
this.maxConcurrent = maxConcurrent;
this.concurrencyLimiter = new Semaphore(maxConcurrent);
}
/**
* 线程安全的推理入口
*
* 注意:每次推理都在方法内创建新的tensor,用完立即关闭
* 不要把OnnxTensor作为成员变量或复用
*/
public float[] inferWithConcurrencyControl(String text)
throws OrtException, InterruptedException {
// 控制最大并发数,防止内存OOM
if (!concurrencyLimiter.tryAcquire(500, TimeUnit.MILLISECONDS)) {
throw new RuntimeException("推理服务繁忙,请稍后重试");
}
int current = activeInferences.incrementAndGet();
long startTime = System.currentTimeMillis();
try {
// tokenizer也是线程安全的(底层Rust实现)
Encoding encoding = tokenizer.encode(text, true);
int seqLen = Math.min((int)encoding.getIds().length, 512);
long[][] inputIds = new long[1][seqLen];
long[][] attentionMask = new long[1][seqLen];
long[] ids = encoding.getIds();
long[] mask = encoding.getAttentionMask();
System.arraycopy(ids, 0, inputIds[0], 0, Math.min(ids.length, seqLen));
System.arraycopy(mask, 0, attentionMask[0], 0, Math.min(mask.length, seqLen));
// 在try-with-resources中创建tensor,确保推理完自动释放
// 这是避免内存泄漏的关键
try (OnnxTensor inputIdsTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), inputIds);
OnnxTensor maskTensor = OnnxTensor.createTensor(
OrtEnvironment.getEnvironment(), attentionMask);
OrtSession.Result results = session.run(Map.of(
"input_ids", inputIdsTensor,
"attention_mask", maskTensor
))) {
// 立即提取数据,results关闭后数据不可访问
float[][] output = (float[][]) results.get(0).getValue();
return Arrays.copyOf(output[0], output[0].length);
}
} finally {
activeInferences.decrementAndGet();
concurrencyLimiter.release();
long elapsed = System.currentTimeMillis() - startTime;
if (elapsed > 100) {
log.warn("推理耗时过长: {}ms, 当前并发: {}", elapsed, current);
}
}
}
}模型预热和JVM热身
/**
* 模型预热
*
* 冷启动问题:第一次推理比后续慢10-50倍
* 原因:JIT编译、内存分配、BLAS库初始化
*
* 解决方案:应用启动后立即做几次预热推理
*/
@Component
@Slf4j
public class OnnxModelWarmer implements ApplicationListener<ApplicationReadyEvent> {
private final LocalEmbeddingService embeddingService;
private final TextClassificationService classificationService;
@Override
public void onApplicationEvent(ApplicationReadyEvent event) {
warmUp();
}
private void warmUp() {
log.info("开始ONNX模型预热...");
String warmUpText = "这是一段用于模型预热的测试文本,目的是触发JIT编译和内存初始化。";
// 做10次推理,JIT通常在第3-5次后达到稳定性能
for (int i = 0; i < 10; i++) {
try {
long start = System.currentTimeMillis();
embeddingService.embed(warmUpText);
classificationService.classify(warmUpText);
long elapsed = System.currentTimeMillis() - start;
log.debug("预热第{}次: {}ms", i + 1, elapsed);
} catch (Exception e) {
log.warn("预热推理失败(第{}次): {}", i + 1, e.getMessage());
}
}
// 预热完成后打印性能基线
try {
int iterations = 100;
long start = System.currentTimeMillis();
for (int i = 0; i < iterations; i++) {
embeddingService.embed(warmUpText);
}
long total = System.currentTimeMillis() - start;
log.info("预热完成。Embedding基准: avg={}ms/次 ({}次测试)",
total / iterations, iterations);
} catch (Exception e) {
log.warn("基准测试失败: {}", e.getMessage());
}
}
}与LangChain4j集成
/**
* 把本地ONNX Embedding服务接入LangChain4j的EmbeddingModel接口
*
* 这样就可以无缝替换LangChain4j的远程Embedding,
* 其余代码(EmbeddingStore、Retriever等)完全不用改
*/
@Component
@RequiredArgsConstructor
@Slf4j
public class LocalOnnxEmbeddingModel implements EmbeddingModel {
private final LocalEmbeddingService embeddingService;
@Override
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
List<String> texts = textSegments.stream()
.map(TextSegment::text)
.toList();
try {
long start = System.currentTimeMillis();
float[][] vectors = embeddingService.batchEmbed(texts);
long elapsed = System.currentTimeMillis() - start;
log.debug("本地Embedding完成: {} texts, {}ms (avg: {}ms/text)",
texts.size(), elapsed, elapsed / texts.size());
List<Embedding> embeddings = new ArrayList<>(vectors.length);
for (float[] vector : vectors) {
embeddings.add(Embedding.from(vector));
}
return Response.from(embeddings);
} catch (OrtException e) {
throw new RuntimeException("本地Embedding推理失败", e);
}
}
@Override
public int dimension() {
// BGE-M3的向量维度是1024,BGE-small-zh是512
// 根据实际加载的模型设置
return 1024;
}
}性能对比和选型建议
/**
* 实测性能数据(MacBook M2 Pro, 12核)
*
* 文本长度:128 tokens
*
* ┌───────────────────────────────────┬──────────┬──────────┬────────────┐
* │ 方案 │ 延迟P50 │ 延迟P99 │ 吞吐量 │
* ├───────────────────────────────────┼──────────┼──────────┼────────────┤
* │ 远程OpenAI Embedding API │ 180ms │ 450ms │ 受速率限制 │
* │ 本地ONNX CPU(单线程) │ 8ms │ 15ms │ 120/s │
* │ 本地ONNX CPU(4线程推理) │ 12ms │ 22ms │ 320/s │
* │ 本地ONNX CPU 批量(batch=8) │ 45ms │ 80ms │ 170/s │
* │ 本地ONNX GPU (A10G) │ 2ms │ 5ms │ 1200/s │
* └───────────────────────────────────┴──────────┴──────────┴────────────┘
*
* 观察:
* - CPU单线程已经比远程API快20倍以上
* - 批量推理并不总是更快(batch组装的等待延迟抵消了部分收益)
* - 4线程vs单线程:延迟略增(竞争),但吞吐量高得多
* - GPU在高并发场景优势明显,单次延迟极低
*
* 选型建议:
* - 实时单次推理(<10ms要求):CPU单线程/GPU
* - 高吞吐批量处理:CPU多线程
* - 成本敏感型业务:CPU已经足够
* - 大规模向量化任务(>100万文档):考虑GPU
*/实践建议
模型导出
ONNX模型通常从Python导出。常用命令:
# 从HuggingFace导出BERT类模型
optimum-cli export onnx \
--model BAAI/bge-small-zh-v1.5 \
--task feature-extraction \
--optimize O2 \
./bge-small-zh-onnx/
# O2优化级别会做算子融合,速度快30-50%
# 检查导出的模型
python3 -c "import onnx; m = onnx.load('model.onnx'); print(onnx.checker.check_model(m))"量化进一步压缩
FP16或INT8量化可以把模型大小减半,速度提升50-100%,精度损失通常在0.5%以内:
python3 -m onnxruntime.quantization.quantize \
--model_input model.onnx \
--model_output model_int8.onnx \
--quant_type IntegerOps实际落地中遇到的坑
有一次我们把一个情感分类模型从Python迁移到Java ONNX,预测结果总是偏差很大。排查了一天才发现:Python里的tokenizer用的是do_lower_case=True,Java这边加载tokenizer.json时没有这个配置,导致大写字母没有小写化,embedding差异巨大。教训:tokenizer的配置必须和训练时完全一致,一个参数的差异就会让模型精度崩掉。
另一个坑是内存泄漏。OnnxTensor用完如果不close,Native内存会持续增长,GC不会帮你释放(因为是JNI管理的内存)。一定要用try-with-resources,这点在高并发压测时会暴露得很明显。
