Spring AI 的多模型路由——同一个请求怎么智能分配给不同模型
Spring AI 的多模型路由——同一个请求怎么智能分配给不同模型
今年二月份,老板把我叫去谈了一次话,主题是:AI 这块每个月烧了多少钱,能不能控制一下。
我去看了一眼账单:GPT-4o 的 API 费用一个月接近 3 万块,这还不算我们用的 Claude。
然后我仔细分析了一下请求日志,发现了一个让我很难受的事实:大约 60% 的请求,根本不需要用 GPT-4o 这个级别的模型。
那 60% 是什么?是用户问「帮我改一下这段话的语气」、「给我写一个 SQL 注释」、「把这个英文翻译成中文」这类简单任务。这类任务用 GPT-3.5-turbo 或者 Qwen-plus 完全够,价格差了将近 10 倍。
从那以后,我花了一个月搭了一套模型路由系统,最终把每月的 API 费用降到了 1.1 万,降幅接近 63%。
这篇文章就写这套路由系统是怎么设计和实现的。
路由的核心思想:不是所有问题都需要 GPT-4
先说清楚路由的几个维度:
按任务复杂度路由
简单任务(翻译、格式化、简单改写)→ 小模型(GPT-3.5 / Qwen-plus / deepseek-chat)
复杂任务(代码审查、逻辑推理、长文创作)→ 大模型(GPT-4o / Claude-3.5-sonnet)
按响应时间要求路由
对响应时延敏感的场景(实时对话、流式输出)→ 小模型(延迟低)
对质量要求高但可以等的场景(批处理、报告生成)→ 大模型
按成本预算路由
每个用户/账号有 token 预算,预算充足时用大模型,接近上限时降级到小模型
按上下文长度路由
上下文很长(超过 8K tokens)→ 支持长上下文的模型(Claude 支持 200K)
短上下文 → 任何模型都行,选便宜的
路由决策流程
复杂度评估的实现
路由的难点在于:怎么判断一个问题「复杂不复杂」?
方案一:用规则(关键词匹配)。简单,但覆盖不全,维护成本高。
方案二:用一个轻量级模型先判断复杂度,再路由。这本身又多了一次 API 调用,而且有延迟。
方案三:基于启发式规则 + 上下文特征的评分。不需要额外的 API 调用,根据 prompt 的多个特征打分,综合评估。
我用的是方案三,效果够用,成本最低。
@Component
public class ComplexityEvaluator {
/**
* 评估任务复杂度,返回 0.0 ~ 1.0 的评分
* 0.0 = 极简单,1.0 = 极复杂
*/
public double evaluate(String userMessage, String systemPrompt) {
double score = 0.0;
// 特征 1:消息长度(越长通常越复杂)
int messageLength = userMessage.length();
if (messageLength > 2000) {
score += 0.3;
} else if (messageLength > 500) {
score += 0.15;
} else {
score += 0.05;
}
// 特征 2:是否包含代码(代码相关任务通常需要更强的推理)
if (containsCode(userMessage)) {
score += 0.2;
}
// 特征 3:是否有推理/分析类关键词
if (containsReasoningKeywords(userMessage)) {
score += 0.25;
}
// 特征 4:是否是简单格式化/翻译任务
if (isSimpleFormattingTask(userMessage)) {
score -= 0.2; // 降低评分
}
// 特征 5:问题中的子问题数量(多个问题通常更复杂)
long questionCount = countQuestions(userMessage);
if (questionCount > 3) {
score += 0.2;
} else if (questionCount > 1) {
score += 0.1;
}
// 特征 6:system prompt 的复杂度
if (systemPrompt != null && systemPrompt.length() > 1000) {
score += 0.1;
}
return Math.max(0.0, Math.min(1.0, score));
}
private boolean containsCode(String text) {
return text.contains("```") ||
text.contains("class ") ||
text.contains("def ") ||
text.matches(".*\\b(function|var|const|let|import|package)\\b.*");
}
private boolean containsReasoningKeywords(String text) {
String lower = text.toLowerCase();
String[] keywords = {
"分析", "推理", "为什么", "原因", "比较", "对比",
"优缺点", "评估", "设计", "架构", "方案", "策略",
"analyze", "compare", "design", "evaluate", "explain"
};
for (String keyword : keywords) {
if (lower.contains(keyword)) return true;
}
return false;
}
private boolean isSimpleFormattingTask(String text) {
String lower = text.toLowerCase();
String[] patterns = {
"翻译", "translate", "格式化", "format",
"改写语气", "润色", "纠错", "改正语法"
};
for (String pattern : patterns) {
if (lower.contains(pattern)) return true;
}
return false;
}
private long countQuestions(String text) {
return text.chars().filter(c -> c == '?' || c == '?').count();
}
}ModelRouter 的核心实现
public enum ModelTier {
SMALL("小模型"),
MEDIUM("中等模型"),
LARGE("大模型"),
LONG_CONTEXT("长上下文模型");
private final String displayName;
ModelTier(String displayName) { this.displayName = displayName; }
}
@Data
public class RoutingDecision {
private ModelTier tier;
private String modelName;
private String reason;
private double complexityScore;
private int estimatedTokens;
}@Component
public class ModelRouter {
private static final Logger log = LoggerFactory.getLogger(ModelRouter.class);
// 不同级别模型的 Spring AI ChatModel Bean
private final Map<ModelTier, ChatModel> modelMap;
private final ComplexityEvaluator complexityEvaluator;
private final UserBudgetService budgetService;
// 配置阈值
private static final int LONG_CONTEXT_THRESHOLD = 32000; // 超过这个 token 数用长上下文模型
private static final double HIGH_COMPLEXITY_THRESHOLD = 0.7;
private static final double LOW_COMPLEXITY_THRESHOLD = 0.3;
public ModelRouter(
@Qualifier("gpt4oChatModel") ChatModel gpt4oModel,
@Qualifier("qwenMaxChatModel") ChatModel qwenMaxModel,
@Qualifier("qwenPlusChatModel") ChatModel qwenPlusModel,
@Qualifier("claudeChatModel") ChatModel claudeModel,
ComplexityEvaluator complexityEvaluator,
UserBudgetService budgetService) {
this.modelMap = new EnumMap<>(ModelTier.class);
this.modelMap.put(ModelTier.LARGE, gpt4oModel);
this.modelMap.put(ModelTier.MEDIUM, qwenMaxModel);
this.modelMap.put(ModelTier.SMALL, qwenPlusModel);
this.modelMap.put(ModelTier.LONG_CONTEXT, claudeModel);
this.complexityEvaluator = complexityEvaluator;
this.budgetService = budgetService;
}
/**
* 路由决策:根据请求特征选择合适的模型
*/
public RoutingDecision route(String userId, String userMessage, String systemPrompt) {
RoutingDecision decision = new RoutingDecision();
// Step 1: 估算 token 数量
int estimatedTokens = estimateTokens(userMessage, systemPrompt);
decision.setEstimatedTokens(estimatedTokens);
// Step 2: 长上下文检查(优先级最高)
if (estimatedTokens > LONG_CONTEXT_THRESHOLD) {
decision.setTier(ModelTier.LONG_CONTEXT);
decision.setModelName("claude-3-5-sonnet");
decision.setReason("上下文超过 32K tokens,使用长上下文模型");
log.info("路由决策: userId={}, tier=LONG_CONTEXT, tokens={}", userId, estimatedTokens);
return decision;
}
// Step 3: 复杂度评估
double complexityScore = complexityEvaluator.evaluate(userMessage, systemPrompt);
decision.setComplexityScore(complexityScore);
// Step 4: 根据复杂度确定初始 tier
ModelTier initialTier;
if (complexityScore >= HIGH_COMPLEXITY_THRESHOLD) {
initialTier = ModelTier.LARGE;
} else if (complexityScore >= LOW_COMPLEXITY_THRESHOLD) {
initialTier = ModelTier.MEDIUM;
} else {
initialTier = ModelTier.SMALL;
}
// Step 5: 预算检查(可能降级)
ModelTier finalTier = applyBudgetConstraint(userId, initialTier);
String reason = String.format(
"复杂度评分=%.2f,初始tier=%s,预算调整后tier=%s",
complexityScore, initialTier, finalTier
);
decision.setTier(finalTier);
decision.setModelName(getModelName(finalTier));
decision.setReason(reason);
log.info("路由决策: userId={}, tier={}, score={}, reason={}",
userId, finalTier, complexityScore, reason);
return decision;
}
/**
* 获取对应 tier 的 ChatModel
*/
public ChatModel getModel(ModelTier tier) {
return modelMap.get(tier);
}
private ModelTier applyBudgetConstraint(String userId, ModelTier requestedTier) {
if (userId == null || "anonymous".equals(userId)) {
// 匿名用户直接降级到 SMALL
return ModelTier.SMALL;
}
UserBudget budget = budgetService.getBudget(userId);
if (budget == null || budget.getRemainingBudget() <= 0) {
// 预算耗尽,强制使用最小模型
log.warn("用户 {} 预算耗尽,降级到 SMALL 模型", userId);
return ModelTier.SMALL;
}
// 根据剩余预算比例决定是否降级
double budgetRatio = budget.getRemainingBudget() / (double) budget.getTotalBudget();
if (budgetRatio < 0.1 && requestedTier == ModelTier.LARGE) {
// 剩余预算不足 10%,大模型请求降级到中等
log.info("用户 {} 预算不足,从 LARGE 降级到 MEDIUM", userId);
return ModelTier.MEDIUM;
}
return requestedTier;
}
private int estimateTokens(String userMessage, String systemPrompt) {
// 粗略估算:中文约 1.5 字/token,英文约 4 字符/token
// 这里用简化公式,生产中可以用 tiktoken 精确计算
int userTokens = (int) (userMessage.length() / 1.5);
int systemTokens = systemPrompt != null ? (int) (systemPrompt.length() / 1.5) : 0;
return userTokens + systemTokens;
}
private String getModelName(ModelTier tier) {
return switch (tier) {
case LARGE -> "gpt-4o";
case MEDIUM -> "qwen-max";
case SMALL -> "qwen-plus";
case LONG_CONTEXT -> "claude-3-5-sonnet";
};
}
}路由服务的整合调用
@Service
public class RoutedChatService {
private final ModelRouter router;
private final Map<ModelTier, ChatClient> chatClientMap;
private final RoutingMetricsService metricsService;
public RoutedChatService(ModelRouter router,
/* 各模型的 ChatClient... */ ) {
this.router = router;
// 初始化每个 tier 对应的 ChatClient
this.chatClientMap = buildChatClientMap();
}
public String chat(String userId, String userMessage) {
return chat(userId, userMessage, null);
}
public String chat(String userId, String userMessage, String systemPrompt) {
// 1. 路由决策
RoutingDecision decision = router.route(userId, userMessage, systemPrompt);
// 2. 拿到对应的 ChatClient
ChatClient chatClient = chatClientMap.get(decision.getTier());
// 3. 执行调用
long startTime = System.currentTimeMillis();
String response;
try {
ChatClient.ChatClientRequest request = chatClient.prompt()
.user(userMessage);
if (systemPrompt != null) {
request = request.system(systemPrompt);
}
response = request.call().content();
} catch (Exception e) {
// 调用失败时降级到小模型重试
log.warn("模型 {} 调用失败,降级重试: {}", decision.getModelName(), e.getMessage());
response = fallbackToSmallModel(userMessage, systemPrompt);
}
long latency = System.currentTimeMillis() - startTime;
// 4. 记录路由指标
metricsService.record(userId, decision, latency);
return response;
}
private String fallbackToSmallModel(String userMessage, String systemPrompt) {
ChatClient smallClient = chatClientMap.get(ModelTier.SMALL);
ChatClient.ChatClientRequest request = smallClient.prompt().user(userMessage);
if (systemPrompt != null) {
request = request.system(systemPrompt);
}
return request.call().content();
}
}多模型的 Spring Boot 配置
同时配置多个模型的关键在于区分 Bean 名称:
spring:
ai:
openai:
api-key: ${OPENAI_API_KEY}
chat:
options:
model: gpt-4o
# 通义千问用 OpenAI 兼容接口
# 需要额外配置两个 ChatModel Bean
# 自定义多模型配置
app:
models:
qwen-max:
base-url: https://dashscope.aliyuncs.com/compatible-mode/v1
api-key: ${DASHSCOPE_API_KEY}
model: qwen-max
qwen-plus:
base-url: https://dashscope.aliyuncs.com/compatible-mode/v1
api-key: ${DASHSCOPE_API_KEY}
model: qwen-plus
claude:
api-key: ${ANTHROPIC_API_KEY}
model: claude-3-5-sonnet-20241022@Configuration
public class MultiModelConfig {
@Bean("gpt4oChatModel")
@Primary
public ChatModel gpt4oChatModel(OpenAiChatModel openAiChatModel) {
return openAiChatModel; // 使用 Spring AI 自动配置的 OpenAI 模型
}
@Bean("qwenMaxChatModel")
public ChatModel qwenMaxChatModel(
@Value("${app.models.qwen-max.base-url}") String baseUrl,
@Value("${app.models.qwen-max.api-key}") String apiKey,
@Value("${app.models.qwen-max.model}") String model) {
OpenAiApi openAiApi = new OpenAiApi(baseUrl, apiKey);
return new OpenAiChatModel(openAiApi,
OpenAiChatOptions.builder()
.withModel(model)
.withTemperature(0.7)
.build());
}
@Bean("qwenPlusChatModel")
public ChatModel qwenPlusChatModel(
@Value("${app.models.qwen-plus.base-url}") String baseUrl,
@Value("${app.models.qwen-plus.api-key}") String apiKey,
@Value("${app.models.qwen-plus.model}") String model) {
OpenAiApi openAiApi = new OpenAiApi(baseUrl, apiKey);
return new OpenAiChatModel(openAiApi,
OpenAiChatOptions.builder()
.withModel(model)
.withTemperature(0.7)
.build());
}
@Bean("claudeChatModel")
public ChatModel claudeChatModel(AnthropicChatModel anthropicChatModel) {
return anthropicChatModel;
}
}路由效果监控
不监控就不知道路由是否按预期工作,我加了一个简单的指标记录:
@Service
public class RoutingMetricsService {
private final MeterRegistry meterRegistry;
public void record(String userId, RoutingDecision decision, long latencyMs) {
// 记录路由到各 tier 的次数
Counter.builder("ai.routing.requests")
.tag("tier", decision.getTier().name())
.tag("model", decision.getModelName())
.register(meterRegistry)
.increment();
// 记录响应时间
Timer.builder("ai.routing.latency")
.tag("tier", decision.getTier().name())
.register(meterRegistry)
.record(latencyMs, TimeUnit.MILLISECONDS);
// 记录复杂度分布
DistributionSummary.builder("ai.routing.complexity")
.register(meterRegistry)
.record(decision.getComplexityScore());
}
}上线一个月后,从 Grafana 里看路由分布:
- SMALL 模型:58% 的请求
- MEDIUM 模型:30% 的请求
- LARGE 模型:11% 的请求
- LONG_CONTEXT 模型:1% 的请求
和我们的预期基本吻合。LARGE 模型用量从原来的 100% 降到了 11%,费用自然大幅下降。
路由系统的局限性和后续改进
说几个真实的局限:
评分不够准确:启发式规则永远比不上真正理解语义。有时候「帮我分析一下这段代码的时间复杂度」被评为简单任务(因为消息不长),但其实需要较强的推理。
冷启动问题:新用户第一次使用,没有历史数据,只能靠规则,准确率会低一些。
用户体验的权衡:有时候用户觉得回答质量不够好,但他们不知道背后用的是小模型。这个需要在产品层面做好预期管理。
后续考虑的改进方向:
- 引入在线学习:记录哪些请求被用户评为「不满意」,这些请求下次路由到更高级的模型
- 基于历史数据自动调整阈值
- 支持用户手动指定「这次用最好的模型」
路由系统不是银弹,但在成本和质量之间找到平衡点,这是每个 AI 应用迟早要面对的问题。
