异步工作流在 AI 应用中的实践——别让用户等着
异步工作流在 AI 应用中的实践——别让用户等着
适读人群:有后端开发经验的 AI 应用工程师 | 阅读时长:约 14 分钟 | 核心价值:一套完整的异步工作流方案,消息队列 + 状态轮询 + WebSocket,可直接用于生产
那是一个下午,我在公司内部演示一个 AI 文档处理功能。用户上传了一个 80 页的 PDF,我的程序开始处理——提取文本、分段、逐段调用 GPT-4 总结、最后汇总。
整个过程花了大概 2 分 20 秒。
在这 140 秒里,前端页面就是一个转圈圈的 loading。没有进度,没有反馈,什么都没有。产品经理当场皱眉,说"这个体验不行"。
他说得对。
同步模型在 AI 应用里行不通
传统 Web 请求是同步的。用户点击,服务器处理,返回结果,一般 500ms 以内完成。这个模型假设"处理时间很短"。
AI 任务打破了这个假设:
- GPT-4 生成一篇 2000 字的文章:15-30 秒
- 处理一个长 PDF:1-5 分钟
- 生成一张高质量图片(Midjourney/DALL-E 3):10-30 秒
- 跑一个多步骤 Agent:不确定,可能 30 秒,可能 5 分钟
用同步 HTTP 请求处理这些任务,你会面临:
超时问题。 Nginx、ALB、客户端都有超时配置,通常 60-120 秒。你的任务跑 3 分钟,HTTP 连接早就断了。
资源浪费。 一个线程/进程在那里 block 等 LLM 响应,完全是在浪费资源。
体验问题。 用户不知道任务到哪了,不知道要等多久,体验很差,复杂网络环境下还容易断连。
正确的做法是异步工作流:任务提交立刻返回,后台异步处理,通过轮询或推送告知结果。
完整架构
我在生产里用的方案,三层组合:
客户端
|
| 1. POST /tasks (提交任务)
v
API 服务
|
| 2. 写入任务到队列
v
消息队列 (Redis Streams / RabbitMQ / SQS)
|
| 3. Worker 消费任务
v
Worker 服务
|
| 4. 调用 LLM,处理结果
| 5. 更新任务状态到 Redis
| 6. 通过 WebSocket 推送进度
v
结果存储 (Redis + PostgreSQL)
|
| 7. 客户端查询或收到推送
v
客户端这套架构里:
- 消息队列负责任务分发和解耦
- Redis 负责任务状态的快速读写
- WebSocket 负责实时进度推送(可选,但体验差异很大)
- PostgreSQL 负责最终结果的持久化
下面逐层写代码。
第一层:任务提交和状态管理
先定义任务的状态机:
# task_models.py
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, Any
import time
import uuid
class TaskStatus(Enum):
PENDING = "pending" # 已提交,等待处理
PROCESSING = "processing" # 处理中
COMPLETED = "completed" # 完成
FAILED = "failed" # 失败
CANCELLED = "cancelled" # 取消
@dataclass
class TaskProgress:
current_step: int
total_steps: int
step_description: str
percentage: float = field(init=False)
def __post_init__(self):
self.percentage = (self.current_step / self.total_steps) * 100 if self.total_steps > 0 else 0
@dataclass
class Task:
task_id: str
task_type: str
payload: dict
status: TaskStatus = TaskStatus.PENDING
progress: Optional[TaskProgress] = None
result: Optional[Any] = None
error: Optional[str] = None
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
@classmethod
def create(cls, task_type: str, payload: dict) -> "Task":
return cls(
task_id=str(uuid.uuid4()),
task_type=task_type,
payload=payload
)任务状态用 Redis 存,快且支持过期:
# task_store.py
import redis
import json
from typing import Optional
from task_models import Task, TaskStatus, TaskProgress
class TaskStore:
def __init__(self, redis_client: redis.Redis, ttl_seconds: int = 86400):
self.redis = redis_client
self.ttl = ttl_seconds
def _key(self, task_id: str) -> str:
return f"task:{task_id}"
def save(self, task: Task) -> None:
import dataclasses
import time
task.updated_at = time.time()
data = {
"task_id": task.task_id,
"task_type": task.task_type,
"payload": task.payload,
"status": task.status.value,
"result": task.result,
"error": task.error,
"created_at": task.created_at,
"updated_at": task.updated_at,
"progress": {
"current_step": task.progress.current_step,
"total_steps": task.progress.total_steps,
"step_description": task.progress.step_description,
"percentage": task.progress.percentage,
} if task.progress else None
}
self.redis.setex(
self._key(task.task_id),
self.ttl,
json.dumps(data, ensure_ascii=False)
)
def get(self, task_id: str) -> Optional[Task]:
raw = self.redis.get(self._key(task_id))
if not raw:
return None
data = json.loads(raw)
task = Task(
task_id=data["task_id"],
task_type=data["task_type"],
payload=data["payload"],
status=TaskStatus(data["status"]),
result=data["result"],
error=data["error"],
created_at=data["created_at"],
updated_at=data["updated_at"],
)
if data.get("progress"):
p = data["progress"]
task.progress = TaskProgress(
current_step=p["current_step"],
total_steps=p["total_steps"],
step_description=p["step_description"]
)
return task
def update_status(self, task_id: str, status: TaskStatus,
progress: Optional[TaskProgress] = None,
result=None, error: Optional[str] = None) -> Optional[Task]:
task = self.get(task_id)
if not task:
return None
task.status = status
if progress:
task.progress = progress
if result is not None:
task.result = result
if error:
task.error = error
self.save(task)
return task第二层:消息队列和 Worker
API 层收到任务请求,把任务塞进队列,立刻返回:
# api.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import redis
import json
from task_models import Task
from task_store import TaskStore
app = FastAPI()
r = redis.Redis(host="localhost", port=6379, decode_responses=True)
task_store = TaskStore(r)
TASK_QUEUE = "ai_tasks"
class DocumentProcessRequest(BaseModel):
document_url: str
language: str = "zh"
output_format: str = "summary"
@app.post("/tasks/document-process")
async def submit_document_task(request: DocumentProcessRequest):
# 创建任务
task = Task.create(
task_type="document_process",
payload={
"document_url": request.document_url,
"language": request.language,
"output_format": request.output_format
}
)
# 保存初始状态
task_store.save(task)
# 推入队列
r.lpush(TASK_QUEUE, json.dumps({
"task_id": task.task_id,
"task_type": task.task_type,
"payload": task.payload
}))
# 立刻返回 task_id,不等待处理结果
return {
"task_id": task.task_id,
"status": "pending",
"poll_url": f"/tasks/{task.task_id}",
"websocket_url": f"/ws/tasks/{task.task_id}"
}
@app.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
task = task_store.get(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
response = {
"task_id": task.task_id,
"status": task.status.value,
"created_at": task.created_at,
"updated_at": task.updated_at,
}
if task.progress:
response["progress"] = {
"percentage": task.progress.percentage,
"step_description": task.progress.step_description,
}
if task.status.value == "completed":
response["result"] = task.result
if task.status.value == "failed":
response["error"] = task.error
return responseWorker 消费队列,处理任务:
# worker.py
import redis
import json
import time
import logging
from openai import OpenAI
from task_models import Task, TaskStatus, TaskProgress
from task_store import TaskStore
from websocket_manager import WebSocketManager
logger = logging.getLogger(__name__)
r = redis.Redis(host="localhost", port=6379, decode_responses=True)
task_store = TaskStore(r)
openai_client = OpenAI()
ws_manager = WebSocketManager(r)
TASK_QUEUE = "ai_tasks"
def process_document(task_id: str, payload: dict):
"""处理文档摘要任务"""
def update_progress(step: int, total: int, description: str):
progress = TaskProgress(current_step=step, total_steps=total, step_description=description)
task_store.update_status(task_id, TaskStatus.PROCESSING, progress=progress)
ws_manager.publish_progress(task_id, progress)
# Step 1: 下载文档
update_progress(1, 5, "正在下载文档...")
document_text = download_document(payload["document_url"])
# Step 2: 分段
update_progress(2, 5, "正在分析文档结构...")
chunks = split_into_chunks(document_text, max_tokens=2000)
# Step 3: 逐段摘要
summaries = []
for i, chunk in enumerate(chunks):
desc = f"正在处理第 {i+1}/{len(chunks)} 段..."
# 在步骤3里细化进度
detailed_progress = TaskProgress(
current_step=2 + (i + 1) / len(chunks),
total_steps=5,
step_description=desc
)
task_store.update_status(task_id, TaskStatus.PROCESSING, progress=detailed_progress)
ws_manager.publish_progress(task_id, detailed_progress)
response = openai_client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "请用简洁的中文概括以下文本的核心内容,不超过200字。"},
{"role": "user", "content": chunk}
]
)
summaries.append(response.choices[0].message.content)
# Step 4: 汇总
update_progress(4, 5, "正在生成最终摘要...")
combined = "\n\n".join(summaries)
final_response = openai_client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": "基于以下各段摘要,生成一份结构清晰的完整文档摘要,包含主要观点和关键信息。"},
{"role": "user", "content": combined}
]
)
final_summary = final_response.choices[0].message.content
# Step 5: 完成
update_progress(5, 5, "处理完成")
task_store.update_status(
task_id,
TaskStatus.COMPLETED,
result={"summary": final_summary, "chunk_count": len(chunks)}
)
ws_manager.publish_completion(task_id, {"summary": final_summary})
def run_worker():
"""Worker 主循环"""
logger.info("Worker 启动,等待任务...")
task_handlers = {
"document_process": process_document,
}
while True:
try:
# BRPOP:阻塞等待,超时 1 秒
result = r.brpop(TASK_QUEUE, timeout=1)
if not result:
continue
_, raw_task = result
task_data = json.loads(raw_task)
task_id = task_data["task_id"]
task_type = task_data["task_type"]
logger.info(f"开始处理任务 {task_id},类型: {task_type}")
# 标记为处理中
task_store.update_status(task_id, TaskStatus.PROCESSING)
handler = task_handlers.get(task_type)
if not handler:
task_store.update_status(
task_id, TaskStatus.FAILED,
error=f"Unknown task type: {task_type}"
)
continue
handler(task_id, task_data["payload"])
logger.info(f"任务 {task_id} 处理完成")
except Exception as e:
logger.error(f"处理任务时出错: {e}", exc_info=True)
if 'task_id' in locals():
task_store.update_status(
task_id, TaskStatus.FAILED,
error=str(e)
)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
run_worker()第三层:WebSocket 实时推送
轮询可以工作,但 WebSocket 体验更好。我用 Redis Pub/Sub 做跨进程通信,Worker 发布进度,API 服务订阅并推给客户端:
# websocket_manager.py
import redis
import json
from typing import Optional
class WebSocketManager:
"""
负责通过 Redis Pub/Sub 发布任务进度
Worker 调用 publish_*,API 服务订阅并转发给 WebSocket 客户端
"""
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
def _channel(self, task_id: str) -> str:
return f"task_progress:{task_id}"
def publish_progress(self, task_id: str, progress) -> None:
message = {
"type": "progress",
"task_id": task_id,
"percentage": progress.percentage,
"step_description": progress.step_description,
}
self.redis.publish(self._channel(task_id), json.dumps(message))
def publish_completion(self, task_id: str, result: dict) -> None:
message = {
"type": "completed",
"task_id": task_id,
"result": result,
}
self.redis.publish(self._channel(task_id), json.dumps(message))
def publish_error(self, task_id: str, error: str) -> None:
message = {
"type": "failed",
"task_id": task_id,
"error": error,
}
self.redis.publish(self._channel(task_id), json.dumps(message))FastAPI 的 WebSocket 端点:
# websocket_endpoint.py
from fastapi import WebSocket, WebSocketDisconnect
import redis
import json
import asyncio
from task_store import TaskStore
r = redis.Redis(host="localhost", port=6379, decode_responses=True)
task_store = TaskStore(r)
@app.websocket("/ws/tasks/{task_id}")
async def websocket_task_progress(websocket: WebSocket, task_id: str):
await websocket.accept()
# 先检查任务是否存在
task = task_store.get(task_id)
if not task:
await websocket.send_json({"type": "error", "message": "Task not found"})
await websocket.close()
return
# 如果任务已经完成,直接返回结果
if task.status.value in ("completed", "failed"):
await websocket.send_json({
"type": task.status.value,
"result": task.result,
"error": task.error,
})
await websocket.close()
return
# 订阅进度频道
pubsub = r.pubsub()
channel = f"task_progress:{task_id}"
pubsub.subscribe(channel)
try:
while True:
message = pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
if message:
data = json.loads(message["data"])
await websocket.send_json(data)
# 任务结束,关闭连接
if data["type"] in ("completed", "failed"):
break
await asyncio.sleep(0.05) # 50ms 轮询 pubsub
except WebSocketDisconnect:
pass
finally:
pubsub.unsubscribe(channel)
pubsub.close()前端怎么接
给个简单的前端示例,React 风格的伪代码展示逻辑:
// 提交任务
async function submitDocument(file) {
const response = await fetch('/tasks/document-process', {
method: 'POST',
body: JSON.stringify({ document_url: file.url }),
headers: { 'Content-Type': 'application/json' }
});
const { task_id, websocket_url } = await response.json();
// 连接 WebSocket
connectWebSocket(task_id, websocket_url);
}
function connectWebSocket(taskId, wsUrl) {
const ws = new WebSocket(`ws://localhost:8000${wsUrl}`);
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch (data.type) {
case 'progress':
updateProgressBar(data.percentage, data.step_description);
break;
case 'completed':
showResult(data.result);
ws.close();
break;
case 'failed':
showError(data.error);
ws.close();
break;
}
};
ws.onerror = () => {
// WebSocket 失败时降级到轮询
startPolling(taskId);
};
}
// 降级轮询方案
function startPolling(taskId) {
const interval = setInterval(async () => {
const response = await fetch(`/tasks/${taskId}`);
const data = await response.json();
if (data.status === 'completed') {
showResult(data.result);
clearInterval(interval);
} else if (data.status === 'failed') {
showError(data.error);
clearInterval(interval);
} else if (data.progress) {
updateProgressBar(data.progress.percentage, data.progress.step_description);
}
}, 2000); // 每 2 秒轮询一次
}WebSocket 挂了自动降级到轮询,这个容错逻辑很重要,生产里网络环境复杂。
几个工程细节
并发控制。 Worker 不要无限并发,LLM API 有速率限制。我用信号量控制:
import asyncio
# 最多同时处理 5 个 LLM 调用
llm_semaphore = asyncio.Semaphore(5)
async def call_llm_with_limit(prompt: str):
async with llm_semaphore:
# 调用 LLM
...任务优先级。 付费用户的任务应该先处理。可以用多个队列实现:
# 两个队列,先消费高优先级
queues = ["ai_tasks:premium", "ai_tasks:standard"]
result = r.brpop(queues, timeout=1)任务超时。 Worker 里要设置任务最长处理时间,避免某个任务卡死占用 Worker:
import signal
class TimeoutError(Exception):
pass
def timeout_handler(signum, frame):
raise TimeoutError("Task processing timeout")
# 设置 5 分钟超时
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(300)
try:
handler(task_id, payload)
except TimeoutError:
task_store.update_status(task_id, TaskStatus.FAILED, error="Task timeout")
finally:
signal.alarm(0) # 取消超时重复消费保护。 Worker 崩溃时任务可能被重复消费,要做幂等处理:
def process_document(task_id: str, payload: dict):
task = task_store.get(task_id)
# 如果任务已经完成,直接跳过
if task and task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
logger.warning(f"任务 {task_id} 已处理,跳过重复消费")
return
# ... 继续处理这套方案的实际效果
我把公司那个 PDF 处理功能改成异步之后:
- 用户提交文档,立刻看到"任务已提交",0 延迟
- 进度条实时更新,每一步处理完都有反馈
- 任务完成后 WebSocket 推送,结果自动展示
- 处理失败有明确错误提示,可以重试
之前产品经理说"体验不行",改完之后他说"这个感觉靠谱多了"。
用户愿不愿意等,很大程度上取决于他们在等待时能不能看到进展。同样等 2 分钟,一个空白 loading 和一个有进度反馈的 loading,用户感受完全不同。
