第1908篇:Spring AI集成Ollama的生产方案——本地模型的负载均衡与故障转移
第1908篇:Spring AI集成Ollama的生产方案——本地模型的负载均衡与故障转移
用 Ollama 跑本地模型这件事,很多团队在 POC 阶段都做过,成功跑起来了就觉得"挺简单的"。但一旦要把它搬到生产环境,问题就来了:一台机器上跑 Ollama 扛不住并发,加机器了但怎么负载均衡,某台 Ollama 节点挂了怎么故障转移,GPU 内存不够时怎么降级处理……
这些问题在开发环境完全感知不到,到了生产才发现是大麻烦。
这篇文章就从实际工程角度,讲一套 Spring AI + Ollama 的生产级部署方案,重点解决多节点负载均衡和故障转移的问题。
先聊 Ollama 的生产部署限制
在设计方案之前,得先了解 Ollama 的几个特性,这些特性直接影响架构设计:
1. 单进程单模型
Ollama 默认一次只加载一个模型到 GPU 内存(可以配置 OLLAMA_MAX_LOADED_MODELS,但受 GPU 内存限制)。如果你有多个请求用不同模型,会有模型卸载/加载的开销。
2. 并发请求的队列机制
Ollama 处理并发请求是排队执行的,不是真正的并行(除非你开了多实例)。这意味着单节点的并发能力受 GPU 处理速度限制,QPS 天花板不高。
3. 没有内置健康检查端点
Ollama 有 /api/health(0.4.x 以后),但没有返回 GPU 负载、当前队列深度等详细状态。需要自己实现监控。
4. 内存 OOM 是常见问题
如果请求的 context 长度超过 GPU 内存限制,Ollama 会报错而不是优雅降级。
多节点 Ollama 的负载均衡方案
核心思路:在 Spring AI 和 Ollama 之间加一层智能代理,实现健康检查、权重路由、故障转移。
@Component
public class OllamaNodeRegistry {
@Value("${ollama.nodes}")
private List<String> nodeUrls; // ["http://gpu-01:11434", "http://gpu-02:11434", ...]
private final ConcurrentHashMap<String, OllamaNodeInfo> nodes = new ConcurrentHashMap<>();
@PostConstruct
public void initNodes() {
nodeUrls.forEach(url -> {
OllamaNodeInfo node = OllamaNodeInfo.builder()
.url(url)
.status(NodeStatus.UNKNOWN)
.weight(1)
.currentLoad(new AtomicInteger(0))
.failureCount(new AtomicInteger(0))
.lastCheckTime(Instant.now())
.build();
nodes.put(url, node);
});
}
/**
* 获取当前可用的节点列表,按负载排序
*/
public List<OllamaNodeInfo> getAvailableNodes() {
return nodes.values().stream()
.filter(n -> n.getStatus() == NodeStatus.HEALTHY)
.sorted(Comparator.comparingDouble(this::calculateScore))
.collect(Collectors.toList());
}
/**
* 节点评分:负载越低、失败次数越少,分数越低(优先选择低分节点)
*/
private double calculateScore(OllamaNodeInfo node) {
double loadScore = node.getCurrentLoad().get();
double failurePenalty = node.getFailureCount().get() * 0.5;
double responseTimePenalty = node.getAvgResponseTimeMs() / 1000.0;
return loadScore + failurePenalty + responseTimePenalty;
}
public void updateStatus(String nodeUrl, NodeStatus status) {
OllamaNodeInfo node = nodes.get(nodeUrl);
if (node != null) {
node.setStatus(status);
node.setLastCheckTime(Instant.now());
if (status == NodeStatus.HEALTHY) {
node.getFailureCount().set(0); // 恢复健康时清零失败计数
}
}
}
public void recordFailure(String nodeUrl) {
OllamaNodeInfo node = nodes.get(nodeUrl);
if (node != null) {
int failures = node.getFailureCount().incrementAndGet();
if (failures >= 3) {
node.setStatus(NodeStatus.UNHEALTHY);
log.warn("节点连续失败 {} 次,标记为不健康: {}", failures, nodeUrl);
}
}
}
public void recordSuccess(String nodeUrl, long responseTimeMs) {
OllamaNodeInfo node = nodes.get(nodeUrl);
if (node != null) {
node.getFailureCount().set(0);
node.setStatus(NodeStatus.HEALTHY);
// 更新平均响应时间(指数移动平均)
double current = node.getAvgResponseTimeMs();
node.setAvgResponseTimeMs(current * 0.8 + responseTimeMs * 0.2);
}
}
}健康检查的定时任务:
@Component
public class OllamaHealthChecker {
@Autowired
private OllamaNodeRegistry nodeRegistry;
private final WebClient webClient = WebClient.builder()
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.build();
@Scheduled(fixedDelay = 10_000) // 每 10 秒检查一次
public void checkAllNodes() {
nodeRegistry.getAllNodes().forEach(this::checkNode);
}
private void checkNode(OllamaNodeInfo node) {
long startTime = System.currentTimeMillis();
webClient.get()
.uri(node.getUrl() + "/api/health")
.retrieve()
.toBodilessEntity()
.timeout(Duration.ofSeconds(5))
.subscribe(
response -> {
long responseTime = System.currentTimeMillis() - startTime;
nodeRegistry.updateStatus(node.getUrl(), NodeStatus.HEALTHY);
nodeRegistry.recordSuccess(node.getUrl(), responseTime);
log.debug("节点健康: {}, 响应时间: {}ms", node.getUrl(), responseTime);
},
error -> {
log.warn("节点健康检查失败: {}, 原因: {}",
node.getUrl(), error.getMessage());
nodeRegistry.recordFailure(node.getUrl());
}
);
}
/**
* 更详细的节点状态检查:查询当前运行中的模型和队列状态
*/
@Scheduled(fixedDelay = 30_000) // 每 30 秒做一次深度检查
public void deepCheckNodes() {
nodeRegistry.getAvailableNodes().forEach(node -> {
webClient.get()
.uri(node.getUrl() + "/api/ps") // Ollama 的进程状态接口
.retrieve()
.bodyToMono(OllamaProcessStatus.class)
.subscribe(status -> {
// 更新节点的实时负载信息
int activeRequests = status.getModels() != null
? status.getModels().size() : 0;
node.getCurrentLoad().set(activeRequests);
log.debug("节点 {} 当前加载模型数: {}",
node.getUrl(), activeRequests);
});
});
}
}实现负载均衡的 ChatModel 包装器
Spring AI 的 ChatModel 是接口,我们可以实现一个负载均衡版本:
@Component
@Primary // 替换默认的 OllamaChatModel
public class LoadBalancedOllamaChatModel implements ChatModel {
@Autowired
private OllamaNodeRegistry nodeRegistry;
// 缓存每个节点的 ChatModel 实例
private final ConcurrentHashMap<String, OllamaChatModel> modelCache =
new ConcurrentHashMap<>();
@Autowired
private OllamaChatOptions defaultOptions;
@Override
public ChatResponse call(Prompt prompt) {
List<OllamaNodeInfo> candidates = nodeRegistry.getAvailableNodes();
if (candidates.isEmpty()) {
throw new NoAvailableNodeException("所有 Ollama 节点不可用,请检查服务状态");
}
// 尝试按优先级执行,失败则自动故障转移
for (OllamaNodeInfo node : candidates) {
try {
node.getCurrentLoad().incrementAndGet();
long startTime = System.currentTimeMillis();
ChatModel model = getOrCreateModel(node.getUrl());
ChatResponse response = model.call(prompt);
long duration = System.currentTimeMillis() - startTime;
nodeRegistry.recordSuccess(node.getUrl(), duration);
node.getCurrentLoad().decrementAndGet();
log.debug("请求成功: node={}, durationMs={}", node.getUrl(), duration);
return response;
} catch (Exception e) {
node.getCurrentLoad().decrementAndGet();
log.warn("节点请求失败: {}, 尝试下一个节点。原因: {}",
node.getUrl(), e.getMessage());
nodeRegistry.recordFailure(node.getUrl());
// 如果是致命错误(认证失败、参数错误),不要尝试其他节点
if (isFatalError(e)) {
throw e;
}
// 否则继续尝试下一个节点
}
}
throw new AllNodesFailedException("所有可用节点都请求失败");
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
List<OllamaNodeInfo> candidates = nodeRegistry.getAvailableNodes();
if (candidates.isEmpty()) {
return Flux.error(new NoAvailableNodeException("所有 Ollama 节点不可用"));
}
// 流式调用选择当前负载最低的节点
OllamaNodeInfo selectedNode = candidates.get(0);
return Flux.defer(() -> {
selectedNode.getCurrentLoad().incrementAndGet();
long startTime = System.currentTimeMillis();
ChatModel model = getOrCreateModel(selectedNode.getUrl());
return model.stream(prompt)
.doOnComplete(() -> {
selectedNode.getCurrentLoad().decrementAndGet();
long duration = System.currentTimeMillis() - startTime;
nodeRegistry.recordSuccess(selectedNode.getUrl(), duration);
})
.doOnError(e -> {
selectedNode.getCurrentLoad().decrementAndGet();
nodeRegistry.recordFailure(selectedNode.getUrl());
log.error("流式请求失败: node={}", selectedNode.getUrl(), e);
})
.onErrorResume(e -> {
// 流式不容易做故障转移(因为已经开始输出),记录错误后上报
return Flux.error(new StreamFailedException(
"流式生成失败: " + e.getMessage(), e));
});
});
}
private OllamaChatModel getOrCreateModel(String nodeUrl) {
return modelCache.computeIfAbsent(nodeUrl, url -> {
OllamaApi api = new OllamaApi(url);
return new OllamaChatModel(api, defaultOptions);
});
}
private boolean isFatalError(Exception e) {
String message = e.getMessage();
return message != null && (
message.contains("model not found") ||
message.contains("invalid parameter") ||
message.contains("context length exceeded")
);
}
@Override
public ChatOptions getDefaultOptions() {
return defaultOptions;
}
}模型不可用时的降级策略
当所有本地 Ollama 节点都不可用时,需要有降级方案:
@Component
public class FallbackAwareChatService {
@Autowired
private LoadBalancedOllamaChatModel localModel;
// 配置一个云端模型作为降级备选(比如 OpenAI 或者国内的通义千问)
@Autowired
@Qualifier("fallbackCloudModel")
private ChatModel fallbackModel;
@Autowired
private FallbackPolicyConfig fallbackPolicy;
/**
* 带降级的模型调用
*/
public ChatResponse callWithFallback(Prompt prompt, String requestType) {
// 检查是否允许降级(某些场景不允许把数据发到云端)
if (!fallbackPolicy.isFallbackAllowed(requestType)) {
// 不允许降级,直接调用本地,失败就抛异常
return localModel.call(prompt);
}
try {
return localModel.call(prompt);
} catch (NoAvailableNodeException | AllNodesFailedException e) {
log.warn("本地 Ollama 不可用,降级到云端模型。requestType={}", requestType);
// 记录降级事件
alertService.sendAlert(AlertLevel.WARNING,
"Ollama 集群不可用,已自动降级到云端模型");
// 降级调用可能需要调整 prompt(不同模型的最佳 prompt 写法不同)
Prompt adaptedPrompt = adaptPromptForFallback(prompt);
return fallbackModel.call(adaptedPrompt);
}
}
private Prompt adaptPromptForFallback(Prompt prompt) {
// 如果 prompt 里有 Ollama 特定的参数,在这里做适配
// 比如某些 Ollama 专用的 template 格式
return prompt; // 简化处理,实际要根据模型差异调整
}
}降级配置:
fallback:
policy:
# 允许降级到云端的请求类型
allowed-types:
- PUBLIC_QA
- CUSTOMER_SERVICE
# 不允许降级的请求类型(涉及敏感数据)
restricted-types:
- INTERNAL_DOCS
- HR_ANALYSIS
- FINANCIAL_REPORT
# 降级时的告警阈值(百分比:降级请求占总请求的比例超过这个值时告警)
alert-threshold: 20Ollama 模型预热与内存管理
Ollama 冷启动时加载模型很慢(7B 模型大约需要 10-20 秒),生产环境需要预热:
@Component
public class OllamaModelWarmer {
@Value("${ollama.models.preload}")
private List<String> modelsToPreload;
@Autowired
private OllamaNodeRegistry nodeRegistry;
private final WebClient webClient = WebClient.create();
/**
* 服务启动后预热所有节点的目标模型
*/
@EventListener(ApplicationReadyEvent.class)
public void warmUpModels() {
log.info("开始预热 Ollama 模型: {}", modelsToPreload);
nodeRegistry.getAvailableNodes().forEach(node -> {
modelsToPreload.forEach(model -> warmUpModel(node.getUrl(), model));
});
}
private void warmUpModel(String nodeUrl, String modelName) {
// 发一个空请求让模型加载到内存
Map<String, Object> request = Map.of(
"model", modelName,
"prompt", "hi",
"stream", false
);
webClient.post()
.uri(nodeUrl + "/api/generate")
.bodyValue(request)
.retrieve()
.toBodilessEntity()
.timeout(Duration.ofSeconds(120)) // 首次加载可能很慢
.subscribe(
resp -> log.info("模型预热完成: node={}, model={}", nodeUrl, modelName),
err -> log.warn("模型预热失败: node={}, model={}, reason={}",
nodeUrl, modelName, err.getMessage())
);
}
/**
* 保持模型在内存中不被卸载(周期性发送 keep-alive 请求)
*/
@Scheduled(fixedDelay = 300_000) // 每 5 分钟一次
public void keepModelsAlive() {
nodeRegistry.getAvailableNodes().forEach(node -> {
modelsToPreload.forEach(model -> {
Map<String, Object> request = Map.of(
"model", model,
"keep_alive", "10m" // 保持 10 分钟
);
webClient.post()
.uri(node.getUrl() + "/api/generate")
.bodyValue(request)
.retrieve()
.toBodilessEntity()
.subscribe();
});
});
}
}多节点架构图
踩坑记录
坑1:Ollama 的并发限制比想象的严格
以为加了多台 Ollama 节点后并发就上去了,结果发现单个 Ollama 节点处理并发请求时,GPU 使用率会飙到 100%,响应时间急剧增加。
Ollama 本身有 OLLAMA_NUM_PARALLEL 环境变量控制并发数,默认是 1。要根据 GPU 内存和模型大小合理设置,不是越大越好。一般 24GB GPU 运行 7B 模型,OLLAMA_NUM_PARALLEL=2 比较合适。
坑2:节点恢复时的请求积压
某个节点挂了一段时间,恢复后健康检查判断它健康了,立刻给它发大量请求。由于刚启动的节点模型还在预热,这批请求都很慢,触发超时,节点又被标记为不健康……形成抖动循环。
解决方案:节点恢复后先进入"预热期",在预热期内只接受少量请求(比如总流量的 10%),过了一段时间再恢复正常权重。
// 节点恢复后的预热逻辑
public void markAsRecovering(String nodeUrl) {
OllamaNodeInfo node = nodes.get(nodeUrl);
if (node != null) {
node.setStatus(NodeStatus.RECOVERING);
node.setWeight(1); // 低权重,只接受少量流量
// 60 秒后恢复正常权重
scheduledExecutor.schedule(() -> {
if (node.getStatus() == NodeStatus.RECOVERING) {
node.setStatus(NodeStatus.HEALTHY);
node.setWeight(10); // 恢复正常权重
log.info("节点预热完成,恢复正常服务: {}", nodeUrl);
}
}, 60, TimeUnit.SECONDS);
}
}坑3:context length 导致的静默错误
用户发了一个很长的对话上下文,超过了 Ollama 配置的 num_ctx,Ollama 不报错,而是静默截断上下文,导致模型"失忆"。
解决方案:在 Spring AI 层做 token 数预估,超过阈值时主动触发内存压缩或给用户提示:
@Component
public class ContextLengthGuard {
@Value("${ollama.context.max-tokens:4096}")
private int maxContextTokens;
@Value("${ollama.context.warn-threshold:0.85}")
private double warnThreshold;
public void checkAndWarn(List<Message> messages) {
int estimatedTokens = estimateTokenCount(messages);
if (estimatedTokens > maxContextTokens) {
throw new ContextLengthExceededException(
String.format("对话上下文太长(约 %d tokens),超过模型限制(%d tokens)。" +
"请开始新对话或使用/clear命令清除历史。",
estimatedTokens, maxContextTokens));
}
if (estimatedTokens > maxContextTokens * warnThreshold) {
log.warn("对话上下文接近限制: {} tokens / {} 最大",
estimatedTokens, maxContextTokens);
}
}
private int estimateTokenCount(List<Message> messages) {
// 粗略估算:中文约 1.5 token/字,英文约 0.75 token/词
int chars = messages.stream()
.mapToInt(m -> m.getContent().length())
.sum();
return (int) (chars * 1.5);
}
}小结
Spring AI + Ollama 的生产方案,核心是围绕多节点管理做工程化:
- 健康检查要主动、频繁,且区分"浅检查"(ping)和"深检查"(负载状态)
- 负载均衡要考虑实时负载,不能用简单的轮询
- 故障转移分两层:节点级(在其他 Ollama 节点重试)和服务级(降级到云端 API)
- 模型预热和 keep-alive 机制避免冷启动延迟
- Context length 超限问题要在应用层提前处理
Ollama 在私有化部署场景下有很强的成本优势,但要做到生产可用,这些工程细节一个都不能省。
