访问者模式:AST抽象语法树遍历与编译原理在工程中的实际应用
访问者模式:AST抽象语法树遍历与编译原理在工程中的实际应用
适读人群:中高级Java开发者 | 阅读时长:约25分钟 | 模式类型:行为型
开篇故事
做了15年 Java,我一直以为编译原理是学院派的东西,离工程实践很远。直到有一天,我们需要开发一个"动态规则引擎"——运营同学可以在后台配置类似 age > 18 AND level >= 3 AND region IN ('北京', '上海') 这样的规则表达式,系统实时解析并判断用户是否符合条件。
最开始我想用正则表达式来解析,写了两天后发现对于嵌套括号、复杂运算符优先级这类场景根本搞不定。后来参考了编译原理中的词法分析+语法分析+AST构建+访问者遍历的流程,用了一个周末把这个规则引擎做出来了。
整个过程中,访问者模式是处理 AST 节点的核心——同样一棵 AST,可以被不同的访问者以不同方式遍历:有的访问者计算表达式的值(求值访问者),有的打印表达式的 SQL 形式(SQL 生成访问者),有的检查表达式是否合法(语义检查访问者)。
这让我对访问者模式有了全新的认识:它不是一个晦涩难懂的"高级"模式,而是在需要对稳定的数据结构执行多种不同操作时,最合适的设计方案。
一、模式动机:数据结构稳定但操作多变
访问者模式(Visitor Pattern)的适用场景:
- 数据结构相对稳定,但需要在上面定义很多种不同的操作。
- 新操作的添加不应该污染数据结构类(开闭原则)。
- 操作需要跨越不同类型的节点(如遍历异构树形结构)。
最典型的案例就是 AST(Abstract Syntax Tree,抽象语法树):
- AST 的节点类型是固定的(文字节点、二元运算节点、一元运算节点、函数调用节点等)。
- 但对 AST 的操作可能有很多种:求值、生成 SQL、代码优化、语义检查……
- 每增加一种操作,如果直接在节点类里加方法,会让节点类越来越臃肿。
- 访问者模式将"操作"从节点类中抽离,集中在独立的访问者类中。
二、模式结构
访问者模式的核心是双重分发(Double Dispatch):
- 调用节点的
accept(visitor)方法(第一次分发,确定节点类型) - 节点内部调用
visitor.visitXxx(this)方法(第二次分发,确定访问者类型)
三、Java 中的访问者模式实例
3.1 Checkstyle 的 AST 访问者
Java 代码检查工具 Checkstyle 就是用访问者模式遍历 Java 源码的 AST:
// Checkstyle 的核心抽象类(所有规则检查都继承这个类)
public abstract class AbstractCheck {
// 声明要访问哪些AST节点类型(类似于register visitor)
public abstract int[] getDefaultTokens();
// 访问方法(访问进入节点时调用)
public void visitToken(DetailAST ast) {}
// 离开节点时调用
public void leaveToken(DetailAST ast) {}
// 日志/报告方法
protected final void log(int line, String key, Object... args) { ... }
}
// 具体规则:检查行宽是否超过限制
public class LineLengthCheck extends AbstractCheck {
private int max = 80; // 默认最大行宽
@Override
public int[] getDefaultTokens() {
return new int[]{TokenTypes.PACKAGE_DEF}; // 只访问package定义(用于触发整体检查)
}
@Override
public void visitToken(DetailAST ast) {
// 当访问到package定义时,检查所有代码行
}
}四、生产级代码实现:规则引擎的 AST 与访问者
4.1 AST 节点定义
/**
* 表达式节点接口(Element,接受访问者)
*/
public interface ExpressionNode {
/**
* 接受访问者(双重分发的关键)
*/
<T> T accept(ExpressionVisitor<T> visitor);
/**
* 节点类型描述
*/
String getNodeType();
}
/**
* 字面量节点(LiteralNode):数字、字符串、布尔值
*/
@Data
public class LiteralNode implements ExpressionNode {
private final Object value;
private final LiteralType type;
public enum LiteralType { INTEGER, DECIMAL, STRING, BOOLEAN, NULL }
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitLiteral(this); // 第二次分发
}
@Override
public String getNodeType() { return "LITERAL"; }
}
/**
* 标识符节点(IdentifierNode):变量名,如 age、level
*/
@Data
public class IdentifierNode implements ExpressionNode {
private final String name; // 变量名
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitIdentifier(this);
}
@Override
public String getNodeType() { return "IDENTIFIER"; }
}
/**
* 二元运算节点(BinaryOpNode):a AND b, x > 5, c IN (...)
*/
@Data
public class BinaryOpNode implements ExpressionNode {
private final ExpressionNode left;
private final ExpressionNode right;
private final BinaryOperator operator;
public enum BinaryOperator {
AND, OR,
EQ("="), NEQ("!="), GT(">"), GTE(">="), LT("<"), LTE("<="),
IN, NOT_IN,
LIKE, NOT_LIKE;
private final String symbol;
BinaryOperator() { this.symbol = name(); }
BinaryOperator(String symbol) { this.symbol = symbol; }
public String getSymbol() { return symbol; }
}
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitBinaryOp(this);
}
@Override
public String getNodeType() { return "BINARY_OP:" + operator; }
}
/**
* 一元运算节点(UnaryOpNode):NOT condition
*/
@Data
public class UnaryOpNode implements ExpressionNode {
private final ExpressionNode operand;
private final UnaryOperator operator;
public enum UnaryOperator { NOT, NEG }
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitUnaryOp(this);
}
@Override
public String getNodeType() { return "UNARY_OP:" + operator; }
}
/**
* 函数调用节点:contains(tags, 'vip')
*/
@Data
public class FunctionCallNode implements ExpressionNode {
private final String functionName;
private final List<ExpressionNode> arguments;
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitFunctionCall(this);
}
@Override
public String getNodeType() { return "FUNCTION:" + functionName; }
}
/**
* IN 列表节点
*/
@Data
public class InListNode implements ExpressionNode {
private final List<ExpressionNode> elements;
@Override
public <T> T accept(ExpressionVisitor<T> visitor) {
return visitor.visitInList(this);
}
@Override
public String getNodeType() { return "IN_LIST"; }
}4.2 访问者接口与各种访问者实现
/**
* 表达式访问者接口(Visitor)
* 泛型 T 表示访问结果的类型
*/
public interface ExpressionVisitor<T> {
T visitLiteral(LiteralNode node);
T visitIdentifier(IdentifierNode node);
T visitBinaryOp(BinaryOpNode node);
T visitUnaryOp(UnaryOpNode node);
T visitFunctionCall(FunctionCallNode node);
T visitInList(InListNode node);
}
/**
* 求值访问者(ConcreteVisitor A)
* 遍历 AST,根据用户数据计算规则表达式的值
*/
@Slf4j
public class EvaluationVisitor implements ExpressionVisitor<Object> {
private final Map<String, Object> context; // 用户数据:{"age": 25, "level": 3, ...}
public EvaluationVisitor(Map<String, Object> context) {
this.context = context;
}
@Override
public Object visitLiteral(LiteralNode node) {
return node.getValue(); // 字面量直接返回值
}
@Override
public Object visitIdentifier(IdentifierNode node) {
Object value = context.get(node.getName());
if (value == null && !context.containsKey(node.getName())) {
log.warn("Variable '{}' not found in context", node.getName());
}
return value;
}
@Override
public Object visitBinaryOp(BinaryOpNode node) {
// 对于 AND/OR,先求左边,再决定是否求右边(短路求值)
if (node.getOperator() == BinaryOpNode.BinaryOperator.AND) {
Object leftResult = node.getLeft().accept(this);
if (!toBool(leftResult)) return false; // 短路:左边为false,直接返回false
return toBool(node.getRight().accept(this));
}
if (node.getOperator() == BinaryOpNode.BinaryOperator.OR) {
Object leftResult = node.getLeft().accept(this);
if (toBool(leftResult)) return true; // 短路:左边为true,直接返回true
return toBool(node.getRight().accept(this));
}
Object left = node.getLeft().accept(this);
Object right = node.getRight().accept(this);
return switch (node.getOperator()) {
case EQ -> equals(left, right);
case NEQ -> !equals(left, right);
case GT -> compare(left, right) > 0;
case GTE -> compare(left, right) >= 0;
case LT -> compare(left, right) < 0;
case LTE -> compare(left, right) <= 0;
case IN -> evaluateIn(left, right, false);
case NOT_IN -> evaluateIn(left, right, true);
case LIKE -> evaluateLike(left, right, false);
case NOT_LIKE -> evaluateLike(left, right, true);
default -> throw new EvaluationException("Unknown operator: " + node.getOperator());
};
}
@Override
public Object visitUnaryOp(UnaryOpNode node) {
Object operand = node.getOperand().accept(this);
return switch (node.getOperator()) {
case NOT -> !toBool(operand);
case NEG -> negate(operand);
};
}
@Override
public Object visitFunctionCall(FunctionCallNode node) {
List<Object> args = node.getArguments().stream()
.map(arg -> arg.accept(this))
.collect(Collectors.toList());
return switch (node.getFunctionName().toLowerCase()) {
case "contains" -> {
if (args.size() != 2) throw new EvaluationException("contains() requires 2 arguments");
Object collection = args.get(0);
Object element = args.get(1);
if (collection instanceof Collection<?> c) yield c.contains(element);
if (collection instanceof String s) yield s.contains(String.valueOf(element));
yield false;
}
case "length" -> {
Object obj = args.get(0);
if (obj instanceof String s) yield s.length();
if (obj instanceof Collection<?> c) yield c.size();
yield 0;
}
case "lower" -> String.valueOf(args.get(0)).toLowerCase();
case "upper" -> String.valueOf(args.get(0)).toUpperCase();
default -> throw new EvaluationException("Unknown function: " + node.getFunctionName());
};
}
@Override
public Object visitInList(InListNode node) {
return node.getElements().stream()
.map(e -> e.accept(this))
.collect(Collectors.toList());
}
private boolean toBool(Object value) {
if (value instanceof Boolean b) return b;
if (value == null) return false;
return Boolean.parseBoolean(String.valueOf(value));
}
@SuppressWarnings("unchecked")
private int compare(Object a, Object b) {
if (a instanceof Comparable ca && b != null) {
return ca.compareTo(b);
}
throw new EvaluationException("Cannot compare " + a + " with " + b);
}
private boolean equals(Object a, Object b) {
if (a == null && b == null) return true;
if (a == null || b == null) return false;
// 数字比较要处理类型转换
if (a instanceof Number na && b instanceof Number nb) {
return na.doubleValue() == nb.doubleValue();
}
return a.equals(b);
}
private boolean evaluateIn(Object value, Object list, boolean negate) {
if (!(list instanceof List<?> listValue)) {
throw new EvaluationException("IN operator requires a list on the right side");
}
boolean inList = listValue.stream().anyMatch(item -> equals(value, item));
return negate ? !inList : inList;
}
private boolean evaluateLike(Object value, Object pattern, boolean negate) {
String str = String.valueOf(value);
String patternStr = String.valueOf(pattern)
.replace("%", ".*")
.replace("_", ".");
boolean matches = str.matches(patternStr);
return negate ? !matches : matches;
}
private Object negate(Object value) {
if (value instanceof Integer i) return -i;
if (value instanceof Long l) return -l;
if (value instanceof Double d) return -d;
throw new EvaluationException("Cannot negate: " + value);
}
}
/**
* SQL 生成访问者(ConcreteVisitor B)
* 将 AST 转换为 SQL WHERE 子句
*/
public class SqlGenerationVisitor implements ExpressionVisitor<String> {
private final List<Object> parameters = new ArrayList<>(); // 参数化查询的参数
@Override
public String visitLiteral(LiteralNode node) {
parameters.add(node.getValue());
return "?"; // 使用参数化查询,防止SQL注入
}
@Override
public String visitIdentifier(IdentifierNode node) {
// 标识符作为列名(需要验证是否是合法的列名,防止注入)
return "`" + node.getName() + "`";
}
@Override
public String visitBinaryOp(BinaryOpNode node) {
String left = node.getLeft().accept(this);
String right = node.getRight().accept(this);
return switch (node.getOperator()) {
case AND -> "(" + left + " AND " + right + ")";
case OR -> "(" + left + " OR " + right + ")";
case EQ -> left + " = " + right;
case NEQ -> left + " != " + right;
case GT -> left + " > " + right;
case GTE -> left + " >= " + right;
case LT -> left + " < " + right;
case LTE -> left + " <= " + right;
case IN -> left + " IN " + right;
case NOT_IN -> left + " NOT IN " + right;
case LIKE -> left + " LIKE " + right;
case NOT_LIKE -> left + " NOT LIKE " + right;
default -> throw new SqlGenerationException("Cannot generate SQL for: " + node.getOperator());
};
}
@Override
public String visitUnaryOp(UnaryOpNode node) {
String operand = node.getOperand().accept(this);
return switch (node.getOperator()) {
case NOT -> "NOT (" + operand + ")";
case NEG -> "-" + operand;
};
}
@Override
public String visitFunctionCall(FunctionCallNode node) {
String args = node.getArguments().stream()
.map(arg -> arg.accept(this))
.collect(Collectors.joining(", "));
return node.getFunctionName().toUpperCase() + "(" + args + ")";
}
@Override
public String visitInList(InListNode node) {
String elements = node.getElements().stream()
.map(e -> e.accept(this))
.collect(Collectors.joining(", "));
return "(" + elements + ")";
}
public List<Object> getParameters() { return Collections.unmodifiableList(parameters); }
}
/**
* 规则引擎:组合使用多个访问者
*/
@Service
@Slf4j
public class RuleEngine {
@Autowired
private RuleRepository ruleRepository;
/**
* 对用户数据执行规则检查
*/
public boolean evaluate(String ruleId, Map<String, Object> userContext) {
Rule rule = ruleRepository.findById(ruleId)
.orElseThrow(() -> new RuleNotFoundException(ruleId));
ExpressionNode ast = parseExpression(rule.getExpression());
EvaluationVisitor evaluator = new EvaluationVisitor(userContext);
Object result = ast.accept(evaluator);
return Boolean.TRUE.equals(result) || "true".equals(String.valueOf(result));
}
/**
* 生成规则对应的 SQL WHERE 子句(用于数据库查询)
*/
public SqlClause generateSql(String ruleId) {
Rule rule = ruleRepository.findById(ruleId)
.orElseThrow(() -> new RuleNotFoundException(ruleId));
ExpressionNode ast = parseExpression(rule.getExpression());
SqlGenerationVisitor sqlVisitor = new SqlGenerationVisitor();
String sql = ast.accept(sqlVisitor);
return new SqlClause(sql, sqlVisitor.getParameters());
}
/**
* 验证规则表达式是否合法
*/
public ValidationResult validate(String expression) {
try {
ExpressionNode ast = parseExpression(expression);
SemanticCheckVisitor checker = new SemanticCheckVisitor();
ast.accept(checker);
return ValidationResult.valid();
} catch (Exception e) {
return ValidationResult.invalid(e.getMessage());
}
}
private ExpressionNode parseExpression(String expression) {
// 词法分析 + 语法分析 + 构建 AST
ExpressionLexer lexer = new ExpressionLexer(expression);
ExpressionParser parser = new ExpressionParser(lexer.tokenize());
return parser.parse();
}
}五、与相关模式的对比
访问者 vs 策略
- 策略:一组可替换的算法,针对同一种对象类型。
- 访问者:针对异构对象结构(如 AST 中的不同节点类型),同一访问者处理所有类型。
访问者模式的缺点
访问者模式的最大缺点是违反开闭原则(对于数据结构的扩展):如果需要新增一种 AST 节点类型(比如增加 TernaryOpNode),则必须在所有访问者实现中增加对应的 visitTernaryOp() 方法,否则编译报错(或漏实现导致运行时异常)。
因此,访问者模式适合"数据结构稳定、操作多变"的场景,不适合"数据结构经常扩展、操作稳定"的场景。
六、踩坑实录
坑一:Java 中没有真正的双重分发
Java 方法是基于静态类型分发的,如果直接写 visitor.visit(node),编译器根据编译期类型决定调用哪个重载,无法实现运行时动态分发。必须使用 accept() 方法才能实现双重分发:
// 错误:这会根据静态类型调用visit(ExpressionNode),而不是具体子类型
ExpressionNode node = new LiteralNode(42);
visitor.visit(node); // 永远调用 visit(ExpressionNode),不管实际类型
// 正确:通过accept()实现双重分发
node.accept(visitor); // LiteralNode.accept() 内部调用 visitor.visitLiteral(this)坑二:访问者修改了被访问对象的状态
访问者应该是只读的(对被访问对象不做修改),如果访问者在遍历过程中修改了 AST 节点,会导致遍历结果不可预测(尤其是在多线程环境下)。
如果需要变换 AST(比如优化,把 NOT (a AND b) 变换为 (NOT a) OR (NOT b)),应该返回新的 AST 节点而不是修改原有节点(immutable 风格)。
坑三:深层 AST 的递归深度导致栈溢出
对于超长的规则表达式(比如有100个 AND 连接的条件),AST 的深度可能超过 JVM 默认的递归栈深度,导致 StackOverflowError。
解决方案:将递归遍历改为迭代遍历(使用显式栈),或者增加 JVM 的栈大小(-Xss),或者在解析时进行表达式树的优化(将深度降低)。
七、总结
访问者模式是 GoF 23 种模式中最复杂、使用场景最特定的模式。它的最大价值在于:在不修改数据结构类的情况下,为数据结构新增操作(对于"新增操作"这个维度是开闭的)。
在工程实践中,访问者模式主要出现在:
- 编译器/解释器开发:对 AST 进行求值、代码生成、优化等不同操作。
- XML/JSON 解析器:对文档树进行不同形式的序列化/分析。
- SQL 生成/解析(MyBatis、Hibernate 的 HQL 解析):对查询 AST 生成不同数据库的 SQL 方言。
- 静态代码分析(Checkstyle、PMD、SpotBugs):遍历 Java 源码 AST 检查各种规则。
这个系列的最后一篇,也是整个设计模式系列的收尾——从创建型、结构型到行为型,20个模式走完一圈,设计模式的本质是代码组织的经验总结,用好它们,代码会更有生命力。
