第1918篇:分布式缓存Redis在AI应用中的高级用法——语义缓存与会话存储
第1918篇:分布式缓存Redis在AI应用中的高级用法——语义缓存与会话存储
AI 应用有一个普遍的痛点:调用 LLM 又贵又慢。
贵:GPT-4 每次对话几毛钱,高并发场景一天费用轻松破万。
慢:一次 LLM 调用 2~10 秒,用户体验难以接受。
大多数团队的第一反应是"加缓存",但 AI 应用的缓存和普通 API 缓存有本质区别:用户的问题千变万化,你没法用精确的 key 做命中。"北京今天天气怎么样"和"北京现在天气如何"是两个不同的字符串,但语义完全一样。
这就引出了本文的核心话题:语义缓存(Semantic Cache)——用向量相似度判断问题是否"足够像",足够像就返回缓存结果,不需要重新调用 LLM。
一、语义缓存的核心思路
传统缓存:key = hash(query),完全匹配才命中。
语义缓存:key = embedding(query),相似度 > 阈值就命中。
实现语义缓存需要两个组件:
- 向量搜索:找到和当前问题最相似的历史问题
- 缓存存储:存储问题-答案对
Redis 同时支持这两件事(RediSearch 做向量搜索 + 普通 String/Hash 存缓存内容),是语义缓存的天然平台。
二、语义缓存的完整实现
2.1 数据结构设计
// 缓存条目
@Data
@Builder
public class SemanticCacheEntry {
private String cacheKey; // 唯一 ID,UUID
private String originalQuery; // 原始问题
private String cachedResponse; // 缓存的回答
private float[] queryEmbedding; // 问题的向量
private String modelId; // 使用的 LLM 模型
private int hitCount; // 命中次数
private long createdAt; // 创建时间(毫秒)
private long lastHitAt; // 最后命中时间
}2.2 Redis 索引创建
@Component
@RequiredArgsConstructor
@Slf4j
public class SemanticCacheIndexManager {
private final UnifiedJedis jedis;
private static final String CACHE_INDEX = "semantic_cache_idx";
private static final int EMBEDDING_DIM = 1536;
public void initIndex() {
try {
jedis.ftInfo(CACHE_INDEX);
return;
} catch (Exception ignored) {}
IndexDefinition def = new IndexDefinition(IndexDefinition.Type.HASH)
.setPrefixes("scache:");
Schema schema = new Schema()
.addTagField("model_id")
.addNumericField("created_at")
.addNumericField("hit_count")
.addVectorField("query_embedding",
Schema.VectorField.VectorAlgo.HNSW,
Map.of(
"TYPE", "FLOAT32",
"DIM", String.valueOf(EMBEDDING_DIM),
"DISTANCE_METRIC", "COSINE",
"M", "16",
"EF_CONSTRUCTION", "200",
"EF_RUNTIME", "80"
)
);
jedis.ftCreate(CACHE_INDEX,
IndexOptions.defaultOptions().setDefinition(def),
schema);
log.info("语义缓存索引创建成功");
}
}2.3 核心语义缓存服务
@Service
@RequiredArgsConstructor
@Slf4j
public class SemanticCacheService {
private final UnifiedJedis jedis;
private final EmbeddingClient embeddingClient;
private final MeterRegistry meterRegistry;
// 相似度阈值:超过此值认为命中缓存
private static final double SIMILARITY_THRESHOLD = 0.92;
// 缓存条目过期时间:24小时
private static final int CACHE_TTL_SECONDS = 86400;
// 最大缓存条目数(防止无限膨胀)
private static final int MAX_CACHE_SIZE = 50000;
/**
* 查询语义缓存
* 返回 Optional:命中则返回缓存结果,否则返回空
*/
public Optional<String> get(String query, String modelId) {
long start = System.currentTimeMillis();
try {
// 1. 向量化查询
float[] queryEmbedding = embeddingClient.embed(query);
byte[] embeddingBytes = floatToBytes(queryEmbedding);
// 2. 向量相似度搜索
String filter = modelId != null
? "@model_id:{" + modelId + "}"
: "*";
String queryStr = filter + "=>[KNN 5 @query_embedding $vec AS score]";
Query searchQuery = new Query(queryStr)
.addParam("vec", embeddingBytes)
.returnFields("cache_key", "cached_response",
"original_query", "score")
.setSortBy("score", true)
.limit(0, 5)
.dialect(2);
SearchResult result = jedis.ftSearch(CACHE_INDEX, searchQuery);
if (result.getDocuments().isEmpty()) {
recordMiss(System.currentTimeMillis() - start);
return Optional.empty();
}
// 3. 检查最高相似度
Document topDoc = result.getDocuments().get(0);
double distance = Double.parseDouble(topDoc.getString("score"));
double similarity = 1.0 - distance; // cosine distance -> similarity
if (similarity >= SIMILARITY_THRESHOLD) {
String cacheKey = topDoc.getString("cache_key");
String cachedResponse = topDoc.getString("cached_response");
// 4. 更新命中统计
updateHitStats(cacheKey);
recordHit(similarity, System.currentTimeMillis() - start);
log.debug("语义缓存命中,query={}, similarity={:.4f}, " +
"original={}", query, similarity,
topDoc.getString("original_query"));
return Optional.of(cachedResponse);
}
recordMiss(System.currentTimeMillis() - start);
return Optional.empty();
} catch (Exception e) {
log.error("语义缓存查询异常", e);
return Optional.empty(); // 降级:缓存失败不影响主流程
}
}
/**
* 写入语义缓存
*/
public void put(String query, String response, String modelId) {
try {
// 检查缓存大小,防止无限膨胀
// 简单策略:超过最大值时不再写入(更好的策略是 LRU 淘汰)
if (getCacheSize() >= MAX_CACHE_SIZE) {
log.warn("语义缓存已满,跳过写入");
return;
}
float[] queryEmbedding = embeddingClient.embed(query);
String cacheKey = UUID.randomUUID().toString();
String redisKey = "scache:" + cacheKey;
long now = System.currentTimeMillis();
// 写入 Hash(文本字段)
jedis.hset(redisKey, Map.of(
"cache_key", cacheKey,
"original_query", query,
"cached_response", response,
"model_id", modelId != null ? modelId : "default",
"hit_count", "0",
"created_at", String.valueOf(now),
"last_hit_at", String.valueOf(now)
));
// 写入向量字段(二进制)
jedis.hset(redisKey.getBytes(),
"query_embedding".getBytes(),
floatToBytes(queryEmbedding));
// 设置 TTL
jedis.expire(redisKey, CACHE_TTL_SECONDS);
log.debug("写入语义缓存,key={}, query={}", cacheKey, query);
} catch (Exception e) {
log.error("语义缓存写入异常", e);
// 写入失败不影响主流程
}
}
private void updateHitStats(String cacheKey) {
String redisKey = "scache:" + cacheKey;
jedis.hincrBy(redisKey, "hit_count", 1);
jedis.hset(redisKey, "last_hit_at",
String.valueOf(System.currentTimeMillis()));
// 刷新 TTL(命中的缓存延长过期时间)
jedis.expire(redisKey, CACHE_TTL_SECONDS);
}
private long getCacheSize() {
try {
Map<String, Object> info = jedis.ftInfo(CACHE_INDEX);
return Long.parseLong(info.get("num_docs").toString());
} catch (Exception e) {
return 0;
}
}
private void recordHit(double similarity, long latencyMs) {
meterRegistry.counter("semantic_cache.hit").increment();
meterRegistry.gauge("semantic_cache.hit_similarity", similarity);
meterRegistry.timer("semantic_cache.query_duration")
.record(latencyMs, java.util.concurrent.TimeUnit.MILLISECONDS);
}
private void recordMiss(long latencyMs) {
meterRegistry.counter("semantic_cache.miss").increment();
meterRegistry.timer("semantic_cache.query_duration")
.record(latencyMs, java.util.concurrent.TimeUnit.MILLISECONDS);
}
private byte[] floatToBytes(float[] floats) {
ByteBuffer buf = ByteBuffer.allocate(floats.length * 4)
.order(ByteOrder.LITTLE_ENDIAN);
for (float f : floats) buf.putFloat(f);
return buf.array();
}
}2.4 在 LLM 调用层集成语义缓存
@Service
@RequiredArgsConstructor
@Slf4j
public class CachedLlmService {
private final SemanticCacheService cacheService;
private final LlmClient llmClient;
/**
* 带语义缓存的 LLM 调用
*/
public String chat(String userMessage, String modelId) {
long start = System.currentTimeMillis();
// 1. 查语义缓存
Optional<String> cached = cacheService.get(userMessage, modelId);
if (cached.isPresent()) {
log.info("语义缓存命中,节省 LLM 调用,耗时 {}ms",
System.currentTimeMillis() - start);
return cached.get();
}
// 2. 缓存未命中,调用 LLM
String response = llmClient.chat(userMessage, modelId);
long llmLatency = System.currentTimeMillis() - start;
log.info("LLM 调用完成,耗时 {}ms", llmLatency);
// 3. 异步写入缓存(不阻塞返回)
CompletableFuture.runAsync(() ->
cacheService.put(userMessage, response, modelId)
).exceptionally(e -> {
log.error("异步写入语义缓存失败", e);
return null;
});
return response;
}
}三、AI 会话存储:多轮对话的上下文管理
多轮对话是 AI 应用的标配,会话存储要解决几个问题:
- 会话数据持久化(用户重新打开应用,历史对话还在)
- 上下文窗口管理(LLM 有 token 限制,不能把所有历史都塞进去)
- 多端同步(手机端和 PC 端的对话历史同步)
- 会话超时清理(避免 Redis 内存无限增长)
3.1 会话数据结构设计
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatSession {
private String sessionId;
private String userId;
private String title; // 对话标题(首条消息截取)
private List<ChatMessage> messages;
private long createdAt;
private long updatedAt;
private int totalTokens; // 累计 token 消耗
private String modelId;
}
@Data
@Builder
public class ChatMessage {
private String messageId;
private String role; // user / assistant / system
private String content;
private long timestamp;
private int tokenCount; // 该条消息的 token 数
private String modelId; // 生成该消息用的模型
}3.2 会话存储 Service
@Service
@RequiredArgsConstructor
@Slf4j
public class ChatSessionService {
private final StringRedisTemplate redisTemplate;
private final ObjectMapper objectMapper;
// 会话 key 前缀
private static final String SESSION_KEY_PREFIX = "chat:session:";
// 用户会话列表 key 前缀
private static final String USER_SESSIONS_KEY = "chat:user:sessions:";
// 会话默认 TTL:7 天
private static final long SESSION_TTL_DAYS = 7;
// 单个会话最大消息数(超出时裁剪历史)
private static final int MAX_MESSAGES_PER_SESSION = 100;
/**
* 创建新会话
*/
public ChatSession createSession(String userId, String modelId) {
String sessionId = UUID.randomUUID().toString();
ChatSession session = ChatSession.builder()
.sessionId(sessionId)
.userId(userId)
.messages(new ArrayList<>())
.createdAt(System.currentTimeMillis())
.updatedAt(System.currentTimeMillis())
.totalTokens(0)
.modelId(modelId)
.build();
saveSession(session);
// 把 sessionId 加入用户的会话列表(Sorted Set,按更新时间排序)
String userKey = USER_SESSIONS_KEY + userId;
redisTemplate.opsForZSet().add(
userKey, sessionId, (double) session.getCreatedAt());
// 用户会话列表的 TTL 延长到 30 天
redisTemplate.expire(userKey, Duration.ofDays(30));
log.info("创建新会话,sessionId={}, userId={}", sessionId, userId);
return session;
}
/**
* 添加消息到会话
*/
public void addMessage(String sessionId, ChatMessage message) {
ChatSession session = getSession(sessionId);
if (session == null) {
throw new SessionNotFoundException("会话不存在: " + sessionId);
}
// 生成消息 ID
message.setMessageId(UUID.randomUUID().toString());
message.setTimestamp(System.currentTimeMillis());
session.getMessages().add(message);
session.setUpdatedAt(System.currentTimeMillis());
session.setTotalTokens(session.getTotalTokens() + message.getTokenCount());
// 自动截断:保留最近 N 条(但保留 system 消息)
if (session.getMessages().size() > MAX_MESSAGES_PER_SESSION) {
pruneOldMessages(session);
}
// 更新会话标题(用第一条用户消息)
if (session.getTitle() == null && "user".equals(message.getRole())) {
String title = message.getContent();
session.setTitle(title.length() > 50
? title.substring(0, 50) + "..."
: title);
}
saveSession(session);
// 更新 Sorted Set 中的时间戳(保持按最新活跃排序)
String userKey = USER_SESSIONS_KEY + session.getUserId();
redisTemplate.opsForZSet().add(
userKey, sessionId, (double) session.getUpdatedAt());
}
/**
* 构建 LLM 上下文:从会话历史中选取最近的消息,控制 token 数量
*/
public List<ChatMessage> buildContext(
String sessionId, int maxTokens) {
ChatSession session = getSession(sessionId);
if (session == null) return Collections.emptyList();
List<ChatMessage> allMessages = session.getMessages();
List<ChatMessage> context = new ArrayList<>();
// 从最新消息往前取,直到 token 超限
int tokenCount = 0;
for (int i = allMessages.size() - 1; i >= 0; i--) {
ChatMessage msg = allMessages.get(i);
tokenCount += msg.getTokenCount();
if (tokenCount > maxTokens) break;
context.add(0, msg);
}
// 始终保留 system 消息(如果有的话)
Optional<ChatMessage> systemMsg = allMessages.stream()
.filter(m -> "system".equals(m.getRole()))
.findFirst();
if (systemMsg.isPresent() && !context.contains(systemMsg.get())) {
context.add(0, systemMsg.get());
}
return context;
}
/**
* 获取用户最近的会话列表(按活跃时间排序)
*/
public List<ChatSession> getUserRecentSessions(
String userId, int limit) {
String userKey = USER_SESSIONS_KEY + userId;
// 取最近 N 个(Sorted Set 按时间戳降序)
Set<String> sessionIds = redisTemplate.opsForZSet()
.reverseRange(userKey, 0, limit - 1);
if (sessionIds == null || sessionIds.isEmpty()) {
return Collections.emptyList();
}
return sessionIds.stream()
.map(this::getSession)
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
private void saveSession(ChatSession session) {
String key = SESSION_KEY_PREFIX + session.getSessionId();
try {
String json = objectMapper.writeValueAsString(session);
redisTemplate.opsForValue().set(
key, json, Duration.ofDays(SESSION_TTL_DAYS));
} catch (JsonProcessingException e) {
throw new RuntimeException("会话序列化失败", e);
}
}
public ChatSession getSession(String sessionId) {
String key = SESSION_KEY_PREFIX + sessionId;
String json = redisTemplate.opsForValue().get(key);
if (json == null) return null;
try {
// 每次访问刷新 TTL
redisTemplate.expire(key, Duration.ofDays(SESSION_TTL_DAYS));
return objectMapper.readValue(json, ChatSession.class);
} catch (JsonProcessingException e) {
log.error("会话反序列化失败,sessionId={}", sessionId, e);
return null;
}
}
/**
* 裁剪历史消息(保留最近消息,保护 system 消息)
*/
private void pruneOldMessages(ChatSession session) {
List<ChatMessage> messages = session.getMessages();
int targetSize = MAX_MESSAGES_PER_SESSION * 3 / 4; // 保留 75%
// 找出所有 system 消息
List<ChatMessage> systemMessages = messages.stream()
.filter(m -> "system".equals(m.getRole()))
.collect(Collectors.toList());
// 保留最近的 targetSize 条非 system 消息
List<ChatMessage> nonSystemMessages = messages.stream()
.filter(m -> !"system".equals(m.getRole()))
.collect(Collectors.toList());
List<ChatMessage> keptMessages = nonSystemMessages.stream()
.skip(Math.max(0, nonSystemMessages.size() - targetSize))
.collect(Collectors.toList());
// system 消息放最前面
List<ChatMessage> newMessages = new ArrayList<>(systemMessages);
newMessages.addAll(keptMessages);
session.setMessages(newMessages);
log.debug("裁剪会话历史,sessionId={}, 原{}条,裁剪至{}条",
session.getSessionId(), messages.size(), newMessages.size());
}
}四、AI 应用的 Redis 最佳实践
4.1 Key 设计规范
语义缓存:scache:{uuid}
会话数据:chat:session:{sessionId}
用户会话列表:chat:user:sessions:{userId}
用户兴趣向量:user:interest:{userId}
限流计数器:ratelimit:{userId}:{minute}统一的 key 前缀便于监控和清理,可以用 SCAN 命令按前缀扫描。
4.2 内存使用监控
@Scheduled(fixedDelay = 60000)
public void monitorMemoryUsage() {
String info = redisTemplate.execute(
(RedisCallback<String>) conn ->
new String(conn.info("memory"))
);
// 解析 used_memory 和 maxmemory
// 当使用率超过 80% 时告警
double usageRate = parseMemoryUsageRate(info);
if (usageRate > 0.8) {
log.warn("Redis 内存使用率 {:.1f}%,请关注", usageRate * 100);
alertService.sendAlert("Redis 内存告警", usageRate);
}
meterRegistry.gauge("redis.memory.usage_rate", usageRate);
}五、踩坑记录
坑1:相似度阈值设置太低导致错误命中
阈值设 0.85,结果"今天北京天气"和"今天上海天气"命中了同一条缓存,用户拿到了错误的天气信息。实体类查询(含地名、人名、时间等具体信息)必须用更高的阈值(0.95+),甚至直接禁用语义缓存。
解决方案:根据查询类型动态调整阈值,或者用 LLM 判断是否包含需要精确匹配的实体。
坑2:会话序列化的性能问题
会话消息多了(100条+)之后,序列化和反序列化耗时显著增加,而且 Redis 里存的 JSON 字符串也越来越大。
解决方案:消息列表单独用 Redis List 存储,读取时只取最近 N 条:
// 更好的方案:消息列表用 Redis List 存储
String msgKey = "chat:msgs:" + sessionId;
// 写入消息
redisTemplate.opsForList().rightPush(msgKey,
objectMapper.writeValueAsString(message));
// 裁剪超出的历史
redisTemplate.opsForList().trim(msgKey, -100, -1); // 只保留最近 100 条
// 读取最近 N 条
List<String> msgs = redisTemplate.opsForList().range(msgKey, -n, -1);坑3:语义缓存命中率统计误导优化方向
表面上命中率 40%,看起来不错。但实际上有些高频问题命中率 95%,很多长尾问题命中率 0%。正确的做法是分类统计命中率,针对高价值但低命中的类别专门优化。
语义缓存是 AI 应用降本增效最有价值的技术之一。在我们的知识库问答项目里,引入语义缓存后 LLM 调用量减少了 35%,每月节省了将近两万的 API 费用。这个投入产出比非常划算。
但语义缓存不是万能的,对于需要实时信息的查询(时间、价格、库存等),必须绕开缓存直接调 LLM。做好查询意图识别,是语义缓存能否真正发挥价值的关键。
