AI 服务的优雅停机——长任务不能被强杀
AI 服务的优雅停机——长任务不能被强杀
那是我们做灰度发布的一个普通下午。
运维同学按照惯例执行滚动重启,发布脚本里是标准的 kill -9,等个 30 秒再拉起新版本。以前这套流程从来没出过问题。
但那天,几个用户反馈他们的 AI 写作辅助卡在一半——屏幕上出现了半截的文章,后面再也没有新的字出来,刷新之后内容消失了。有个用户写了一篇产品分析报告,正好在 GPT 生成最关键的结论段落时被强杀了,前面几千字的铺垫全没了。
那之后运维同学每次发布前都会先问我:「现在有人在用 AI 吗?」
这个问题本身就是个架构问题的信号:我们的停机流程完全没有感知到 AI 长任务的存在。
一、AI 应用和普通应用的停机差异
先说清楚为什么 AI 应用的优雅停机比普通应用难。
普通 Web 应用的请求生命周期通常在几百毫秒到几秒之间。Spring Boot 默认的 graceful shutdown 会等待正在处理的请求完成,超时时间默认 30 秒。这对普通应用来说绰绰有余。
但 AI 应用有几类特殊的长任务:
流式响应(Streaming Response)
用户在接收流式输出,可能已经持续了 20-30 秒,还没有完成。如果在这时候关闭服务,用户看到的是截断的半段内容,比直接报错还难受——因为用户不确定这是 AI 写完了还是出了问题。
长时间推理任务
某些场景下(代码生成、长文写作、文档分析),单次请求的 AI 调用可能需要 1-3 分钟。30 秒的 graceful shutdown 超时远远不够。
批量异步任务
很多 AI 应用有后台批量处理任务,比如对文档批量做 embedding、对知识库批量更新向量。这类任务可能正在中间某个步骤,强杀会导致数据不一致。
多步骤 Agent 任务
AI Agent 可能正在执行一个多工具调用的任务链,中途断掉会让任务状态悬空,难以恢复。
二、Spring Boot Graceful Shutdown 的局限
Spring Boot 2.3 开始支持 server.shutdown=graceful,原理是:
- 收到 SIGTERM 信号后,停止接受新请求
- 等待正在处理的请求完成
- 超过
spring.lifecycle.timeout-per-shutdown-phase配置的超时时间后强制退出
server:
shutdown: graceful
spring:
lifecycle:
timeout-per-shutdown-phase: 30s # 默认 30 秒这对普通 Web 请求有效。但有两个问题:
问题一:超时时间对 AI 任务不够。把超时改成 5 分钟?那每次部署都要等 5 分钟才能完成实例替换,滚动发布会极其缓慢。
问题二:无法区分任务类型。流式响应和批量任务的处理策略完全不同,统一等待是错的。流式响应应该尽快通知客户端,让用户知道发生了什么;批量任务应该保存进度,下次启动继续。
三、自定义 ShutdownHook 的实现
3.1 AI 任务注册表
首先,我们需要一个地方追踪所有正在进行的 AI 任务:
@Component
@Slf4j
public class AiTaskRegistry {
// 使用 ConcurrentHashMap 保证线程安全
private final ConcurrentHashMap<String, AiTask> activeTasks = new ConcurrentHashMap<>();
// 是否正在停机
private volatile boolean shuttingDown = false;
/**
* 注册一个 AI 任务
*/
public void register(AiTask task) {
if (shuttingDown) {
throw new ServiceUnavailableException("服务正在重启,暂时无法接受新请求,请稍后再试");
}
activeTasks.put(task.getTaskId(), task);
log.debug("AI task registered: {} (type={})", task.getTaskId(), task.getType());
}
/**
* 任务完成后注销
*/
public void unregister(String taskId) {
activeTasks.remove(taskId);
log.debug("AI task unregistered: {}", taskId);
}
/**
* 获取所有活跃任务(快照,不影响原集合)
*/
public List<AiTask> getActiveTasks() {
return new ArrayList<>(activeTasks.values());
}
/**
* 标记为停机状态(停止接受新任务)
*/
public void setShuttingDown() {
this.shuttingDown = true;
log.info("AiTaskRegistry is now in shutdown mode, {} tasks still active",
activeTasks.size());
}
public int getActiveCount() {
return activeTasks.size();
}
public boolean isShuttingDown() {
return shuttingDown;
}
}3.2 AI 任务模型定义
@Data
@Builder
public class AiTask {
private String taskId;
private AiTaskType type;
private long startTime;
private String userId;
private String scene;
// 流式任务需要持有对 SSE emitter 或 Flux sink 的引用
private StreamTerminator streamTerminator;
// 批量任务需要持有进度保存的回调
private ProgressSaver progressSaver;
public enum AiTaskType {
STREAM_RESPONSE, // 流式响应
BATCH_EMBEDDING, // 批量 Embedding
LONG_INFERENCE, // 长时间推理
AGENT_TASK // Agent 多步骤任务
}
@FunctionalInterface
public interface StreamTerminator {
void terminate(String reason);
}
@FunctionalInterface
public interface ProgressSaver {
void saveProgress();
}
}3.3 自定义 Shutdown Handler
@Component
@Slf4j
public class AiGracefulShutdownHandler implements SmartLifecycle {
@Autowired
private AiTaskRegistry taskRegistry;
@Value("${ai.shutdown.stream-grace-period-ms:3000}")
private long streamGracePeriodMs;
@Value("${ai.shutdown.batch-save-timeout-ms:30000}")
private long batchSaveTimeoutMs;
private volatile boolean running = false;
@Override
public void start() {
this.running = true;
log.info("AiGracefulShutdownHandler started");
}
@Override
public void stop(Runnable callback) {
log.info("AiGracefulShutdownHandler stopping...");
this.running = false;
// 标记停机状态,阻止新任务注册
taskRegistry.setShuttingDown();
try {
handleShutdown();
} finally {
callback.run(); // 通知 Spring 这个 Bean 已经停止
log.info("AiGracefulShutdownHandler stopped");
}
}
private void handleShutdown() {
List<AiTask> tasks = taskRegistry.getActiveTasks();
if (tasks.isEmpty()) {
log.info("No active AI tasks, shutting down immediately");
return;
}
log.info("Found {} active AI tasks during shutdown", tasks.size());
// 按类型分组处理
Map<AiTask.AiTaskType, List<AiTask>> tasksByType = tasks.stream()
.collect(Collectors.groupingBy(AiTask::getType));
// 1. 先处理流式响应——快速通知客户端
List<AiTask> streamTasks = tasksByType.getOrDefault(AiTask.AiTaskType.STREAM_RESPONSE, List.of());
terminateStreamTasks(streamTasks);
// 2. 等待一个短暂的缓冲期,让流式终止帧发出去
if (!streamTasks.isEmpty()) {
sleep(streamGracePeriodMs);
}
// 3. 保存批量任务进度
List<AiTask> batchTasks = tasksByType.getOrDefault(AiTask.AiTaskType.BATCH_EMBEDDING, List.of());
saveBatchProgress(batchTasks);
// 4. 等待批量任务保存完成
waitForBatchSave(batchTasks);
// 5. 记录还未完成的任务(供重启后恢复)
List<AiTask> remainingTasks = taskRegistry.getActiveTasks();
if (!remainingTasks.isEmpty()) {
log.warn("Shutdown with {} tasks still incomplete: {}",
remainingTasks.size(),
remainingTasks.stream().map(AiTask::getTaskId).collect(Collectors.joining(", "))
);
}
}
private void terminateStreamTasks(List<AiTask> streamTasks) {
for (AiTask task : streamTasks) {
try {
if (task.getStreamTerminator() != null) {
task.getStreamTerminator().terminate("服务正在重启,请重新发送您的问题");
log.info("Stream task {} terminated gracefully", task.getTaskId());
}
} catch (Exception e) {
log.warn("Failed to terminate stream task {}: {}", task.getTaskId(), e.getMessage());
}
}
}
private void saveBatchProgress(List<AiTask> batchTasks) {
for (AiTask task : batchTasks) {
try {
if (task.getProgressSaver() != null) {
log.info("Saving progress for batch task {}", task.getTaskId());
task.getProgressSaver().saveProgress();
}
} catch (Exception e) {
log.error("Failed to save progress for batch task {}: {}",
task.getTaskId(), e.getMessage());
}
}
}
private void waitForBatchSave(List<AiTask> batchTasks) {
if (batchTasks.isEmpty()) return;
long deadline = System.currentTimeMillis() + batchSaveTimeoutMs;
while (System.currentTimeMillis() < deadline) {
boolean allDone = batchTasks.stream()
.noneMatch(t -> taskRegistry.getActiveTasks().contains(t));
if (allDone) {
log.info("All batch tasks completed or saved");
return;
}
sleep(500);
}
log.warn("Batch task save timeout after {}ms", batchSaveTimeoutMs);
}
private void sleep(long ms) {
try {
Thread.sleep(ms);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
@Override
public boolean isRunning() {
return running;
}
@Override
public int getPhase() {
// 在 Spring 默认的 Web Server 停机之后执行(Web Server phase 是 Integer.MAX_VALUE)
return Integer.MAX_VALUE - 100;
}
}3.4 在流式响应中集成任务注册
@Service
@Slf4j
public class StreamingAiService {
@Autowired
private AiTaskRegistry taskRegistry;
@Autowired
private OpenAiClient openAiClient;
/**
* 流式 AI 响应,集成优雅停机
*/
public Flux<String> streamChat(String userId, String prompt) {
String taskId = UUID.randomUUID().toString();
// 创建用于停机时通知客户端的终止器
Sinks.Many<String> sink = Sinks.many().unicast().onBackpressureBuffer();
AiTask task = AiTask.builder()
.taskId(taskId)
.type(AiTask.AiTaskType.STREAM_RESPONSE)
.startTime(System.currentTimeMillis())
.userId(userId)
.streamTerminator(reason -> {
// 停机时发送一个特殊的终止消息给客户端
sink.tryEmitNext("\n\n[SYSTEM] " + reason);
sink.tryEmitComplete();
})
.build();
try {
taskRegistry.register(task);
} catch (ServiceUnavailableException e) {
return Flux.error(e);
}
// 开始流式调用,完成或出错后注销任务
return openAiClient.streamChat(prompt)
.doOnNext(token -> {
// 将 AI 输出转发给 sink
sink.tryEmitNext(token);
})
.doOnComplete(() -> {
sink.tryEmitComplete();
taskRegistry.unregister(taskId);
log.debug("Stream task {} completed", taskId);
})
.doOnError(e -> {
sink.tryEmitError(e);
taskRegistry.unregister(taskId);
log.warn("Stream task {} failed: {}", taskId, e.getMessage());
})
.thenMany(sink.asFlux());
}
}3.5 批量任务的进度保存
@Service
@Slf4j
public class BatchEmbeddingService {
@Autowired
private AiTaskRegistry taskRegistry;
@Autowired
private EmbeddingProgressRepository progressRepo;
@Autowired
private VectorStoreClient vectorStoreClient;
public void startBatchEmbedding(String jobId, List<Document> documents) {
String taskId = "batch-" + jobId;
// 记录当前处理进度的原子引用
AtomicInteger processedCount = new AtomicInteger(0);
AiTask task = AiTask.builder()
.taskId(taskId)
.type(AiTask.AiTaskType.BATCH_EMBEDDING)
.startTime(System.currentTimeMillis())
.progressSaver(() -> {
// 停机时保存当前处理到哪个文档
int progress = processedCount.get();
progressRepo.saveCheckpoint(jobId, progress);
log.info("Saved batch job {} progress at document {}/{}",
jobId, progress, documents.size());
})
.build();
taskRegistry.register(task);
try {
// 检查是否有断点续传记录
int startFrom = progressRepo.getCheckpoint(jobId).orElse(0);
log.info("Starting batch embedding job {}, from document {}", jobId, startFrom);
for (int i = startFrom; i < documents.size(); i++) {
// 停机检查:如果收到停机信号,不处理更多文档
if (taskRegistry.isShuttingDown()) {
log.info("Shutdown detected, stopping batch job {} at document {}", jobId, i);
processedCount.set(i);
break;
}
Document doc = documents.get(i);
float[] embedding = vectorStoreClient.embed(doc.getContent());
vectorStoreClient.upsert(doc.getId(), embedding, doc.getMetadata());
processedCount.set(i + 1);
// 每处理 100 个文档保存一次进度
if ((i + 1) % 100 == 0) {
progressRepo.saveCheckpoint(jobId, i + 1);
}
}
// 全部完成,清除断点记录
if (!taskRegistry.isShuttingDown()) {
progressRepo.deleteCheckpoint(jobId);
log.info("Batch embedding job {} completed successfully", jobId);
}
} finally {
taskRegistry.unregister(taskId);
}
}
}四、配置和测试
4.1 完整配置
server:
shutdown: graceful
spring:
lifecycle:
# Spring 默认的 Web 层超时
timeout-per-shutdown-phase: 60s
ai:
shutdown:
# 流式任务终止后等待的缓冲时间
stream-grace-period-ms: 3000
# 批量任务保存进度的最长等待时间
batch-save-timeout-ms: 300004.2 验证停机行为
模拟停机场景的测试:
@SpringBootTest
@Slf4j
class GracefulShutdownTest {
@Autowired
private StreamingAiService streamingService;
@Autowired
private AiGracefulShutdownHandler shutdownHandler;
@Autowired
private AiTaskRegistry taskRegistry;
@Test
void testStreamTaskReceivesShutdownNotification() throws InterruptedException {
List<String> receivedTokens = Collections.synchronizedList(new ArrayList<>());
// 启动一个流式任务(使用 Mock AI 客户端)
Flux<String> stream = streamingService.streamChat("test-user", "写一篇文章");
stream.subscribe(receivedTokens::add);
// 确认任务已注册
Thread.sleep(100);
assertThat(taskRegistry.getActiveCount()).isEqualTo(1);
// 模拟停机
CountDownLatch shutdownLatch = new CountDownLatch(1);
shutdownHandler.stop(shutdownLatch::countDown);
// 等待停机完成
boolean completed = shutdownLatch.await(10, TimeUnit.SECONDS);
assertThat(completed).isTrue();
// 验证流式任务收到了停机通知
assertThat(receivedTokens)
.anyMatch(token -> token.contains("SYSTEM") && token.contains("重启"));
log.info("Received tokens during shutdown: {}", receivedTokens);
}
}五、真实场景复盘
回到开篇那次事故。
如果当时有这套机制,会发生什么:
- 运维执行重启,发出 SIGTERM 信号
- Spring Boot 的 graceful shutdown 开始,停止接受新请求
AiGracefulShutdownHandler检测到 3 个正在进行的流式任务- 向这 3 个任务发送终止消息,用户屏幕上看到「服务正在重启,请重新发送您的问题」
- 等待 3 秒缓冲期后,确认消息已发出
- 服务安全退出
用户会知道发生了什么,会重新发送请求,不会丢失他们的意图,只是有 10 几秒的中断。
这和强杀导致的半截截断完全是两种用户体验。
总结
AI 应用的优雅停机核心在于「感知任务类型,分类处理」:
- 流式响应:快速通知客户端,发送有意义的停机消息
- 批量任务:保存断点进度,支持重启续传
- Spring Boot 的
graceful shutdown是基础,但不够,需要用SmartLifecycle扩展 - 任务注册表是核心数据结构,停机时的所有决策都基于它
优雅停机是生产 AI 应用的基本素养。用户可以接受偶尔的重启,但无法接受没有解释的截断。
