Designing an AI Gateway: Unified Management of All LLM Calls Across the Company
2026/4/30
Audience: tech leads and architects who need to manage multiple AI applications centrally | Reading time: ~15 minutes | Core value: a complete enterprise AI gateway design, from architecture to implementation
At the end of last year, our company was running eight AI-related projects at the same time. Eight projects, each maintaining its own API keys, each writing its own rate-limiting logic, each tracking its own costs, each keeping its own logs.
One day the CFO asked me, "How much are we spending on AI each month?"
It took me three days to pull the numbers together from every project. Worse, one project's key had been shared around, and part of the spend couldn't be traced to any source at all.
After that, I did one thing: I moved every LLM call behind a single, unified AI gateway.
Why You Need an AI Gateway
Before the gateway, the typical state of chaos:
Project A ── calls OpenAI directly ──────────┐
Project B ── calls Anthropic directly ───────┤──→ assorted LLM APIs
Project C ── calls OpenAI directly ──────────┤
Project D ── calls Azure OpenAI directly ────┘
Problems:
- API keys are scattered and impossible to manage centrally
- each project implements its own rate limiting, with no common standard
- costs can't be broken down by department
- logs are scattered, making incidents hard to trace
- when a model goes down, every project has to change its own code to switch over

With an AI gateway in place:
Project A ─┐
Project B ─┤──→ AI Gateway ──→ OpenAI
Project C ─┤               ──→ Anthropic
Project D ─┘               ──→ Azure OpenAI
                           ──→ local models
The gateway is responsible for:
- unified authentication
- unified rate limiting
- unified logging
- cost accounting
- automatic failover
- model routing

Overall Architecture
                  ┌─────────────────────────────────────────┐
                  │                AI Gateway               │
                  │                                         │
 request entry    │  ┌──────────┐     ┌──────────────────┐  │
 client ────────→ │  │   Auth   │────→│   Rate Limiter   │  │
                  │  └──────────┘     └────────┬─────────┘  │
                  │                            │            │
                  │                   ┌────────▼─────────┐  │
                  │                   │      Router      │  │
                  │                   └────────┬─────────┘  │
                  │                            │            │
                  │  ┌─────────────────────────▼──────────┐ │
                  │  │          Provider Manager          │ │
                  │  │ OpenAI │ Anthropic │ Azure │ local │ │
                  │  └─────────────────────────┬──────────┘ │
                  │                            │            │
                  │  ┌──────────┐     ┌────────▼─────────┐  │
                  │  │   Cost   │     │      Logger      │  │
                  │  │  Tracker │     └──────────────────┘  │
                  │  └──────────┘                           │
                  └─────────────────────────────────────────┘
                                       │
                                       ▼
                          ┌──────────────────────────┐
                          │         LLM APIs         │
                          │ OpenAI / Claude / others │
                          └──────────────────────────┘

Core Module Implementation
Module 1: Unified Authentication Layer
Every project or service that connects through the gateway gets its own virtual API key, isolated from the real LLM provider keys:
from dataclasses import dataclass
from typing import Optional
import hashlib
import json
import secrets

@dataclass
class ApiClient:
    client_id: str
    client_name: str
    virtual_key: str           # the virtual key handed to downstream callers
    allowed_models: list       # models this client may use
    department: str            # owning department (for cost accounting)
    monthly_budget_usd: float  # monthly budget
    is_active: bool

class GatewayAuthManager:
    def __init__(self, db_client, redis_client):
        self.db = db_client
        self.redis = redis_client

    def generate_virtual_key(self) -> str:
        """Generate a virtual API key."""
        return 'aig-' + secrets.token_hex(24)  # aig = AI Gateway

    async def create_client(
        self,
        client_name: str,
        department: str,
        allowed_models: list,
        monthly_budget_usd: float = 100.0
    ) -> ApiClient:
        """Register a new downstream client."""
        virtual_key = self.generate_virtual_key()
        await self.db.execute(
            """INSERT INTO gateway_clients
               (client_name, virtual_key, allowed_models, department, monthly_budget_usd, is_active)
               VALUES (%s, %s, %s, %s, %s, 1)""",
            (client_name, virtual_key, ','.join(allowed_models), department, monthly_budget_usd)
        )
        return ApiClient(
            client_id=str(await self.db.lastrowid()),
            client_name=client_name,
            virtual_key=virtual_key,
            allowed_models=allowed_models,
            department=department,
            monthly_budget_usd=monthly_budget_usd,
            is_active=True
        )

    async def authenticate(self, virtual_key: str) -> Optional[ApiClient]:
        """Validate a virtual key and return the client record."""
        cache_key = f"gateway:auth:{hashlib.md5(virtual_key.encode()).hexdigest()}"
        # Cache lookup (5-minute TTL)
        cached = self.redis.get(cache_key)
        if cached:
            return ApiClient(**json.loads(cached))
        # Database lookup
        row = await self.db.query_one(
            "SELECT * FROM gateway_clients WHERE virtual_key = %s AND is_active = 1",
            (virtual_key,)
        )
        if not row:
            return None
        client = ApiClient(
            client_id=str(row['id']),
            client_name=row['client_name'],
            virtual_key=virtual_key,
            allowed_models=row['allowed_models'].split(','),
            department=row['department'],
            monthly_budget_usd=row['monthly_budget_usd'],
            is_active=True
        )
        # Populate the cache
        self.redis.setex(cache_key, 300, json.dumps({
            'client_id': client.client_id,
            'client_name': client.client_name,
            'virtual_key': virtual_key,
            'allowed_models': client.allowed_models,
            'department': client.department,
            'monthly_budget_usd': client.monthly_budget_usd,
            'is_active': True
        }))
        return client
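To see the flow end to end, here is a hypothetical onboarding round trip against the manager above. The db_client and redis_client wiring is assumed, and all names and values are illustrative:

# Hypothetical usage sketch — assumes db_client and redis_client exist.
auth_manager = GatewayAuthManager(db_client, redis_client)

async def onboard_and_verify():
    # Issue a virtual key to a support-bot project, capped at $200/month.
    new_client = await auth_manager.create_client(
        client_name='support-bot',
        department='customer-success',
        allowed_models=['gpt-4o'],
        monthly_budget_usd=200.0
    )
    print('hand this key to the team:', new_client.virtual_key)

    # Later, on every incoming request, resolve the key back to a client.
    resolved = await auth_manager.authenticate(new_client.virtual_key)
    assert resolved is not None and 'gpt-4o' in resolved.allowed_models

Module 2: Model Routing Layer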
import random
from dataclasses import dataclass
from typing import Optional

@dataclass
class ProviderConfig:
    provider_name: str        # openai / anthropic / azure
    model_id: str             # the provider's actual model ID
    api_key: str
    base_url: Optional[str]
    weight: int               # routing weight (for load balancing)
    is_healthy: bool
    input_cost_per_1k: float
    output_cost_per_1k: float

class ModelRouter:
    def __init__(self, redis_client):
        self.redis = redis_client
        # Maps a logical model alias to its concrete provider configs.
        # One logical model can have several providers (for failover and load balancing).
        self._model_registry = {}

    def register_model(self, logical_name: str, providers: list[ProviderConfig]):
        """Register a logical model and its provider configs."""
        self._model_registry[logical_name] = providers

    def select_provider(self, logical_model: str) -> Optional[ProviderConfig]:
        """
        Pick a healthy provider.
        Supports weighted routing and failover.
        """
        providers = self._model_registry.get(logical_model, [])
        if not providers:
            return None
        # Drop providers currently marked unhealthy
        healthy_providers = [p for p in providers if self._is_healthy(p)]
        if not healthy_providers:
            # Every provider is unhealthy; fall back to the first one (degraded mode)
            return providers[0] if providers else None
        # Weighted random selection
        total_weight = sum(p.weight for p in healthy_providers)
        rand_val = random.uniform(0, total_weight)
        cumulative = 0
        for provider in healthy_providers:
            cumulative += provider.weight
            if rand_val <= cumulative:
                return provider
        return healthy_providers[-1]

    def _is_healthy(self, provider: ProviderConfig) -> bool:
        """Check provider health via the status stored in Redis."""
        health_key = f"gateway:health:{provider.provider_name}:{provider.model_id}"
        status = self.redis.get(health_key)
        if status is None:
            return True  # no record means healthy by default
        return status.decode() == 'healthy'

    def mark_unhealthy(self, provider: ProviderConfig, duration_seconds: int = 60):
        """Mark a provider unhealthy for N seconds."""
        health_key = f"gateway:health:{provider.provider_name}:{provider.model_id}"
        self.redis.setex(health_key, duration_seconds, 'unhealthy')

    def mark_healthy(self, provider: ProviderConfig):
        health_key = f"gateway:health:{provider.provider_name}:{provider.model_id}"
        self.redis.set(health_key, 'healthy')
# Setup example
router = ModelRouter(redis_client)
router.register_model('gpt-4o', [
    ProviderConfig(
        provider_name='openai',
        model_id='gpt-4o',
        api_key=OPENAI_API_KEY,
        base_url=None,
        weight=70,
        is_healthy=True,
        input_cost_per_1k=0.0025,
        output_cost_per_1k=0.01
    ),
    ProviderConfig(
        provider_name='azure',
        model_id='gpt-4o-2024-08-06',
        api_key=AZURE_API_KEY,
        base_url=AZURE_BASE_URL,
        weight=30,
        is_healthy=True,
        input_cost_per_1k=0.0025,
        output_cost_per_1k=0.01
    ),
])
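Note that the router only reads health flags from Redis; the error path in the next module is what writes the 'unhealthy' side. If you also want active probing, a small background task can refresh the flags. This is a minimal hypothetical sketch, not part of the gateway code above: it assumes an OpenAI-compatible /models endpoint on every provider, depends on httpx, and reaches into the router's internal registry for brevity:

# Hypothetical active health probe (assumes httpx and OpenAI-compatible providers).
import asyncio
import httpx

async def probe_providers_forever(router: ModelRouter, interval_seconds: int = 30):
    """Periodically ping each provider and refresh its health flag in Redis."""
    while True:
        for providers in router._model_registry.values():
            for p in providers:
                base = p.base_url or 'https://api.openai.com/v1'
                try:
                    async with httpx.AsyncClient(timeout=5) as http:
                        resp = await http.get(
                            f"{base}/models",
                            headers={'Authorization': f'Bearer {p.api_key}'}
                        )
                    if resp.status_code < 500:
                        router.mark_healthy(p)
                    else:
                        router.mark_unhealthy(p, duration_seconds=interval_seconds * 2)
                except httpx.HTTPError:
                    # Network-level failure: treat the provider as down for a while
                    router.mark_unhealthy(p, duration_seconds=interval_seconds * 2)
        await asyncio.sleep(interval_seconds)

Module 3: Unified Invocation Layer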
import time
import uuid
import anthropic
from openai import AsyncOpenAI

class UnifiedLLMClient:
    """Unified LLM client that hides the differences between providers."""

    def __init__(self, router: ModelRouter, cost_tracker, logger):
        self.router = router
        self.cost_tracker = cost_tracker
        self.logger = logger
        # Per-provider SDK clients
        self._openai_clients = {}
        self._anthropic_client = None

    def _get_openai_client(self, provider: ProviderConfig) -> AsyncOpenAI:
        key = f"{provider.provider_name}:{provider.base_url}"
        if key not in self._openai_clients:
            self._openai_clients[key] = AsyncOpenAI(
                api_key=provider.api_key,
                base_url=provider.base_url
            )
        return self._openai_clients[key]

    async def chat_complete(
        self,
        logical_model: str,
        messages: list,
        client: ApiClient,
        **kwargs
    ) -> dict:
        """
        Unified chat-completion entry point.
        Handles provider routing, failover, logging, and billing automatically.
        """
        # Check model permissions
        if logical_model not in client.allowed_models:
            raise PermissionError(
                f"Client {client.client_name} is not allowed to use model {logical_model}"
            )
        provider = self.router.select_provider(logical_model)
        if not provider:
            raise ValueError(f"No available provider for: {logical_model}")
        start_time = time.time()
        request_id = self._generate_request_id()
        try:
            response = await self._call_provider(provider, messages, **kwargs)
            # Compute cost
            usage = response.get('usage', {})
            cost_usd = self._calculate_cost(provider, usage)
            latency_ms = int((time.time() - start_time) * 1000)
            # Log the successful request
            await self.logger.log_request(
                request_id=request_id,
                client_id=client.client_id,
                department=client.department,
                logical_model=logical_model,
                actual_provider=provider.provider_name,
                actual_model=provider.model_id,
                input_tokens=usage.get('prompt_tokens', 0),
                output_tokens=usage.get('completion_tokens', 0),
                cost_usd=cost_usd,
                latency_ms=latency_ms,
                status='success'
            )
            # Update cost accounting
            await self.cost_tracker.record(
                client_id=client.client_id,
                department=client.department,
                cost_usd=cost_usd,
                tokens=usage.get('total_tokens', 0)
            )
            # Mark the provider healthy
            self.router.mark_healthy(provider)
            return response
        except Exception as e:
            latency_ms = int((time.time() - start_time) * 1000)
            # Log the failed request
            await self.logger.log_request(
                request_id=request_id,
                client_id=client.client_id,
                department=client.department,
                logical_model=logical_model,
                actual_provider=provider.provider_name,
                actual_model=provider.model_id,
                input_tokens=0,
                output_tokens=0,
                cost_usd=0,
                latency_ms=latency_ms,
                status='error',
                error_message=str(e)
            )
            # On a provider-side failure, mark it unhealthy and try to switch
            if self._is_provider_error(e):
                self.router.mark_unhealthy(provider, duration_seconds=120)
                # Retry on a fallback provider
                fallback_provider = self.router.select_provider(logical_model)
                if fallback_provider and fallback_provider != provider:
                    return await self.chat_complete(
                        logical_model, messages, client, **kwargs
                    )
            raise

    async def _call_provider(self, provider: ProviderConfig, messages: list, **kwargs) -> dict:
        """Dispatch to the right API based on the provider type."""
        if provider.provider_name in ('openai', 'azure'):
            client = self._get_openai_client(provider)
            response = await client.chat.completions.create(
                model=provider.model_id,
                messages=messages,
                **kwargs
            )
            return response.model_dump()
        elif provider.provider_name == 'anthropic':
            # The Anthropic API has a different shape, so adapt it
            if not self._anthropic_client:
                self._anthropic_client = anthropic.AsyncAnthropic(api_key=provider.api_key)
            # Pull the system prompt out of the messages list
            system_prompt = None
            chat_messages = []
            for msg in messages:
                if msg['role'] == 'system':
                    system_prompt = msg['content']
                else:
                    chat_messages.append(msg)
            create_kwargs = {
                'model': provider.model_id,
                'max_tokens': kwargs.get('max_tokens', 1024),
                'messages': chat_messages
            }
            if system_prompt:
                create_kwargs['system'] = system_prompt
            response = await self._anthropic_client.messages.create(**create_kwargs)
            # Convert to the unified (OpenAI-style) format
            return {
                'choices': [{
                    'message': {
                        'role': 'assistant',
                        'content': response.content[0].text
                    },
                    'finish_reason': response.stop_reason
                }],
                'usage': {
                    'prompt_tokens': response.usage.input_tokens,
                    'completion_tokens': response.usage.output_tokens,
                    'total_tokens': response.usage.input_tokens + response.usage.output_tokens
                }
            }
        else:
            raise ValueError(f"Unsupported provider: {provider.provider_name}")

    def _calculate_cost(self, provider: ProviderConfig, usage: dict) -> float:
        input_tokens = usage.get('prompt_tokens', 0)
        output_tokens = usage.get('completion_tokens', 0)
        return (input_tokens * provider.input_cost_per_1k +
                output_tokens * provider.output_cost_per_1k) / 1000

    def _is_provider_error(self, error: Exception) -> bool:
        """Decide whether a failure is on the provider side (not a bad user request)."""
        error_str = str(error).lower()
        return any(keyword in error_str for keyword in [
            'connection', 'timeout', 'service unavailable', '503', '502', '504'
        ])
    def _generate_request_id(self) -> str:
        return str(uuid.uuid4())

Module 4: FastAPI Gateway Entry Point
from typing import Optional

from fastapi import FastAPI, Header, HTTPException, Request
from pydantic import BaseModel

app = FastAPI(title="AI Gateway")
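
# The original post never shows cost_tracker's implementation, so here is a
# hypothetical minimal Redis-backed sketch with the two methods the gateway
# relies on, record() and get_summary(), plus a get_month_spend() helper used
# by the budget check further down. Key names are assumptions; the wiring
# below assumes an instance named cost_tracker (e.g. CostTracker(redis_client)).
import time

class CostTracker:
    def __init__(self, redis_client):
        self.redis = redis_client

    async def record(self, client_id: str, department: str, cost_usd: float, tokens: int):
        """Accumulate spend per client and per department for the current month."""
        month = time.strftime('%Y%m')
        pipe = self.redis.pipeline()
        pipe.incrbyfloat(f"gateway:cost:{month}:client:{client_id}", cost_usd)
        pipe.incrbyfloat(f"gateway:cost:{month}:dept:{department}", cost_usd)
        pipe.incrby(f"gateway:tokens:{month}:client:{client_id}", tokens)
        pipe.execute()

    async def get_month_spend(self, client_id: str, month: str = None) -> float:
        """Total spend for one client in the given month (defaults to current)."""
        month = month or time.strftime('%Y%m')
        val = self.redis.get(f"gateway:cost:{month}:client:{client_id}")
        return float(val) if val else 0.0

    async def get_summary(self, month: str = None) -> dict:
        """Spend broken down by department for the given month."""
        month = month or time.strftime('%Y%m')
        keys = self.redis.keys(f"gateway:cost:{month}:dept:*")
        return {
            k.decode().rsplit(':', 1)[-1]: float(self.redis.get(k))
            for k in keys
        }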
# Wire up the components (use dependency injection in a real project);
# db_client, redis_client, cost_tracker, and logger are provided elsewhere.
auth_manager = GatewayAuthManager(db_client, redis_client)
model_router = ModelRouter(redis_client)
gateway_client = UnifiedLLMClient(model_router, cost_tracker, logger)

class ChatRequest(BaseModel):
    model: str
    messages: list
    max_tokens: int = 1024
    temperature: float = 1.0
    stream: bool = False
@app.post("/v1/chat/completions")
async def chat_completions(
    request: ChatRequest,
    http_request: Request,
    authorization: str = Header(None)
):
    # 1. Extract the API key
    if not authorization or not authorization.startswith('Bearer '):
        raise HTTPException(status_code=401, detail="Missing API key")
    virtual_key = authorization[7:]  # strip 'Bearer '
    # 2. Authenticate
    client = await auth_manager.authenticate(virtual_key)
    if not client:
        raise HTTPException(status_code=401, detail="Invalid API key")
    # 3. Rate-limit check (reuses the RateLimiter from the previous post)
    rate_key = f"gateway:rate:{client.client_id}"
    result, info = rate_limiter.check_and_increment(rate_key, limit=60, window_seconds=60)
    if result == LimitResult.RATE_LIMITED:
        raise HTTPException(
            status_code=429,
            headers={"Retry-After": str(info.get('retry_after', 60))},
            detail="Rate limit exceeded"
        )
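    # 3.5 (Hypothetical) budget check — the post defines monthly_budget_usd on
    # ApiClient but never enforces it in the endpoint shown; this is a sketch,
    # assuming cost_tracker exposes the get_month_spend(client_id) helper above.
    month_spend = await cost_tracker.get_month_spend(client.client_id)
    if month_spend >= client.monthly_budget_usd:
        raise HTTPException(
            status_code=402,
            detail=f"Monthly budget of ${client.monthly_budget_usd:.2f} exhausted"
        )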
    # 4. Call the LLM
    try:
        response = await gateway_client.chat_complete(
            logical_model=request.model,
            messages=request.messages,
            client=client,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )
        return response
    except PermissionError as e:
        raise HTTPException(status_code=403, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
@app.get("/v1/usage/summary")
async def usage_summary(
    authorization: str = Header(None),
    month: Optional[str] = None  # format: 202401
):
    """Usage and cost summary (admin endpoint)."""
    # Authenticate (only an admin key may call this)
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing API key")
    client = await auth_manager.authenticate(authorization.replace('Bearer ', ''))
    if not client or 'admin' not in client.allowed_models:
        raise HTTPException(status_code=403)
    return await cost_tracker.get_summary(month=month)

Deployment Notes
API compatibility: the interface the gateway exposes should match OpenAI's format (/v1/chat/completions), so downstream changes are minimal: swap the base_url and api_key, and the application code stays untouched.
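In practice that means a downstream project keeps using the official OpenAI SDK unchanged. A sketch, where the gateway URL and virtual key are placeholder values:

from openai import OpenAI

# Point the stock OpenAI SDK at the gateway; the key below is the
# virtual key issued by the auth layer (both values are placeholders).
client = OpenAI(
    base_url="https://ai-gateway.internal.example.com/v1",
    api_key="aig-<your-virtual-key>"
)

response = client.chat.completions.create(
    model="gpt-4o",  # logical model name; the gateway picks the real provider
    messages=[{"role": "user", "content": "ping"}]
)
print(response.choices[0].message.content)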
High-availability deployment:

             Load balancer
             /           \
    Gateway node 1   Gateway node 2
             \           /
    Redis cluster (shared rate-limit state)
                  |
        Database (logs, config)

The gateway itself is stateless; all state lives in Redis, so it can scale horizontally.
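The post doesn't show the load balancer's health probe; a minimal /healthz endpoint that also verifies the Redis dependency might look like this (a hypothetical sketch, assuming the same redis_client):

@app.get("/healthz")
async def healthz():
    # Liveness/readiness probe for the load balancer: a node is only
    # "ready" if it can reach Redis, since all shared state lives there.
    try:
        redis_client.ping()
    except Exception:
        raise HTTPException(status_code=503, detail="Redis unreachable")
    return {"status": "ok"}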
Latency impact: adding a gateway hop adds latency. In our deployment the gateway's own processing time is 2-5 ms, which is negligible against LLM responses that typically take 500 ms or more.
Real-world payoff, after rolling this out across the company:
- projects no longer maintain their own API keys, a major security win
- when the CFO asks for cost numbers, a single API call answers it; no more three-day consolidation
- during one OpenAI outage, the gateway switched to Azure automatically and downstream projects never noticed
- every request is fully logged, so incidents can be pinpointed within 5 minutes

Building all of this took about two weeks. The trouble it has saved since is worth far more than that.
