Deep Integration of Spring AI and Redis: Session Management and Semantic Caching in Practice
The AI Customer-Service Bot That Drove Users to Demand Refunds
Late last year I got a call from a friend who runs engineering at an e-commerce company. They had launched an AI customer-service system built on GPT-4. Sounds impressive, right? But user complaints came pouring in.
The complaints were strikingly similar: "Is this AI agent an idiot? I just told it the item number I want to return, and in the very next message it asks me which item I want to return!"
My friend sent me a transcript:
User: I want to return an item from order ORD-20241205-8847
AI: Hello! Would you like a refund? Which order is it?
User: I just told you! ORD-20241205-8847
AI: Got it. Which item would you like to return?
User: (gives up in despair)

They had spent ¥800,000 on this system, and in its first week live, 237 users demanded refunds on the grounds that the AI agent was "brain-dead." The operations director was climbing the walls, and the engineering team spent three straight days of overtime digging into it.
The problem was actually simple: their AI agent had no conversation memory. Every user message triggered a standalone LLM API call with no conversation history attached. From the model's perspective, each message came from a complete stranger.
Worse, the system was handling roughly 5,000 conversations per day, with about 400 concurrent users at peak. All of those conversations lived in the Java process's heap, so a single restart wiped out every in-flight session. A user only had to refresh the page to be forced to start over.
I spent half a day helping them rebuild the session-management module with Redis as the persistent store and added a semantic cache on top, cutting LLM call volume by roughly 60%. This article is the full technical post-mortem of that work.
Conclusions First (TL;DR)
| Approach | Best for | Implementation effort | Cost impact |
|---|---|---|---|
| InMemoryChatMemory | Development and testing | Low | None |
| Redis ChatMemory (single user) | Single-tenant production | Medium | Minimal |
| Redis multi-tenant session isolation | Multi-tenant SaaS production | Medium-high | Minimal |
| Semantic cache | Workloads with many repeated queries | High | 60-90% reduction |
| Redis Cluster high availability | Large-scale production | High | Infrastructure cost |
Key takeaways:
- In production, sessions must be persisted in Redis; never keep them in process memory
- For multi-tenant workloads, use a three-layer key scheme: {tenantId}:{userId}:{sessionId}
- Semantic caching is the single most effective cost lever; a similarity threshold of 0.92 is a good starting point
- Keep the session window at or below 20 messages, truncating the oldest when it overflows
How ChatMemory Works: Managing Messages in Memory
Before writing any code, let's understand how Spring AI's ChatMemory actually works, so that when something breaks you know where to look.
The message history data structure
In Spring AI, a conversation's history is represented as a List<Message>, and each message has one of three types:
// UserMessage - a message sent by the user
new UserMessage("I want to return an item from ORD-20241205-8847")
// AssistantMessage - a reply from the AI
new AssistantMessage("Hello, I've found the order. Which item would you like to return?")
// SystemMessage - the system prompt
new SystemMessage("You are a professional e-commerce support assistant; answer based on the conversation history")

On every LLM call, Spring AI concatenates these messages in order and sends them to the model. That is the essence of how an LLM "remembers" context: every call stuffs the entire history into the prompt.
The problem with InMemoryChatMemory
Spring AI's default InMemoryChatMemory stores messages in a ConcurrentHashMap:
- Data lives in JVM memory and disappears on restart
- With multiple instances deployed, sessions are not shared between them
- There is no TTL, so memory grows without bound
The message window mechanism
MessageWindowChatMemory adds a sliding window: once the message count exceeds a configured limit, the oldest messages are discarded and only the most recent N are kept. This design matters in practice: LLM context windows are finite (128K tokens for GPT-4 Turbo), and a long history inflates token usage and cost.
Approach 1: A Redis-Backed ChatMemory Implementation
Dependencies
<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
</dependencies>

Redis configuration
spring:
data:
redis:
host: ${REDIS_HOST:localhost}
port: ${REDIS_PORT:6379}
password: ${REDIS_PASSWORD:}
database: 0
lettuce:
pool:
max-active: 20
max-idle: 10
min-idle: 5
max-wait: 2000ms
shutdown-timeout: 200ms
connect-timeout: 3000ms
timeout: 3000ms
ai:
chat:
memory:
key-prefix: "chat:memory:"
max-messages: 20
ttl: 86400

Core implementation: RedisChatMemory
package com.laozhang.ai.memory;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
/**
 * Redis-backed ChatMemory implementation.
 * Supports session persistence, sharing across instances, and TTL-based expiry.
 */
@Slf4j
@Component
public class RedisChatMemory implements ChatMemory {
private static final String KEY_PREFIX = "chat:memory:";
private static final int DEFAULT_MAX_MESSAGES = 20;
private static final long DEFAULT_TTL_SECONDS = 86400L;
private final RedisTemplate<String, String> redisTemplate;
private final ObjectMapper objectMapper;
private final int maxMessages;
private final long ttlSeconds;
public RedisChatMemory(RedisTemplate<String, String> redisTemplate,
ObjectMapper objectMapper) {
this.redisTemplate = redisTemplate;
this.objectMapper = objectMapper;
this.maxMessages = DEFAULT_MAX_MESSAGES;
this.ttlSeconds = DEFAULT_TTL_SECONDS;
}
/**
 * Append messages to the conversation history.
 * Uses a Redis List; RPUSH appends to the tail.
 */
@Override
public void add(String conversationId, List<Message> messages) {
if (conversationId == null || messages == null || messages.isEmpty()) return;
String redisKey = buildKey(conversationId);
try {
List<String> serializedMessages = messages.stream()
.map(this::serializeMessage)
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (!serializedMessages.isEmpty()) {
redisTemplate.opsForList().rightPushAll(redisKey, serializedMessages);
// Refresh the TTL: every new message extends the session's lifetime
redisTemplate.expire(redisKey, ttlSeconds, TimeUnit.SECONDS);
// Trim the list so only the most recent maxMessages entries remain
trimMessages(redisKey);
log.debug("Added {} messages to conversation {}", messages.size(), conversationId);
}
} catch (Exception e) {
// Fail soft: don't throw; degrade to memoryless mode rather than breaking the chat
log.error("Failed to add messages to Redis for conversation: {}", conversationId, e);
}
}
/**
 * Retrieve the conversation history.
 */
@Override
public List<Message> get(String conversationId, int lastN) {
if (conversationId == null) return Collections.emptyList();
String redisKey = buildKey(conversationId);
try {
long listSize = Optional.ofNullable(
redisTemplate.opsForList().size(redisKey)).orElse(0L);
if (listSize == 0) return Collections.emptyList();
long startIndex = (lastN < 0 || lastN >= listSize) ? 0 : listSize - lastN;
List<String> serializedMessages = redisTemplate.opsForList()
.range(redisKey, startIndex, -1);
if (serializedMessages == null || serializedMessages.isEmpty()) {
return Collections.emptyList();
}
return serializedMessages.stream()
.map(this::deserializeMessage)
.filter(Objects::nonNull)
.collect(Collectors.toList());
} catch (Exception e) {
log.error("Failed to get messages from Redis for conversation: {}", conversationId, e);
return Collections.emptyList();
}
}
/**
 * Clear the conversation history.
 */
@Override
public void clear(String conversationId) {
if (conversationId == null) return;
try {
redisTemplate.delete(buildKey(conversationId));
log.info("Cleared conversation: {}", conversationId);
} catch (Exception e) {
log.error("Failed to clear conversation: {}", conversationId, e);
}
}
private void trimMessages(String redisKey) {
try {
long size = Optional.ofNullable(
redisTemplate.opsForList().size(redisKey)).orElse(0L);
if (size > maxMessages) {
redisTemplate.opsForList().trim(redisKey, size - maxMessages, -1);
}
} catch (Exception e) {
log.warn("Failed to trim messages for key: {}", redisKey, e);
}
}
private String buildKey(String conversationId) {
return KEY_PREFIX + conversationId;
}
private String serializeMessage(Message message) {
try {
Map<String, Object> messageMap = new HashMap<>();
if (message instanceof UserMessage) {
messageMap.put("type", "USER");
} else if (message instanceof AssistantMessage) {
messageMap.put("type", "ASSISTANT");
} else if (message instanceof SystemMessage) {
messageMap.put("type", "SYSTEM");
} else {
log.warn("Unknown message type: {}", message.getClass().getName());
return null;
}
messageMap.put("content", message.getText());
messageMap.put("timestamp", System.currentTimeMillis());
return objectMapper.writeValueAsString(messageMap);
} catch (Exception e) {
log.error("Failed to serialize message", e);
return null;
}
}
private Message deserializeMessage(String json) {
try {
Map<String, Object> messageMap = objectMapper.readValue(
json, new TypeReference<Map<String, Object>>() {});
String type = (String) messageMap.get("type");
String content = (String) messageMap.get("content");
return switch (type) {
case "USER" -> new UserMessage(content);
case "ASSISTANT" -> new AssistantMessage(content);
case "SYSTEM" -> new SystemMessage(content);
default -> {
log.warn("Unknown message type in Redis: {}", type);
yield null;
}
};
} catch (Exception e) {
log.error("Failed to deserialize message from JSON: {}", json, e);
return null;
}
}
}

Redis configuration class
package com.laozhang.ai.config;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.StringRedisSerializer;
@Configuration
public class RedisConfig {
/**
 * Configure RedisTemplate with String serializers,
 * avoiding the encoding and version-compatibility issues of Java object serialization.
 */
@Bean
public RedisTemplate<String, String> redisTemplate(RedisConnectionFactory connectionFactory) {
RedisTemplate<String, String> template = new RedisTemplate<>();
template.setConnectionFactory(connectionFactory);
StringRedisSerializer stringSerializer = new StringRedisSerializer();
template.setKeySerializer(stringSerializer);
template.setValueSerializer(stringSerializer);
template.setHashKeySerializer(stringSerializer);
template.setHashValueSerializer(stringSerializer);
template.afterPropertiesSet();
return template;
}
@Bean
public ObjectMapper objectMapper() {
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JavaTimeModule());
mapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);
return mapper;
}
}

Using RedisChatMemory in a service
package com.laozhang.ai.service;
import com.laozhang.ai.memory.RedisChatMemory;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.stereotype.Service;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;
@Slf4j
@Service
@RequiredArgsConstructor
public class CustomerServiceChatService {
private final ChatClient.Builder chatClientBuilder;
private final RedisChatMemory redisChatMemory;
/**
 * Chat with memory.
 */
public String chat(String sessionId, String userMessage) {
log.info("Chat request - sessionId: {}, message: {}", sessionId, userMessage);
try {
ChatClient chatClient = chatClientBuilder
.defaultSystem("""
You are a professional e-commerce customer-service assistant.
Use the conversation history to remember key details the user mentions, such as order numbers and items.
Keep answers concise and professional; handle refund requests per the standard process.
""")
.defaultAdvisors(
MessageChatMemoryAdvisor.builder(redisChatMemory).build()
)
.build();
String response = chatClient.prompt()
.user(userMessage)
.advisors(advisorSpec -> advisorSpec
.param(CHAT_MEMORY_CONVERSATION_ID_KEY, sessionId)
.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 10)
)
.call()
.content();
log.info("Chat response - sessionId: {}, responseLength: {}",
sessionId, response != null ? response.length() : 0);
return response;
} catch (Exception e) {
log.error("Chat failed for sessionId: {}", sessionId, e);
throw new RuntimeException("Chat processing failed, please try again later", e);
}
}
public void clearSession(String sessionId) {
redisChatMemory.clear(sessionId);
log.info("Session cleared: {}", sessionId);
}
}

Approach 2: Multi-Tenant Session Isolation in Redis
If you are building a SaaS product, sessions belonging to different customer organizations must be strictly isolated; this touches data security, quota management, and billing.
Multi-tenant key design
# Key format: chat:memory:{tenantId}:{userId}:{sessionId}
chat:memory:tenant_001:user_12345:session_abc
chat:memory:tenant_002:user_99999:session_xyz

What the three layers buy you:
- tenantId: tenant level; TTL and message-window size can vary per tenant
- userId: user level; lets you list a given user's historical sessions
- sessionId: session level; one user can hold several concurrent sessions
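For illustration, here is a minimal sketch of building and parsing the composite key. This helper is an assumption for this article, not part of Spring AI, and it assumes tenantId and userId never contain a colon:

```java
// Builds and parses the three-layer key: chat:memory:{tenantId}:{userId}:{sessionId}
public class SessionKey {
    private static final String PREFIX = "chat:memory:";

    public static String build(String tenantId, String userId, String sessionId) {
        return PREFIX + tenantId + ":" + userId + ":" + sessionId;
    }

    /** Returns [tenantId, userId, sessionId] extracted from a full Redis key. */
    public static String[] parse(String redisKey) {
        if (!redisKey.startsWith(PREFIX)) {
            throw new IllegalArgumentException("Not a chat memory key: " + redisKey);
        }
        // Limit of 3 keeps any colons inside sessionId in the last segment
        String[] parts = redisKey.substring(PREFIX.length()).split(":", 3);
        if (parts.length != 3) {
            throw new IllegalArgumentException("Malformed composite key: " + redisKey);
        }
        return parts;
    }
}
```

Having one canonical build/parse helper prevents the subtle drift that happens when several services each concatenate the key by hand.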
Multi-tenant ChatMemory implementation
package com.laozhang.ai.memory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
/**
 * Multi-tenant session manager.
 * Supports per-tenant session policies.
 */
@Slf4j
@Component
public class MultiTenantChatMemoryManager {
private final Map<String, TenantMemoryConfig> tenantConfigs = new ConcurrentHashMap<>();
private final RedisChatMemory redisChatMemory;
private final RedisTemplate<String, String> redisTemplate;
public MultiTenantChatMemoryManager(RedisChatMemory redisChatMemory,
RedisTemplate<String, String> redisTemplate) {
this.redisChatMemory = redisChatMemory;
this.redisTemplate = redisTemplate;
initDefaultConfigs();
}
public String buildTenantSessionId(String tenantId, String userId, String sessionId) {
return String.format("%s:%s:%s", tenantId, userId, sessionId);
}
public void addMessages(String tenantId, String userId, String sessionId,
List<Message> messages) {
TenantMemoryConfig config = getTenantConfig(tenantId);
String compositeKey = buildTenantSessionId(tenantId, userId, sessionId);
if (!checkTenantQuota(tenantId, userId)) {
log.warn("Tenant quota exceeded - tenantId: {}, userId: {}", tenantId, userId);
return;
}
redisChatMemory.add(compositeKey, messages);
String redisKey = "chat:memory:" + compositeKey;
redisTemplate.expire(redisKey, config.getTtlSeconds(), TimeUnit.SECONDS);
}
public List<Message> getMessages(String tenantId, String userId,
String sessionId, int lastN) {
String compositeKey = buildTenantSessionId(tenantId, userId, sessionId);
TenantMemoryConfig config = getTenantConfig(tenantId);
int effectiveLastN = Math.min(lastN, config.getMaxMessages());
return redisChatMemory.get(compositeKey, effectiveLastN);
}
private boolean checkTenantQuota(String tenantId, String userId) {
TenantMemoryConfig config = getTenantConfig(tenantId);
String countKey = "chat:session:count:" + tenantId + ":" + userId;
String countStr = redisTemplate.opsForValue().get(countKey);
long currentCount = countStr != null ? Long.parseLong(countStr) : 0;
return currentCount < config.getMaxConcurrentSessions();
}
private TenantMemoryConfig getTenantConfig(String tenantId) {
return tenantConfigs.getOrDefault(tenantId, TenantMemoryConfig.defaultConfig());
}
private void initDefaultConfigs() {
// Free tier: 10-message window, 1-hour TTL, at most 3 concurrent sessions
tenantConfigs.put("FREE_TIER", new TenantMemoryConfig(10, 3600L, 3));
// Pro tier: 20-message window, 24-hour TTL, at most 20 concurrent sessions
tenantConfigs.put("PRO_TIER", new TenantMemoryConfig(20, 86400L, 20));
// Enterprise tier: 50-message window, 7-day TTL, unlimited concurrent sessions
tenantConfigs.put("ENTERPRISE", new TenantMemoryConfig(50, 604800L, Integer.MAX_VALUE));
}
public static class TenantMemoryConfig {
private final int maxMessages;
private final long ttlSeconds;
private final int maxConcurrentSessions;
public TenantMemoryConfig(int maxMessages, long ttlSeconds, int maxConcurrentSessions) {
this.maxMessages = maxMessages;
this.ttlSeconds = ttlSeconds;
this.maxConcurrentSessions = maxConcurrentSessions;
}
public static TenantMemoryConfig defaultConfig() {
return new TenantMemoryConfig(20, 86400L, 10);
}
public int getMaxMessages() { return maxMessages; }
public long getTtlSeconds() { return ttlSeconds; }
public int getMaxConcurrentSessions() { return maxConcurrentSessions; }
}
}

Semantic Caching: Serving Similar Questions Straight from Cache
This is the most technically involved part of this article, and also where the cost savings are most visible.
Why a semantic cache
An ordinary cache requires an exact key match. A semantic cache uses vector-similarity search, so different phrasings of the same question hit the same cache entry.
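To make the contrast concrete, here is a toy exact-match cache (just a HashMap; the question strings are made up for illustration): the paraphrased question misses even though the stored answer would serve perfectly well.

```java
import java.util.HashMap;
import java.util.Map;

// An exact-match cache keyed on the raw question string: any rewording misses.
public class ExactMatchCacheDemo {
    private final Map<String, String> cache = new HashMap<>();

    public void put(String question, String answer) {
        cache.put(question, answer);
    }

    public String get(String question) {
        return cache.get(question); // null unless the string matches exactly
    }

    public static void main(String[] args) {
        ExactMatchCacheDemo cache = new ExactMatchCacheDemo();
        cache.put("How do I request a refund?", "Go to Orders and choose Refund.");
        System.out.println(cache.get("How do I request a refund?"));   // exact hit
        System.out.println(cache.get("How can I get my money back?")); // prints null
    }
}
```

A semantic cache replaces the string key with an embedding vector plus a similarity threshold, which is exactly what the implementation below does.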
A complete semantic cache implementation
package com.laozhang.ai.cache;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
 * Redis-backed semantic cache.
 * Similar questions hit the cache directly, reducing LLM call costs.
 */
@Slf4j
@Component
public class SemanticCacheService {
private static final double SIMILARITY_THRESHOLD = 0.92;
private static final long CACHE_TTL_SECONDS = 3600L;
private static final String CACHE_KEY_PREFIX = "semantic:cache:";
private final RedisTemplate<String, String> redisTemplate;
private final EmbeddingModel embeddingModel;
public SemanticCacheService(RedisTemplate<String, String> redisTemplate,
EmbeddingModel embeddingModel) {
this.redisTemplate = redisTemplate;
this.embeddingModel = embeddingModel;
}
/**
 * Look up the semantic cache.
 */
public Optional<String> get(String question) {
try {
float[] questionVector = generateEmbedding(question);
List<CacheEntry> candidates = vectorSearch(questionVector, 5);
if (candidates.isEmpty()) return Optional.empty();
CacheEntry best = candidates.stream()
.max(Comparator.comparingDouble(CacheEntry::getSimilarity))
.orElse(null);
if (best == null || best.getSimilarity() < SIMILARITY_THRESHOLD) {
return Optional.empty();
}
log.info("Semantic cache HIT - similarity: {}, question: {}",
String.format("%.4f", best.getSimilarity()), truncate(best.getQuestion()));
return Optional.of(best.getAnswer());
} catch (Exception e) {
log.error("Semantic cache lookup failed", e);
return Optional.empty();
}
}
/**
 * Store a question/answer pair in the semantic cache.
 */
public void put(String question, String answer) {
try {
float[] vector = generateEmbedding(question);
String cacheId = UUID.randomUUID().toString();
String redisKey = CACHE_KEY_PREFIX + cacheId;
Map<String, String> hashFields = new HashMap<>();
hashFields.put("question", question);
hashFields.put("answer", answer);
hashFields.put("timestamp", String.valueOf(System.currentTimeMillis()));
hashFields.put("vector", encodeVector(vector));
redisTemplate.opsForHash().putAll(redisKey, hashFields);
redisTemplate.expire(redisKey, CACHE_TTL_SECONDS, TimeUnit.SECONDS);
log.debug("Stored semantic cache entry - id: {}", cacheId);
} catch (Exception e) {
log.error("Failed to store semantic cache entry", e);
}
}
/**
 * Compute cosine similarity.
 */
public double cosineSimilarity(float[] vectorA, float[] vectorB) {
if (vectorA.length != vectorB.length) {
throw new IllegalArgumentException("Vector dimensions must match");
}
double dotProduct = 0.0, normA = 0.0, normB = 0.0;
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += vectorA[i] * vectorA[i];
normB += vectorB[i] * vectorB[i];
}
if (normA == 0 || normB == 0) return 0.0;
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
private float[] generateEmbedding(String text) {
long startTime = System.currentTimeMillis();
float[] vector = embeddingModel.embed(text);
log.debug("Generated embedding in {}ms, dim: {}",
System.currentTimeMillis() - startTime, vector.length);
return vector;
}
/**
 * Vector search (naive full scan over KEYS for demo purposes only;
 * in production use RediSearch KNN, or at minimum SCAN - see Pitfall 2).
 */
private List<CacheEntry> vectorSearch(float[] queryVector, int topK) {
List<CacheEntry> results = new ArrayList<>();
Set<String> keys = redisTemplate.keys(CACHE_KEY_PREFIX + "*");
if (keys == null || keys.isEmpty()) return results;
for (String key : keys) {
try {
Map<Object, Object> fields = redisTemplate.opsForHash().entries(key);
if (fields == null || fields.isEmpty()) continue;
String vectorEncoded = (String) fields.get("vector");
if (vectorEncoded == null) continue;
float[] cachedVector = decodeVector(vectorEncoded);
double similarity = cosineSimilarity(queryVector, cachedVector);
if (similarity > SIMILARITY_THRESHOLD * 0.9) {
results.add(new CacheEntry(
(String) fields.get("question"),
(String) fields.get("answer"),
similarity
));
}
} catch (Exception e) {
log.warn("Failed to process cache entry: {}", key, e);
}
}
results.sort((a, b) -> Double.compare(b.getSimilarity(), a.getSimilarity()));
return results.subList(0, Math.min(topK, results.size()));
}
private String encodeVector(float[] vector) {
byte[] bytes = new byte[vector.length * 4];
for (int i = 0; i < vector.length; i++) {
int intBits = Float.floatToIntBits(vector[i]);
bytes[i * 4] = (byte) (intBits >> 24);
bytes[i * 4 + 1] = (byte) (intBits >> 16);
bytes[i * 4 + 2] = (byte) (intBits >> 8);
bytes[i * 4 + 3] = (byte) intBits;
}
return Base64.getEncoder().encodeToString(bytes);
}
private float[] decodeVector(String encoded) {
byte[] bytes = Base64.getDecoder().decode(encoded);
float[] vector = new float[bytes.length / 4];
for (int i = 0; i < vector.length; i++) {
int intBits = ((bytes[i * 4] & 0xFF) << 24) |
((bytes[i * 4 + 1] & 0xFF) << 16) |
((bytes[i * 4 + 2] & 0xFF) << 8) |
(bytes[i * 4 + 3] & 0xFF);
vector[i] = Float.intBitsToFloat(intBits);
}
return vector;
}
private String truncate(String text) {
if (text == null) return "null";
return text.length() > 50 ? text.substring(0, 50) + "..." : text;
}
public static class CacheEntry {
private final String question;
private final String answer;
private final double similarity;
public CacheEntry(String question, String answer, double similarity) {
this.question = question;
this.answer = answer;
this.similarity = similarity;
}
public String getQuestion() { return question; }
public String getAnswer() { return answer; }
public double getSimilarity() { return similarity; }
}
}

Wiring the semantic cache into the chat service
package com.laozhang.ai.service;
import com.laozhang.ai.cache.SemanticCacheService;
import com.laozhang.ai.memory.RedisChatMemory;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Timer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.stereotype.Service;
import java.util.Optional;
/**
 * AI chat service with a semantic cache in front of the LLM.
 */
@Slf4j
@Service
public class SemanticCachedChatService {
private final ChatClient chatClient;
private final SemanticCacheService semanticCache;
private final Counter cacheHitCounter;
private final Counter cacheMissCounter;
private final Timer llmCallTimer;
public SemanticCachedChatService(
ChatClient.Builder chatClientBuilder,
RedisChatMemory redisChatMemory,
SemanticCacheService semanticCache,
MeterRegistry meterRegistry) {
this.chatClient = chatClientBuilder
.defaultAdvisors(MessageChatMemoryAdvisor.builder(redisChatMemory).build())
.build();
this.semanticCache = semanticCache;
this.cacheHitCounter = Counter.builder("ai.semantic.cache.hit")
        .description("Semantic cache hits").register(meterRegistry);
this.cacheMissCounter = Counter.builder("ai.semantic.cache.miss")
        .description("Semantic cache misses").register(meterRegistry);
this.llmCallTimer = Timer.builder("ai.llm.call.duration")
        .description("LLM call latency").register(meterRegistry);
}
/**
 * Handle a message (with semantic caching).
 */
public ChatResponse chat(String sessionId, String userMessage) {
long startTime = System.currentTimeMillis();
// 1. Check the semantic cache first
Optional<String> cachedAnswer = semanticCache.get(userMessage);
if (cachedAnswer.isPresent()) {
cacheHitCounter.increment();
long elapsed = System.currentTimeMillis() - startTime;
log.info("Semantic cache HIT - sessionId: {}, elapsed: {}ms", sessionId, elapsed);
return new ChatResponse(cachedAnswer.get(), true, elapsed);
}
// 2. Cache miss: call the LLM
cacheMissCounter.increment();
String response = llmCallTimer.record(() ->
chatClient.prompt()
.user(userMessage)
.advisors(spec -> spec.param("chat_memory_conversation_id", sessionId))
.call()
.content()
);
long elapsed = System.currentTimeMillis() - startTime;
// 3. Store the answer in the semantic cache
asyncStoreToCache(userMessage, response);
log.info("LLM call completed - sessionId: {}, elapsed: {}ms", sessionId, elapsed);
return new ChatResponse(response, false, elapsed);
}
// NOTE: despite the name, this runs synchronously on the request thread;
// annotate it with @Async (and enable @EnableAsync) to make it truly non-blocking.
private void asyncStoreToCache(String question, String answer) {
try {
semanticCache.put(question, answer);
} catch (Exception e) {
log.warn("Failed to store to semantic cache", e);
}
}
public record ChatResponse(String content, boolean fromCache, long elapsedMs) {}
}

Session Expiry and Cleanup Strategy
A TTL strategy enum
package com.laozhang.ai.config;
/**
 * Session TTL strategies.
 */
public enum SessionTtlStrategy {
/** Short customer-service sessions: expire after 30 minutes of inactivity */
CUSTOMER_SERVICE(1800L),
/** Everyday chat: expires after 24 hours */
DAILY_CHAT(86400L),
/** Project collaboration: expires after 7 days */
PROJECT_WORK(604800L),
/** Permanent sessions: never expire (cleaned up manually) */
PERMANENT(-1L);
private final long seconds;
SessionTtlStrategy(long seconds) { this.seconds = seconds; }
public long getSeconds() { return seconds; }
}会话清理定时任务
package com.laozhang.ai.task;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.Set;
/**
 * Scheduled session cleanup and stats task.
 */
@Slf4j
@Component
public class SessionCleanupTask {
private final RedisTemplate<String, String> redisTemplate;
public SessionCleanupTask(RedisTemplate<String, String> redisTemplate) {
this.redisTemplate = redisTemplate;
}
/**
 * Runs daily stats at 2:00 AM.
 */
@Scheduled(cron = "0 0 2 * * ?")
public void dailyCleanup() {
log.info("Starting daily session cleanup...");
try {
long chatMemoryCount = countKeys("chat:memory:*");
long semanticCacheCount = countKeys("semantic:cache:*");
log.info("Session stats - chatMemory: {}, semanticCache: {}",
chatMemoryCount, semanticCacheCount);
if (chatMemoryCount > 100000) {
log.warn("Chat memory key count is too high: {}. Consider adjusting TTL.",
chatMemoryCount);
}
} catch (Exception e) {
log.error("Daily cleanup failed", e);
}
}
@Scheduled(cron = "0 0 * * * ?")
public void hourlyStats() {
log.info("Hourly stats check completed");
}
private long countKeys(String pattern) {
    // WARNING: KEYS blocks Redis and is acceptable only for a low-traffic stats job;
    // in production use SCAN instead (see Pitfall 2 below).
    Set<String> keys = redisTemplate.keys(pattern);
    return keys != null ? keys.size() : 0;
}
}

Session Analytics: Mining User Behavior Patterns
package com.laozhang.ai.analytics;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Session analytics service.
 */
@Slf4j
@Service
public class SessionAnalyticsService {
private final RedisTemplate<String, String> redisTemplate;
public SessionAnalyticsService(RedisTemplate<String, String> redisTemplate) {
this.redisTemplate = redisTemplate;
}
/**
 * Record per-session metrics.
 */
public void recordSessionMetrics(String sessionId, String userId,
String question, String answer,
long responseTimeMs, boolean fromCache) {
String date = java.time.LocalDate.now().toString();
try {
redisTemplate.opsForValue().increment("analytics:daily:questions:" + date);
if (fromCache) {
redisTemplate.opsForValue().increment("analytics:daily:cache_hits:" + date);
}
redisTemplate.opsForZSet().add(
"analytics:response_times:" + date,
sessionId + ":" + System.currentTimeMillis(),
responseTimeMs
);
// Track high-frequency keywords
List<String> keywords = extractKeywords(question);
for (String keyword : keywords) {
redisTemplate.opsForZSet().incrementScore(
"analytics:keywords:" + date, keyword, 1);
}
} catch (Exception e) {
log.error("Failed to record session metrics", e);
}
}
/**
 * Build the daily stats report.
 */
public DailyReport getDailyReport(String date) {
String questionCountStr = redisTemplate.opsForValue()
.get("analytics:daily:questions:" + date);
String cacheHitStr = redisTemplate.opsForValue()
.get("analytics:daily:cache_hits:" + date);
long totalQuestions = questionCountStr != null ? Long.parseLong(questionCountStr) : 0;
long cacheHits = cacheHitStr != null ? Long.parseLong(cacheHitStr) : 0;
double cacheHitRate = totalQuestions > 0 ? (double) cacheHits / totalQuestions : 0;
Set<org.springframework.data.redis.core.ZSetOperations.TypedTuple<String>> topKeywords =
redisTemplate.opsForZSet().reverseRangeWithScores(
"analytics:keywords:" + date, 0, 9);
Map<String, Double> keywordFrequency = new LinkedHashMap<>();
if (topKeywords != null) {
for (var tuple : topKeywords) {
keywordFrequency.put(tuple.getValue(), tuple.getScore());
}
}
return new DailyReport(date, totalQuestions, cacheHits, cacheHitRate, keywordFrequency);
}
private List<String> extractKeywords(String text) {
String[] words = text.split("[,。!?、\\s]+");
return Arrays.stream(words)
.filter(w -> w.length() >= 2 && w.length() <= 8)
.limit(5)
.collect(Collectors.toList());
}
public record DailyReport(
String date, long totalQuestions, long cacheHits,
double cacheHitRate, Map<String, Double> topKeywords
) {}
}

High Availability: Redis Cluster and Failover
Redis Cluster production configuration
# application-prod.yml
spring:
data:
redis:
cluster:
nodes:
- redis-node-1:7001
- redis-node-2:7002
- redis-node-3:7003
- redis-node-4:7004
- redis-node-5:7005
- redis-node-6:7006
max-redirects: 3
password: ${REDIS_CLUSTER_PASSWORD}
lettuce:
cluster:
refresh:
adaptive: true
period: 30000ms
pool:
max-active: 50
max-idle: 20
min-idle: 10
max-wait: 3000ms

Adaptive topology refresh
package com.laozhang.ai.config;
import io.lettuce.core.cluster.ClusterClientOptions;
import io.lettuce.core.cluster.ClusterTopologyRefreshOptions;
import org.springframework.boot.autoconfigure.data.redis.LettuceClientConfigurationBuilderCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Profile;
import java.time.Duration;
/**
 * Redis Cluster high-availability configuration.
 */
@Configuration
@Profile("prod")
public class RedisClusterConfig {
@Bean
public LettuceClientConfigurationBuilderCustomizer lettuceClientConfigurationBuilderCustomizer() {
return clientConfigurationBuilder -> {
ClusterTopologyRefreshOptions refreshOptions = ClusterTopologyRefreshOptions.builder()
.enablePeriodicRefresh(Duration.ofSeconds(30))
.enableAdaptiveRefreshTrigger(
ClusterTopologyRefreshOptions.RefreshTrigger.MOVED_REDIRECT,
ClusterTopologyRefreshOptions.RefreshTrigger.PERSISTENT_RECONNECTS
)
.adaptiveRefreshTriggersTimeout(Duration.ofSeconds(30))
.build();
ClusterClientOptions clusterClientOptions = ClusterClientOptions.builder()
.topologyRefreshOptions(refreshOptions)
.autoReconnect(true)
.validateClusterNodeMembership(true)
.build();
clientConfigurationBuilder.clientOptions(clusterClientOptions);
};
}
}

Production Notes and Pitfalls
Pitfall 1: serialization compatibility across versions
When you change the Message data layout, old entries in Redis will fail to deserialize. Always handle this defensively in deserializeMessage: on failure, return null and log it; never throw.
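One defensive pattern worth considering (an assumption of this article, not something Spring AI ships) is stamping every serialized entry with a schema version, so a format change is detected explicitly instead of surfacing as a random parse error:

```java
import java.util.HashMap;
import java.util.Map;

// Wraps serialized message fields with a schema version ("v" is a made-up field name).
public class VersionedEnvelope {
    static final int CURRENT_VERSION = 2;

    public static Map<String, Object> wrap(String type, String content) {
        Map<String, Object> m = new HashMap<>();
        m.put("v", CURRENT_VERSION);
        m.put("type", type);
        m.put("content", content);
        return m;
    }

    /** True only for entries written by a compatible schema version. */
    public static boolean isCompatible(Map<String, Object> m) {
        Object v = m.get("v");
        return v instanceof Integer && (Integer) v == CURRENT_VERSION;
    }
}
```

deserializeMessage can then skip (and optionally delete) incompatible entries instead of attempting to parse them.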
Pitfall 2: never run KEYS in production
KEYS * blocks Redis; with hundreds of thousands of keys it can stall the server for several seconds. In production, only SCAN is acceptable.
// Anti-pattern - never do this in production!
Set<String> keys = redisTemplate.keys("chat:memory:*");
// Correct approach: SCAN in batches of ~100 keys
ScanOptions options = ScanOptions.scanOptions()
        .match("chat:memory:*")
        .count(100)
        .build();
Set<String> scanned = new HashSet<>();
try (Cursor<String> cursor = redisTemplate.scan(options)) {
    cursor.forEachRemaining(scanned::add);
}

Pitfall 3: semantic cache writes need idempotency protection
The same question may be processed by several threads at once, so guard the write with a distributed lock:
String questionHash = Integer.toHexString(question.hashCode());
String lockKey = "lock:semantic:cache:" + questionHash;
Boolean locked = redisTemplate.opsForValue()
.setIfAbsent(lockKey, "1", 5, TimeUnit.SECONDS);
if (Boolean.TRUE.equals(locked)) {
try {
semanticCache.put(question, answer);
} finally {
redisTemplate.delete(lockKey);
}
}

Pitfall 4: embedding vector dimensions change
Upgrading the embedding model (say, from text-embedding-ada-002 to text-embedding-3-large) changes the vector dimension, making old entries unusable. Include the model version in the cache key, and rebuild the vector index in full on upgrade.
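A minimal sketch of that idea; the key layout and helper names are illustrative assumptions, not part of any library:

```java
// Namespaces semantic-cache keys by embedding-model version, so a model upgrade
// starts a fresh index instead of mixing vectors of different dimensions.
public class VersionedCacheKey {
    public static String build(String modelVersion, String cacheId) {
        return "semantic:cache:" + modelVersion + ":" + cacheId;
    }

    /** Pattern for scanning only the current model's entries. */
    public static String scanPattern(String modelVersion) {
        return "semantic:cache:" + modelVersion + ":*";
    }
}
```

Entries written under the previous model's namespace simply age out via TTL, or can be deleted in one batch after the new index is rebuilt.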
Performance Test Data
Measured in production (4-core/8GB Redis, GPT-4 Turbo, China East region):
| Scenario | Avg latency | P99 latency | QPS |
|---|---|---|---|
| No cache, no memory | 1800ms | 4200ms | 45 |
| Redis session memory (10 messages) | 1850ms | 4300ms | 43 |
| Semantic cache hit | 45ms | 120ms | 2200 |
| Semantic cache miss | 1920ms | 4400ms | 41 |
Semantic cache hit rate: in the customer-service workload, roughly 65% of questions were repeats; the measured cache hit rate was 58%, cutting LLM calls by about 58% and reducing the monthly bill from ¥12,000 to ¥5,000.
FAQ
Q1: What if session data in Redis gets lost?
A: Lost sessions usually point at Redis persistence configuration. Make sure AOF is enabled: appendonly yes plus appendfsync everysec. Production should also run Redis Sentinel or Cluster for high availability. As a final backstop, consider asynchronously backing up session data to MySQL.
Q2: Will multiple Spring Boot instances conflict over sessions?
A: No; this is exactly where Redis shines. All instances share the same Redis, so whichever instance serves a request can read the full session history. The only requirement is that the key naming scheme stays consistent across all instances.
Q3: What similarity threshold should the semantic cache use?
A: Tune it per workload:
- Customer service: 0.90-0.93 (questions have standard answers; similar questions can share one)
- Knowledge Q&A: 0.95-0.97 (precision matters; avoid plausible-but-wrong answers)
- Creative writing: skip semantic caching entirely (every response should be personalized)
The best approach is to collect data after launch and A/B test your way to the optimal value.
Q4: How large should the message window be?
A: At most 20 messages (10 turns) is recommended. Measured token costs:
- 10 messages: roughly 500-800 tokens per call
- 20 messages: roughly 1000-1600 tokens per call
- 50 messages: roughly 2500-4000 tokens per call
Workloads that genuinely need very long memory should use a vector database for long-term recall, rather than endlessly widening the message window.
Q5: How do I stop users injecting malicious prompts through the session history?
A: Filter content before storing messages: reject phrases like "ignore previous instructions" and similar injection markers; cap single-message length (2000 characters is a reasonable limit); and escape HTML tags.
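Those three defenses can be sketched as a single filter. The blocked-phrase list and the 2000-character cap below are illustrative choices for this article, not a complete injection defense:

```java
import java.util.List;

// Applies a length cap, an injection-phrase blocklist, and HTML escaping
// before a user message is stored in the session history.
public class MessageSanitizer {
    private static final int MAX_LENGTH = 2000;
    private static final List<String> BLOCKED_PHRASES = List.of(
            "ignore previous instructions",
            "ignore all previous instructions");

    public static String sanitize(String input) {
        // 1. Cap the length
        String text = input.length() > MAX_LENGTH ? input.substring(0, MAX_LENGTH) : input;
        // 2. Reject known injection phrases (case-insensitive)
        String lower = text.toLowerCase();
        for (String phrase : BLOCKED_PHRASES) {
            if (lower.contains(phrase)) {
                throw new IllegalArgumentException("Possible prompt injection detected");
            }
        }
        // 3. Escape HTML ('&' first so existing escapes aren't double-escaped)
        return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;");
    }
}
```

In practice you would call this at the top of the add() path, before the message is serialized into Redis.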
Q6: How do I wipe all of a user's sessions without restarting the service?
A: Find the user's keys with SCAN and delete them in a batch:
public void clearAllUserSessions(String tenantId, String userId) {
    String pattern = "chat:memory:" + tenantId + ":" + userId + ":*";
    // scanKeys: a SCAN-based lookup helper (never KEYS; see Pitfall 2)
    List<String> keys = scanKeys(pattern);
    if (!keys.isEmpty()) {
        redisTemplate.delete(keys);
        log.info("Cleared {} sessions for user: {}", keys.size(), userId);
    }
}

Summary
Today we built out a complete session-management and semantic-caching stack on Spring AI and Redis. After this rewrite, the session loss rate dropped from 35% to 0%, LLM costs fell by 58%, and customer satisfaction climbed from 2.1 back up to 4.3 (on a 5-point scale).
Actionable checklist:
- Replace InMemoryChatMemory with a Redis-backed ChatMemory before going to production
- Adopt the {tenantId}:{userId}:{sessionId} key scheme if you serve multiple tenants
- Put a semantic cache in front of the LLM; start at a 0.92 similarity threshold and tune with A/B tests
- Cap the message window at 20 messages and set TTLs appropriate to each session type
- Monitor cache hit rate and key counts, and never run KEYS against production Redis