# Spring AI与Kafka集成:构建异步AI处理流水线
## 那个让200个用户骂了整整一夜的Bug
2024年9月,我的朋友林晓东在一家法律科技公司做研发总监,他们做了一个合同AI审查系统——用户上传PDF合同,AI自动分析风险条款,给出修改建议。
上线第一周,产品大卖。律师们纷纷称赞,效率提升明显。
但第二周,出事了。
周三晚上,一家大型律所把一批合同(共217份)同时提交给系统审查,每份合同平均30页,AI处理一份大约需要45秒。
这217个HTTP请求同时打进来,全部在等待AI处理完成返回。45秒之后,Nginx的upstream timeout触发了——所有请求返回502。
用户界面显示"服务器错误"。律师们不知道任务有没有处理,重新提交,又是217个请求……
林晓东是被一通投诉电话叫醒的,那时是凌晨2点,服务器还在转圈,OOM日志在飞速滚动。
他打电话给我,我问了他一个问题:"你的AI调用是同步的还是异步的?"
沉默了三秒。
"同步的。"
就是这个问题。
同步调用AI,对于长任务来说是架构性的错误。你不能让HTTP连接等待AI的45秒处理,这会耗尽连接池,最终拖垮整个服务。
解决方案只有一个:异步处理 + 消息队列。
这篇文章,就是我们当时重构的完整方案。
## 先说结论(TL;DR)
| 处理模式 | 适用场景 | 最大等待时间 | 推荐方案 |
|---|---|---|---|
| 同步调用 | 简单问答、<3秒响应 | 10秒 | Spring AI直接调用 |
| 异步+轮询 | 单文档处理、进度可见 | 分钟级 | Kafka + SSE进度推送 |
| 批量异步 | 批量文档处理 | 小时级 | Kafka Streams |
| 实时流处理 | 实时分析、监控 | 秒级 | Kafka + Spring AI Stream |
核心架构决策:
- HTTP接口只负责接收任务、返回任务ID(<100ms)
- 真正的AI处理放在Kafka Consumer中
- 用户通过SSE(Server-Sent Events)获取实时进度
## 整体架构设计

整条流水线的关键是把"接收任务"和"处理任务"彻底解耦,每一步都是独立的Kafka Consumer:

文件上传 → `document.upload` → 文本提取Worker → `document.extract` → 分块Worker → `document.embed` → 向量嵌入Worker → `document.complete`

任务状态与进度写入Redis,前端通过SSE订阅实时进度;每一步都有重试和死信队列兜底。
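进入各个Worker之前,先看HTTP入口长什么样。下面是一个最小的Controller示意(委托给后文"核心实现二"的 DocumentTaskSubmitService;路径、请求头等命名为示例假设),它只做一件事:受理任务,立即返回taskId:

```java
package com.laozhang.kafka.controller;

import com.laozhang.kafka.service.DocumentTaskSubmitService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.Map;

/**
 * HTTP入口:只接收任务、返回taskId,不等待AI处理(鉴权、参数校验从略)
 */
@RestController
@RequestMapping("/api/documents")
@RequiredArgsConstructor
public class DocumentSubmitController {

    private final DocumentTaskSubmitService submitService;

    @PostMapping
    public ResponseEntity<Map<String, String>> submit(
            @RequestParam("file") MultipartFile file,
            @RequestHeader("X-Tenant-Id") String tenantId,
            @RequestHeader("X-User-Id") String userId,
            @RequestParam(defaultValue = "GENERAL") String documentCategory) {
        // 同步路径上只有:存对象存储 + 写任务记录 + 发Kafka,目标<100ms
        String taskId = submitService.submitDocument(file, tenantId, userId, documentCategory);
        // 202 Accepted:任务已受理,结果异步产出
        return ResponseEntity.accepted().body(Map.of(
                "taskId", taskId,
                "progressUrl", "/api/tasks/" + taskId + "/progress"));
    }
}
```

前端拿到taskId后,立即用 progressUrl 订阅SSE进度(见核心实现五)。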
## 核心实现一:消息格式设计

```java
package com.laozhang.kafka.message;
import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import lombok.Builder;
import lombok.Data;
import lombok.experimental.SuperBuilder;
import java.time.Instant;
import java.util.Map;
/**
* AI任务消息基类
* 设计原则:消息必须包含追踪所需的全部元数据,消息体要幂等,消息大小控制在1MB以内
*/
@Data
@SuperBuilder
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "messageType")
@JsonSubTypes({
@JsonSubTypes.Type(value = DocumentUploadMessage.class, name = "DOCUMENT_UPLOAD"),
@JsonSubTypes.Type(value = DocumentExtractMessage.class, name = "DOCUMENT_EXTRACT"),
@JsonSubTypes.Type(value = DocumentEmbedMessage.class, name = "DOCUMENT_EMBED"),
@JsonSubTypes.Type(value = DocumentCompleteMessage.class, name = "DOCUMENT_COMPLETE")
})
public abstract class AiTaskMessage {
/** 全局唯一的任务ID(用户请求时生成,贯穿整个流水线) */
private String taskId;
/** 租户ID(多租户系统必须) */
private String tenantId;
/** 用户ID */
private String userId;
/** 消息发送时间 */
private Instant sentAt;
    /** 消息版本(用于Schema演进);@Builder.Default保证builder不把默认值覆盖为0 */
    @Builder.Default
    private int version = 1;
    /** 重试次数 */
    @Builder.Default
    private int retryCount = 0;
/** 追踪上下文(分布式追踪) */
private Map<String, String> traceContext;
    /** 任务优先级 */
    @Builder.Default
    private TaskPriority priority = TaskPriority.NORMAL;
public enum TaskPriority { NORMAL, HIGH, URGENT }
}
```

```java
package com.laozhang.kafka.message;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
/**
* 文档上传消息(流水线第一步)
*/
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DocumentUploadMessage extends AiTaskMessage {
/** 文档在对象存储中的Key(不直接包含文件内容!) */
private String objectStorageKey;
private String originalFileName;
private long fileSizeBytes;
private String fileType;
private String language;
private String documentCategory;
private String webhookUrl;
}
```

```java
package com.laozhang.kafka.message;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
import java.util.List;
import java.util.Map;
/**
* 文档嵌入消息(流水线第三步)
*/
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DocumentEmbedMessage extends AiTaskMessage {
private Long documentId;
private List<TextChunk> chunks;
private String embeddingModel;
@Data
public static class TextChunk {
private int chunkIndex;
private String content;
private int tokenCount;
private Map<String, Object> metadata;
}
}
```
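上面的 `@JsonSubTypes` 还注册了 `DocumentExtractMessage` 和 `DocumentCompleteMessage`,正文没有单独展示。按后文Worker中builder的调用方式,它们大致如下(字段为推断,仅作示意,实际项目中两个类各自独立成文件):

```java
package com.laozhang.kafka.message;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;

/** 文本提取完成消息(流水线第二步),字段按后文builder调用推断 */
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DocumentExtractMessage extends AiTaskMessage {
    private String extractedText;
    private int pageCount;
    private String documentCategory;
}

/** 处理完成消息(流水线最后一步),字段按后文builder调用推断 */
@Data
@SuperBuilder
@EqualsAndHashCode(callSuper = true)
public class DocumentCompleteMessage extends AiTaskMessage {
    private int totalChunks;
    private int totalDocumentsIndexed;
}
```

## 核心实现二:文件上传与任务分发

```java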
package com.laozhang.kafka.service;
import com.laozhang.kafka.message.AiTaskMessage;
import com.laozhang.kafka.message.DocumentUploadMessage;
import com.laozhang.kafka.entity.ProcessingTask;
import com.laozhang.kafka.repository.ProcessingTaskRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.support.SendResult;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.time.Instant;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
/**
* 文档处理任务提交服务
* 立即返回taskId,不等待AI处理
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class DocumentTaskSubmitService {
private final ObjectStorageService objectStorageService;
private final ProcessingTaskRepository taskRepository;
private final KafkaTemplate<String, AiTaskMessage> kafkaTemplate;
private static final String UPLOAD_TOPIC = "document.upload";
public String submitDocument(MultipartFile file, String tenantId,
String userId, String documentCategory) {
String taskId = UUID.randomUUID().toString();
log.info("提交文档处理任务: taskId={}, file={}, tenant={}",
taskId, file.getOriginalFilename(), tenantId);
String objectKey = uploadToObjectStorage(taskId, file);
ProcessingTask task = ProcessingTask.builder()
.taskId(taskId)
.tenantId(tenantId)
.userId(userId)
.originalFileName(file.getOriginalFilename())
.fileSizeBytes(file.getSize())
.status(ProcessingTask.Status.PENDING)
.progress(0)
.submittedAt(Instant.now())
.build();
taskRepository.save(task);
DocumentUploadMessage message = DocumentUploadMessage.builder()
.taskId(taskId)
.tenantId(tenantId)
.userId(userId)
.objectStorageKey(objectKey)
.originalFileName(file.getOriginalFilename())
.fileSizeBytes(file.getSize())
.fileType(detectFileType(file))
.documentCategory(documentCategory)
.sentAt(Instant.now())
.build();
CompletableFuture<SendResult<String, AiTaskMessage>> future =
kafkaTemplate.send(UPLOAD_TOPIC, taskId, message);
future.whenComplete((result, ex) -> {
if (ex != null) {
log.error("Kafka消息发送失败: taskId={}", taskId, ex);
taskRepository.updateStatus(taskId, ProcessingTask.Status.FAILED,
"消息队列发送失败: " + ex.getMessage());
} else {
log.debug("Kafka消息发送成功: taskId={}, partition={}, offset={}",
taskId,
result.getRecordMetadata().partition(),
result.getRecordMetadata().offset());
}
});
return taskId;
}
private String uploadToObjectStorage(String taskId, MultipartFile file) {
String objectKey = "documents/" + taskId + "/" + file.getOriginalFilename();
try {
objectStorageService.upload(objectKey, file.getInputStream(),
file.getContentType(), file.getSize());
return objectKey;
} catch (Exception e) {
throw new DocumentUploadException("文件上传到对象存储失败", e);
}
}
private String detectFileType(MultipartFile file) {
String name = file.getOriginalFilename();
if (name == null) return "UNKNOWN";
if (name.endsWith(".pdf")) return "PDF";
if (name.endsWith(".docx") || name.endsWith(".doc")) return "DOCX";
if (name.endsWith(".txt")) return "TXT";
return "UNKNOWN";
}
}
```

## 核心实现三:文本提取Worker

```java
package com.laozhang.kafka.consumer;
import com.laozhang.kafka.message.AiTaskMessage;
import com.laozhang.kafka.message.DocumentUploadMessage;
import com.laozhang.kafka.message.DocumentExtractMessage;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.kafka.annotation.RetryableTopic;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.retrytopic.DltStrategy;
import org.springframework.kafka.retrytopic.TopicSuffixingStrategy;
import org.springframework.retry.annotation.Backoff;
import org.springframework.stereotype.Component;
/**
* 文本提取Worker
* 消费 document.upload 消息,提取文本内容,发布到下一步
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class DocumentExtractConsumer {
    private final ObjectStorageService objectStorageService;
    private final DocumentTextExtractor textExtractor;
    private final KafkaTemplate<String, AiTaskMessage> kafkaTemplate;
    private final ProgressReporter progressReporter;
    private final ProcessingTaskRepository taskRepository;
    private final DocumentRepository documentRepository; // 下文保存提取文本时用到
    private final AlertService alertService; // DLQ处理器发告警时用到
private static final String EXTRACT_OUTPUT_TOPIC = "document.extract";
    /**
     * @RetryableTopic: 最多投递3次(首次+2次重试),指数退避,重试耗尽后进死信队列
     */
@RetryableTopic(
attempts = "3",
backoff = @Backoff(delay = 1000, multiplier = 2.0, maxDelay = 30000),
dltStrategy = DltStrategy.FAIL_ON_ERROR,
topicSuffixingStrategy = TopicSuffixingStrategy.SUFFIX_WITH_INDEX_VALUE,
dltTopicSuffix = ".dlq"
)
@KafkaListener(
topics = "document.upload",
groupId = "document-extract-group",
concurrency = "5"
)
public void consume(ConsumerRecord<String, DocumentUploadMessage> record) {
DocumentUploadMessage message = record.value();
String taskId = message.getTaskId();
log.info("开始提取文档文本: taskId={}, file={}", taskId, message.getOriginalFileName());
try {
taskRepository.updateStatus(taskId, ProcessingTask.Status.EXTRACTING);
progressReporter.report(taskId, 10, "正在提取文档内容...");
byte[] fileContent = objectStorageService.download(message.getObjectStorageKey());
ExtractedText extractedText = textExtractor.extract(
fileContent, message.getFileType(), message.getLanguage());
log.info("文本提取完成: taskId={}, 页数={}, 字符数={}",
taskId, extractedText.getPageCount(), extractedText.getContent().length());
progressReporter.report(taskId, 25, "文档内容提取完成,准备分块...");
documentRepository.saveExtractedText(taskId, extractedText);
DocumentExtractMessage nextMessage = DocumentExtractMessage.builder()
.taskId(taskId)
.tenantId(message.getTenantId())
.userId(message.getUserId())
.extractedText(extractedText.getContent())
.pageCount(extractedText.getPageCount())
.documentCategory(message.getDocumentCategory())
.sentAt(java.time.Instant.now())
.build();
kafkaTemplate.send(EXTRACT_OUTPUT_TOPIC, taskId, nextMessage);
} catch (UnsupportedFileTypeException e) {
log.error("不支持的文件类型: taskId={}, type={}", taskId, message.getFileType());
taskRepository.updateStatus(taskId, ProcessingTask.Status.FAILED,
"不支持的文件类型: " + message.getFileType());
progressReporter.reportError(taskId, "不支持的文件类型");
// 不抛出异常,不触发重试
} catch (Exception e) {
log.error("文本提取失败: taskId={}", taskId, e);
progressReporter.reportError(taskId, "文本提取失败: " + e.getMessage());
throw e; // 触发@RetryableTopic重试
}
}
/**
* 死信队列处理器 - 重试耗尽后进入DLQ
*/
@KafkaListener(topics = "document.upload.dlq", groupId = "document-dlq-group")
public void handleDlq(ConsumerRecord<String, DocumentUploadMessage> record) {
DocumentUploadMessage message = record.value();
log.error("文档提取彻底失败,进入DLQ: taskId={}", message.getTaskId());
taskRepository.updateStatus(message.getTaskId(),
ProcessingTask.Status.FAILED_PERMANENT,
"重试" + message.getRetryCount() + "次后仍然失败");
progressReporter.reportFinalError(message.getTaskId(),
"文档处理失败,请检查文件格式是否正确,或联系客服");
alertService.sendAlert(
String.format("文档处理彻底失败: taskId=%s, file=%s",
message.getTaskId(), message.getOriginalFileName()));
}
}
```
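提取Worker把全文发到了 `document.extract`,但嵌入Worker消费的是带分块的 `document.embed`,中间还差一个分块Worker。原文未展示这一步,下面是一个按固定字符数切块的简化示意(真实场景建议按token数和语义边界切分;`DocumentChunkConsumer` 及其groupId为示例命名):

```java
package com.laozhang.kafka.consumer;

import com.laozhang.kafka.message.AiTaskMessage;
import com.laozhang.kafka.message.DocumentEmbedMessage;
import com.laozhang.kafka.message.DocumentExtractMessage;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

/**
 * 分块Worker(简化示意):消费document.extract,切块后发往document.embed
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class DocumentChunkConsumer {

    private static final int CHUNK_SIZE = 1500; // 简化:按字符数切块

    private final KafkaTemplate<String, AiTaskMessage> kafkaTemplate;

    @KafkaListener(topics = "document.extract", groupId = "document-chunk-group")
    public void consume(ConsumerRecord<String, DocumentExtractMessage> record) {
        DocumentExtractMessage message = record.value();
        String text = message.getExtractedText();
        List<DocumentEmbedMessage.TextChunk> chunks = new ArrayList<>();
        for (int i = 0, index = 0; i < text.length(); i += CHUNK_SIZE, index++) {
            String content = text.substring(i, Math.min(i + CHUNK_SIZE, text.length()));
            DocumentEmbedMessage.TextChunk chunk = new DocumentEmbedMessage.TextChunk();
            chunk.setChunkIndex(index);
            chunk.setContent(content);
            chunk.setTokenCount(content.length() / 4); // 粗略估算,仅作示意
            chunk.setMetadata(new HashMap<>());
            chunks.add(chunk);
        }
        DocumentEmbedMessage next = DocumentEmbedMessage.builder()
                .taskId(message.getTaskId())
                .tenantId(message.getTenantId())
                .userId(message.getUserId())
                .chunks(chunks)
                .sentAt(java.time.Instant.now())
                .build();
        kafkaTemplate.send("document.embed", message.getTaskId(), next);
        log.info("分块完成: taskId={}, chunks={}", message.getTaskId(), chunks.size());
    }
}
```

## 核心实现四:AI嵌入Worker

```java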
package com.laozhang.kafka.consumer;
import com.laozhang.kafka.message.AiTaskMessage;
import com.laozhang.kafka.message.DocumentCompleteMessage;
import com.laozhang.kafka.message.DocumentEmbedMessage;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.kafka.annotation.RetryableTopic;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.retry.annotation.Backoff;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* 向量嵌入Worker
* 调用Spring AI生成向量并存储
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class DocumentEmbedConsumer {
    /** 注:下方 vectorStore.add 内部会调用EmbeddingModel生成向量 */
    private final EmbeddingModel embeddingModel;
private final VectorStore vectorStore;
private final KafkaTemplate<String, AiTaskMessage> kafkaTemplate;
private final ProgressReporter progressReporter;
private final ProcessingTaskRepository taskRepository;
private static final int BATCH_SIZE = 20;
@RetryableTopic(
attempts = "5",
backoff = @Backoff(delay = 2000, multiplier = 2.0, maxDelay = 60000),
include = {OpenAiRateLimitException.class, java.net.SocketTimeoutException.class}
)
@KafkaListener(
topics = "document.embed",
groupId = "document-embed-group",
concurrency = "3"
)
public void consume(ConsumerRecord<String, DocumentEmbedMessage> record) {
DocumentEmbedMessage message = record.value();
String taskId = message.getTaskId();
List<DocumentEmbedMessage.TextChunk> chunks = message.getChunks();
log.info("开始向量嵌入: taskId={}, chunks={}", taskId, chunks.size());
taskRepository.updateStatus(taskId, ProcessingTask.Status.EMBEDDING);
List<Document> allDocuments = new ArrayList<>();
int processedChunks = 0;
for (int i = 0; i < chunks.size(); i += BATCH_SIZE) {
List<DocumentEmbedMessage.TextChunk> batch =
chunks.subList(i, Math.min(i + BATCH_SIZE, chunks.size()));
List<Document> batchDocuments = batch.stream()
.map(chunk -> {
                    Map<String, Object> metadata = chunk.getMetadata() != null
                            ? new java.util.HashMap<>(chunk.getMetadata())
                            : new java.util.HashMap<>();
metadata.put("task_id", taskId);
metadata.put("tenant_id", message.getTenantId());
metadata.put("chunk_index", chunk.getChunkIndex());
return new Document(chunk.getContent(), metadata);
})
.toList();
try {
vectorStore.add(batchDocuments);
allDocuments.addAll(batchDocuments);
processedChunks += batch.size();
int progress = 50 + (int) ((double) processedChunks / chunks.size() * 40);
progressReporter.report(taskId, progress,
String.format("向量化进度:%d/%d", processedChunks, chunks.size()));
} catch (Exception e) {
log.error("批次嵌入失败: taskId={}, batchStart={}", taskId, i, e);
throw e;
}
}
DocumentCompleteMessage completeMessage = DocumentCompleteMessage.builder()
.taskId(taskId)
.tenantId(message.getTenantId())
.userId(message.getUserId())
.totalChunks(chunks.size())
.totalDocumentsIndexed(allDocuments.size())
.sentAt(java.time.Instant.now())
.build();
kafkaTemplate.send("document.complete", taskId, completeMessage);
log.info("向量嵌入完成: taskId={}, indexed={}条", taskId, allDocuments.size());
}
}
```

## 核心实现五:进度追踪与SSE推送

```java
package com.laozhang.kafka.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
/**
* 任务进度上报服务
* 进度信息写入Redis,SSE端点从Redis读取推送给前端
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class ProgressReporter {
private final StringRedisTemplate redisTemplate;
private static final Duration PROGRESS_TTL = Duration.ofHours(24);
public void report(String taskId, int progress, String message) {
Map<String, String> progressData = new HashMap<>();
progressData.put("progress", String.valueOf(progress));
progressData.put("status", "PROCESSING");
progressData.put("message", message);
progressData.put("updatedAt", String.valueOf(System.currentTimeMillis()));
String key = "task:progress:" + taskId;
redisTemplate.opsForHash().putAll(key, progressData);
redisTemplate.expire(key, PROGRESS_TTL);
redisTemplate.convertAndSend("task-progress-channel",
taskId + ":" + progress + ":" + message);
log.debug("进度上报: taskId={}, progress={}%", taskId, progress);
}
public void reportError(String taskId, String errorMessage) {
Map<String, String> data = new HashMap<>();
data.put("progress", "-1");
data.put("status", "ERROR");
data.put("message", errorMessage);
data.put("updatedAt", String.valueOf(System.currentTimeMillis()));
String key = "task:progress:" + taskId;
redisTemplate.opsForHash().putAll(key, data);
redisTemplate.expire(key, PROGRESS_TTL);
redisTemplate.convertAndSend("task-progress-channel", taskId + ":ERROR:" + errorMessage);
}
public void reportFinalError(String taskId, String errorMessage) {
Map<String, String> data = new HashMap<>();
data.put("progress", "-1");
data.put("status", "FAILED_PERMANENT");
data.put("message", errorMessage);
data.put("updatedAt", String.valueOf(System.currentTimeMillis()));
String key = "task:progress:" + taskId;
redisTemplate.opsForHash().putAll(key, data);
redisTemplate.expire(key, PROGRESS_TTL);
redisTemplate.convertAndSend("task-progress-channel", taskId + ":FAILED:" + errorMessage);
}
@SuppressWarnings("unchecked")
public TaskProgress getProgress(String taskId) {
String key = "task:progress:" + taskId;
Map<Object, Object> data = redisTemplate.opsForHash().entries(key);
if (data.isEmpty()) return null;
return TaskProgress.builder()
.taskId(taskId)
.progress(Integer.parseInt((String) data.getOrDefault("progress", "0")))
.status((String) data.getOrDefault("status", "UNKNOWN"))
.message((String) data.getOrDefault("message", ""))
.result((String) data.get("result"))
.build();
}
}
```

```java
package com.laozhang.kafka.controller;
import com.laozhang.kafka.service.ProgressReporter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
/**
* SSE进度推送Controller
* 前端通过EventSource订阅任务进度
*/
@Slf4j
@RestController
@RequestMapping("/api/tasks")
@RequiredArgsConstructor
public class TaskProgressController {
private final ProgressReporter progressReporter;
private final RedisMessageListenerContainer listenerContainer;
private final ConcurrentHashMap<String, Sinks.Many<ServerSentEvent<String>>> activeSinks =
new ConcurrentHashMap<>();
@GetMapping(value = "/{taskId}/progress", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamProgress(@PathVariable String taskId) {
log.info("客户端订阅任务进度: taskId={}", taskId);
TaskProgress currentProgress = progressReporter.getProgress(taskId);
Sinks.Many<ServerSentEvent<String>> sink = Sinks.many().multicast().directBestEffort();
activeSinks.put(taskId, sink);
MessageListener listener = setupRedisListener(taskId, sink);
        Flux<ServerSentEvent<String>> progressFlux = sink.asFlux();
        // 心跳流随进度流一起结束,否则合并后的SSE流永远不会complete
        Flux<ServerSentEvent<String>> heartbeat = Flux.interval(Duration.ofSeconds(30))
                .map(tick -> ServerSentEvent.<String>builder()
                        .event("heartbeat")
                        .data("{\"alive\":true}")
                        .build())
                .takeUntilOther(progressFlux.then());
        return progressFlux
                .startWith(buildCurrentProgressEvent(taskId, currentProgress))
                .mergeWith(heartbeat)
                // 正常完成和客户端断开都要清理监听器,避免泄漏
                .doFinally(signal -> {
                    log.info("SSE连接关闭: taskId={}, signal={}", taskId, signal);
                    activeSinks.remove(taskId);
                    listenerContainer.removeMessageListener(listener);
                });
}
private MessageListener setupRedisListener(String taskId,
Sinks.Many<ServerSentEvent<String>> sink) {
MessageListener listener = (message, pattern) -> {
String body = new String(message.getBody());
if (body.startsWith(taskId + ":")) {
String[] parts = body.split(":", 3);
String progressStr = parts[1];
String msg = parts.length > 2 ? parts[2] : "";
String sseData = String.format(
"{\"taskId\":\"%s\",\"progress\":\"%s\",\"message\":\"%s\"}",
taskId, progressStr, msg);
sink.tryEmitNext(ServerSentEvent.<String>builder()
.event("progress").data(sseData).build());
if ("100".equals(progressStr) || "FAILED".equals(progressStr)) {
sink.tryEmitComplete();
}
}
};
listenerContainer.addMessageListener(listener, new ChannelTopic("task-progress-channel"));
return listener;
}
private ServerSentEvent<String> buildCurrentProgressEvent(String taskId,
TaskProgress progress) {
if (progress == null) {
return ServerSentEvent.<String>builder()
.event("progress")
.data(String.format("{\"taskId\":\"%s\",\"progress\":\"0\",\"message\":\"等待处理...\"}",
taskId))
.build();
}
return ServerSentEvent.<String>builder()
.event("progress")
.data(String.format("{\"taskId\":\"%s\",\"progress\":\"%d\",\"status\":\"%s\",\"message\":\"%s\"}",
taskId, progress.getProgress(), progress.getStatus(), progress.getMessage()))
.build();
}
@GetMapping("/{taskId}/status")
public ResponseEntity<TaskProgress> getStatus(@PathVariable String taskId) {
TaskProgress progress = progressReporter.getProgress(taskId);
if (progress == null) return ResponseEntity.notFound().build();
return ResponseEntity.ok(progress);
}
}
```

## 核心实现六:背压控制

```yaml
# application.yml
spring:
kafka:
bootstrap-servers: ${KAFKA_BOOTSTRAP_SERVERS:localhost:9092}
producer:
key-serializer: org.apache.kafka.common.serialization.StringSerializer
value-serializer: org.springframework.kafka.support.serializer.JsonSerializer
acks: all
retries: 3
compression-type: lz4
consumer:
key-deserializer: org.apache.kafka.common.serialization.StringDeserializer
value-deserializer: org.springframework.kafka.support.serializer.JsonDeserializer
properties:
spring.json.trusted.packages: "com.laozhang.kafka.message"
listener:
      shutdown-timeout: 30000
```

```java
@Configuration
public class KafkaConsumerConfig {
@Bean
    public Map<String, Object> aiTaskConsumerProps() {
        Map<String, Object> props = new HashMap<>();
        // 每次poll最多拉5条:限制单个Consumer同时在手的AI任务数,这是背压的关键
        props.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 5);
        // 两次poll间隔上限6分钟:给慢速AI调用留足时间,避免被判离线触发rebalance
        props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 360000);
        // 关闭自动提交,处理成功后再手动提交offset
        props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
        props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
        // 注意:这些属性需合并进ConsumerFactory的配置才会生效
        return props;
}
@Bean
public ConcurrentKafkaListenerContainerFactory<String, AiTaskMessage>
aiTaskKafkaListenerContainerFactory(
ConsumerFactory<String, AiTaskMessage> consumerFactory) {
ConcurrentKafkaListenerContainerFactory<String, AiTaskMessage> factory =
new ConcurrentKafkaListenerContainerFactory<>();
factory.setConsumerFactory(consumerFactory);
factory.getContainerProperties()
.setAckMode(ContainerProperties.AckMode.MANUAL_IMMEDIATE);
factory.getContainerProperties().setIdleBetweenPolls(1000);
return factory;
}
}
```

```java
/**
* 带手动确认的Consumer示例
*/
@KafkaListener(topics = "document.embed", groupId = "embed-group",
containerFactory = "aiTaskKafkaListenerContainerFactory")
public void consumeWithAck(
ConsumerRecord<String, DocumentEmbedMessage> record,
Acknowledgment acknowledgment) {
String taskId = record.value().getTaskId();
try {
processEmbedding(record.value());
acknowledgment.acknowledge();
log.debug("消息处理完成并确认: taskId={}", taskId);
    } catch (RecoverableException e) {
        log.warn("可恢复错误,不提交offset: taskId={}", taskId, e);
        // 不调用acknowledge:offset不前进,Consumer重启或rebalance后会重新消费这条消息
} catch (FatalException e) {
log.error("不可恢复错误: taskId={}", taskId, e);
acknowledgment.acknowledge(); // 提交偏移量,不再重试
taskRepository.markAsFailed(taskId, e.getMessage());
}
}
```

## 核心实现七:顺序保证

```java
/**
* 顺序保证:使用taskId作为消息Key,相同Key总是发到同一分区,同一分区内消息有序
*/
@Slf4j
@Service
@RequiredArgsConstructor
public class OrderedKafkaProducer {

    private final KafkaTemplate<String, AiTaskMessage> kafkaTemplate;
public void sendOrdered(String topic, String taskId, AiTaskMessage message) {
kafkaTemplate.send(topic, taskId, message)
.whenComplete((result, ex) -> {
if (ex == null) {
log.debug("有序消息发送成功: topic={}, key={}, partition={}",
topic, taskId, result.getRecordMetadata().partition());
                } else {
                    // 注意:在whenComplete回调里抛异常不会传播给调用方,
                    // 只能记录日志并走补偿逻辑(如标记任务失败)
                    log.error("有序消息发送失败: topic={}, key={}", topic, taskId, ex);
                }
});
}
}
```

## 核心实现八:消费积压监控

```java
import org.apache.kafka.clients.admin.AdminClient;
import org.springframework.scheduling.annotation.Scheduled;
import java.util.Map;

@Component
@Slf4j
@RequiredArgsConstructor
public class KafkaLagMonitor {

    private final AdminClient kafkaAdminClient;
    private final AlertService alertService;
@Scheduled(fixedDelay = 60000)
public void checkConsumerLag() {
Map<String, Integer> lagAlertThresholds = Map.of(
"document.upload", 50,
"document.extract", 100,
"document.embed", 200
);
lagAlertThresholds.forEach((topic, threshold) -> {
long lag = calculateLag(topic);
if (lag > threshold) {
log.warn("Kafka消费积压告警: topic={}, lag={}, threshold={}", topic, lag, threshold);
alertService.sendAlert(String.format(
"AI处理积压告警!Topic: %s, 积压: %d条(阈值: %d条)", topic, lag, threshold));
}
});
}
private long calculateLag(String topic) {
// 实际实现调用Kafka AdminClient API计算消费滞后
return 0L;
}
}
```
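`calculateLag` 的一个可行实现示意(基于 `AdminClient.listConsumerGroupOffsets` 和 `listOffsets`;topic与消费组的对应关系按前文各Worker的groupId假设,`document-chunk-group` 为假设值):

```java
private long calculateLag(String topic) {
    // topic -> 消费组ID(按前文各Worker的groupId,分块组名为假设值)
    Map<String, String> topicGroups = Map.of(
            "document.upload", "document-extract-group",
            "document.extract", "document-chunk-group",
            "document.embed", "document-embed-group");
    String groupId = topicGroups.get(topic);
    if (groupId == null) return 0L;
    try {
        // 1. 消费组已提交的offset
        Map<TopicPartition, OffsetAndMetadata> committed = kafkaAdminClient
                .listConsumerGroupOffsets(groupId)
                .partitionsToOffsetAndMetadata().get();
        // 2. 各分区最新的log end offset
        Map<TopicPartition, OffsetSpec> specs = new HashMap<>();
        committed.keySet().stream()
                .filter(tp -> tp.topic().equals(topic))
                .forEach(tp -> specs.put(tp, OffsetSpec.latest()));
        // 3. lag = sum(末端offset - 已提交offset)
        long lag = 0L;
        for (var entry : kafkaAdminClient.listOffsets(specs).all().get().entrySet()) {
            lag += entry.getValue().offset() - committed.get(entry.getKey()).offset();
        }
        return lag;
    } catch (Exception e) {
        log.error("计算消费滞后失败: topic={}", topic, e);
        return 0L;
    }
}
```

(还需导入 `org.apache.kafka.common.TopicPartition`、`org.apache.kafka.clients.consumer.OffsetAndMetadata`、`org.apache.kafka.clients.admin.OffsetSpec` 等。)

## 生产环境注意事项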
### 踩坑1:消息体包含大文件内容
绝对不要把文件内容放到Kafka消息里!Kafka单条消息默认最大1MB,超过会失败。正确做法:文件存对象存储,消息里只放引用Key。
### 踩坑2:重试幂等性
如果第二步处理成功但发Kafka消息失败,整个流水线会从第二步重试。确保每一步的处理是幂等的,用task_id + step作为去重Key。
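一个基于Redis `SETNX` 的幂等示意(key规则沿用上文的 task_id + step;类名、TTL为示例假设):

```java
import java.time.Duration;
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
public class IdempotencyGuard {

    private final StringRedisTemplate redisTemplate;

    /**
     * 第一次调用返回true(获得执行权),重复投递返回false(直接ack跳过)
     * 注意:若处理失败应删除该key,否则后续重试会被误跳过
     */
    public boolean tryAcquire(String taskId, String step) {
        String key = "task:idem:" + taskId + ":" + step;
        Boolean first = redisTemplate.opsForValue()
                .setIfAbsent(key, "1", Duration.ofHours(24));
        return Boolean.TRUE.equals(first);
    }
}
```

Consumer入口处先 `tryAcquire(taskId, "extract")`,拿不到执行权就直接确认消息跳过,重试就安全了。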
### 踩坑3:SSE连接数管理
大量用户同时订阅SSE会消耗大量连接。设置连接超时(5分钟),任务完成后主动关闭SSE流,避免连接泄漏。
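对应的一个简单做法:在SSE的Flux链上用 `take(Duration)` 限制最长存活时间,到期自动complete、释放连接(5分钟为示例值,`progressFlux`/`heartbeat` 沿用前文SSE Controller中的变量):

```java
return progressFlux
        .startWith(buildCurrentProgressEvent(taskId, currentProgress))
        .mergeWith(heartbeat)
        .take(Duration.ofMinutes(5)) // 超过5分钟强制结束流,前端可重连
        .doFinally(signal -> cleanup(taskId)); // cleanup为示意:移除sink与Redis监听器
```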
### 踩坑4:Consumer Group ID管理
不同环境(dev/staging/prod)必须使用不同的Consumer Group ID,否则测试环境会消费生产环境的消息!
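`@KafkaListener` 的 `groupId` 支持属性占位符,可以直接把环境名拼进组ID(`ENV` 变量名为示例):

```java
// dev环境解析为 document-extract-group-dev,prod解析为 document-extract-group-prod
@KafkaListener(
        topics = "document.upload",
        groupId = "document-extract-group-${ENV:dev}"
)
public void consume(ConsumerRecord<String, DocumentUploadMessage> record) {
    // ...
}
```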
## 常见问题解答
Q1:为什么不用RabbitMQ而选Kafka?
A:对于AI处理流水线,Kafka的核心优势:消息持久化(可回溯重放)、高吞吐(处理批量文档)、原生支持分区顺序。RabbitMQ适合低延迟任务调度,Kafka适合高吞吐数据流水线。
Q2:Consumer宕机了,正在处理的消息怎么办?
A:使用手动偏移量提交(MANUAL_IMMEDIATE模式),只有处理成功才提交offset。Consumer宕机后,Kafka把未提交的消息重新分配给其他Consumer。配合幂等处理,不会有消息丢失。
Q3:如何保证消息不丢失?
A:三重保证:1)Producer设置acks=all;2)Consumer手动提交offset;3)业务逻辑写入数据库后再提交offset。
Q4:任务处理超时了怎么办?
A:在任务创建时记录deadline,定时任务扫描超时任务并发送告警。对超时任务选择:重新发布消息触发重处理,或标记为失败通知用户。关键是不让用户无限等待。
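超时扫描的一个最小示意(`findTimedOutTasks` 为假设的repository方法,按 status=PROCESSING 且 deadline 已过查询):

```java
@Scheduled(fixedDelay = 60000)
public void scanTimedOutTasks() {
    List<ProcessingTask> timedOut = taskRepository.findTimedOutTasks(Instant.now());
    for (ProcessingTask task : timedOut) {
        // 这里选择标记失败并通知;也可以改为重新发布消息触发重处理
        taskRepository.updateStatus(task.getTaskId(),
                ProcessingTask.Status.FAILED, "处理超时");
        progressReporter.reportError(task.getTaskId(), "任务处理超时,请重新提交");
        alertService.sendAlert("任务超时: taskId=" + task.getTaskId());
    }
}
```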
Q5:如何做流水线的全链路追踪?
A:在消息体的traceContext字段中传递TraceId(Spring Cloud Sleuth或Micrometer Tracing)。每个Consumer接收消息时从traceContext恢复TraceId,Zipkin或Jaeger就能看到完整跨服务链路。
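traceContext 传递的一个最小示意(手工透传TraceId + SLF4J MDC;`"traceId"` 这个key为示例约定,实际项目可交给Micrometer Tracing自动完成):

```java
// 发送侧:把当前线程的TraceId放进消息
String traceId = MDC.get("traceId");
if (traceId != null) {
    message.setTraceContext(Map.of("traceId", traceId));
}

// 消费侧:进入Consumer时恢复到MDC,这一步的日志就能按TraceId串联
Map<String, String> ctx = message.getTraceContext();
if (ctx != null && ctx.get("traceId") != null) {
    MDC.put("traceId", ctx.get("traceId"));
}
try {
    // ...业务处理...
} finally {
    MDC.remove("traceId"); // 线程会被复用,必须清理,避免TraceId串台
}
```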
Q6:消息格式变更了怎么向前兼容?
A:消息体中包含version字段。Consumer按版本号路由处理:if (message.getVersion() >= 2) { handleV2(message); } else { handleV1(message); }。新版本Consumer能处理老版本消息,老版本Consumer忽略新字段(Jackson默认忽略未知字段)。
## 总结
林晓东的故事告诉我们:同步调用AI处理长任务,是架构性错误。
可操作行动清单:

- 把耗时超过几秒的AI调用从HTTP同步链路中拿掉:接口只接收任务、返回taskId
- 文件走对象存储,Kafka消息里只放引用Key,消息体控制在1MB以内
- Consumer关闭自动提交,用手动ack + @RetryableTopic重试 + 死信队列兜底
- 每一步处理做到幂等(task_id + step去重),让重试安全
- 用Redis + SSE向用户推送实时进度,别让用户面对转圈的页面
- 上线消费积压监控和超时任务扫描,积压超阈值立即告警
- dev/staging/prod使用不同的Consumer Group ID

异步架构带来的复杂度,是AI系统能撑住生产流量的基础设施投资,值得。
