Post #1633: Advanced ChatClient Usage in Spring AI: Interceptors, Decorators, and End-to-End Tracing
ChatClient is the most heavily used component in Spring AI, but most tutorials cast it as a simple "send request, get reply" wrapper and never go beyond the basics. Today I want to cover what makes ChatClient genuinely interesting: how interceptors and decorators can turn it into an observable, governable core for all of your AI calls.
The topic comes from a real production incident. A colleague came to me one day and said: the AI answers in production have gotten worse, but I can't tell whether it's the model or a prompt change. I asked whether he had logs; he had request logs, but no prompt contents. Without those there was nothing to investigate. Afterwards we spent a week building an end-to-end tracing setup, and this post shares the core of it.
The core design of ChatClient
First, the core design. ChatClient does not simply wrap ChatModel; it layers a fluent API on top and builds in the Advisor mechanism.
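Before looking at the real interfaces, it helps to see the mechanics in miniature. The sketch below is a plain-Java model of an around-style interceptor chain, which is the pattern Advisors implement: each element can rewrite the request, short-circuit, or delegate onward. All names here are illustrative stand-ins, not Spring AI types.

```java
import java.util.List;

public class MiniChain {
    // Illustrative stand-ins for the advised request/response types
    record Request(String prompt) {}
    record Response(String content) {}

    interface Advisor {
        Response aroundCall(Request request, Chain chain);
    }

    // The chain hands control to the next advisor, or to the "model" at the end
    static class Chain {
        private final List<Advisor> advisors;
        private final int index;
        Chain(List<Advisor> advisors, int index) {
            this.advisors = advisors;
            this.index = index;
        }
        Response next(Request request) {
            if (index < advisors.size()) {
                return advisors.get(index).aroundCall(request, new Chain(advisors, index + 1));
            }
            // terminal step: stands in for the actual model call
            return new Response("model-reply-to:" + request.prompt());
        }
    }

    public static void main(String[] args) {
        // A "pre" advisor that rewrites the request before delegating
        Advisor upperCase = (req, chain) -> chain.next(new Request(req.prompt().toUpperCase()));
        // A "post" advisor that decorates the response on the way back out
        Advisor tagging = (req, chain) -> {
            Response r = chain.next(req);
            return new Response(r.content() + " [checked]");
        };
        Response out = new Chain(List.of(upperCase, tagging), 0).next(new Request("hello"));
        System.out.println(out.content()); // model-reply-to:HELLO [checked]
    }
}
```

The terminal step plays the role of the ChatModel call; a caching advisor short-circuits simply by returning a response without calling chain.next.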
ChatClient
├── ChatModel (the underlying model call)
├── Advisors (the pre/post-processing chain)
└── DefaultChatClientRequest (request building)

The Advisor is Spring AI's core extension point. It is essentially an interceptor chain, similar to a Servlet Filter but designed specifically for AI calls (the interface names below follow a 1.0 milestone release; the advisor API has been renamed across Spring AI versions):
public interface ChatClientRequestAdvisor extends Ordered {
AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object> context);
default int getOrder() {
return Ordered.LOWEST_PRECEDENCE;
}
}
public interface ChatClientResponseAdvisor extends Ordered {
ChatClientResponse adviseResponse(ChatClientResponse response, Map<String, Object> context);
}

There is also a form that handles request and response together:
public interface AroundAdvisor extends RequestResponseAdvisor {
String NAME = AroundAdvisor.class.getName();
AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain);
}

Writing your first Advisor: full logging
The simplest and most useful Advisor records the prompt and response of every call:
@Slf4j
@Component
public class FullLoggingAdvisor implements AroundAdvisor {
private final ObjectMapper objectMapper;
// Whether to log full prompt contents (may need masking in production)
@Value("${ai.logging.log-prompt:true}")
private boolean logPrompt;
@Value("${ai.logging.log-response:true}")
private boolean logResponse;
public FullLoggingAdvisor(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest,
CallAroundAdvisorChain chain) {
long startTime = System.currentTimeMillis();
String traceId = MDC.get("traceId");
if (traceId == null) {
traceId = UUID.randomUUID().toString().replace("-", "").substring(0, 16);
MDC.put("traceId", traceId);
}
if (logPrompt) {
logRequest(advisedRequest, traceId);
}
AdvisedResponse response;
try {
response = chain.nextAroundCall(advisedRequest);
} catch (Exception e) {
log.error("[AI call failed] traceId={}, duration={}ms, error={}",
traceId, System.currentTimeMillis() - startTime, e.getMessage());
throw e;
}
long duration = System.currentTimeMillis() - startTime;
if (logResponse) {
logResponse(response, traceId, duration);
}
// Stash call info in the context so other Advisors can read it
Map<String, Object> context = new HashMap<>(response.adviseContext());
context.put("ai.call.duration", duration);
context.put("ai.call.traceId", traceId);
return new AdvisedResponse(response.response(), context);
}
private void logRequest(AdvisedRequest request, String traceId) {
try {
List<Map<String, String>> messages = request.messages().stream()
.map(msg -> Map.of(
"role", msg.getMessageType().getValue(),
"content", truncate(msg.getContent(), 500)
))
.collect(Collectors.toList());
log.info("[AI request] traceId={}, model={}, messageCount={}, messages={}",
traceId,
request.chatOptions() != null ?
request.chatOptions().getModel() : "default",
messages.size(),
objectMapper.writeValueAsString(messages));
} catch (Exception e) {
log.warn("Failed to log AI request", e);
}
}
private void logResponse(AdvisedResponse response, String traceId, long duration) {
try {
ChatResponse chatResponse = response.response();
if (chatResponse != null && !chatResponse.getResults().isEmpty()) {
String content = chatResponse.getResults().get(0)
.getOutput().getContent();
// Token usage (if the model returns it)
Long promptTokens = null;
Long completionTokens = null;
if (chatResponse.getMetadata() != null &&
chatResponse.getMetadata().getUsage() != null) {
promptTokens = chatResponse.getMetadata().getUsage().getPromptTokens();
completionTokens = chatResponse.getMetadata().getUsage().getGenerationTokens();
}
log.info("[AI response] traceId={}, duration={}ms, promptTokens={}, completionTokens={}, content={}",
traceId, duration, promptTokens, completionTokens,
truncate(content, 300));
}
} catch (Exception e) {
log.warn("Failed to log AI response", e);
}
}
private String truncate(String text, int maxLength) {
if (text == null) return "";
if (text.length() <= maxLength) return text;
return text.substring(0, maxLength) + "...[truncated]";
}
@Override
public int getOrder() {
return Ordered.HIGHEST_PRECEDENCE; // run first so every call gets recorded
}
}

A compliance Advisor: sensitive-content filtering
In many scenarios you need to strip sensitive information out of the prompt before it reaches the model, or check whether the model's response contains non-compliant content:
@Component
@Slf4j
public class ContentSafetyAdvisor implements AroundAdvisor {
private final ContentSafetyService safetyService;
public ContentSafetyAdvisor(ContentSafetyService safetyService) {
this.safetyService = safetyService;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest,
CallAroundAdvisorChain chain) {
// Pre-check: scrub sensitive content from the request
AdvisedRequest sanitizedRequest = sanitizeRequest(advisedRequest);
// Call the model
AdvisedResponse response = chain.nextAroundCall(sanitizedRequest);
// Post-check: inspect the response content
return checkResponse(response);
}
private AdvisedRequest sanitizeRequest(AdvisedRequest request) {
List<Message> sanitizedMessages = request.messages().stream()
.map(this::sanitizeMessage)
.collect(Collectors.toList());
return AdvisedRequest.from(request)
.withMessages(sanitizedMessages)
.build();
}
private Message sanitizeMessage(Message message) {
String content = message.getContent();
if (content == null) return message;
// Mask phone numbers, national IDs, bank card numbers, etc.
String sanitized = content
.replaceAll("1[3-9]\\d{9}", "***phone***")
.replaceAll("\\d{17}[\\dXx]", "***id-number***")
.replaceAll("\\d{16,19}", "***card-number***");
if (!sanitized.equals(content)) {
log.info("Sensitive content detected in request; masked before sending");
}
if (message instanceof UserMessage) {
return new UserMessage(sanitized);
} else if (message instanceof SystemMessage) {
return new SystemMessage(sanitized);
}
return message;
}
private AdvisedResponse checkResponse(AdvisedResponse response) {
if (response.response() == null) return response;
String responseContent = response.response().getResults().stream()
.map(r -> r.getOutput().getContent())
.collect(Collectors.joining());
SafetyCheckResult checkResult = safetyService.check(responseContent);
if (checkResult.isUnsafe()) {
log.warn("AI response violated content policy, type: {}", checkResult.getViolationType());
// Replace it with a safe canned reply
ChatResponse safeResponse = buildSafeResponse(checkResult.getViolationType());
return new AdvisedResponse(safeResponse, response.adviseContext());
}
return response;
}
private ChatResponse buildSafeResponse(String violationType) {
AssistantMessage message = new AssistantMessage(
"Sorry, I can't answer that question. Please contact a human agent if you need further help."
);
return new ChatResponse(List.of(new Generation(message)));
}
@Override
public int getOrder() {
return 10; // runs after the logging Advisor
}
}

A cost-control Advisor: token metering and quotas
Cost control for AI calls is something many teams ignore until the bill arrives and the budget is already blown:
@Component
@Slf4j
public class CostControlAdvisor implements AroundAdvisor {
private final TokenBudgetService budgetService;
private final MeterRegistry meterRegistry;
public CostControlAdvisor(TokenBudgetService budgetService,
MeterRegistry meterRegistry) {
this.budgetService = budgetService;
this.meterRegistry = meterRegistry;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest request,
CallAroundAdvisorChain chain) {
// Pull the user ID from the context (the caller must put it there)
String userId = (String) request.adviseContext().get("userId");
String businessUnit = (String) request.adviseContext().get("businessUnit");
// Reject the call if the budget is exhausted
if (userId != null && budgetService.isExceeded(userId)) {
throw new TokenBudgetExceededException(
"Token quota exhausted for user " + userId + "; please contact an administrator");
}
AdvisedResponse response = chain.nextAroundCall(request);
// Record token consumption
if (response.response() != null &&
response.response().getMetadata() != null &&
response.response().getMetadata().getUsage() != null) {
var usage = response.response().getMetadata().getUsage();
long totalTokens = usage.getTotalTokens() != null ? usage.getTotalTokens() : 0L;
// Deduct from the user's quota
if (userId != null) {
budgetService.deduct(userId, totalTokens);
}
// Report metrics
Counter.builder("ai.tokens.total")
.tag("user", userId != null ? userId : "unknown")
.tag("business_unit", businessUnit != null ? businessUnit : "unknown")
.tag("model", extractModel(request))
.register(meterRegistry)
.increment(totalTokens);
}
return response;
}
private String extractModel(AdvisedRequest request) {
if (request.chatOptions() instanceof ChatOptions opts) {
return opts.getModel() != null ? opts.getModel() : "unknown";
}
return "unknown";
}
@Override
public int getOrder() {
return 20;
}
}

A caching Advisor: reusing results for repeated questions
In some scenarios the same prompt gets sent over and over, which makes a semantic cache worthwhile:
@Component
@Slf4j
public class SemanticCacheAdvisor implements AroundAdvisor {
private final EmbeddingModel embeddingModel;
private final VectorStore vectorStore;
private final double similarityThreshold;
public SemanticCacheAdvisor(EmbeddingModel embeddingModel,
VectorStore vectorStore,
@Value("${ai.cache.similarity-threshold:0.95}")
double similarityThreshold) {
this.embeddingModel = embeddingModel;
this.vectorStore = vectorStore;
this.similarityThreshold = similarityThreshold;
}
@Override
public AdvisedResponse aroundCall(AdvisedRequest request,
CallAroundAdvisorChain chain) {
// Only cache based on the user message
String userInput = extractUserInput(request);
if (userInput == null || userInput.length() < 10) {
return chain.nextAroundCall(request);
}
// Look for a semantically similar cached entry
List<Document> cachedResults = vectorStore.similaritySearch(
SearchRequest.query(userInput)
.withTopK(1)
.withSimilarityThreshold(similarityThreshold)
);
if (!cachedResults.isEmpty()) {
Document cached = cachedResults.get(0);
String cachedResponse = (String) cached.getMetadata().get("response");
log.debug("Semantic cache hit for user input");
AssistantMessage message = new AssistantMessage(cachedResponse);
ChatResponse cachedChatResponse = new ChatResponse(
List.of(new Generation(message))
);
Map<String, Object> context = new HashMap<>(request.adviseContext());
context.put("cache.hit", true);
return new AdvisedResponse(cachedChatResponse, context);
}
// Cache miss: call through normally
AdvisedResponse response = chain.nextAroundCall(request);
// Store the result in the cache
if (response.response() != null && !response.response().getResults().isEmpty()) {
String responseContent = response.response().getResults().get(0)
.getOutput().getContent();
Document cacheDoc = Document.builder()
.content(userInput)
.metadata(Map.of("response", responseContent,
"timestamp", System.currentTimeMillis()))
.build();
vectorStore.add(List.of(cacheDoc));
}
return response;
}
private String extractUserInput(AdvisedRequest request) {
return request.messages().stream()
.filter(m -> m instanceof UserMessage)
.map(Message::getContent)
.findFirst()
.orElse(null);
}
@Override
public int getOrder() {
return 5; // run early, to avoid unnecessary logging
}
}

Wiring the Advisors together
Now the important part: Advisors have to be assembled in the right order, and getting the order wrong causes real problems.
Wait, there is actually a bug in the setup above. I gave CacheAdvisor order=5, which is larger than LoggingAdvisor's HIGHEST_PRECEDENCE, so the real execution order follows the semantics of Spring's Ordered interface: HIGHEST_PRECEDENCE is Integer.MIN_VALUE, and the smaller the value, the earlier it runs.
So the actual order is: LoggingAdvisor → CacheAdvisor → SafetyAdvisor → CostAdvisor.
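The Ordered semantics are easy to verify in plain Java: Spring sorts by order value ascending, and HIGHEST_PRECEDENCE is Integer.MIN_VALUE. The advisor names below are just string labels for the demo.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

public class OrderDemo {
    public static void main(String[] args) {
        // The order values used in the article's four Advisors
        Map<String, Integer> orders = Map.of(
            "LoggingAdvisor", Integer.MIN_VALUE, // Ordered.HIGHEST_PRECEDENCE
            "CacheAdvisor", 5,
            "SafetyAdvisor", 10,
            "CostAdvisor", 20
        );
        List<String> chain = new ArrayList<>(orders.keySet());
        // Spring sorts ascending: the smallest order value runs first
        chain.sort(Comparator.comparingInt(orders::get));
        System.out.println(String.join(" -> ", chain));
        // LoggingAdvisor -> CacheAdvisor -> SafetyAdvisor -> CostAdvisor
    }
}
```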
If you want the cache to run before logging (so cache hits never get logged, cutting log volume), CacheAdvisor would need an order value below HIGHEST_PRECEDENCE, and no such value exists: Integer.MIN_VALUE is the floor. The cleaner approach is to give every Advisor an explicit order value:
// Recommended order values
CacheAdvisor:   order = -200  // first: a cache hit returns immediately
LoggingAdvisor: order = -100  // logs every request that gets past the cache
SafetyAdvisor:  order = 0     // content safety checks
CostAdvisor:    order = 100   // cost control (quota is deducted last)

When configuring the ChatClient:
@Configuration
public class ChatClientConfig {
@Bean
public ChatClient chatClient(
ChatModel chatModel,
FullLoggingAdvisor loggingAdvisor,
ContentSafetyAdvisor safetyAdvisor,
CostControlAdvisor costAdvisor,
SemanticCacheAdvisor cacheAdvisor) {
return ChatClient.builder(chatModel)
// Default system prompt
.defaultSystem("You are a professional AI assistant. Please answer in Chinese.")
// Register the global Advisors
.defaultAdvisors(
cacheAdvisor,
loggingAdvisor,
safetyAdvisor,
costAdvisor
)
.build();
}
}

Passing context dynamically at request time
Many Advisors need to know which user is making the current call, and that information has to be injected at call time:
@Service
public class AIService {
private final ChatClient chatClient;
public AIService(ChatClient chatClient) {
this.chatClient = chatClient;
}
public String chat(String userId, String businessUnit, String userInput) {
return chatClient.prompt()
.system("You are a professional AI assistant")
.user(userInput)
// Inject per-call context that the Advisors can read
.advisors(a -> a
.param("userId", userId)
.param("businessUnit", businessUnit)
.param("requestId", MDC.get("requestId"))
)
.call()
.content();
}
}

The decorator pattern: wrapping ChatModel itself
Besides Advisors, sometimes you need to intercept at a lower level. Say I want the real call latency of each model, excluding the Advisors' processing time. For that, wrap the ChatModel with a decorator:
public class MetricsChatModelDecorator implements ChatModel {
private final ChatModel delegate;
private final MeterRegistry meterRegistry;
private final String modelName;
public MetricsChatModelDecorator(ChatModel delegate,
MeterRegistry meterRegistry,
String modelName) {
this.delegate = delegate;
this.meterRegistry = meterRegistry;
this.modelName = modelName;
}
@Override
public ChatResponse call(Prompt prompt) {
Timer.Sample sample = Timer.start(meterRegistry);
String status = "success";
try {
return delegate.call(prompt);
} catch (Exception e) {
status = "error";
throw e;
} finally {
sample.stop(Timer.builder("ai.model.latency")
.tag("model", modelName)
.tag("status", status)
.register(meterRegistry));
}
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
long startTime = System.currentTimeMillis();
return delegate.stream(prompt)
.doOnComplete(() -> {
// record the total stream duration as well as a completion count
meterRegistry.timer("ai.model.stream.latency",
"model", modelName)
.record(Duration.ofMillis(System.currentTimeMillis() - startTime));
meterRegistry.counter("ai.model.stream.complete",
"model", modelName).increment();
})
.doOnError(e -> {
meterRegistry.counter("ai.model.stream.error",
"model", modelName,
"error", e.getClass().getSimpleName()).increment();
});
}
}

Apply the decorator when defining the bean in your auto-configuration:
@Bean
public ChatModel chatModel(OpenAiChatModel openAiChatModel,
MeterRegistry meterRegistry) {
return new MetricsChatModelDecorator(
openAiChatModel,
meterRegistry,
"gpt-4"
);
}

What end-to-end tracing looks like in practice
With all of the above combined, the full log trail for a single AI call looks roughly like this:
[INFO] [AI request] traceId=a3f2b1c4, model=gpt-4, messageCount=2, messages=[{"role":"system","content":"You are..."}, {"role":"user","content":"Help me analyze this report..."}]
[INFO] [Content check] traceId=a3f2b1c4, request content is safe
[DEBUG] [Token quota] traceId=a3f2b1c4, user user123 remaining quota: 45000
[INFO] [AI response] traceId=a3f2b1c4, duration=2341ms, promptTokens=450, completionTokens=823, content="Based on the report..."
[INFO] [Token deduction] traceId=a3f2b1c4, user user123 deducted 1273 tokens, remaining: 43727

Paired with a distributed tracing system such as Jaeger or Zipkin, this traceId can stitch together the entire request chain: from the user's HTTP request, through the AI service, down to the underlying model call.
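One practical note on the traceId: MDC is essentially a thread-local map, so it does not cross thread boundaries by itself. The stripped-down sketch below (not the real SLF4J implementation) makes that behavior concrete:

```java
import java.util.HashMap;
import java.util.Map;

public class MiniMdc {
    // Each thread gets its own context map, just like SLF4J's MDC
    private static final ThreadLocal<Map<String, String>> CONTEXT =
        ThreadLocal.withInitial(HashMap::new);

    public static void put(String key, String value) { CONTEXT.get().put(key, value); }
    public static String get(String key) { return CONTEXT.get().get(key); }
    public static void clear() { CONTEXT.remove(); }

    public static void main(String[] args) throws InterruptedException {
        put("traceId", "a3f2b1c4");
        System.out.println("caller thread sees: " + get("traceId"));
        // A different thread has its own (empty) map, which is why
        // async AI calls need explicit context propagation
        Thread worker = new Thread(() ->
            System.out.println("worker thread sees: " + get("traceId")));
        worker.start();
        worker.join();
    }
}
```

This is also why the streaming decorator above reports through Reactor callbacks instead of relying on MDC: a reactive pipeline can hop threads, and a traceId set on the caller thread would read as null there.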
The value of this setup shows up the moment something goes wrong in production. With these logs in place, my colleague's original question (did the prompt change, or did the model's output quality drop?) could have been answered in five minutes.
