Article 1744: Real-Time Feature Computation with Flink in Online Learning Systems
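I was once chatting with a friend who works on recommender systems. He told me their model is retrained only once a week. When I asked why they didn't update it more often, his answer stuck with me: "Weekly retraining is already the most we can manage. Our features are T+1: the batch job runs in the early hours, and they only become usable in the morning."
This situation is very common in traditional recommendation and risk-control systems. T+1 features mean the model sees yesterday's world and knows nothing about how users are behaving today. In content recommendation, when a trending event breaks today, the model cannot sense the shift in user interest until tomorrow, and by then the best recommendation window has passed.
Real-time feature computation with Flink addresses exactly this problem. This article walks through the whole approach, from principles to engineering practice.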
1. The Overall Architecture of an Online Learning System
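Let's lay out the overall architecture first, since everything that follows fits into it. In outline, as the later sections show: user behavior events flow from Kafka into Flink jobs, which write real-time features to Redis; batch features are computed offline and served from HBase/MySQL; a Java inference service reads both; and an optional Flink-driven online learning job pushes model parameter updates.
A few points are worth noting: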
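- Real-time and batch features coexist; one does not fully replace the other. Batch features (user profiles, long-term statistics) are still more stable when computed offline, and real-time features cover only the parts that need high freshness.
- Online learning is an optional component. Some scenarios need only real-time features and no online model training.
- The inference service (Java) reads both real-time and batch features and concatenates them.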
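2. Core Patterns for Real-Time Feature Computation in Flink
2.1 Sliding-Window Aggregate Features
This is the most common real-time feature requirement: clicks in the last hour, purchase amount in the last 10 minutes, and so on.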
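To make the window semantics concrete: with a 1-hour window that slides every 5 minutes, each event belongs to 12 overlapping windows. The sketch below is plain Java rather than Flink code; `SlidingWindowSketch` and `windowStarts` are hypothetical names, and it assumes windows aligned to multiples of the slide with no offset.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative helper (not a Flink class): lists the sliding windows that contain one event. */
final class SlidingWindowSketch {

    /** Returns the start timestamps (ms) of every [start, start + sizeMs) window containing ts. */
    static List<Long> windowStarts(long ts, long sizeMs, long slideMs) {
        List<Long> starts = new ArrayList<>();
        // Latest window start that is <= ts, aligned to a multiple of the slide
        long lastStart = ts - Math.floorMod(ts, slideMs);
        // Walk backwards one slide at a time while the window still covers ts
        for (long start = lastStart; start > ts - sizeMs; start -= slideMs) {
            starts.add(start);
        }
        return starts;
    }
}
```

For an event at minute 7, `windowStarts(420_000L, 3_600_000L, 300_000L)` returns 12 starts, the latest being minute 5. This is also why incremental aggregation matters for sliding windows: every event is folded into 12 accumulators instead of being stored 12 times.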
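// Flink job: compute real-time behavioral statistics per user.
// Imports assume the Flink 1.x DataStream API with the flink-connector-kafka dependency.
import java.time.Duration;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.connector.kafka.source.KafkaSource;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.windowing.assigners.SlidingEventTimeWindows;
import org.apache.flink.streaming.api.windowing.time.Time;

public class UserRealtimeFeatureJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment();

        // Enable checkpointing so state can be restored after a failure
        env.enableCheckpointing(60000); // one checkpoint every 60 seconds
        env.getCheckpointConfig()
            .setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);

        // Read user behavior events from Kafka
        KafkaSource<UserBehaviorEvent> kafkaSource = KafkaSource
            .<UserBehaviorEvent>builder()
            .setBootstrapServers("kafka:9092")
            .setTopics("user-behavior-events")
            .setGroupId("flink-feature-job")
            .setValueOnlyDeserializer(new UserBehaviorEventDeserializer())
            .build();

        // Event-time stream that tolerates up to 10 seconds of out-of-orderness
        DataStream<UserBehaviorEvent> eventStream = env
            .fromSource(kafkaSource, WatermarkStrategy
                .<UserBehaviorEvent>forBoundedOutOfOrderness(Duration.ofSeconds(10))
                .withTimestampAssigner((event, ts) -> event.getEventTime()),
                "UserBehaviorKafkaSource");

        // Key by user ID
        KeyedStream<UserBehaviorEvent, String> keyedStream =
            eventStream.keyBy(UserBehaviorEvent::getUserId);

        // Sliding window: covers the last hour, refreshed every 5 minutes
        DataStream<UserFeatureRecord> hourlyFeatures = keyedStream
            .window(SlidingEventTimeWindows.of(
                Time.hours(1), Time.minutes(5)))
            .aggregate(new UserBehaviorAggregator(),
                new UserFeatureWindowFunction());

        // Write to Redis
        hourlyFeatures.addSink(new RedisFeatureSink());
        env.execute("UserRealtimeFeatureJob");
    }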
}

import org.apache.flink.api.common.functions.AggregateFunction;

/**
 * Incremental aggregator: each event is folded into an accumulator,
 * so the window never has to store the raw events.
 */
public class UserBehaviorAggregator
        implements AggregateFunction<UserBehaviorEvent,
                                     UserBehaviorAccumulator,
                                     UserBehaviorAccumulator> {
    @Override
    public UserBehaviorAccumulator createAccumulator() {
        return new UserBehaviorAccumulator();
    }

    @Override
    public UserBehaviorAccumulator add(UserBehaviorEvent event,
                                       UserBehaviorAccumulator acc) {
        acc.totalEvents++;
        switch (event.getEventType()) {
            case "click":
                acc.clickCount++;
                break;
            case "purchase":
                acc.purchaseCount++;
                acc.purchaseAmount += event.getAmount();
                break;
            case "browse":
                acc.browseCount++;
                acc.browseSeconds += event.getDuration();
                break;
            default:
                // Other event types only count toward totalEvents
                break;
        }
        // Track the timestamp of the most recent event
        acc.lastEventTime = Math.max(acc.lastEventTime, event.getEventTime());
        return acc;
    }

    @Override
    public UserBehaviorAccumulator getResult(UserBehaviorAccumulator acc) {
        return acc;
    }

    // Combines two partial accumulators field by field
    @Override
    public UserBehaviorAccumulator merge(UserBehaviorAccumulator a,
                                         UserBehaviorAccumulator b) {
        UserBehaviorAccumulator merged = new UserBehaviorAccumulator();
        merged.totalEvents = a.totalEvents + b.totalEvents;
        merged.clickCount = a.clickCount + b.clickCount;
        merged.purchaseCount = a.purchaseCount + b.purchaseCount;
        merged.purchaseAmount = a.purchaseAmount + b.purchaseAmount;
        merged.browseCount = a.browseCount + b.browseCount;
        merged.browseSeconds = a.browseSeconds + b.browseSeconds;
        merged.lastEventTime = Math.max(a.lastEventTime, b.lastEventTime);
        return merged;
    }
}

@Data
public class UserBehaviorAccumulator {
    public int totalEvents = 0;
    public int clickCount = 0;
    public int purchaseCount = 0;
    public double purchaseAmount = 0;
    public int browseCount = 0;
    public long browseSeconds = 0;
    public long lastEventTime = 0;
}

2.2 Session Features: Capturing the User's Current Session
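Sliding windows are driven by time; session features are driven by behavior: if a user produces no events for 30 minutes, the session is considered over.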
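The gap rule can be illustrated without Flink. The sketch below is a simplification in plain Java: it takes timestamps already sorted in ascending order and starts a new session whenever the pause since the previous event is longer than the gap. `SessionGapSketch` and `split` are hypothetical names, and the sketch ignores late and out-of-order events, which a real session window has to handle by merging windows.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative helper (not a Flink class): splits one user's events into sessions. */
final class SessionGapSketch {

    /**
     * Splits ascending event timestamps (ms) into sessions. A new session starts
     * when the pause since the previous event is longer than gapMs.
     */
    static List<List<Long>> split(List<Long> sortedTs, long gapMs) {
        List<List<Long>> sessions = new ArrayList<>();
        List<Long> current = new ArrayList<>();
        for (long ts : sortedTs) {
            if (!current.isEmpty() && ts - current.get(current.size() - 1) > gapMs) {
                sessions.add(current);          // close the previous session
                current = new ArrayList<>();    // and open a new one
            }
            current.add(ts);
        }
        if (!current.isEmpty()) {
            sessions.add(current);
        }
        return sessions;
    }
}
```

With events at minutes 0, 10, 25, 70 and 75 and a 30-minute gap, this yields two sessions: the first three events and the last two. Unlike a sliding window, the session's length is not known in advance; it is decided by when the user stops.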
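/**
 * Session feature computation: behavior-sequence features within each session
 */
public class UserSessionFeatureJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment();
        // buildEventStream is not shown in the original; it is assumed to build the
        // Kafka source with timestamps and watermarks as in section 2.1
        DataStream<UserBehaviorEvent> eventStream = buildEventStream(env);

        // Session window: the session ends after 30 minutes without events
        DataStream<SessionFeature> sessionFeatures = eventStream
            .keyBy(UserBehaviorEvent::getUserId)
            .window(EventTimeSessionWindows.withGap(Time.minutes(30)))
            .process(new SessionFeatureProcessor());

        // Note: the RedisFeatureSink in section 3.1 is typed for UserFeatureRecord;
        // a sink that accepts SessionFeature is needed for this line to compile
        sessionFeatures.addSink(new RedisFeatureSink());
        env.execute("UserSessionFeatureJob");
    }
}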

public class SessionFeatureProcessor
        extends ProcessWindowFunction<UserBehaviorEvent,
                                      SessionFeature, String, TimeWindow> {
    @Override
    public void process(String userId,
                        Context context,
                        Iterable<UserBehaviorEvent> events,
                        Collector<SessionFeature> out) {
        List<UserBehaviorEvent> eventList = new ArrayList<>();
        events.forEach(eventList::add);
        if (eventList.isEmpty()) return;
        // Sort by event time
        eventList.sort(Comparator.comparingLong(UserBehaviorEvent::getEventTime));

        SessionFeature feature = new SessionFeature();
        feature.setUserId(userId);
        feature.setSessionEventCount(eventList.size());

        // Session duration in milliseconds
        long sessionDuration = eventList.get(eventList.size() - 1).getEventTime()
            - eventList.get(0).getEventTime();
        feature.setSessionDurationMs(sessionDuration);

        // Sequence of event types (for sequence-pattern mining)
        String eventSequence = eventList.stream()
            .map(UserBehaviorEvent::getEventType)
            .collect(Collectors.joining(","));
        feature.setEventSequence(eventSequence);

        // Browse depth (largest page index seen)
        int maxPageDepth = eventList.stream()
            .filter(e -> e.getPageDepth() != null)
            .mapToInt(UserBehaviorEvent::getPageDepth)
            .max().orElse(0);
        feature.setMaxPageDepth(maxPageDepth);

        // Whether the session contains a purchase
        boolean hasPurchase = eventList.stream()
            .anyMatch(e -> "purchase".equals(e.getEventType()));
        feature.setHasPurchase(hasPurchase);

        // Age of the most recent event (freshness). This uses the wall clock at
        // the moment the window fires, so it is a snapshot that goes stale after
        // it is written and differs when historical data is replayed.
        feature.setLastEventAgeMs(
            System.currentTimeMillis() -
            eventList.get(eventList.size() - 1).getEventTime());
        out.collect(feature);
    }
}

2.3 Two-Stream Join: Stitching Features Together in Real Time
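Often two event streams need to be joined. For example, joining the user behavior stream with the item information stream to compute whether the user's preferred category matches the category of the item currently being viewed: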
public class BehaviorItemJoinJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment();
        // The two builder helpers are not shown in the original
        DataStream<UserBehaviorEvent> behaviorStream = buildBehaviorStream(env);
        DataStream<ItemViewEvent> itemStream = buildItemStream(env);

        // Window-based two-stream join: a pair is emitted only when both events
        // share the same item ID and fall into the same 1-minute tumbling window
        DataStream<JoinedFeature> joinedStream = behaviorStream
            .join(itemStream)
            .where(UserBehaviorEvent::getItemId)   // join key of the behavior stream
            .equalTo(ItemViewEvent::getItemId)     // join key of the item stream
            .window(TumblingEventTimeWindows.of(Time.minutes(1)))
            .apply(new JoinFunction<UserBehaviorEvent,
                                    ItemViewEvent,
                                    JoinedFeature>() {
                @Override
                public JoinedFeature join(UserBehaviorEvent behavior,
                                          ItemViewEvent item) {
                    JoinedFeature feature = new JoinedFeature();
                    feature.setUserId(behavior.getUserId());
                    feature.setItemId(behavior.getItemId());
                    feature.setEventType(behavior.getEventType());
                    feature.setItemCategory(item.getCategory());
                    feature.setItemPrice(item.getPrice());
                    feature.setUserCategory(behavior.getPreferredCategory());
                    // Does the user's preferred category match the item's category?
                    // Objects.equals avoids an NPE when no preference is set
                    feature.setCategoryMatch(Objects.equals(
                        behavior.getPreferredCategory(), item.getCategory()));
                    return feature;
                }
            });

        // As in 2.2, the sink must be typed for JoinedFeature for this to compile
        joinedStream.addSink(new RedisFeatureSink());
        env.execute("BehaviorItemJoinJob");
    }
}

3. Storing and Reading Real-Time Features
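3.1 Redis Storage Design
Reading real-time features sits on the critical path of online inference, so it must be low latency. Redis is the most common choice.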
// Flink instantiates and serializes the sink itself, so it is a plain class here
// rather than a Spring-managed @Component
public class RedisFeatureSink extends RichSinkFunction<UserFeatureRecord> {
    private transient JedisCluster jedis;

    @Override
    public void open(Configuration parameters) {
        Set<HostAndPort> nodes = new HashSet<>();
        nodes.add(new HostAndPort("redis-node1", 6379));
        nodes.add(new HostAndPort("redis-node2", 6379));
        nodes.add(new HostAndPort("redis-node3", 6379));
        JedisPoolConfig poolConfig = new JedisPoolConfig();
        poolConfig.setMaxTotal(100);
        // 2000 ms connection timeout, 2000 ms socket timeout, up to 5 attempts
        jedis = new JedisCluster(nodes, 2000, 2000, 5, poolConfig);
    }

    @Override
    public void invoke(UserFeatureRecord record, Context context) {
        String key = "rt_feature:" + record.getUserId();
        // Store all features in one hash: a single HSET writes every field
        // atomically, so readers never see a half-updated feature set
        Map<String, String> featureMap = new HashMap<>();
        featureMap.put("click_1h", String.valueOf(record.getClickCount1h()));
        featureMap.put("purchase_1h", String.valueOf(record.getPurchaseCount1h()));
        featureMap.put("browse_1h", String.valueOf(record.getBrowseCount1h()));
        featureMap.put("session_depth", String.valueOf(record.getMaxPageDepth()));
        featureMap.put("last_event_ts", String.valueOf(record.getLastEventTime()));
        featureMap.put("updated_at", String.valueOf(System.currentTimeMillis()));
        jedis.hset(key, featureMap);
        // Separate command (not atomic with the HSET): 24-hour TTL so that
        // features of inactive users are cleaned up automatically
        jedis.expire(key, 86400);
    }

    @Override
    public void close() {
        if (jedis != null) jedis.close();
    }
}

3.2 Reading Features in the Java Inference Service
At inference time, the Java service has to merge real-time features (Redis) with batch features (HBase/MySQL):
@Slf4j // Lombok; assumed here because the original uses `log` without declaring it
@Service
public class FeatureAssemblyService {
    @Autowired
    private RedisTemplate<String, String> redisTemplate;
    @Autowired
    private BatchFeatureService batchFeatureService;

    /**
     * Assemble the full feature vector for online inference.
     * Timeout budget: no more than 20 ms overall.
     */
    public FeatureVector assembleFeatures(String userId, String itemId) {
        long startTime = System.currentTimeMillis();
        // Shared deadline: two independent 15 ms waits could add up to 30 ms,
        // so the second wait only gets what is left of the 20 ms budget
        long deadlineNanos = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(20);

        // Read real-time and batch features in parallel. supplyAsync without an
        // executor runs on the common ForkJoinPool; a dedicated I/O executor is
        // usually preferable for blocking reads.
        CompletableFuture<Map<String, Double>> rtFeatureFuture =
            CompletableFuture.supplyAsync(() -> getRealtimeFeatures(userId));
        CompletableFuture<Map<String, Double>> batchFeatureFuture =
            CompletableFuture.supplyAsync(() ->
                batchFeatureService.getUserFeatures(userId));

        Map<String, Double> allFeatures = new HashMap<>();
        try {
            // Wait at most 15 ms for the real-time features
            Map<String, Double> rtFeatures = rtFeatureFuture.get(15, TimeUnit.MILLISECONDS);
            allFeatures.putAll(rtFeatures);
        } catch (TimeoutException e) {
            log.warn("Real-time feature read timed out, using defaults: userId={}", userId);
            allFeatures.putAll(getDefaultRealtimeFeatures());
        } catch (Exception e) {
            log.error("Real-time feature read failed: userId={}", userId, e);
            allFeatures.putAll(getDefaultRealtimeFeatures());
        }

        try {
            // Wait only for the remainder of the overall budget
            long remainingNanos = Math.max(0L, deadlineNanos - System.nanoTime());
            Map<String, Double> batchFeatures =
                batchFeatureFuture.get(remainingNanos, TimeUnit.NANOSECONDS);
            allFeatures.putAll(batchFeatures);
        } catch (Exception e) {
            log.warn("Batch feature read failed, using cache: userId={}", userId);
            allFeatures.putAll(getCachedBatchFeatures(userId));
        }

        log.debug("Feature assembly took {}ms", System.currentTimeMillis() - startTime);
        return new FeatureVector(userId, allFeatures);
    }

    private Map<String, Double> getRealtimeFeatures(String userId) {
        String key = "rt_feature:" + userId;
        Map<Object, Object> rawMap = redisTemplate.opsForHash().entries(key);
        if (rawMap.isEmpty()) {
            return getDefaultRealtimeFeatures();
        }
        Map<String, Double> features = new HashMap<>();
        rawMap.forEach((k, v) -> {
            // updated_at parses as a number, so it has to be skipped explicitly;
            // it is turned into a freshness feature below instead
            if ("updated_at".equals(k)) return;
            try {
                features.put("rt_" + k, Double.parseDouble((String) v));
            } catch (NumberFormatException e) {
                // Ignore non-numeric fields
            }
        });
        // Feature freshness: minutes since the last update
        String updatedAt = (String) rawMap.get("updated_at");
        if (updatedAt != null) {
            double ageMinutes = (System.currentTimeMillis() -
                Long.parseLong(updatedAt)) / 60000.0;
            features.put("rt_feature_age_minutes", ageMinutes);
        }
        return features;
    }

    private Map<String, Double> getDefaultRealtimeFeatures() {
        // Fallback defaults when real-time features cannot be fetched
        Map<String, Double> defaults = new HashMap<>();
        defaults.put("rt_click_1h", 0.0);
        defaults.put("rt_purchase_1h", 0.0);
        defaults.put("rt_browse_1h", 0.0);
        defaults.put("rt_session_depth", 0.0);
        defaults.put("rt_feature_age_minutes", 999.0); // marks the features as stale
        return defaults;
    }
}

4. Online Learning: Updating Model Parameters in Real Time
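Online learning goes one step beyond real-time features: not only are the features real-time, the model parameters are updated in real time as well. It is fairly common in ad CTR prediction and feed recommendation.
4.1 The FTRL Online Learning Algorithm (Java Implementation)
FTRL (Follow The Regularized Leader) is one of the most widely used online learning algorithms in industry and is particularly well suited to sparse features: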
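For reference, the per-coordinate FTRL-Proximal update, as given in the paper "Ad Click Prediction: a View from the Trenches", can be written as below. The notation is chosen here for illustration: $p$ is the predicted probability, $y$ the label, $x_i$ the value of feature $i$, and $w_i, z_i, n_i$ the per-feature weight, FTRL intermediate variable and sum of squared gradients; $\alpha, \beta, \lambda_1, \lambda_2$ are the learning-rate and regularization parameters.

```latex
\begin{align*}
g_i &= (p - y)\, x_i \\
\sigma_i &= \frac{\sqrt{n_i + g_i^2} - \sqrt{n_i}}{\alpha} \\
z_i &\leftarrow z_i + g_i - \sigma_i\, w_i \\
n_i &\leftarrow n_i + g_i^2 \\
w_i &=
\begin{cases}
0 & \text{if } |z_i| \le \lambda_1 \\[4pt]
-\dfrac{z_i - \operatorname{sgn}(z_i)\,\lambda_1}{\dfrac{\beta + \sqrt{n_i}}{\alpha} + \lambda_2} & \text{otherwise}
\end{cases}
\end{align*}
```

The first case is what produces sparsity: any feature whose accumulated $|z_i|$ stays within $\lambda_1$ keeps a weight of exactly zero and does not need to be exported.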
/**
 * FTRL-Proximal online learning
 * Paper: Ad Click Prediction: a View from the Trenches (Google 2013)
 *
 * Plain class (no Spring @Component): it takes primitive constructor
 * arguments and is created with `new` inside the Flink operator in 4.2.
 */
public class FtrlOnlineLearner {
    private final double alpha;   // learning-rate parameter
    private final double beta;    // learning-rate parameter
    private final double lambda1; // L1 regularization
    private final double lambda2; // L2 regularization

    // Sparse parameter storage (feature ID -> parameters)
    private final ConcurrentHashMap<Long, double[]> paramMap;
    // Each double[] in paramMap has the layout [w, z, n]
    // w: weight, z: FTRL intermediate variable, n: sum of squared gradients

    public FtrlOnlineLearner(double alpha, double beta,
                             double lambda1, double lambda2) {
        this.alpha = alpha;
        this.beta = beta;
        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
        this.paramMap = new ConcurrentHashMap<>();
    }

    /**
     * Online prediction (sigmoid, binary classification)
     */
    public double predict(Map<Long, Double> features) {
        double score = 0.0;
        for (Map.Entry<Long, Double> entry : features.entrySet()) {
            double w = getWeight(entry.getKey());
            score += w * entry.getValue();
        }
        return sigmoid(score);
    }

    /**
     * Online update: take one sample and update the parameters
     */
    public void update(Map<Long, Double> features, double label) {
        double prediction = predict(features);
        double gradient = prediction - label; // log-loss gradient w.r.t. the score
        for (Map.Entry<Long, Double> entry : features.entrySet()) {
            long featureId = entry.getKey();
            double featureVal = entry.getValue();
            double g = gradient * featureVal; // gradient for this feature
            double[] param = paramMap.computeIfAbsent(
                featureId, k -> new double[]{0, 0, 0});
            // sigma must be computed from the old n, before n is updated
            double sigma = (Math.sqrt(param[2] + g * g) - Math.sqrt(param[2])) / alpha;
            param[1] += g - sigma * param[0]; // update z
            param[2] += g * g;                // update n
            // Recompute w (FTRL soft-threshold update)
            param[0] = computeWeight(param[1], param[2]);
        }
    }

    private double getWeight(long featureId) {
        double[] param = paramMap.get(featureId);
        if (param == null) return 0.0;
        return param[0];
    }

    private double computeWeight(double z, double n) {
        // L1 soft threshold: small |z| gives an exactly-zero weight
        if (Math.abs(z) <= lambda1) return 0.0;
        double sign = z > 0 ? 1.0 : -1.0;
        return -sign * (Math.abs(z) - lambda1) /
            ((beta + Math.sqrt(n)) / alpha + lambda2);
    }

    private double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Export the non-zero weights (for periodic persistence to Redis/S3)
     */
    public Map<Long, Double> getWeights() {
        Map<Long, Double> weights = new HashMap<>();
        paramMap.forEach((id, param) -> {
            if (param[0] != 0) weights.put(id, param[0]);
        });
        return weights;
    }
}

4.2 Flink-Driven Online Training
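public class OnlineLearningJob {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env =
            StreamExecutionEnvironment.getExecutionEnvironment();
        // Stream of labeled training samples (user behavior joined with the final outcome)
        DataStream<LabeledSample> sampleStream = buildLabeledSampleStream(env);

        // Online training: each operator instance maintains its own model replica.
        // Samples are bucketed into 10 key groups by user, so each replica is
        // trained on the buckets routed to it. floorMod keeps the bucket in
        // [0, 10) even when hashCode() is negative.
        sampleStream
            .keyBy(sample -> Math.floorMod(sample.getUserId().hashCode(), 10))
            .process(new OnlineTrainingProcessor())
            .addSink(new ModelParameterSink());
        env.execute("OnlineLearningJob");
    }
}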

public class OnlineTrainingProcessor
        extends KeyedProcessFunction<Integer, LabeledSample, ModelUpdateEvent> {
    // Held in a plain field, not in Flink state: the learner is shared by all
    // keys on this subtask and is NOT restored from checkpoints after a failure
    private transient FtrlOnlineLearner learner;
    // Keyed state: one sample counter per bucket
    private transient ValueState<Long> sampleCountState;

    @Override
    public void open(Configuration parameters) {
        learner = new FtrlOnlineLearner(0.1, 1.0, 0.01, 0.01);
        ValueStateDescriptor<Long> descriptor =
            new ValueStateDescriptor<>("sampleCount", Long.class);
        sampleCountState = getRuntimeContext().getState(descriptor);
    }

    @Override
    public void processElement(LabeledSample sample,
                               Context ctx,
                               Collector<ModelUpdateEvent> out) throws Exception {
        // Update the model online
        learner.update(sample.getFeatures(), sample.getLabel());
        // State is null until the first sample for this key has been counted
        Long previous = sampleCountState.value();
        long count = (previous == null ? 0L : previous) + 1;
        sampleCountState.update(count);
        // Push a parameter update every 1000 samples
        if (count % 1000 == 0) {
            Map<Long, Double> weights = learner.getWeights();
            out.collect(new ModelUpdateEvent(
                ctx.getCurrentKey(), weights, count));
        }
    }
}

5. Watermarks and Out-of-Order Events
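One of the hard parts of real-time processing is out-of-order events. Because of network delays, user behavior events may already be out of order by the time they reach Flink, and the watermark mechanism is what deals with this.
One pitfall I hit: the watermark delay was set too small, so a large number of events were treated as late data and dropped, and the real-time features became badly inaccurate. The lesson was to set it from the actual delay distribution: run for a while, record the delay between an event being produced and arriving at Flink, and use the P99 as the watermark delay.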
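A minimal sketch of that calculation, assuming the delay samples (arrival time minus event time, in milliseconds) have already been collected somewhere. `WatermarkDelaySketch`, `percentile` and `suggestedBound` are hypothetical names, and the nearest-rank percentile is just one reasonable definition.

```java
import java.time.Duration;
import java.util.Arrays;

/** Illustrative helper: derives a watermark bound from observed event delays. */
final class WatermarkDelaySketch {

    /** Nearest-rank percentile of the observed delays; p is in (0, 100]. */
    static long percentile(long[] delaysMs, double p) {
        if (delaysMs.length == 0) {
            throw new IllegalArgumentException("no delay samples");
        }
        long[] sorted = delaysMs.clone();
        Arrays.sort(sorted);
        // 1-based rank of the smallest value that covers p percent of the samples
        int rank = (int) Math.ceil(p * sorted.length / 100.0);
        return sorted[Math.max(rank, 1) - 1];
    }

    /** Suggested out-of-orderness bound: the P99 delay, rounded up to whole seconds. */
    static Duration suggestedBound(long[] delaysMs) {
        long p99 = percentile(delaysMs, 99.0);
        return Duration.ofSeconds((p99 + 999) / 1000);
    }
}
```

The resulting `Duration` is the kind of value that would be passed to `forBoundedOutOfOrderness`. Events slower than the P99 still arrive late by construction, so this bound pairs naturally with an explicit policy for late data rather than replacing one.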
// Watermark strategy with the bound taken from the measured delay distribution
DataStream<UserBehaviorEvent> eventStream = kafkaStream
    .assignTimestampsAndWatermarks(
        WatermarkStrategy
            .<UserBehaviorEvent>forBoundedOutOfOrderness(
                Duration.ofSeconds(30)) // set from the observed P99 delay
            .withTimestampAssigner((event, ts) -> event.getEventTime())
            .withIdleness(Duration.ofMinutes(5)) // keep idle partitions from stalling the watermark
    );

// Policy for late events: route them to a side output for follow-up processing
OutputTag<UserBehaviorEvent> lateTag =
    new OutputTag<UserBehaviorEvent>("late-events"){};

SingleOutputStreamOperator<UserFeatureRecord> mainStream = eventStream
    .keyBy(UserBehaviorEvent::getUserId)
    .window(SlidingEventTimeWindows.of(Time.hours(1), Time.minutes(5)))
    .allowedLateness(Time.minutes(5)) // data up to 5 minutes late still updates the window
    .sideOutputLateData(lateTag)      // anything later than that goes to the side output
    .aggregate(new UserBehaviorAggregator(),
        new UserFeatureWindowFunction());

// Handle late data separately (for example, log it or correct features already emitted)
DataStream<UserBehaviorEvent> lateStream = mainStream.getSideOutput(lateTag);
lateStream.addSink(new LateEventSink());

6. Pitfalls and Lessons Learned
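Pitfall 1: Oversized Flink state causing OOM
Computing session features requires keeping the user's event history in state; for active users with long windows, that state keeps growing. There are two fixes: first, use the RocksDB state backend (on disk, with compression); second, slim down what is stored, keeping only aggregated intermediate results rather than raw events.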
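As a sketch of the second fix, several of the session features from section 2.2 can be maintained in a fixed-size accumulator, so state per session stays constant no matter how many events arrive. This is plain Java with hypothetical names (`SessionAccumulatorSketch`, `add`, `merge`); in a Flink job the same logic would sit behind an incremental aggregate function. The event-type sequence string is deliberately left out, because it inherently grows with the number of events.

```java
/** Illustrative fixed-size accumulator for session features (hypothetical class). */
final class SessionAccumulatorSketch {
    long eventCount = 0;
    long firstEventTime = Long.MAX_VALUE;
    long lastEventTime = Long.MIN_VALUE;
    int maxPageDepth = 0;
    boolean hasPurchase = false;

    /** Folds one event into the running aggregates; the event itself is not retained. */
    void add(long eventTime, String eventType, Integer pageDepth) {
        eventCount++;
        firstEventTime = Math.min(firstEventTime, eventTime);
        lastEventTime = Math.max(lastEventTime, eventTime);
        if (pageDepth != null) {
            maxPageDepth = Math.max(maxPageDepth, pageDepth);
        }
        if ("purchase".equals(eventType)) {
            hasPurchase = true;
        }
    }

    /** Combines two partial sessions, which is needed when session windows merge. */
    static SessionAccumulatorSketch merge(SessionAccumulatorSketch a, SessionAccumulatorSketch b) {
        SessionAccumulatorSketch m = new SessionAccumulatorSketch();
        m.eventCount = a.eventCount + b.eventCount;
        m.firstEventTime = Math.min(a.firstEventTime, b.firstEventTime);
        m.lastEventTime = Math.max(a.lastEventTime, b.lastEventTime);
        m.maxPageDepth = Math.max(a.maxPageDepth, b.maxPageDepth);
        m.hasPurchase = a.hasPurchase || b.hasPurchase;
        return m;
    }

    /** Session duration in milliseconds; 0 for an empty accumulator. */
    long durationMs() {
        return eventCount == 0 ? 0 : lastEventTime - firstEventTime;
    }
}
```

Because min, max, sum and logical OR are all order-independent, events do not need to be sorted before they are folded in, which removes the in-memory sort as well as the buffered list.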
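Pitfall 2: Uneven Kafka partitions causing hotspots
Some hot users generate hundreds of times more events than ordinary users. Kafka's default partitioner hashes the message key (here the userId), so all of a hot user's events land on one partition, and the Flink operator instance handling it became the bottleneck. We later switched to random partitioning plus two-phase aggregation (aggregate locally first, then merge globally).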
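The two-phase idea can be shown on a plain list of events. In this sketch, phase 1 counts per salted key so that one hot user is spread over several sub-keys, and phase 2 strips the salt and merges the partial counts. Everything here is illustrative: `TwoPhaseAggregationSketch` and its methods are hypothetical names, the salt is round-robin rather than random so the example is deterministic, and it assumes user IDs do not need special handling for the `#` separator beyond using the last occurrence.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Illustrative two-phase (salted key) aggregation on an in-memory list. */
final class TwoPhaseAggregationSketch {

    /** Phase 1: count per (userId, salt), spreading each user over `buckets` sub-keys. */
    static Map<String, Long> localCounts(List<String> userIds, int buckets) {
        Map<String, Long> partial = new HashMap<>();
        int i = 0;
        for (String userId : userIds) {
            int salt = i++ % buckets; // round-robin salt; a random salt works the same way
            partial.merge(userId + "#" + salt, 1L, Long::sum);
        }
        return partial;
    }

    /** Phase 2: strip the salt and merge the partial counts into one total per user. */
    static Map<String, Long> globalCounts(Map<String, Long> partial) {
        Map<String, Long> totals = new HashMap<>();
        partial.forEach((saltedKey, count) -> {
            String userId = saltedKey.substring(0, saltedKey.lastIndexOf('#'));
            totals.merge(userId, count, Long::sum);
        });
        return totals;
    }
}
```

The second phase still funnels each user to a single key, but it now receives at most one partial result per sub-key instead of every raw event, which is what removes the hotspot. This only works for aggregations that can be merged, such as counts, sums, min and max.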
Pitfall 3: Redis writes becoming the bottleneck
At high QPS, Redis writes became the bottleneck. The optimization: batched writes (Pipeline) plus a local buffer:
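// Batched writes with a Pipeline inside the Flink sink.
// How the pipeline is obtained is not shown in the original and depends on the
// Redis client and version in use.
private transient Pipeline pipeline;
private transient int bufferCount = 0;
private static final int BUFFER_SIZE = 100;

@Override
public void invoke(UserFeatureRecord record, Context context) {
    // Queue the commands on the pipeline; nothing is sent yet
    String key = "rt_feature:" + record.getUserId();
    pipeline.hset(key, buildFeatureMap(record));
    pipeline.expire(key, 86400); // keep the 24-hour TTL from section 3.1
    bufferCount++;
    if (bufferCount >= BUFFER_SIZE) {
        pipeline.sync(); // submit the batch
        bufferCount = 0;
    }
}

@Override
public void close() {
    // Flush the tail of the buffer; records still buffered at a failure are
    // lost unless the sink also flushes when a checkpoint is taken
    if (pipeline != null && bufferCount > 0) {
        pipeline.sync();
    }
}

Pitfall 4: Parameter divergence in online learning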
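If FTRL's learning rate is set too high, the model parameters swing wildly on noisy samples, and online performance ends up worse than a static model. Before launching online learning, always run a time-ordered validation on offline data (replay historical data to simulate the online update process) to confirm that the model does not diverge.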
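One common way to run such a replay is predict-then-update: each historical sample is scored before the model learns from it, and the running log loss is compared against a baseline. The sketch below assumes samples are already sorted by event time. To stay self-contained it uses a tiny dense SGD logistic regression as a stand-in for the real learner; `ReplayValidationSketch`, `SgdLearner`, `Sample` and `progressiveLogLoss` are all hypothetical names.

```java
import java.util.List;

/** Illustrative predict-then-update replay over time-ordered samples. */
final class ReplayValidationSketch {

    /** Minimal online logistic regression on dense features; stands in for the real learner. */
    static final class SgdLearner {
        final double[] w;
        final double learningRate;

        SgdLearner(int dim, double learningRate) {
            this.w = new double[dim];
            this.learningRate = learningRate;
        }

        double predict(double[] x) {
            double score = 0.0;
            for (int i = 0; i < w.length; i++) {
                score += w[i] * x[i];
            }
            return 1.0 / (1.0 + Math.exp(-score));
        }

        void update(double[] x, double label) {
            double g = predict(x) - label; // log-loss gradient w.r.t. the score
            for (int i = 0; i < w.length; i++) {
                w[i] -= learningRate * g * x[i];
            }
        }
    }

    /** One labeled sample; the list passed to the replay must be in event-time order. */
    static final class Sample {
        final double[] x;
        final double label;

        Sample(double[] x, double label) {
            this.x = x;
            this.label = label;
        }
    }

    /** Average log loss where every sample is scored BEFORE the model is updated with it. */
    static double progressiveLogLoss(SgdLearner learner, List<Sample> samples) {
        double total = 0.0;
        for (Sample s : samples) {
            // Clamp to avoid log(0) if the model becomes extremely confident
            double p = Math.min(Math.max(learner.predict(s.x), 1e-12), 1.0 - 1e-12);
            total += -(s.label * Math.log(p) + (1.0 - s.label) * Math.log(1.0 - p));
            learner.update(s.x, s.label);
        }
        return samples.isEmpty() ? 0.0 : total / samples.size();
    }
}
```

A model that never learns and always predicts 0.5 scores ln 2 (about 0.693) on this metric, so that is a natural floor to beat; in practice the more useful comparison is against the static production model on the same replay. Running the replay for several candidate learning rates and watching whether the loss stays stable over time is one way to catch divergence before launch.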
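7. Summary
Real-time feature computation is the key to moving an AI system from "T+1 awareness" to "second-level awareness". Flink's core value in this scenario lies in precise event-time semantics, stateful incremental computation, and graceful handling of out-of-order data.
Combined with online learning, it lets the model truly follow users as they change in real time. The engineering complexity is considerable, though: from operating the Flink cluster and managing state to keeping Redis highly available, every link needs careful attention.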
