第2001篇:从零搭建AI Agent框架——ReAct模式在Java中的工程实现
第2001篇:从零搭建AI Agent框架——ReAct模式在Java中的工程实现
适读人群:有Spring Boot基础、想深入理解AI Agent工程化的Java开发者 | 阅读时长:约22分钟 | 核心价值:理解ReAct模式的本质,掌握在Java中实现可生产的Agent框架
去年我们团队接到一个需求:给内部ERP系统加一个"智能助手",要能查询订单状态、修改交货期、发送邮件通知。
听起来简单,但实现起来踩了很多坑。
最初的方案是:用户说什么,我们就构造一个Prompt发给LLM,LLM返回文本,我们parse文本提取操作,然后执行。这套东西勉强能跑,但:
- 单次对话无法处理需要多步操作的任务(比如"先查一下订单,如果延迟了就自动发邮件给客户")
- LLM的输出格式不稳定,parse逻辑脆弱
- 出错了不知道哪步出的问题,调试困难
后来我认真研究了ReAct(Reasoning + Acting)论文,才明白Agent该怎么设计。这篇文章就是从那次重构里提炼出来的。
先理解ReAct的本质
ReAct不是什么神秘的算法,它就是把LLM的思考过程结构化:
Thought: 我需要先查订单O-20240115的状态
Action: query_order
Action Input: {"order_id": "O-20240115"}
Observation: 订单状态=延迟,预计延迟3天
Thought: 订单延迟了,需要发邮件通知客户
Action: send_email
Action Input: {"to": "customer@example.com", "subject": "订单延迟通知", ...}
Observation: 邮件发送成功
Thought: 任务完成
Final Answer: 已查询到订单O-20240115延迟3天,已自动通知客户。每一轮LLM输出一个Thought(推理)+Action(要用什么工具)+Action Input(工具的输入参数),我们的框架执行这个工具,把结果作为Observation反馈给LLM,然后进入下一轮。
这个循环一直持续到LLM输出Final Answer为止。
核心数据结构设计
先定义Agent框架里最重要的几个概念:
/**
* 工具接口——每个Agent能使用的"能力单元"
*/
public interface AgentTool {
String name(); // 工具名称,LLM在Action字段填这个
String description(); // 工具描述,放进System Prompt告诉LLM这个工具干什么
String parametersSchema(); // 参数的JSON Schema,让LLM知道怎么填参数
/**
* 执行工具,返回执行结果字符串(会作为Observation反馈给LLM)
*/
String execute(Map<String, Object> parameters);
}
/**
* 一步推理的结构化表示
*/
@Data
@Builder
public class AgentStep {
private String thought; // LLM的推理过程
private String actionName; // 调用的工具名称(null表示Final Answer)
private Map<String, Object> actionInput; // 工具参数
private String observation; // 工具执行结果
private String finalAnswer; // 最终答案(非null表示任务完成)
private boolean isError; // 这一步是否出错
private String errorMessage; // 错误信息
}
/**
* Agent执行结果
*/
@Data
@Builder
public class AgentResult {
private String finalAnswer;
private List<AgentStep> steps; // 完整的推理链
private int totalTokensUsed;
private long totalDurationMs;
private boolean succeeded;
private String failureReason;
}Agent的执行引擎
这是整个框架的核心——ReAct循环:
@Service
@Slf4j
public class ReActAgent {
private final ChatClient chatClient;
private final ObjectMapper objectMapper;
// 最大迭代次数,防止无限循环
private static final int MAX_ITERATIONS = 10;
public AgentResult run(String userQuery, List<AgentTool> tools) {
long startTime = System.currentTimeMillis();
List<AgentStep> steps = new ArrayList<>();
// 构建包含工具描述的System Prompt
String systemPrompt = buildSystemPrompt(tools);
// 构建初始对话历史
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage(systemPrompt));
messages.add(new UserMessage(userQuery));
// ReAct主循环
for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) {
log.debug("Agent迭代 {}/{}", iteration + 1, MAX_ITERATIONS);
// 调用LLM
String llmResponse;
try {
llmResponse = callLlm(messages);
} catch (Exception e) {
log.error("LLM调用失败", e);
return AgentResult.builder()
.succeeded(false)
.failureReason("LLM调用失败: " + e.getMessage())
.steps(steps)
.build();
}
// 解析LLM的输出
AgentStep step = parseAgentStep(llmResponse);
// 如果是Final Answer,结束循环
if (step.getFinalAnswer() != null) {
steps.add(step);
return AgentResult.builder()
.finalAnswer(step.getFinalAnswer())
.steps(steps)
.totalDurationMs(System.currentTimeMillis() - startTime)
.succeeded(true)
.build();
}
// 执行工具
String observation = executeTool(step.getActionName(), step.getActionInput(), tools);
step.setObservation(observation);
steps.add(step);
// 把这一步加入对话历史
messages.add(new AssistantMessage(llmResponse));
messages.add(new UserMessage("Observation: " + observation));
}
// 超过最大迭代次数
return AgentResult.builder()
.succeeded(false)
.failureReason("超过最大迭代次数 " + MAX_ITERATIONS)
.steps(steps)
.build();
}
private String buildSystemPrompt(List<AgentTool> tools) {
StringBuilder sb = new StringBuilder();
sb.append("你是一个智能助手,可以使用以下工具来帮助用户完成任务:\n\n");
for (AgentTool tool : tools) {
sb.append("工具名称: ").append(tool.name()).append("\n");
sb.append("功能描述: ").append(tool.description()).append("\n");
sb.append("参数格式: ").append(tool.parametersSchema()).append("\n\n");
}
sb.append("""
请按照以下格式进行推理和操作:
Thought: [你的推理过程]
Action: [工具名称]
Action Input: [JSON格式的参数]
等待Observation后继续推理。当任务完成时,用以下格式结束:
Thought: [最终推理]
Final Answer: [给用户的最终回答]
注意:
1. Action Input必须是合法的JSON格式
2. 每次只能调用一个工具
3. 仔细阅读Observation,再决定下一步行动
""");
return sb.toString();
}
private String callLlm(List<Message> messages) {
return chatClient.prompt()
.messages(messages)
.call()
.content();
}
private AgentStep parseAgentStep(String llmResponse) {
AgentStep.AgentStepBuilder builder = AgentStep.builder();
try {
// 提取Thought
String thought = extractSection(llmResponse, "Thought:", "Action:");
if (thought == null) {
thought = extractSection(llmResponse, "Thought:", "Final Answer:");
}
builder.thought(thought != null ? thought.trim() : "");
// 检查是否是Final Answer
String finalAnswer = extractSection(llmResponse, "Final Answer:", null);
if (finalAnswer != null) {
builder.finalAnswer(finalAnswer.trim());
return builder.build();
}
// 提取Action
String actionName = extractSection(llmResponse, "Action:", "Action Input:");
if (actionName != null) {
builder.actionName(actionName.trim());
}
// 提取Action Input(JSON)
String actionInputStr = extractSection(llmResponse, "Action Input:", "Observation:");
if (actionInputStr == null) {
actionInputStr = extractSection(llmResponse, "Action Input:", null);
}
if (actionInputStr != null) {
// 清理JSON字符串(可能有代码块标记)
actionInputStr = cleanJsonString(actionInputStr.trim());
Map<String, Object> actionInput = objectMapper.readValue(
actionInputStr, new TypeReference<Map<String, Object>>() {}
);
builder.actionInput(actionInput);
}
} catch (Exception e) {
log.warn("解析Agent输出失败: {}", llmResponse, e);
builder.isError(true).errorMessage("解析LLM输出失败: " + e.getMessage());
}
return builder.build();
}
private String extractSection(String text, String startMarker, String endMarker) {
int startIdx = text.indexOf(startMarker);
if (startIdx == -1) return null;
startIdx += startMarker.length();
if (endMarker != null) {
int endIdx = text.indexOf(endMarker, startIdx);
if (endIdx == -1) return text.substring(startIdx);
return text.substring(startIdx, endIdx);
}
return text.substring(startIdx);
}
private String cleanJsonString(String s) {
// 去掉Markdown代码块标记
if (s.startsWith("```json")) s = s.substring(7);
if (s.startsWith("```")) s = s.substring(3);
if (s.endsWith("```")) s = s.substring(0, s.length() - 3);
return s.trim();
}
private String executeTool(String toolName, Map<String, Object> params,
List<AgentTool> tools) {
// 按名称查找工具
Optional<AgentTool> tool = tools.stream()
.filter(t -> t.name().equals(toolName))
.findFirst();
if (tool.isEmpty()) {
return "Error: 工具 '" + toolName + "' 不存在。可用工具: " +
tools.stream().map(AgentTool::name).collect(Collectors.joining(", "));
}
try {
String result = tool.get().execute(params != null ? params : Collections.emptyMap());
log.debug("工具 {} 执行成功,结果长度: {}", toolName, result.length());
return result;
} catch (Exception e) {
log.error("工具 {} 执行失败", toolName, e);
return "Error: 工具执行失败 - " + e.getMessage();
}
}
}实现具体工具
以ERP场景为例,实现几个真实工具:
/**
* 查询订单工具
*/
@Component
public class QueryOrderTool implements AgentTool {
private final OrderRepository orderRepository;
private final ObjectMapper objectMapper;
@Override
public String name() {
return "query_order";
}
@Override
public String description() {
return "查询订单信息,包括订单状态、交货期、客户信息等";
}
@Override
public String parametersSchema() {
return """
{
"type": "object",
"properties": {
"order_id": {
"type": "string",
"description": "订单ID,格式为O-YYYYMMDD-XXX"
}
},
"required": ["order_id"]
}
""";
}
@Override
public String execute(Map<String, Object> parameters) {
String orderId = (String) parameters.get("order_id");
if (orderId == null || orderId.isBlank()) {
return "Error: order_id不能为空";
}
Order order = orderRepository.findById(orderId).orElse(null);
if (order == null) {
return "订单 " + orderId + " 不存在";
}
// 构建结构化的返回信息
Map<String, Object> result = new LinkedHashMap<>();
result.put("order_id", order.getId());
result.put("status", order.getStatus().getDisplayName());
result.put("customer_name", order.getCustomerName());
result.put("customer_email", order.getCustomerEmail());
result.put("planned_delivery_date", order.getPlannedDeliveryDate().toString());
result.put("actual_delivery_date",
order.getActualDeliveryDate() != null ? order.getActualDeliveryDate().toString() : "未知");
boolean isDelayed = order.isDelayed();
result.put("is_delayed", isDelayed);
if (isDelayed) {
result.put("delay_days", order.getDelayDays());
}
try {
return objectMapper.writeValueAsString(result);
} catch (JsonProcessingException e) {
return result.toString();
}
}
}
/**
* 发送邮件工具
*/
@Component
public class SendEmailTool implements AgentTool {
private final JavaMailSender mailSender;
@Override
public String name() {
return "send_email";
}
@Override
public String description() {
return "发送邮件通知。用于通知客户订单状态变化等。";
}
@Override
public String parametersSchema() {
return """
{
"type": "object",
"properties": {
"to": {"type": "string", "description": "收件人邮箱"},
"subject": {"type": "string", "description": "邮件主题"},
"body": {"type": "string", "description": "邮件正文(纯文本)"}
},
"required": ["to", "subject", "body"]
}
""";
}
@Override
public String execute(Map<String, Object> parameters) {
String to = (String) parameters.get("to");
String subject = (String) parameters.get("subject");
String body = (String) parameters.get("body");
// 基本校验
if (to == null || !to.contains("@")) {
return "Error: 无效的收件人邮箱";
}
try {
SimpleMailMessage message = new SimpleMailMessage();
message.setTo(to);
message.setSubject(subject);
message.setText(body);
mailSender.send(message);
return "邮件已成功发送到 " + to;
} catch (Exception e) {
return "Error: 邮件发送失败 - " + e.getMessage();
}
}
}
/**
* 修改交货期工具
*/
@Component
public class UpdateDeliveryDateTool implements AgentTool {
private final OrderRepository orderRepository;
@Override
public String name() {
return "update_delivery_date";
}
@Override
public String description() {
return "修改订单的计划交货日期。需要提供订单ID和新的交货日期。";
}
@Override
public String parametersSchema() {
return """
{
"type": "object",
"properties": {
"order_id": {"type": "string", "description": "订单ID"},
"new_date": {"type": "string", "description": "新的交货日期,格式YYYY-MM-DD"},
"reason": {"type": "string", "description": "修改原因(可选)"}
},
"required": ["order_id", "new_date"]
}
""";
}
@Override
public String execute(Map<String, Object> parameters) {
String orderId = (String) parameters.get("order_id");
String newDateStr = (String) parameters.get("new_date");
String reason = (String) parameters.get("reason");
try {
LocalDate newDate = LocalDate.parse(newDateStr);
Order order = orderRepository.findById(orderId)
.orElseThrow(() -> new IllegalArgumentException("订单不存在: " + orderId));
LocalDate oldDate = order.getPlannedDeliveryDate();
order.setPlannedDeliveryDate(newDate);
order.addChangeLog("交货期修改",
"从 " + oldDate + " 改为 " + newDate +
(reason != null ? ",原因:" + reason : ""));
orderRepository.save(order);
return "订单 " + orderId + " 的交货期已从 " + oldDate + " 修改为 " + newDate;
} catch (DateTimeParseException e) {
return "Error: 日期格式无效,请使用YYYY-MM-DD格式";
} catch (Exception e) {
return "Error: " + e.getMessage();
}
}
}把工具和Agent组装起来
@Service
@RequiredArgsConstructor
public class ErpAgentService {
private final ReActAgent agent;
private final QueryOrderTool queryOrderTool;
private final SendEmailTool sendEmailTool;
private final UpdateDeliveryDateTool updateDeliveryDateTool;
private List<AgentTool> getTools() {
return List.of(queryOrderTool, sendEmailTool, updateDeliveryDateTool);
}
public AgentResult handleUserRequest(String userId, String request) {
log.info("ERP Agent处理请求, userId={}, request={}", userId, request);
AgentResult result = agent.run(request, getTools());
// 记录完整的推理链,用于审计
logAgentExecution(userId, request, result);
return result;
}
private void logAgentExecution(String userId, String request, AgentResult result) {
StringBuilder log = new StringBuilder();
log.append("用户: ").append(userId).append("\n");
log.append("请求: ").append(request).append("\n");
log.append("成功: ").append(result.isSucceeded()).append("\n");
log.append("耗时: ").append(result.getTotalDurationMs()).append("ms\n");
log.append("步骤数: ").append(result.getSteps().size()).append("\n");
for (int i = 0; i < result.getSteps().size(); i++) {
AgentStep step = result.getSteps().get(i);
log.append("\n--- 步骤").append(i + 1).append(" ---\n");
log.append("Thought: ").append(step.getThought()).append("\n");
if (step.getActionName() != null) {
log.append("Action: ").append(step.getActionName()).append("\n");
log.append("Observation: ").append(step.getObservation()).append("\n");
}
if (step.getFinalAnswer() != null) {
log.append("Final Answer: ").append(step.getFinalAnswer()).append("\n");
}
}
auditLogService.save(userId, "AGENT_EXECUTION", log.toString());
}
}HTTP接口
@RestController
@RequestMapping("/api/agent")
@RequiredArgsConstructor
public class AgentController {
private final ErpAgentService erpAgentService;
@PostMapping("/query")
public ResponseEntity<AgentResponse> query(
@RequestBody AgentRequest request,
@AuthenticationPrincipal UserDetails user) {
AgentResult result = erpAgentService.handleUserRequest(
user.getUsername(),
request.getMessage()
);
AgentResponse response = AgentResponse.builder()
.answer(result.getFinalAnswer())
.succeeded(result.isSucceeded())
.steps(result.getSteps().stream()
.map(this::toStepDto)
.collect(Collectors.toList()))
.durationMs(result.getTotalDurationMs())
.build();
return ResponseEntity.ok(response);
}
private StepDto toStepDto(AgentStep step) {
return StepDto.builder()
.thought(step.getThought())
.action(step.getActionName())
.observation(step.getObservation())
.build();
}
}生产中遇到的真实问题
问题1:LLM偶尔输出格式不符合预期
LLM有时会输出"Thought: ... \nAction: ...\nParameters: ..."(把Action Input写成了Parameters),或者Action和Action Input之间没有换行。
我的解决方案是多种格式都尝试匹配,不强依赖一种格式:
private String extractActionInput(String text) {
// 尝试多种可能的标记
String[] markers = {"Action Input:", "Parameters:", "Input:", "Arguments:"};
for (String marker : markers) {
String result = extractSection(text, marker, null);
if (result != null && !result.isBlank()) {
return result.trim();
}
}
return null;
}问题2:工具执行结果太长,塞满上下文
有时候查询结果非常长(比如查询一个有几百个子项的大订单),把它完整塞回给LLM会浪费大量Token,而且LLM对超长的Observation理解能力下降。
解决方案:对Observation做截断,超过阈值时自动摘要:
private String truncateObservation(String observation, int maxChars) {
if (observation.length() <= maxChars) return observation;
// 超长时,只取前maxChars字符,并标注已截断
return observation.substring(0, maxChars) +
"\n...[结果已截断,共" + observation.length() + "字符,只显示前" + maxChars + "字符]";
}问题3:循环引用——Agent试图调用不存在的工具
当LLM产生幻觉,调用了一个不存在的工具名时,我返回了一个详细的错误信息(包含可用工具列表)作为Observation,让LLM在下一轮自我纠正。大多数情况下,LLM能从这个错误中恢复。
关于MAX_ITERATIONS的设定
我见过有人把这个值设成100,然后费用账单让他大吃一惊。
正确的思路是:
- 简单任务(查一查、改一改):3-5步足够
- 中等复杂任务(多步骤协调):5-8步
- 把MAX_ITERATIONS设成10-15是合理的上限
如果Agent做了10步还没完成,要么任务本身设计有问题,要么LLM陷入了循环。这时候强制终止,并把已完成的部分告知用户,比一直傻跑下去强。
这套ReAct框架上线后,我们的ERP助手能处理大约80%的常见操作请求了。剩下20%是超出工具能力范围的复杂场景,这些我们明确地告诉用户"这个操作超出了我的能力范围",而不是给出一个错误的答案。
Agent的工程价值,不在于让它无所不能,而在于让它在自己能力范围内可靠地工作,在能力边界清晰地认识自己的局限。
