AI 应用的消息队列解耦——异步处理长时间 AI 任务
AI 应用的消息队列解耦——异步处理长时间 AI 任务
"这个功能为什么这么慢?我点了提交,等了 30 秒才出结果。"
这是我们的产品 lead 在内测时的反馈。他说的是我们刚上线的"合同审查"功能——用户上传一份合同 PDF,AI 分析风险条款,给出审查报告。
那个功能的技术实现很朴素:用户 POST 上传文件 → 服务端解析 PDF → 分块 → 向量检索相关条款 → 发给 GPT-4 逐条分析 → 汇总报告 → HTTP 返回。
整个流程跑下来,短合同要 25 秒,长合同要 2 分钟。
25 秒的 HTTP 请求是不可接受的。用户的浏览器会焦虑,移动端的 4G 网络可能中途断连,Nginx 的超时配置默认是 60 秒,长合同直接超时 502。
这件事让我开始认真设计异步 AI 任务处理框架。
哪些 AI 任务应该异步处理
先把判断标准说清楚,不是所有 AI 任务都要异步:
适合同步处理(< 5 秒):
- 单轮对话问答
- 短文本分类
- 关键词提取
- 简单的 RAG 召回问答
必须异步处理(> 10 秒):
- 长文档 AI 分析(合同审查、报告生成)
- 批量文档处理(一次性索引几百份文档)
- 多轮 Agent 任务(需要多步骤推理+工具调用)
- 视频/音频转录后的 AI 分析
- 定时 AI 批处理任务
判断标准很简单:用户能等 10 秒,能等 30 秒,但绝对不能等 2 分钟。超过 10 秒的操作,都应该考虑异步化。
异步任务系统的整体架构
任务状态机设计
在动手写代码之前,先把任务状态机设计好。状态流转是异步任务系统最容易出问题的地方:
public enum TaskStatus {
PENDING, // 已创建,等待执行
QUEUED, // 已入队,等待 Worker 拾取
RUNNING, // Worker 正在执行
COMPLETED, // 执行成功
FAILED, // 执行失败(可重试)
DEAD_LETTER, // 重试耗尽,进入死信(需要人工介入)
CANCELLED // 用户取消
}
// 允许的状态流转
private static final Map<TaskStatus, Set<TaskStatus>> VALID_TRANSITIONS = Map.of(
PENDING, Set.of(QUEUED, CANCELLED),
QUEUED, Set.of(RUNNING, CANCELLED),
RUNNING, Set.of(COMPLETED, FAILED),
FAILED, Set.of(QUEUED, DEAD_LETTER), // 重试或进死信
DEAD_LETTER, Set.of(QUEUED) // 人工介入后可重新入队
);
public boolean canTransitionTo(TaskStatus current, TaskStatus next) {
Set<TaskStatus> allowed = VALID_TRANSITIONS.getOrDefault(current, Set.of());
return allowed.contains(next);
}任务创建:HTTP 同步接口变异步
@RestController
@RequestMapping("/api/v1/tasks")
public class AITaskController {
@Autowired
private AITaskService taskService;
/**
* 提交 AI 任务(异步)
* 立即返回 task_id,不等待执行结果
*/
@PostMapping("/contract-review")
public ResponseEntity<TaskSubmitResponse> submitContractReview(
@RequestParam("file") MultipartFile file,
@RequestHeader("X-Tenant-Id") String tenantId,
@RequestHeader("X-User-Id") String userId
) {
// 1. 保存文件到 OSS,获取文件引用
String fileKey = ossService.upload(file, tenantId);
// 2. 创建任务记录(状态:PENDING)
AITask task = taskService.createTask(AITaskCreateRequest.builder()
.tenantId(tenantId)
.userId(userId)
.taskType(TaskType.CONTRACT_REVIEW)
.inputParams(Map.of("file_key", fileKey, "file_name", file.getOriginalFilename()))
.build()
);
// 3. 发送到 Kafka 队列(状态变为:QUEUED)
taskService.enqueue(task);
// 4. 立即返回 task_id(HTTP 200,不等待执行)
return ResponseEntity.accepted().body(
TaskSubmitResponse.builder()
.taskId(task.getId())
.status(TaskStatus.QUEUED)
.estimatedSeconds(estimateProcessingTime(file.getSize()))
.pollUrl("/api/v1/tasks/" + task.getId() + "/status")
.sseUrl("/api/v1/tasks/" + task.getId() + "/progress")
.build()
);
}
/**
* 查询任务状态(轮询接口)
*/
@GetMapping("/{taskId}/status")
public TaskStatusResponse getStatus(@PathVariable String taskId,
@RequestHeader("X-Tenant-Id") String tenantId) {
AITask task = taskService.getTask(taskId, tenantId);
return TaskStatusResponse.builder()
.taskId(task.getId())
.status(task.getStatus())
.progress(task.getProgressPercent())
.currentStep(task.getCurrentStep())
.result(task.getStatus() == TaskStatus.COMPLETED ? task.getResult() : null)
.errorMessage(task.getStatus() == TaskStatus.FAILED ? task.getErrorMessage() : null)
.createdAt(task.getCreatedAt())
.completedAt(task.getCompletedAt())
.build();
}
/**
* SSE 进度推送(实时进度更新)
*/
@GetMapping(value = "/{taskId}/progress", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter streamProgress(@PathVariable String taskId,
@RequestHeader("X-Tenant-Id") String tenantId) {
SseEmitter emitter = new SseEmitter(300_000L); // 5 分钟超时
// 注册 SSE 连接
sseService.register(taskId, emitter);
// 立即发送当前状态(避免客户端空等)
AITask task = taskService.getTask(taskId, tenantId);
try {
emitter.send(SseEmitter.event()
.name("status")
.data(buildProgressEvent(task)));
} catch (Exception e) {
emitter.complete();
}
return emitter;
}
private int estimateProcessingTime(long fileSize) {
// 粗略估算:每 100KB 约 5 秒
return (int)(fileSize / 1024 / 100 * 5) + 10;
}
}Kafka + Spring AI:异步任务消费者
@Service
@Slf4j
public class ContractReviewWorker {
@Autowired
private AITaskRepository taskRepo;
@Autowired
private OSSService ossService;
@Autowired
private RAGService ragService;
@Autowired
private AICapabilityService aiService;
@Autowired
private SseProgressService sseService;
@KafkaListener(
topics = "ai-tasks-contract-review",
groupId = "contract-review-workers",
concurrency = "3" // 3 个并发消费者,同时处理 3 个合同
)
public void processContractReview(
@Payload AITaskMessage message,
@Header(KafkaHeaders.RECEIVED_TOPIC) String topic,
Acknowledgment ack
) {
String taskId = message.getTaskId();
// 幂等检查:如果任务已完成,直接 ack 跳过(Kafka 重试场景)
AITask task = taskRepo.findById(taskId).orElse(null);
if (task == null) {
log.warn("Task {} not found, skipping", taskId);
ack.acknowledge();
return;
}
if (task.getStatus() == TaskStatus.COMPLETED ||
task.getStatus() == TaskStatus.DEAD_LETTER) {
log.info("Task {} already in terminal state {}, skipping", taskId, task.getStatus());
ack.acknowledge();
return;
}
try {
// 更新状态为 RUNNING
updateTaskStatus(task, TaskStatus.RUNNING, 0, "开始处理合同...");
// 执行实际处理逻辑
ContractReviewResult result = doReview(task, message);
// 更新状态为 COMPLETED
task.setStatus(TaskStatus.COMPLETED);
task.setResult(serialize(result));
task.setCompletedAt(Instant.now());
task.setProgressPercent(100);
taskRepo.save(task);
// 推送完成事件
sseService.pushCompletion(taskId, result);
// 手动 ack(确认消息处理成功)
ack.acknowledge();
} catch (RetryableException e) {
// 可重试的错误(比如 API 临时限速):不 ack,让 Kafka 重试
log.warn("Retryable error for task {}: {}", taskId, e.getMessage());
updateTaskStatus(task, TaskStatus.FAILED, task.getProgressPercent(),
"处理失败,即将重试:" + e.getMessage());
task.setRetryCount(task.getRetryCount() + 1);
taskRepo.save(task);
// 不 ack,Kafka 会根据配置重新投递
} catch (Exception e) {
// 不可重试的错误:ack 并标记为失败
log.error("Non-retryable error for task {}: {}", taskId, e.getMessage(), e);
task.setStatus(TaskStatus.FAILED);
task.setErrorMessage(e.getMessage());
task.setCompletedAt(Instant.now());
taskRepo.save(task);
sseService.pushError(taskId, e.getMessage());
ack.acknowledge();
}
}
private ContractReviewResult doReview(AITask task, AITaskMessage message) {
String fileKey = (String) message.getInputParams().get("file_key");
String tenantId = task.getTenantId();
// Step 1: 下载并解析合同
updateTaskStatus(task, TaskStatus.RUNNING, 10, "解析合同文本...");
byte[] fileContent = ossService.download(fileKey);
String contractText = pdfParser.extractText(fileContent);
// Step 2: 分块
updateTaskStatus(task, TaskStatus.RUNNING, 20, "分析合同结构...");
List<String> clauses = contractParser.extractClauses(contractText);
// Step 3: 向量检索相关法规(RAG)
updateTaskStatus(task, TaskStatus.RUNNING, 35, "检索相关法规...");
List<DocumentChunk> relevantLaws = ragService.retrieve(
tenantId, "合同风险条款", 20
);
// Step 4: 逐条分析(最耗时的部分)
List<ClauseAnalysis> analyses = new ArrayList<>();
int totalClauses = clauses.size();
for (int i = 0; i < clauses.size(); i++) {
String clause = clauses.get(i);
int progress = 35 + (int)((double)(i + 1) / totalClauses * 50);
updateTaskStatus(task, TaskStatus.RUNNING, progress,
String.format("分析第 %d/%d 条款...", i + 1, totalClauses));
ClauseAnalysis analysis = analyzeClause(clause, relevantLaws, tenantId);
analyses.add(analysis);
}
// Step 5: 汇总报告
updateTaskStatus(task, TaskStatus.RUNNING, 90, "生成审查报告...");
String summary = generateSummary(analyses, tenantId);
updateTaskStatus(task, TaskStatus.RUNNING, 95, "完成...");
return ContractReviewResult.builder()
.summary(summary)
.clauseAnalyses(analyses)
.riskLevel(assessOverallRisk(analyses))
.build();
}
private ClauseAnalysis analyzeClause(String clause, List<DocumentChunk> laws, String tenantId) {
String prompt = String.format("""
你是一位专业的合同法律顾问,请分析以下合同条款的风险:
合同条款:
%s
相关法规参考:
%s
请从以下维度给出分析:
1. 风险等级(高/中/低)
2. 具体风险描述
3. 修改建议
""", clause, formatLawContext(laws));
return aiService.chat(ChatRequest.builder()
.appId("contract-review")
.userMessage(prompt)
.modelPreference(ModelPreference.BEST) // 合同审查用最好的模型
.build());
}
private void updateTaskStatus(AITask task, TaskStatus status,
int progress, String currentStep) {
task.setStatus(status);
task.setProgressPercent(progress);
task.setCurrentStep(currentStep);
taskRepo.save(task);
// 推送进度到 SSE 订阅者
sseService.pushProgress(task.getId(), progress, currentStep);
}
}SSE 进度推送服务
SSE(Server-Sent Events)比轮询好在哪里?服务端主动推送,客户端不需要频繁轮询,实时性更好,对服务端的压力也更小。
@Service
public class SseProgressService {
// 任务 ID → SSE 发射器 Map
private final ConcurrentHashMap<String, List<SseEmitter>> emitters =
new ConcurrentHashMap<>();
/**
* 注册新的 SSE 连接
*/
public void register(String taskId, SseEmitter emitter) {
emitters.computeIfAbsent(taskId, k -> new CopyOnWriteArrayList<>()).add(emitter);
// 连接断开时清理
emitter.onCompletion(() -> removeEmitter(taskId, emitter));
emitter.onTimeout(() -> removeEmitter(taskId, emitter));
emitter.onError(e -> removeEmitter(taskId, emitter));
}
/**
* 推送进度更新
*/
public void pushProgress(String taskId, int progress, String message) {
sendToAll(taskId, "progress",
Map.of("progress", progress, "message", message, "timestamp", Instant.now()));
}
/**
* 推送任务完成
*/
public void pushCompletion(String taskId, Object result) {
sendToAll(taskId, "completed",
Map.of("progress", 100, "result", result, "timestamp", Instant.now()));
// 完成后清理连接
emitters.remove(taskId);
}
/**
* 推送错误
*/
public void pushError(String taskId, String errorMessage) {
sendToAll(taskId, "error",
Map.of("message", errorMessage, "timestamp", Instant.now()));
emitters.remove(taskId);
}
private void sendToAll(String taskId, String eventName, Object data) {
List<SseEmitter> taskEmitters = emitters.get(taskId);
if (taskEmitters == null || taskEmitters.isEmpty()) return;
List<SseEmitter> deadEmitters = new ArrayList<>();
for (SseEmitter emitter : taskEmitters) {
try {
emitter.send(SseEmitter.event()
.name(eventName)
.data(objectMapper.writeValueAsString(data)));
} catch (Exception e) {
deadEmitters.add(emitter); // 发送失败说明连接已断开
}
}
taskEmitters.removeAll(deadEmitters);
}
private void removeEmitter(String taskId, SseEmitter emitter) {
List<SseEmitter> taskEmitters = emitters.get(taskId);
if (taskEmitters != null) {
taskEmitters.remove(emitter);
if (taskEmitters.isEmpty()) {
emitters.remove(taskId);
}
}
}
}任务轮询 vs SSE 推送:怎么选
两种方案不是非此即彼,可以同时提供,让客户端根据自己的需求选择:
轮询适合:
- 移动端(SSE 在部分 Android 浏览器支持不稳定)
- 任务执行时间不确定,用户可能关掉页面过一会儿再来查
- 不需要细粒度进度显示,只需要知道是否完成
SSE 适合:
- 需要实时进度条
- 用户在页面上等待,希望有反馈感
- 任务执行有多个明确的阶段("解析中"、"分析中"、"生成报告")
客户端代码示例(JavaScript):
// SSE 进度监听
function trackTaskProgress(taskId) {
const eventSource = new EventSource(`/api/v1/tasks/${taskId}/progress`);
eventSource.addEventListener('progress', (event) => {
const data = JSON.parse(event.data);
updateProgressBar(data.progress, data.message);
});
eventSource.addEventListener('completed', (event) => {
const data = JSON.parse(event.data);
showResult(data.result);
eventSource.close();
});
eventSource.addEventListener('error', (event) => {
const data = JSON.parse(event.data);
showError(data.message);
eventSource.close();
});
// 连接错误时降级到轮询
eventSource.onerror = () => {
eventSource.close();
fallbackToPoll(taskId);
};
}
// 轮询降级方案
async function fallbackToPoll(taskId) {
const poll = async () => {
const res = await fetch(`/api/v1/tasks/${taskId}/status`);
const data = await res.json();
if (data.status === 'COMPLETED') {
showResult(data.result);
return;
} else if (data.status === 'FAILED') {
showError(data.errorMessage);
return;
}
updateProgressBar(data.progress, data.currentStep);
setTimeout(poll, 3000); // 3 秒轮询一次
};
poll();
}Kafka 配置:确保 AI 任务可靠投递
# application.yml - Kafka 配置
spring:
kafka:
bootstrap-servers: kafka:9092
producer:
acks: all # 等待所有副本确认,保证投递可靠性
retries: 3
key-serializer: org.apache.kafka.common.serialization.StringSerializer
value-serializer: org.springframework.kafka.support.serializer.JsonSerializer
consumer:
group-id: ai-task-workers
auto-offset-reset: earliest
enable-auto-commit: false # 手动 ack,确保任务处理完成再提交偏移量
max-poll-records: 1 # 每次只拉取 1 条消息(AI 任务处理时间长,不要批量)
max-poll-interval-ms: 300000 # 5 分钟,给 AI 任务足够的处理时间
listener:
ack-mode: manual_immediate # 手动确认
concurrency: 3 # 3 个消费线程
# 主题配置(通过代码或 Kafka Admin 创建)
# ai-tasks-contract-review:3 个分区,2 个副本
# ai-tasks-dead-letter:存放处理失败的消息重试和死信队列配置:
@Configuration
public class KafkaConfig {
@Bean
public DefaultErrorHandler errorHandler(KafkaTemplate<String, Object> kafkaTemplate) {
// 发送到死信 Topic 的恢复器
DeadLetterPublishingRecoverer recoverer = new DeadLetterPublishingRecoverer(
kafkaTemplate,
(record, ex) -> new TopicPartition("ai-tasks-dead-letter", record.partition())
);
// 重试策略:3 次重试,指数退避
ExponentialBackOffWithMaxRetries backOff = new ExponentialBackOffWithMaxRetries(3);
backOff.setInitialInterval(1000L); // 1 秒后第一次重试
backOff.setMultiplier(3.0); // 每次重试间隔 3 倍
backOff.setMaxInterval(30000L); // 最长等 30 秒
DefaultErrorHandler handler = new DefaultErrorHandler(recoverer, backOff);
// 某些异常不重试,直接进死信
handler.addNotRetryableExceptions(
InvalidTaskException.class,
IllegalArgumentException.class
);
return handler;
}
}一次真实的故障:消费者处理太慢导致堆积
上线第一周,我遇到了一个问题:Kafka 消费者组的 lag(堆积量)持续增长,已经堆了 300 个任务没处理完。
排查下来原因是:max-poll-interval-ms 设置太小(默认 5 分钟,但某些超长合同需要 8 分钟处理),消费者被认为"挂了",Kafka 触发了 rebalance,把消息重新分配,导致同一条消息被反复消费又反复 rebalance,真正处理的任务越来越少。
解决方案:
- 把
max-poll-interval-ms调高到 15 分钟(留足余量) - 在 Worker 里每处理一个阶段,通过
consumer.pause()+ 心跳线程来避免超时
// 长时间任务的心跳保持(防止 Kafka 认为消费者挂了)
@Component
public class TaskHeartbeatManager {
private final ScheduledExecutorService scheduler =
Executors.newSingleThreadScheduledExecutor();
private final Set<String> activeTaskIds = ConcurrentHashMap.newKeySet();
@Autowired
private KafkaListenerEndpointRegistry registry;
public void startHeartbeat(String taskId) {
activeTaskIds.add(taskId);
}
public void stopHeartbeat(String taskId) {
activeTaskIds.remove(taskId);
}
@Scheduled(fixedDelay = 60_000) // 每分钟记录一次活跃任务
public void reportActiveWork() {
if (!activeTaskIds.isEmpty()) {
log.info("Active AI tasks: {}", activeTaskIds);
}
}
}小结
AI 异步任务处理的核心设计要点:
- 任务状态机要严格:定义清晰的状态流转,防止任务卡在中间状态
- 幂等消费:Worker 要能安全地处理重复消息(断点续传)
- 进度推送:用 SSE 主动推进度,比轮询用户体验好
- 可靠投递:Kafka 手动 ack + 死信队列,确保任务不丢失
- 超时配置:
max-poll-interval-ms必须大于最长任务的处理时间
把同步的 AI 任务改成异步之后,用户的等待体验从"茫然等待"变成了"有进度反馈的等待",转化率和满意度都有明显提升。
