自定义注解+运行时反射:实现一个轻量级参数校验框架
自定义注解+运行时反射:实现一个轻量级参数校验框架
适读人群:Java中级开发者、想深入理解注解机制的后端工程师 | 阅读时长:约16分钟 | 文章类型:实战开发+原理讲解
开篇故事
去年我在一个没有引入Spring Validation的项目里,看到了大量这样的参数校验代码:
public ApiResult createOrder(OrderRequest request) {
if (request.getUserId() == null) {
return ApiResult.error("userId不能为空");
}
if (request.getAmount() == null || request.getAmount().compareTo(BigDecimal.ZERO) <= 0) {
return ApiResult.error("金额必须大于0");
}
if (request.getProductList() == null || request.getProductList().isEmpty()) {
return ApiResult.error("商品列表不能为空");
}
if (request.getAddress() != null && request.getAddress().length() > 200) {
return ApiResult.error("地址长度不能超过200字符");
}
// 到这里才是真正的业务逻辑...
return doCreateOrder(request);
}一个Controller方法,前10行全是手写的参数校验。每个接口都这样,代码量翻倍,而且校验逻辑散落在各处,改起来容易遗漏。
我花了半天时间,写了一个基于注解的轻量级校验框架,把上面那段代码变成了:
@NotNull(field = "userId", message = "userId不能为空")
@Positive(field = "amount", message = "金额必须大于0")
@NotEmpty(field = "productList", message = "商品列表不能为空")
@MaxLength(field = "address", value = 200, message = "地址长度不能超过200字符")
public ApiResult createOrder(OrderRequest request) {
return doCreateOrder(request);
}今天把这个框架的实现思路和完整代码写出来。这个例子能把Java注解的核心用法都覆盖到。
一、Java注解的基础知识
注解的四个元注解
定义注解时,需要用元注解来描述注解本身的行为:
@Target:注解可以用在哪里(METHOD/FIELD/TYPE/PARAMETER等)@Retention:注解的保留策略SOURCE:只在源码中,编译后丢弃(如@Override)CLASS:保留到.class文件,但运行时JVM不加载(默认)RUNTIME:保留到运行时,可以通过反射读取(我们需要这个)
@Documented:是否包含在JavaDoc中@Inherited:是否被子类继承
注解是接口的特殊形式
@interface MyAnnotation {
String value() default ""; // 注解的"字段"叫属性
}编译后,MyAnnotation本质上是一个继承自java.lang.annotation.Annotation的接口。
二、核心原理深挖
注解的运行时读取
当注解的@Retention是RUNTIME时,可以通过反射读取:
Method method = SomeClass.class.getMethod("someMethod");
MyAnnotation ann = method.getAnnotation(MyAnnotation.class);
if (ann != null) {
String value = ann.value();
}对于字段的注解,可以在Field上读取:
Field field = SomeClass.class.getDeclaredField("someField");
NotNull ann = field.getAnnotation(NotNull.class);框架设计思路
三、完整代码实现
代码一:注解定义
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 非空校验(对象不为null)
*/
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NotNull {
String message() default "字段不能为null";
}
// ---
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 非空校验(字符串/集合/数组不为空)
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface NotEmpty {
String message() default "字段不能为空";
}
// ---
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 最大长度校验(适用于String)
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MaxLength {
int value();
String message() default "字段长度超出限制";
}
// ---
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 正数校验(适用于数值类型,必须 > 0)
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Positive {
String message() default "字段必须是正数";
}
// ---
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 范围校验(适用于整数)
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Range {
long min() default Long.MIN_VALUE;
long max() default Long.MAX_VALUE;
String message() default "字段超出范围";
}
// ---
package com.laozhang.validation.annotation;
import java.lang.annotation.*;
/**
* 正则表达式校验
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Pattern {
String value(); // 正则表达式
String message() default "字段格式不正确";
}代码二:校验框架核心实现
package com.laozhang.validation;
import com.laozhang.validation.annotation.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
* 轻量级参数校验框架核心实现
*
* 用法:
* ValidationResult result = Validator.validate(requestObject);
* if (!result.isValid()) {
* throw new ValidationException(result.getFirstError());
* }
*/
public class Validator {
/**
* 校验结果
*/
public static class ValidationResult {
private final List<String> errors;
ValidationResult(List<String> errors) {
this.errors = errors;
}
public boolean isValid() { return errors.isEmpty(); }
public List<String> getErrors() { return Collections.unmodifiableList(errors); }
public String getFirstError() { return errors.isEmpty() ? null : errors.get(0); }
public String getAllErrors() { return String.join("; ", errors); }
}
/**
* 字段校验接口
*/
@FunctionalInterface
interface FieldValidator<A extends Annotation> {
/**
* @param annotation 注解实例
* @param fieldName 字段名
* @param value 字段值
* @return 错误信息,null表示校验通过
*/
String validate(A annotation, String fieldName, Object value);
}
// 注解类型 → 校验器的映射
@SuppressWarnings("rawtypes")
private static final Map<Class<? extends Annotation>, FieldValidator> VALIDATORS
= new HashMap<>();
// 缓存类的字段列表
private static final ConcurrentHashMap<Class<?>, Field[]> FIELD_CACHE
= new ConcurrentHashMap<>();
// 注册各注解的校验逻辑
static {
// @NotNull:不能为null
VALIDATORS.put(NotNull.class, (FieldValidator<NotNull>) (ann, field, value) -> {
if (value == null) {
return "[" + field + "] " + ann.message();
}
return null;
});
// @NotEmpty:不能为null,且不能为空(字符串/集合/数组)
VALIDATORS.put(NotEmpty.class, (FieldValidator<NotEmpty>) (ann, field, value) -> {
if (value == null) return "[" + field + "] " + ann.message();
if (value instanceof String && ((String) value).trim().isEmpty()) {
return "[" + field + "] " + ann.message();
}
if (value instanceof Collection && ((Collection<?>) value).isEmpty()) {
return "[" + field + "] " + ann.message();
}
if (value instanceof Object[] && ((Object[]) value).length == 0) {
return "[" + field + "] " + ann.message();
}
return null;
});
// @MaxLength:字符串最大长度
VALIDATORS.put(MaxLength.class, (FieldValidator<MaxLength>) (ann, field, value) -> {
if (value == null) return null; // null交给@NotNull处理
if (value instanceof String && ((String) value).length() > ann.value()) {
return "[" + field + "] " + ann.message() + "(最大" + ann.value() + "字符)";
}
return null;
});
// @Positive:必须是正数
VALIDATORS.put(Positive.class, (FieldValidator<Positive>) (ann, field, value) -> {
if (value == null) return null;
double d;
if (value instanceof BigDecimal) {
d = ((BigDecimal) value).doubleValue();
} else if (value instanceof Number) {
d = ((Number) value).doubleValue();
} else {
return null; // 非数值类型跳过
}
if (d <= 0) return "[" + field + "] " + ann.message();
return null;
});
// @Range:整数范围
VALIDATORS.put(Range.class, (FieldValidator<Range>) (ann, field, value) -> {
if (value == null) return null;
if (!(value instanceof Number)) return null;
long v = ((Number) value).longValue();
if (v < ann.min() || v > ann.max()) {
return "[" + field + "] " + ann.message() +
String.format("(范围: %d ~ %d)", ann.min(), ann.max());
}
return null;
});
// @Pattern:正则校验
VALIDATORS.put(Pattern.class, (FieldValidator<Pattern>) (ann, field, value) -> {
if (value == null) return null;
if (!(value instanceof String)) return null;
String str = (String) value;
if (!str.matches(ann.value())) {
return "[" + field + "] " + ann.message();
}
return null;
});
}
/**
* 校验Bean对象上所有字段的注解
* 返回所有校验错误(fail-all模式)
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public static ValidationResult validate(Object bean) {
if (bean == null) {
return new ValidationResult(List.of("校验对象不能为null"));
}
List<String> errors = new ArrayList<>();
Class<?> clazz = bean.getClass();
// 获取缓存的字段列表(包含父类字段)
Field[] fields = FIELD_CACHE.computeIfAbsent(clazz, c -> {
List<Field> allFields = new ArrayList<>();
Class<?> current = c;
while (current != null && current != Object.class) {
for (Field f : current.getDeclaredFields()) {
f.setAccessible(true);
allFields.add(f);
}
current = current.getSuperclass();
}
return allFields.toArray(new Field[0]);
});
// 遍历每个字段
for (Field field : fields) {
Object value;
try {
value = field.get(bean);
} catch (IllegalAccessException e) {
continue; // 忽略无法访问的字段
}
// 遍历字段上的所有注解
for (Annotation annotation : field.getAnnotations()) {
FieldValidator validator = VALIDATORS.get(annotation.annotationType());
if (validator != null) {
String error = validator.validate(annotation, field.getName(), value);
if (error != null) {
errors.add(error);
}
}
}
}
return new ValidationResult(errors);
}
/**
* 快速校验,遇到第一个错误立即返回(fail-fast模式)
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public static Optional<String> validateFast(Object bean) {
if (bean == null) return Optional.of("校验对象不能为null");
Class<?> clazz = bean.getClass();
Field[] fields = FIELD_CACHE.computeIfAbsent(clazz, c -> {
List<Field> allFields = new ArrayList<>();
Class<?> current = c;
while (current != null && current != Object.class) {
for (Field f : current.getDeclaredFields()) {
f.setAccessible(true);
allFields.add(f);
}
current = current.getSuperclass();
}
return allFields.toArray(new Field[0]);
});
for (Field field : fields) {
Object value;
try { value = field.get(bean); }
catch (IllegalAccessException e) { continue; }
for (Annotation annotation : field.getAnnotations()) {
FieldValidator validator = VALIDATORS.get(annotation.annotationType());
if (validator != null) {
String error = validator.validate(annotation, field.getName(), value);
if (error != null) return Optional.of(error);
}
}
}
return Optional.empty();
}
/**
* 校验失败时抛出异常(适合在AOP切面中使用)
*/
public static void validateOrThrow(Object bean) {
ValidationResult result = validate(bean);
if (!result.isValid()) {
throw new ValidationException(result.getFirstError(), result.getErrors());
}
}
/**
* 校验异常
*/
public static class ValidationException extends RuntimeException {
private final List<String> errors;
ValidationException(String message, List<String> errors) {
super(message);
this.errors = errors;
}
public List<String> getErrors() { return errors; }
}
// ===== 使用示例 =====
static class OrderRequest {
@NotNull(message = "userId不能为空")
private Long userId;
@NotNull(message = "金额不能为空")
@Positive(message = "金额必须大于0")
private BigDecimal amount;
@NotEmpty(message = "商品列表不能为空")
private List<String> productList;
@MaxLength(value = 200, message = "地址长度不能超过200字符")
private String address;
@Pattern(value = "^1[3-9]\\d{9}$", message = "手机号格式不正确")
private String phone;
@Range(min = 1, max = 5, message = "评分范围1-5")
private Integer rating;
// 构造器和setter
public OrderRequest(Long userId, BigDecimal amount, List<String> products,
String address, String phone, Integer rating) {
this.userId = userId;
this.amount = amount;
this.productList = products;
this.address = address;
this.phone = phone;
this.rating = rating;
}
}
public static void main(String[] args) {
System.out.println("=== 校验失败的请求 ===");
OrderRequest bad = new OrderRequest(
null, // userId: null → 校验失败
new BigDecimal("-10"), // amount: 负数 → 校验失败
new ArrayList<>(), // productList: 空 → 校验失败
"这是一个超过200字符的地址," + "x".repeat(200), // address: 超长 → 校验失败
"12345", // phone: 格式错 → 校验失败
10 // rating: 超出范围 → 校验失败
);
ValidationResult result = validate(bad);
System.out.println("校验通过: " + result.isValid());
System.out.println("错误列表:");
result.getErrors().forEach(e -> System.out.println(" - " + e));
System.out.println("\n=== 校验通过的请求 ===");
OrderRequest good = new OrderRequest(
10001L,
new BigDecimal("99.99"),
List.of("product-001", "product-002"),
"北京市朝阳区xx街道xx号",
"13812345678",
5
);
ValidationResult goodResult = validate(good);
System.out.println("校验通过: " + goodResult.isValid());
}
}四、踩坑实录
坑1:注解的@Retention设置错了,运行时读不到
报错现象:
// getAnnotation返回null,反射读不到注解
NotNull ann = field.getAnnotation(NotNull.class);
// ann == null,校验框架完全不生效根本原因:
注解没有加@Retention(RetentionPolicy.RUNTIME),默认是CLASS,运行时JVM不会加载注解信息。
具体解法:
// 必须加这行
@Retention(RetentionPolicy.RUNTIME)
public @interface NotNull { ... }坑2:字段有注解但getField/getDeclaredField找不到
报错现象:
java.lang.NoSuchFieldException: userId
at java.base/java.lang.Class.getDeclaredField(Class.java:2599)根本原因:
getDeclaredField只能获取当前类的字段,不包括父类。如果要验证继承的字段,需要遍历整个继承链。
具体解法:
// 遍历继承链获取所有字段
private static List<Field> getAllFields(Class<?> clazz) {
List<Field> fields = new ArrayList<>();
while (clazz != null && clazz != Object.class) {
Collections.addAll(fields, clazz.getDeclaredFields());
clazz = clazz.getSuperclass();
}
return fields;
}坑3:注解属性的默认值和用户指定值混淆
报错现象:
用户没有设置message属性,但代码里读到的ann.message()却是空字符串而不是预期的默认值。
根本原因:
注解属性的默认值是在@interface定义里的:
@interface NotNull {
String message() default "字段不能为null"; // 这里的default是注解定义的默认值
}如果用户在使用时写了@NotNull(message = "")(空字符串),就会覆盖默认值。这和没写message完全不同。
具体解法:
在校验器实现里处理空字符串:
String msg = ann.message();
if (msg == null || msg.isEmpty()) {
msg = "字段不能为null"; // fallback
}五、总结与延伸
今天实现的这个框架展示了注解的核心使用模式:
- 定义注解:用
@Target和@Retention(RUNTIME)描述注解的行为 - 读取注解:用反射的
getAnnotation/getAnnotations在运行时读取 - 执行逻辑:根据注解类型找到对应的处理逻辑
这个模式在很多场景都适用:
- 参数校验(我们今天做的)
- 权限控制(读取@RequireRole,检查当前用户是否有权限)
- 缓存控制(读取@Cacheable,决定是否走缓存)
- 日志打印(读取@LogMethod,自动记录方法入参和出参)
当然,生产项目里通常直接用Spring Validation(Hibernate Validator),功能更完善,和Spring AOP集成更好。但理解自己手撸一遍的原理,能让你在遇到框架的奇怪行为时更快定位问题。
另外,这个框架有一个明显的局限:嵌套对象的校验(OrderRequest里的Address对象)需要额外支持。可以加一个@Valid注解表示"递归校验此字段",在框架里判断字段值是否是POJO并递归调用validate——这是留给感兴趣的读者的小作业。
