Building an Enterprise AI Assistant from 0 to 1 (Part 3): Dialogue System and Access Control
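Opening story: a regular employee got an executive's salary out of the AI, and it nearly became a disaster
On day 3 after launch, Chen Jianguo got an urgent call from the HR director.
"Jianguo, something is wrong with your AI! A warehouse stock clerk asked it 'How much does the general manager earn per month?' and it actually told him!"
The exact figure.
Chen Jianguo broke out in a cold sweat and spent the whole night tracking down the cause.
The root cause, as he told me later, was both simple and absurd: the knowledge base contained a document titled "2024 Compensation Survey Report" that listed the salary ranges of every job level, executives included. The document's permission was set to "public" (visible to all employees) because whoever uploaded it picked the wrong setting.
But that was only the surface cause. The deeper problem was that the system had no fine-grained access control by design: knowledge-base permissions were enforced per document rather than per field, and there was no content-masking mechanism at all.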
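How the incident was handled:
- The document was taken offline immediately (2 hours of business impact)
- The system was suspended for one week for a security overhaul
- The HR director required the permissions of every document to be re-reviewed
- Chen Jianguo presented a post-mortem at the company-wide IT meeting
I designed three items of the remediation plan for him:
- Tiered document permissions: four levels, public/department/manager/executive
- Answer content filtering: when the user's permission is insufficient, sensitive fields are masked automatically or the question is declined
- Question intent detection: recognize sensitive topics such as salaries and personal information and trigger a permission check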
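The first two remediation items can be sketched as a small policy class. This is a minimal illustration under stated assumptions, not the project's actual code: the names DocumentAccessPolicy, Level, canRead and maskAmounts are hypothetical, mapping the employee role to the DEPARTMENT tier and letting only EXECUTIVE users read other departments' department-level documents are design assumptions, and the amount pattern is deliberately rough. The point it illustrates is that the check runs per retrieved document, before that document's text is placed into the prompt.

```java
import java.util.Objects;
import java.util.regex.Pattern;

// Hypothetical sketch: names and the role-to-tier mapping are illustrative assumptions
final class DocumentAccessPolicy {

    // Declaration order defines the hierarchy: PUBLIC < DEPARTMENT < MANAGER < EXECUTIVE
    enum Level { PUBLIC, DEPARTMENT, MANAGER, EXECUTIVE }

    // Rough pattern for monetary amounts: "¥12000", "85,000元", "3.5万".
    // Plain numbers without a currency sign or unit (e.g. years) are left alone.
    private static final Pattern AMOUNT =
            Pattern.compile("[¥￥]\\s*\\d[\\d,.]*|\\d[\\d,.]*\\s*(?:万元|万|元)");

    private DocumentAccessPolicy() {}

    /**
     * Whether a user may read a document.
     * userLevel is the highest tier the user is cleared for
     * (assumed mapping: employee -> DEPARTMENT, manager -> MANAGER, executive -> EXECUTIVE).
     */
    static boolean canRead(Level userLevel, String userDept, Level docLevel, String docDept) {
        if (userLevel.compareTo(docLevel) < 0) {
            return false; // the document's tier is above the user's clearance
        }
        // Department documents stay inside the owning department, except for executives
        if (docLevel == Level.DEPARTMENT && userLevel != Level.EXECUTIVE) {
            return Objects.equals(userDept, docDept);
        }
        return true;
    }

    /** Replaces monetary amounts with "***" for users who may not see them. */
    static String maskAmounts(String text) {
        return AMOUNT.matcher(text).replaceAll("***");
    }
}
```

Filtering at retrieval time means a mis-tagged document only leaks to users whose tier actually covers it, while masking is a second, coarser safety net for text that does get through.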
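In this article we build the complete dialogue system, with access control at its core, and all of the code is written to be production-ready.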
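1. Chat API Design: Streaming and Non-Streaming
1.1 Interface Design Principles
Principle 1: Build non-streaming first, then streaming
Non-streaming is simple to debug; once the logic is confirmed correct, add streaming via SSE
Principle 2: Streaming for the front end, non-streaming for API integration
Human-facing UI: streaming (better user experience)
System-to-system calls: non-streaming (easier to handle)
Principle 3: Interrupted streams must be recoverable
If the SSE connection drops, the front end should be able to resume from the break point (or re-send the question)
1.2 The Complete ChatController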
```java
// controller/ChatController.java
package com.enterprise.aiassistant.controller;

import com.enterprise.aiassistant.dto.request.ChatRequest;
import com.enterprise.aiassistant.dto.request.FeedbackRequest;
import com.enterprise.aiassistant.dto.request.SessionCreateRequest;
import com.enterprise.aiassistant.dto.response.ApiResponse;
import com.enterprise.aiassistant.dto.response.ChatResponse;
import com.enterprise.aiassistant.dto.response.SessionResponse;
import com.enterprise.aiassistant.entity.User;
import com.enterprise.aiassistant.service.ChatService;
import com.enterprise.aiassistant.service.SessionService;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.List;

@RestController
@RequestMapping("/api/chat")
@RequiredArgsConstructor
@Validated
public class ChatController {

    private final ChatService chatService;
    private final SessionService sessionService;

    // ===== Session management =====

    @PostMapping("/sessions")
    public ApiResponse<SessionResponse> createSession(
            @Valid @RequestBody SessionCreateRequest request,
            @AuthenticationPrincipal User user) {
        return ApiResponse.success(sessionService.create(request, user));
    }

    @GetMapping("/sessions")
    public ApiResponse<List<SessionResponse>> listSessions(
            @AuthenticationPrincipal User user,
            @RequestParam(defaultValue = "0") int page,
            @RequestParam(defaultValue = "20") int size) {
        return ApiResponse.success(sessionService.listByUser(user, page, size));
    }

    @GetMapping("/sessions/{sessionId}")
    public ApiResponse<SessionResponse> getSession(
            @PathVariable String sessionId,
            @AuthenticationPrincipal User user) {
        return ApiResponse.success(sessionService.getBySessionId(sessionId, user));
    }

    @DeleteMapping("/sessions/{sessionId}")
    public ApiResponse<Void> deleteSession(
            @PathVariable String sessionId,
            @AuthenticationPrincipal User user) {
        sessionService.delete(sessionId, user);
        return ApiResponse.success(null);
    }

    // ===== Sending messages =====

    /**
     * Non-streaming chat.
     */
    @PostMapping("/sessions/{sessionId}/messages")
    public ApiResponse<ChatResponse> chat(
            @PathVariable String sessionId,
            @Valid @RequestBody ChatRequest request,
            @AuthenticationPrincipal User user) {
        return ApiResponse.success(chatService.chat(sessionId, request, user));
    }

    /**
     * Streaming chat (SSE).
     * Front-end options: EventSource, or fetch with a ReadableStream
     * (fetch is required when an Authorization header has to be sent).
     */
    @GetMapping(value = "/sessions/{sessionId}/messages/stream",
            produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter chatStream(
            @PathVariable String sessionId,
            @RequestParam String message,
            @RequestParam(required = false) Long kbId,
            @AuthenticationPrincipal User user) {
        SseEmitter emitter = new SseEmitter(60_000L); // 60-second timeout
        chatService.chatStream(sessionId, message, kbId, user, emitter);
        return emitter;
    }

    /**
     * Query message history.
     */
    @GetMapping("/sessions/{sessionId}/messages")
    public ApiResponse<?> listMessages(
            @PathVariable String sessionId,
            @AuthenticationPrincipal User user,
            @RequestParam(defaultValue = "0") int page,
            @RequestParam(defaultValue = "50") int size) {
        return ApiResponse.success(chatService.listMessages(sessionId, user, page, size));
    }

    /**
     * Submit feedback on a message (like/dislike).
     */
    @PostMapping("/messages/{messageId}/feedback")
    public ApiResponse<Void> submitFeedback(
            @PathVariable Long messageId,
            @Valid @RequestBody FeedbackRequest request,
            @AuthenticationPrincipal User user) {
        chatService.submitFeedback(messageId, request, user);
        return ApiResponse.success(null);
    }
}
```
2. Session Management Service
```java
// service/SessionService.java
package com.enterprise.aiassistant.service;

import com.enterprise.aiassistant.dto.request.SessionCreateRequest;
import com.enterprise.aiassistant.dto.response.SessionResponse;
import com.enterprise.aiassistant.entity.Session;
import com.enterprise.aiassistant.entity.User;
import com.enterprise.aiassistant.repository.SessionRepository;
import lombok.RequiredArgsConstructor;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.util.List;
import java.util.UUID;

@Service
@RequiredArgsConstructor
public class SessionService {

    private final SessionRepository sessionRepository;

    public SessionResponse create(SessionCreateRequest request, User user) {
        Session session = new Session();
        session.setUser(user);
        session.setSessionId(UUID.randomUUID().toString());
        session.setTitle(request.getTitle() != null
                ? request.getTitle() : "New chat");
        // If a knowledge base is specified, store it on the session
        // (later messages in this session use it by default)
        if (request.getKbId() != null) {
            session.setKbId(request.getKbId());
        }
        session.setCreatedAt(LocalDateTime.now());
        session.setLastActive(LocalDateTime.now());
        Session saved = sessionRepository.save(session);
        return SessionResponse.from(saved);
    }

    public List<SessionResponse> listByUser(User user, int page, int size) {
        // Most recently active sessions first
        var pageable = PageRequest.of(page, size,
                Sort.by(Sort.Direction.DESC, "lastActive"));
        return sessionRepository.findByUser(user, pageable)
                .stream()
                .map(SessionResponse::from)
                .toList();
    }

    public SessionResponse getBySessionId(String sessionId, User user) {
        Session session = sessionRepository.findBySessionId(sessionId)
                .orElseThrow(() -> new RuntimeException("Session not found: " + sessionId));
        // Verify that the session belongs to the current user
        if (!session.getUser().getId().equals(user.getId())) {
            throw new RuntimeException("Not authorized to access this session");
        }
        return SessionResponse.from(session);
    }

    public void delete(String sessionId, User user) {
        Session session = sessionRepository.findBySessionId(sessionId)
                .orElseThrow(() -> new RuntimeException("Session not found"));
        if (!session.getUser().getId().equals(user.getId())) {
            throw new RuntimeException("Not authorized to delete this session");
        }
        sessionRepository.delete(session);
    }
}
```
3. Core: The Chat Service (with Access Control)
3.1 Full ChatService Implementation
```java
// service/ChatService.java
package com.enterprise.aiassistant.service;

import com.enterprise.aiassistant.dto.request.ChatRequest;
import com.enterprise.aiassistant.dto.request.FeedbackRequest;
import com.enterprise.aiassistant.dto.response.ChatResponse;
import com.enterprise.aiassistant.dto.response.SourceInfo;
import com.enterprise.aiassistant.entity.Message;
import com.enterprise.aiassistant.entity.Session;
import com.enterprise.aiassistant.entity.User;
import com.enterprise.aiassistant.repository.MessageRepository;
import com.enterprise.aiassistant.repository.SessionRepository;
import com.enterprise.aiassistant.security.SensitiveTopicDetector;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.document.Document;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

// The Spring AI ChatResponse is referenced by its fully qualified name: importing it
// alongside the DTO ChatResponse (same simple name) would not compile.
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatService {

    private final ChatClient chatClient;
    private final KnowledgeRetrievalService retrievalService;
    private final SessionRepository sessionRepository;
    private final MessageRepository messageRepository;
    private final SensitiveTopicDetector sensitiveTopicDetector;

    private static final String SYSTEM_PROMPT = """
            You are an internal enterprise AI assistant that answers employees' questions based on the company knowledge base.
            Answering rules:
            1. Answer only from the reference documents provided; do not add information that is not in them
            2. If the reference documents contain nothing relevant, say explicitly: "I couldn't find relevant information in the knowledge base; please contact HR or the relevant department"
            3. Keep answers concise and accurate; use a numbered list for step-by-step questions
            4. End every answer with the note "This information is for reference only; the official documents prevail"
            5. Do not make up specific data such as numbers, dates, or percentages

            Reference documents:
            %s
            """;

    /**
     * Non-streaming chat.
     */
    public ChatResponse chat(String sessionId, ChatRequest request, User user) {
        long startTime = System.currentTimeMillis();

        // 1. Validate the session (existence + ownership)
        Session session = getAndValidateSession(sessionId, user);

        // 2. Sensitive-topic detection
        SensitiveTopicDetector.Result sensitiveCheck =
                sensitiveTopicDetector.check(request.getMessage(), user);
        if (sensitiveCheck.isBlocked()) {
            // Result is a record, so its accessor is reason(), not getReason()
            return buildBlockedResponse(sensitiveCheck.reason());
        }

        // 3. Retrieve relevant knowledge (the request's kbId overrides the session default)
        Long kbId = request.getKbId() != null ? request.getKbId() : session.getKbId();
        List<Document> relevantDocs = retrievalService.retrieve(
                request.getMessage(), user, kbId);

        // 4. Build the conversation context (the 10 most recent history messages)
        List<org.springframework.ai.chat.messages.Message> history =
                buildConversationHistory(session.getId(), 10);

        // 5. Build the system prompt (inject the retrieval results)
        String systemPrompt = String.format(SYSTEM_PROMPT, buildContextText(relevantDocs));

        // 6. Call the LLM
        List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
        messages.add(new SystemMessage(systemPrompt));
        messages.addAll(history);
        messages.add(new UserMessage(request.getMessage()));

        String answer;
        int tokensUsed = 0;
        try {
            org.springframework.ai.chat.model.ChatResponse aiResponse = chatClient.prompt()
                    .messages(messages)
                    .call()
                    .chatResponse();
            answer = aiResponse.getResult().getOutput().getText();
            if (aiResponse.getMetadata() != null
                    && aiResponse.getMetadata().getUsage() != null) {
                // The return type of getTotalTokens() has varied between Spring AI
                // versions (Long/Integer); going through Number works for both
                Number total = aiResponse.getMetadata().getUsage().getTotalTokens();
                tokensUsed = total != null ? total.intValue() : 0;
            }
        } catch (Exception e) {
            log.error("LLM call failed: {}", e.getMessage(), e);
            throw new RuntimeException("The AI service is temporarily unavailable, please try again later");
        }

        long latencyMs = System.currentTimeMillis() - startTime;

        // 7. Persist the conversation record (see the note on saveMessageAsync)
        saveMessageAsync(session, request.getMessage(), answer,
                tokensUsed, (int) latencyMs, relevantDocs);

        // 8. Update the session's last-active time
        sessionRepository.updateLastActive(session.getId(), LocalDateTime.now());

        return ChatResponse.builder()
                .answer(answer)
                .sources(buildSources(relevantDocs))
                .tokensUsed(tokensUsed)
                .latencyMs((int) latencyMs)
                .build();
    }

    /**
     * Streaming chat (SSE).
     * Event sequence: sources -> chunk* -> done, or a single blocked / error event.
     */
    public void chatStream(String sessionId, String message, Long kbId,
                           User user, SseEmitter emitter) {
        CompletableFuture.runAsync(() -> {
            long startTime = System.currentTimeMillis();
            try {
                // 1. Validate the session
                Session session = getAndValidateSession(sessionId, user);

                // 2. Sensitive-topic detection
                SensitiveTopicDetector.Result sensitiveCheck =
                        sensitiveTopicDetector.check(message, user);
                if (sensitiveCheck.isBlocked()) {
                    emitter.send(SseEmitter.event()
                            .name("blocked")
                            .data("{\"reason\":" + escapeJson(sensitiveCheck.reason()) + "}"));
                    emitter.complete();
                    return;
                }

                // 3. Retrieval
                Long effectiveKbId = kbId != null ? kbId : session.getKbId();
                List<Document> relevantDocs =
                        retrievalService.retrieve(message, user, effectiveKbId);

                // 4. Send a "sources" event (the front end can show the source documents)
                emitter.send(SseEmitter.event()
                        .name("sources")
                        .data(buildSourcesJson(relevantDocs)));

                // 5. Build the message list
                List<org.springframework.ai.chat.messages.Message> history =
                        buildConversationHistory(session.getId(), 10);
                String systemPrompt = String.format(SYSTEM_PROMPT, buildContextText(relevantDocs));
                List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
                messages.add(new SystemMessage(systemPrompt));
                messages.addAll(history);
                messages.add(new UserMessage(message));

                // 6. Stream the LLM call
                StringBuilder fullAnswer = new StringBuilder();
                chatClient.prompt()
                        .messages(messages)
                        .stream()
                        .chatResponse()
                        .doOnNext(chunk -> {
                            // Some providers emit a final chunk without a result (e.g. usage only)
                            if (chunk.getResult() == null || chunk.getResult().getOutput() == null) {
                                return;
                            }
                            String text = chunk.getResult().getOutput().getText();
                            if (text != null && !text.isEmpty()) {
                                fullAnswer.append(text);
                                try {
                                    emitter.send(SseEmitter.event()
                                            .name("chunk")
                                            .data("{\"content\":" + escapeJson(text) + "}"));
                                } catch (Exception e) {
                                    // Usually means the client has disconnected
                                    log.warn("SSE send failed: {}", e.getMessage());
                                }
                            }
                        })
                        .doOnError(e -> {
                            log.error("Streaming response error: {}", e.getMessage(), e);
                            try {
                                emitter.send(SseEmitter.event()
                                        .name("error")
                                        .data("{\"message\":\"AI service temporarily unavailable\"}"));
                                emitter.complete();
                            } catch (Exception ex) {
                                emitter.completeWithError(ex);
                            }
                        })
                        .doOnComplete(() -> {
                            long latencyMs = System.currentTimeMillis() - startTime;
                            try {
                                emitter.send(SseEmitter.event()
                                        .name("done")
                                        .data("{\"latencyMs\":" + latencyMs + "}"));
                                emitter.complete();
                            } catch (Exception e) {
                                emitter.completeWithError(e);
                            }
                            // Persist the record (token usage is not tracked for streams here)
                            saveMessageAsync(session, message, fullAnswer.toString(),
                                    0, (int) latencyMs, relevantDocs);
                            sessionRepository.updateLastActive(
                                    session.getId(), LocalDateTime.now());
                        })
                        // Errors are already handled in doOnError; the explicit error consumer
                        // keeps Reactor from logging ErrorCallbackNotImplemented
                        .subscribe(ignored -> { }, error -> { });
            } catch (Exception e) {
                log.error("Streaming chat error: {}", e.getMessage(), e);
                try {
                    emitter.send(SseEmitter.event()
                            .name("error")
                            .data("{\"message\":" + escapeJson(e.getMessage()) + "}"));
                    emitter.complete();
                } catch (Exception ex) {
                    emitter.completeWithError(ex);
                }
            }
        });
    }

    private Session getAndValidateSession(String sessionId, User user) {
        Session session = sessionRepository.findBySessionId(sessionId)
                .orElseThrow(() -> new RuntimeException("Session not found: " + sessionId));
        if (!session.getUser().getId().equals(user.getId())) {
            throw new RuntimeException("Not authorized to access this session");
        }
        return session;
    }

    private List<org.springframework.ai.chat.messages.Message> buildConversationHistory(
            Long sessionId, int maxMessages) {
        // Fetch the newest messages first...
        var pageable = PageRequest.of(0, maxMessages,
                Sort.by(Sort.Direction.DESC, "createdAt"));
        List<Message> recentMessages = messageRepository
                .findBySessionId(sessionId, pageable);

        // ...then reverse them into chronological order (oldest to newest)
        List<org.springframework.ai.chat.messages.Message> history = new ArrayList<>();
        for (int i = recentMessages.size() - 1; i >= 0; i--) {
            Message msg = recentMessages.get(i);
            if ("user".equals(msg.getRole())) {
                history.add(new UserMessage(msg.getContent()));
            } else {
                history.add(new AssistantMessage(msg.getContent()));
            }
        }
        return history;
    }

    private String buildContextText(List<Document> docs) {
        if (docs.isEmpty()) {
            return "(No relevant documents were found in the knowledge base)";
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < docs.size(); i++) {
            Document doc = docs.get(i);
            String docTitle = String.valueOf(
                    doc.getMetadata().getOrDefault("doc_title", "Unknown document"));
            sb.append(String.format("[Document %d: %s]\n%s\n\n",
                    i + 1, docTitle, doc.getText()));
        }
        return sb.toString();
    }

    private String buildSourcesJson(List<Document> docs) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < docs.size(); i++) {
            Document doc = docs.get(i);
            if (i > 0) sb.append(",");
            String title = String.valueOf(doc.getMetadata().getOrDefault("doc_title", "Unknown"));
            Double score = doc.getScore();
            // Escape the title, and format with Locale.ROOT so the decimal separator
            // is always "." (some locales use ",", which would break the JSON)
            sb.append("{\"docTitle\":").append(escapeJson(title))
              .append(",\"similarity\":")
              .append(score == null ? "null" : String.format(Locale.ROOT, "%.2f", score))
              .append("}");
        }
        sb.append("]");
        return sb.toString();
    }

    private List<SourceInfo> buildSources(List<Document> docs) {
        return docs.stream()
                .map(d -> {
                    String text = d.getText() != null ? d.getText() : "";
                    return new SourceInfo(
                            String.valueOf(d.getMetadata().getOrDefault("doc_title", "Unknown document")),
                            d.getScore(),
                            text.substring(0, Math.min(200, text.length()))); // 200-char snippet
                })
                .toList();
    }

    /**
     * Persists one question/answer pair.
     * Caveat: @Async only takes effect when the method is invoked through the Spring
     * proxy. Because it is called from inside this class (self-invocation), it actually
     * runs synchronously on the caller's thread; move it into a separate bean if true
     * asynchronous persistence is required.
     */
    @Async
    protected void saveMessageAsync(Session session, String userMessage,
                                    String assistantAnswer, int tokensUsed,
                                    int latencyMs, List<Document> retrievedDocs) {
        try {
            // Retrieved document IDs as a JSON array ("[]" when nothing was retrieved),
            // used for tracing. It is stored on both messages so that the miss statistics
            // in MessageRepository (role = 'user' AND retrievedChunks = '[]') can find
            // questions that got no retrieval hits.
            String retrievedJson = retrievedDocs.stream()
                    .map(d -> escapeJson(String.valueOf(d.getMetadata().getOrDefault("doc_id", ""))))
                    .collect(Collectors.joining(",", "[", "]"));

            // Save the user message
            Message userMsg = new Message();
            userMsg.setSession(session);
            userMsg.setRole("user");
            userMsg.setContent(userMessage);
            userMsg.setRetrievedChunks(retrievedJson);
            userMsg.setCreatedAt(LocalDateTime.now());
            messageRepository.save(userMsg);

            // Save the AI reply
            Message assistantMsg = new Message();
            assistantMsg.setSession(session);
            assistantMsg.setRole("assistant");
            assistantMsg.setContent(assistantAnswer);
            assistantMsg.setTokensUsed(tokensUsed);
            assistantMsg.setLatencyMs(latencyMs);
            assistantMsg.setRetrievedChunks(retrievedJson);
            assistantMsg.setCreatedAt(LocalDateTime.now());
            messageRepository.save(assistantMsg);

            // Update the session's message count
            sessionRepository.incrementMessageCount(session.getId());
        } catch (Exception e) {
            // Does not affect the main flow; only logged
            log.error("Failed to save conversation record: {}", e.getMessage(), e);
        }
    }

    // Transactional: the UPDATE query needs a transaction, and the lazy
    // message -> session -> user navigation needs an open persistence context
    @Transactional
    public void submitFeedback(Long messageId, FeedbackRequest request, User user) {
        Message message = messageRepository.findById(messageId)
                .orElseThrow(() -> new RuntimeException("Message not found"));
        // Verify that the message belongs to one of the current user's sessions
        if (!message.getSession().getUser().getId().equals(user.getId())) {
            throw new RuntimeException("Not authorized to operate on this message");
        }
        messageRepository.updateFeedback(messageId,
                request.getFeedback(), request.getFeedbackText());
    }

    private ChatResponse buildBlockedResponse(String reason) {
        return ChatResponse.builder()
                .answer("Sorry, your question involves sensitive information and I can't answer it. "
                        + reason + ". If you need this information, please contact the head of the relevant department.")
                .sources(List.of())
                .tokensUsed(0)
                .latencyMs(0)
                .build();
    }

    // Minimal JSON string escaping (backslash, quote, common control characters);
    // returns the JSON literal null for a null input
    private String escapeJson(String text) {
        if (text == null) {
            return "null";
        }
        return "\"" + text.replace("\\", "\\\\")
                .replace("\"", "\\\"")
                .replace("\n", "\\n")
                .replace("\r", "\\r")
                .replace("\t", "\\t") + "\"";
    }

    public List<?> listMessages(String sessionId, User user, int page, int size) {
        Session session = getAndValidateSession(sessionId, user);
        // Oldest first for display; relies on the repository query not hard-coding ORDER BY
        var pageable = PageRequest.of(page, size,
                Sort.by(Sort.Direction.ASC, "createdAt"));
        return messageRepository.findBySessionId(session.getId(), pageable);
    }
}
```
4. Core Security: The Sensitive-Topic Detector
This is the most important component added during Chen Jianguo's post-incident remediation.
```java
// security/SensitiveTopicDetector.java
package com.enterprise.aiassistant.security;

import com.enterprise.aiassistant.entity.User;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import java.util.List;
import java.util.regex.Pattern;

@Slf4j
@Component
public class SensitiveTopicDetector {

    /**
     * Sensitive-topic rules.
     * Each rule holds: a keyword pattern, the minimum role required, and the block reason.
     * The patterns stay in Chinese on purpose: they match Chinese questions.
     */
    private static final List<SensitiveRule> RULES = List.of(
            // Salary (salary/wage/monthly/annual pay... + how much/how calculated): manager or above
            new SensitiveRule(
                    Pattern.compile("(?:薪资|工资|月薪|年薪|收入|薪酬|报酬|底薪|绩效奖金)"
                            + ".*(?:多少|几|是什么|怎么算)"),
                    "manager",
                    "Salary information is sensitive and requires the manager role"
            ),
            // Other people's personal information (salary, address, phone, ID card, family): admin only
            new SensitiveRule(
                    Pattern.compile("(?:谁的|某某|他|她|张.{1,3}|李.{1,3}).*"
                            + "(?:工资|薪资|地址|电话|身份证|家庭)"),
                    "admin",
                    "Querying other people's personal information requires the admin role"
            ),
            // Executive compensation: executive role.
            // 月薪 (monthly pay) and 薪酬 (compensation) are included so the incident question
            // "公司总经理月薪多少" is also caught for managers, who pass the salary rule above
            new SensitiveRule(
                    Pattern.compile("(?:总经理|CEO|董事长|副总|高管|领导).*"
                            + "(?:薪资|工资|月薪|年薪|薪酬|股权|期权)"),
                    "executive",
                    "Executive compensation is confidential"
            ),
            // Financial data (net profit, revenue, profit, loss, financing, valuation)
            new SensitiveRule(
                    Pattern.compile("(?:公司|我们|企业).*(?:净利润|营收|利润|亏损|融资|估值)"),
                    "manager",
                    "Financial data is sensitive information"
            )
    );

    /**
     * Checks whether a question touches a sensitive topic the user is not cleared for.
     */
    public Result check(String question, User user) {
        if (question == null || question.isBlank()) {
            return Result.allowed();
        }
        for (SensitiveRule rule : RULES) {
            if (rule.pattern().matcher(question).find()) {
                // Rule matched: check whether the user's role is high enough
                if (!hasRequiredRole(user, rule.requiredRole())) {
                    log.warn("Sensitive topic blocked: user={}, question={}, rule={}",
                            user.getUsername(),
                            question.substring(0, Math.min(30, question.length())),
                            rule.reason());
                    return Result.blocked(rule.reason());
                }
            }
        }
        return Result.allowed();
    }

    /**
     * Role check (hierarchy: employee < manager < executive < admin).
     */
    private boolean hasRequiredRole(User user, String requiredRole) {
        int userLevel = getRoleLevel(user.getRole());
        int requiredLevel = getRoleLevel(requiredRole);
        return userLevel >= requiredLevel;
    }

    private int getRoleLevel(String role) {
        if (role == null) {
            return 0; // missing role: treated as the lowest level (a switch on null would throw)
        }
        return switch (role) {
            case "employee"  -> 1;
            case "manager"   -> 2;
            case "executive" -> 3;
            case "admin"     -> 4;
            default          -> 0;
        };
    }

    public record SensitiveRule(Pattern pattern, String requiredRole, String reason) {}

    public record Result(boolean isBlocked, String reason) {
        public static Result allowed() { return new Result(false, null); }
        public static Result blocked(String reason) { return new Result(true, reason); }
    }
}
```
5. Spring Security Integration
5.1 JWT Authentication Filter
```java
// security/JwtAuthFilter.java
package com.enterprise.aiassistant.security;

import com.enterprise.aiassistant.entity.User;
import com.enterprise.aiassistant.repository.UserRepository;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import java.io.IOException;
import java.util.List;
import java.util.Locale;

@Slf4j
@Component
@RequiredArgsConstructor
public class JwtAuthFilter extends OncePerRequestFilter {

    private final JwtTokenProvider tokenProvider;
    private final UserRepository userRepository;

    @Override
    protected void doFilterInternal(HttpServletRequest request,
                                    HttpServletResponse response,
                                    FilterChain chain)
            throws ServletException, IOException {
        String token = extractToken(request);
        if (token != null) {
            try {
                if (tokenProvider.validateToken(token)) {
                    Long userId = tokenProvider.getUserIdFromToken(token);
                    User user = userRepository.findById(userId)
                            .orElse(null);
                    if (user != null && user.isActive()) {
                        // Build the Spring Security authentication object; the User entity
                        // becomes the principal injected by @AuthenticationPrincipal
                        var authorities = List.of(
                                new SimpleGrantedAuthority("ROLE_" + user.getRole().toUpperCase(Locale.ROOT))
                        );
                        var authentication = new UsernamePasswordAuthenticationToken(
                                user, null, authorities);
                        SecurityContextHolder.getContext()
                                .setAuthentication(authentication);
                    }
                }
            } catch (Exception e) {
                log.warn("JWT authentication failed: {}", e.getMessage());
                // Don't throw: the request continues unauthenticated and later filters decide
            }
        }
        chain.doFilter(request, response);
    }

    private String extractToken(HttpServletRequest request) {
        String bearerToken = request.getHeader("Authorization");
        if (StringUtils.hasText(bearerToken) && bearerToken.startsWith("Bearer ")) {
            return bearerToken.substring(7);
        }
        return null;
    }
}
```
5.2 JWT Token Provider
```java
// security/JwtTokenProvider.java  (jjwt 0.12.x API)
package com.enterprise.aiassistant.security;

import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;
import java.util.Date;

@Slf4j
@Component
public class JwtTokenProvider {

    @Value("${app.jwt.secret}")
    private String jwtSecret;

    @Value("${app.jwt.expiration}")
    private long jwtExpiration; // token lifetime in milliseconds

    private SecretKey getSigningKey() {
        // HMAC-SHA keys must be at least 256 bits (32 bytes);
        // Keys.hmacShaKeyFor throws WeakKeyException for shorter secrets
        return Keys.hmacShaKeyFor(jwtSecret.getBytes(StandardCharsets.UTF_8));
    }

    public String generateToken(Long userId, String username, String role) {
        Date now = new Date();
        Date expiry = new Date(now.getTime() + jwtExpiration);
        return Jwts.builder()
                .subject(userId.toString())
                .claim("username", username)
                .claim("role", role)
                .issuedAt(now)
                .expiration(expiry)
                .signWith(getSigningKey())
                .compact();
    }

    public boolean validateToken(String token) {
        try {
            Jwts.parser()
                    .verifyWith(getSigningKey())
                    .build()
                    .parseSignedClaims(token);
            return true;
        } catch (JwtException | IllegalArgumentException e) {
            log.warn("JWT validation failed: {}", e.getMessage());
            return false;
        }
    }

    public Long getUserIdFromToken(String token) {
        Claims claims = Jwts.parser()
                .verifyWith(getSigningKey())
                .build()
                .parseSignedClaims(token)
                .getPayload();
        return Long.parseLong(claims.getSubject());
    }
}
```
6. Conversation Logs: Complete, Compliance-Grade Records
6.1 Audit Log Table
```sql
-- Audit log table (compliance requirement: record every operation in full)
CREATE TABLE audit_logs (
    id            BIGSERIAL PRIMARY KEY,
    user_id       BIGINT       NOT NULL,
    username      VARCHAR(100) NOT NULL,
    department    VARCHAR(100),
    action        VARCHAR(50)  NOT NULL,
    -- ASK_QUESTION / VIEW_SESSION / DELETE_SESSION / SENSITIVE_BLOCKED
    resource_type VARCHAR(50),
    resource_id   VARCHAR(100),
    detail        JSONB,        -- details (question text, knowledge base used, etc.)
    ip_address    VARCHAR(50),
    user_agent    TEXT,
    created_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE INDEX idx_audit_user ON audit_logs(user_id, created_at DESC);
CREATE INDEX idx_audit_action ON audit_logs(action, created_at DESC);
```
6.2 Audit Log Service
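```java
// service/AuditLogService.java
package com.enterprise.aiassistant.service;

import com.enterprise.aiassistant.entity.User;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.util.Map;
import java.util.concurrent.CompletableFuture;

@Slf4j
@Service
@RequiredArgsConstructor
public class AuditLogService {

    private final JdbcTemplate jdbcTemplate;
    private final ObjectMapper objectMapper;

    /**
     * Records an audit log entry asynchronously.
     * Audit logs are a compliance requirement and must be written, but must not block the main flow.
     *
     * The client IP and User-Agent are read on the calling (request) thread first:
     * RequestContextHolder is thread-bound, so reading it inside an @Async method would
     * always fall back to "unknown". Only the database insert runs asynchronously
     * (on the common pool here; a dedicated executor is preferable in production).
     */
    public void log(User user, String action, String resourceType,
                    String resourceId, Map<String, Object> detail) {
        HttpServletRequest request = currentRequest();
        String ipAddress = request != null ? clientIp(request) : "unknown";
        String userAgent = request != null ? request.getHeader("User-Agent") : null;

        CompletableFuture.runAsync(() -> {
            try {
                String detailJson = objectMapper.writeValueAsString(detail);
                jdbcTemplate.update(
                        """
                        INSERT INTO audit_logs
                            (user_id, username, department, action,
                             resource_type, resource_id, detail, ip_address, user_agent)
                        VALUES (?, ?, ?, ?, ?, ?, ?::jsonb, ?, ?)
                        """,
                        user.getId(), user.getUsername(), user.getDepartment(),
                        action, resourceType, resourceId, detailJson, ipAddress, userAgent
                );
            } catch (Exception e) {
                // A failed audit write must not affect the main business flow
                log.error("Failed to write audit log: action={}, error={}", action, e.getMessage());
            }
        });
    }

    private HttpServletRequest currentRequest() {
        var attrs = RequestContextHolder.getRequestAttributes(); // null outside a request thread
        return attrs instanceof ServletRequestAttributes servletAttrs
                ? servletAttrs.getRequest() : null;
    }

    private String clientIp(HttpServletRequest request) {
        // X-Forwarded-For may hold "client, proxy1, proxy2"; the first entry is the client.
        // Only trust it when the header is set by your own reverse proxy.
        String forwarded = request.getHeader("X-Forwarded-For");
        if (forwarded != null && !forwarded.isBlank()) {
            return forwarded.split(",")[0].trim();
        }
        return request.getRemoteAddr();
    }
}
```
7. Collecting User Feedback
7.1 Feedback DTO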
```java
// dto/request/FeedbackRequest.java
package com.enterprise.aiassistant.dto.request;

import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import lombok.Data;

@Data
public class FeedbackRequest {

    @NotBlank(message = "Feedback type must not be blank")
    @Pattern(regexp = "like|dislike", message = "Feedback type must be either like or dislike")
    private String feedback;

    // Optional free-text feedback
    private String feedbackText;
}
```
7.2 Feedback Statistics Queries
```java
// repository/MessageRepository.java (key query methods)
package com.enterprise.aiassistant.repository;

import com.enterprise.aiassistant.entity.Message;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDate;
import java.util.List;

public interface MessageRepository extends JpaRepository<Message, Long> {

    // No hard-coded ORDER BY: ordering comes from the Pageable's Sort, so callers can ask
    // for newest-first (context building) or oldest-first (history display).
    // A fixed ORDER BY here would take precedence over the Pageable's sort.
    @Query("""
            SELECT m FROM Message m
            WHERE m.session.id = :sessionId
            """)
    List<Message> findBySessionId(@Param("sessionId") Long sessionId, Pageable pageable);

    // UPDATE queries need @Modifying and an active transaction
    @Transactional
    @Modifying
    @Query("""
            UPDATE Message m
            SET m.feedback = :feedback, m.feedbackText = :feedbackText
            WHERE m.id = :messageId
            """)
    void updateFeedback(@Param("messageId") Long messageId,
                        @Param("feedback") String feedback,
                        @Param("feedbackText") String feedbackText);

    // Questions with no retrieval hits on a given day; relies on retrievedChunks
    // being stored as the JSON string "[]" on user messages
    @Query("""
            SELECT m.content FROM Message m
            WHERE m.role = 'user'
              AND m.retrievedChunks = '[]'
              AND CAST(m.createdAt AS LocalDate) = :date
            """)
    List<String> findMissedQueries(@Param("date") LocalDate date);

    @Query("""
            SELECT COUNT(m) FROM Message m
            WHERE m.role = 'user'
              AND CAST(m.createdAt AS LocalDate) BETWEEN :startDate AND :endDate
            """)
    long countByDateRange(@Param("startDate") LocalDate startDate,
                          @Param("endDate") LocalDate endDate);

    @Query("""
            SELECT COUNT(m) FROM Message m
            WHERE m.role = 'user'
              AND m.retrievedChunks = '[]'
              AND CAST(m.createdAt AS LocalDate) BETWEEN :startDate AND :endDate
            """)
    long countMissedByDateRange(@Param("startDate") LocalDate startDate,
                                @Param("endDate") LocalDate endDate);

    @Query("""
            SELECT COUNT(m) FROM Message m
            WHERE m.feedback = :feedback
              AND CAST(m.createdAt AS LocalDate) BETWEEN :startDate AND :endDate
            """)
    long countByFeedback(@Param("feedback") String feedback,
                         @Param("startDate") LocalDate startDate,
                         @Param("endDate") LocalDate endDate);
}
```
8. Multi-Turn Context Strategies
8.1 Four Strategies for Injecting Context
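Strategy 1: Full history (not recommended)
Inject every historical message
Cons: token usage explodes, and the call fails once the context window is exceeded
Strategy 2: Most recent N messages (recommended; the approach we use)
Inject only the 10 most recent messages
Pros: simple to implement, token usage stays under control
Cons: long conversations lose early information
Strategy 3: Summary compression
Once the history exceeds 20 messages, summarize it with the LLM first, then inject the summary
Pros: preserves long-term information
Cons: more complex to implement, and costs one extra LLM call
Strategy 4: Key-information extraction
Extract key entities from the history (names, numbers, preconditions the user mentioned) and inject them into the current turn
Pros: high information density
Cons: the extraction logic is complex and error-prone
Chen Jianguo went with Strategy 2, which is sufficient for an enterprise of 1,800 employees.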
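Strategy 2 combines naturally with a token budget, so that a few very long messages cannot push the prompt past the context window even when N is small. The sketch below is an illustration under assumptions, not the project's code: HistoryWindow and Turn are hypothetical names, and the estimator is any String-to-int function (for example TokenEstimator::estimate from 8.2).

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.ToIntFunction;

// Hypothetical sketch: "most recent N messages" plus a token budget
final class HistoryWindow {

    record Turn(String role, String content) {}

    private HistoryWindow() {}

    /**
     * Walks the history from newest to oldest, keeps at most maxMessages, and stops at
     * the first message that would push the running estimate past tokenBudget, so the
     * kept window is always a contiguous, most-recent slice of the conversation.
     */
    static List<Turn> select(List<Turn> chronological, int maxMessages,
                             int tokenBudget, ToIntFunction<String> estimator) {
        Deque<Turn> kept = new ArrayDeque<>();
        int used = 0;
        for (int i = chronological.size() - 1; i >= 0 && kept.size() < maxMessages; i--) {
            Turn turn = chronological.get(i);
            int cost = estimator.applyAsInt(turn.content());
            if (used + cost > tokenBudget) {
                break; // stop here instead of skipping, to keep the window contiguous
            }
            used += cost;
            kept.addFirst(turn); // addFirst restores oldest-to-newest order
        }
        return new ArrayList<>(kept);
    }
}
```

Stopping at the first message that does not fit, rather than skipping it and continuing with older ones, avoids a context with gaps in the middle of the conversation.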
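8.2 Estimating Context Tokens
```java
// util/TokenEstimator.java
package com.enterprise.aiassistant.util;

import org.springframework.stereotype.Component;

import java.util.List;

@Component
public class TokenEstimator {

    /**
     * Rough token-count estimate (the real ratio depends on the model's tokenizer).
     * Chinese: roughly 1-2 tokens per character
     * English: roughly 4 characters per token
     */
    public int estimate(String text) {
        if (text == null) return 0;
        // Counts CJK Unified Ideographs (U+4E00 to U+9FFF) as Chinese characters
        long chineseChars = text.chars()
                .filter(c -> c >= 0x4E00 && c <= 0x9FFF)
                .count();
        long otherChars = text.length() - chineseChars;
        // Chinese characters: ~1.5 tokens each (estimate)
        // Other characters: ~0.25 tokens each (English averages ~4 characters per token)
        return (int) (chineseChars * 1.5 + otherChars * 0.25);
    }

    /**
     * Checks whether the texts are close to the context-window limit.
     * Example limits: GPT-4o 128K tokens, Qwen-Max 32K tokens
     * (check the current documentation of the model you use).
     */
    public boolean isNearLimit(List<String> texts, int limitTokens) {
        int total = texts.stream().mapToInt(this::estimate).sum();
        return total > (int) (limitTokens * 0.8); // warn at 80%
    }
}
```
9. Front-End Integration: Recommendations for Streaming Responses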
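```javascript
// Front-end SSE client example (works with Vue 3 or React)
// Uses fetch instead of EventSource so custom headers (e.g. Authorization) can be sent
async function chatStream(sessionId, message, kbId,
                          onChunk, onSources, onDone, onBlocked, onError) {
  const token = localStorage.getItem('jwt_token');
  const url = `/api/chat/sessions/${encodeURIComponent(sessionId)}/messages/stream`
    + `?message=${encodeURIComponent(message)}`
    + (kbId ? `&kbId=${kbId}` : '');

  const response = await fetch(url, {
    method: 'GET',
    headers: {
      'Authorization': `Bearer ${token}`,
      'Accept': 'text/event-stream',
    },
  });
  if (!response.ok) {
    throw new Error(`HTTP error: ${response.status}`);
  }

  // One handler per event name sent by ChatService; unknown events are ignored
  const handlers = {
    chunk:   (data) => onChunk?.(data.content),
    sources: (data) => onSources?.(data),
    done:    (data) => onDone?.(data),
    blocked: (data) => onBlocked?.(data.reason),
    error:   (data) => onError?.(data.message),
  };
  const dispatch = (eventName, dataText) => {
    const handler = handlers[eventName];
    if (handler && dataText !== '') {
      handler(JSON.parse(dataText));
    }
  };

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffer = '';
  let eventName = 'message'; // SSE default event name
  let dataLines = [];

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split(/\r?\n/);
    buffer = lines.pop(); // keep the trailing incomplete line for the next read

    for (const line of lines) {
      if (line === '') {
        // A blank line ends the event: dispatch it, then reset the parser state
        dispatch(eventName, dataLines.join('\n'));
        eventName = 'message';
        dataLines = [];
      } else if (line.startsWith('event:')) {
        eventName = line.slice(6).trim();
      } else if (line.startsWith('data:')) {
        // Spring's SseEmitter may write "data:" without a space; the SSE format
        // allows one optional space after the colon, so strip it if present
        const field = line.slice(5);
        dataLines.push(field.startsWith(' ') ? field.slice(1) : field);
      }
      // Other fields (id:, retry:, and ":" comment lines) are ignored here
    }
  }
  // Flush a final event if the stream ended without a trailing blank line
  if (dataLines.length > 0) {
    dispatch(eventName, dataLines.join('\n'));
  }
}
```
10. Performance Metrics and Optimization Targets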
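| Metric | Target | Actual (Chen Jianguo's project) | Optimization |
|---|---|---|---|
| Time to first token (streaming) | <2 s | 1.6 s | Cache hot questions in Redis to cut retrieval time |
| Full response time (non-streaming) | <5 s | 3.2 s on average | Lower top_k, shorter prompt |
| Concurrent users | 100 online at once | 120 in testing | Connection-pool tuning, asynchronous logging |
| Permission-check latency | <10 ms | 2 ms (in-memory check) | Rules cached in memory |
| Audit-log writes | Asynchronous, non-blocking | 0 ms on the main path | Fully asynchronous writes |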
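The hot-question cache in the first row needs one safeguard in a system with access control: the cache key must include the permission scope, otherwise a result built for a manager could be served to an employee who asks the same question. A minimal in-memory sketch standing in for Redis; HotQuestionCache and its key format (role, kbId, normalized question) are hypothetical, if department-level documents are involved the department should be part of the key too, and a production version would also need a TTL and invalidation when documents or permissions change.

```java
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

// Hypothetical sketch: permission-scoped LRU cache (in-memory stand-in for Redis)
final class HotQuestionCache<V> {

    private final Map<String, V> entries;

    HotQuestionCache(int maxEntries) {
        // accessOrder = true keeps the least recently used entry first;
        // removeEldestEntry evicts it once the cache grows past maxEntries
        this.entries = new LinkedHashMap<String, V>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, V> eldest) {
                return size() > maxEntries;
            }
        };
    }

    /** Permission-scoped key: role | knowledge base | normalized question. */
    static String key(String role, Long kbId, String question) {
        String normalized = question.strip().toLowerCase(Locale.ROOT).replaceAll("\\s+", " ");
        return role + "|" + (kbId == null ? "default" : kbId) + "|" + normalized;
    }

    synchronized Optional<V> get(String role, Long kbId, String question) {
        return Optional.ofNullable(entries.get(key(role, kbId, question)));
    }

    synchronized void put(String role, Long kbId, String question, V value) {
        entries.put(key(role, kbId, question), value);
    }
}
```

The cached value can be the retrieval results or a finished answer; either way the scoping rule is the same.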
FAQ
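Q1: What if users get around the sensitive-topic detection, for example by asking about salaries in pinyin?
A: The current regex detection can indeed be bypassed. More advanced options: 1) use an LLM for intent classification (99%+ accuracy, but each question costs an extra API call of roughly 0.002 yuan); 2) for specific knowledge bases (e.g. compensation-related ones), enforce permissions directly at the metadata layer, which is more reliable than regexes. For HR scenarios in a manufacturing company the current approach is good enough, since sophisticated bypass attempts are very unlikely.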
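Option 1 is usually layered on top of the regex rather than replacing it, so that only questions the regex lets through pay for the extra classifier call. A sketch under assumptions: LayeredTopicGuard is a hypothetical name, the classifier is modeled as a plain Function (in practice it would wrap an LLM call and a prompt of your choosing), and the normalization step only defeats trivial tricks such as spaces or punctuation inserted between characters. It does nothing against pinyin, which is exactly why the second stage exists.

```java
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Pattern;

// Hypothetical sketch: regex rules first, intent classifier as a fallback
final class LayeredTopicGuard {

    private final List<Pattern> regexRules;
    // Returns a topic label if the question is sensitive (would wrap an LLM call)
    private final Function<String, Optional<String>> intentClassifier;

    LayeredTopicGuard(List<Pattern> regexRules,
                      Function<String, Optional<String>> intentClassifier) {
        this.regexRules = regexRules;
        this.intentClassifier = intentClassifier;
    }

    /** Removes whitespace and Unicode punctuation, so "月 薪" or "月-薪" still match "月薪". */
    static String normalize(String question) {
        return question.replaceAll("[\\s\\p{P}]+", "");
    }

    /** Returns why the question is considered sensitive, or empty if it is not. */
    Optional<String> detect(String question) {
        String normalized = normalize(question);
        for (Pattern rule : regexRules) {
            if (rule.matcher(normalized).find()) {
                return Optional.of("regex:" + rule.pattern());
            }
        }
        // Only questions that pass the cheap regex stage pay for the classifier call
        return intentClassifier.apply(question).map(label -> "classifier:" + label);
    }
}
```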
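Q2: How are expired JWTs handled?
A: Provide a refresh-token endpoint. When the front end receives a 401 response, it automatically exchanges the refresh token for a new access token, so the user never notices. Refresh tokens are valid for 7 days and stored in Redis, which allows active revocation (delete the Redis record when the user logs out).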
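A minimal sketch of the refresh-token store described above, using an in-memory map as a stand-in for Redis (with Redis, issuing maps to SET with an expiry and revocation to DEL). RefreshTokenStore is a hypothetical name, the time source is injected so expiry can be tested, and making each refresh token single-use (rotation) is an extra design choice, not something the answer above specifies.

```java
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

// Hypothetical sketch: in-memory stand-in for a Redis-backed refresh-token store
final class RefreshTokenStore {

    private record Entry(long userId, Instant expiresAt) {}

    private final Map<String, Entry> tokens = new ConcurrentHashMap<>();
    private final SecureRandom random = new SecureRandom();
    private final Duration ttl;
    private final Supplier<Instant> now; // injectable clock, e.g. Instant::now

    RefreshTokenStore(Duration ttl, Supplier<Instant> now) {
        this.ttl = ttl;
        this.now = now;
    }

    /** Issues a random, opaque refresh token (Redis equivalent: SET key value EX ttl). */
    String issue(long userId) {
        byte[] bytes = new byte[32];
        random.nextBytes(bytes);
        String token = Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
        tokens.put(token, new Entry(userId, now.get().plus(ttl)));
        return token;
    }

    /**
     * Exchanges a refresh token for the user ID. The token is removed on use, so each
     * one works only once and the caller issues a new one (rotation). Unlike Redis,
     * expired entries here are only cleaned up when they are looked up.
     */
    Optional<Long> consume(String token) {
        Entry entry = tokens.remove(token);
        if (entry == null || now.get().isAfter(entry.expiresAt())) {
            return Optional.empty();
        }
        return Optional.of(entry.userId());
    }

    /** Revokes a token, e.g. when the user logs out (Redis equivalent: DEL key). */
    void revoke(String token) {
        tokens.remove(token);
    }
}
```

Usage: new RefreshTokenStore(Duration.ofDays(7), Instant::now) matches the 7-day lifetime mentioned above.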
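Q3: When several users ask questions at the same time, can their session data get mixed up?
A: No. Every session is bound to a user_id and identified by a UUID sessionId. Session lists are queried by user, and every single-session operation (getAndValidateSession in ChatService, and SessionService) checks that the session's owner matches the authenticated user before anything is read or written. The current user comes from each request's own security context, so concurrent requests do not share user state.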
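Q4: What if a proxy (Nginx) drops the SSE connection?
A: Configure Nginx with proxy_read_timeout 300; and proxy_buffering off;, and have the backend send the response header X-Accel-Buffering: no, which disables Nginx buffering for that response. With Nginx's default proxy_read_timeout of 60 seconds, a connection that receives no data for 60 seconds is closed, so the front end also needs reconnection logic.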
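Besides the Nginx settings, a common complement is a server-side heartbeat: periodically writing an SSE comment line keeps bytes flowing during long pauses (slow retrieval, a slow first token), so the proxy's read timeout is not reached. A sketch with a hypothetical SseHeartbeat class; the send callback is left abstract (with Spring, it could instead call emitter.send(SseEmitter.event().comment("ping"))), and one scheduler per connection is used only for brevity, a shared scheduler would be preferable in production.

```java
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

// Hypothetical sketch: periodic SSE comment lines as a keep-alive
final class SseHeartbeat implements AutoCloseable {

    // An SSE comment: clients ignore it, but the proxy sees traffic on the connection
    static final String PING = ": ping\n\n";

    private final ScheduledExecutorService scheduler =
            Executors.newSingleThreadScheduledExecutor();
    private final ScheduledFuture<?> task;

    /**
     * @param send         writes raw text to the stream; if it throws (client gone),
     *                     the executor suppresses all further runs of the task
     * @param periodMillis interval between pings; keep it well below proxy_read_timeout
     */
    SseHeartbeat(Consumer<String> send, long periodMillis) {
        this.task = scheduler.scheduleAtFixedRate(
                () -> send.accept(PING), periodMillis, periodMillis, TimeUnit.MILLISECONDS);
    }

    /** Stops the heartbeat; call it when the stream completes or fails. */
    @Override
    public void close() {
        task.cancel(false);
        scheduler.shutdownNow();
    }
}
```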
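Q5: How long should conversation records be kept?
A: This depends on your compliance requirements; for internal enterprise conversation records we recommend 3 years. Data older than 3 years is moved to cold storage (such as S3), and the hot data in the database is cleaned up periodically.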
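If the message and audit tables are partitioned by month (an assumption; the article does not show any partitioning), the 3-year rule reduces to picking the partitions that lie entirely before a cutoff date. RetentionPolicy is a hypothetical name, and the actual export to S3 and the partition drop are left out of this sketch.

```java
import java.time.LocalDate;
import java.time.Period;
import java.time.YearMonth;
import java.util.List;

// Hypothetical sketch: which monthly partitions fall outside the retention window
final class RetentionPolicy {

    private final Period retention;

    RetentionPolicy(Period retention) {
        this.retention = retention;
    }

    /** First day that must still stay in the hot database. */
    LocalDate cutoff(LocalDate today) {
        return today.minus(retention);
    }

    /** Partitions whose last day is before the cutoff, oldest first. */
    List<YearMonth> partitionsToArchive(List<YearMonth> partitions, LocalDate today) {
        LocalDate cutoff = cutoff(today);
        return partitions.stream()
                .filter(month -> month.atEndOfMonth().isBefore(cutoff))
                .sorted()
                .toList();
    }
}
```

A partition that straddles the cutoff stays in the hot database until it has aged out completely, so nothing younger than the retention period is archived early.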
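Series Preview
This is the third article in the series. We have built the core of the dialogue system: session management, multi-turn context, streaming responses, access control (including sensitive-topic detection), audit logging, and user feedback.
In the next article (article-178) we move on to deployment and operational optimization. On launch day 1,000 users came online and the system held up, but what problems surfaced on day two? A 20-item go-live checklist, a Grafana monitoring dashboard, and a canary release plan: all revealed.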
