JDK8 Stream源码解析:Spliterator分治遍历与Collector聚合器
2026/4/30大约 13 分钟
JDK8 Stream源码解析:Spliterator分治遍历与Collector聚合器
适读人群:熟悉Stream API基本用法、想深入理解底层实现的Java开发者 | 阅读时长:约20分钟
开篇故事
2020年双十一大促前,我们团队做了一次性能压测,发现有个数据处理服务CPU跑满了。我登上机器一看,火焰图里有个方法占了35%的时间:一段用Stream处理订单数据的代码。
当时业务同学的代码大概是这样:
List<Order> orders = fetchAllOrders(); // 十几万条
Map<String, Double> result = orders.stream()
.filter(o -> o.getStatus() == OrderStatus.PAID)
.collect(Collectors.groupingBy(
Order::getUserId,
Collectors.summingDouble(Order::getAmount)
));我们改成了并行流:
Map<String, Double> result = orders.parallelStream()
.filter(o -> o.getStatus() == OrderStatus.PAID)
.collect(Collectors.groupingBy(
Order::getUserId,
Collectors.summingDouble(Order::getAmount)
));CPU占用降到了8%,处理时间从1.2秒降到了340毫秒。
但这引出了一个问题:为什么parallelStream()能这么快?Stream底层到底是怎么把数据分给多个CPU的?这里面有个叫Spliterator的数据结构,就是Stream并行化的核心秘密。今天彻底讲清楚。
一、为什么Stream需要Spliterator?
1.1 传统遍历的局限
在JDK8之前,遍历集合只有两种方式:
- 外部迭代(for循环、Iterator):调用方控制遍历节奏,天然串行
- Enumeration(JDK1.0时代):只能顺序遍历
这两种方式都无法高效地支持并行遍历,因为没有"把数据分成两半"的能力。
1.2 Spliterator的设计目标
Spliterator(Splittable Iterator,可分割迭代器)在JDK8引入,解决了三个问题:
- 分割能力:
trySplit()将数据集分成两部分,支持并行处理 - 特征描述:通过
characteristics()告诉Stream数据的特征(有序、去重、有限等),让Stream做更多优化 - 大小估算:
estimateSize()帮助Fork/Join框架决定最优分割策略
二、Spliterator深度解析
2.1 Spliterator接口定义
public interface Spliterator<T> {
// 核心方法1:尝试处理下一个元素,如果有则返回true
boolean tryAdvance(Consumer<? super T> action);
// 核心方法2:尝试分割,返回新的Spliterator负责前半部分
// 如果无法分割(数据太少或不支持),返回null
Spliterator<T> trySplit();
// 核心方法3:估算剩余元素数量
long estimateSize();
// 核心方法4:返回特征位掩码
int characteristics();
// 默认方法:遍历所有剩余元素
default void forEachRemaining(Consumer<? super T> action) {
do { } while (tryAdvance(action));
}
// 如果SIZED特征已知,返回精确大小
default long getExactSizeIfKnown() {
return (characteristics() & SIZED) == 0 ? -1L : estimateSize();
}
// 特征常量
int ORDERED = 0x00000010; // 有顺序(如List)
int DISTINCT = 0x00000001; // 元素唯一(如Set)
int SORTED = 0x00000004; // 已排序
int SIZED = 0x00000040; // 大小已知
int NONNULL = 0x00000100; // 无null元素
int IMMUTABLE = 0x00000400; // 数据不可变
int CONCURRENT = 0x00001000; // 支持并发修改
int SUBSIZED = 0x00004000; // 分割后子Spliterator大小也已知
}2.2 并行Stream的工作流程
┌─────────────────────────────────────────────────────────────────────┐
│ parallelStream() 处理流程 │
└─────────────────────────────────────────────────────────────────────┘
数据源 (List/Array/etc.)
│
▼
获取根 Spliterator
│
▼
┌─────────────────────────────────────────────────┐
│ Fork/Join框架 │
│ │
│ Spliterator(全量) │
│ │ │
│ ├── trySplit() ──► Spliterator(前半) │
│ │ │ │
│ │ ├── trySplit() ──► │
│ │ │ Spliterator │
│ │ │ (前1/4) │
│ │ │ │
│ │ └── forEachRemaining│
│ │ (后1/4) │
│ │ │
│ └── 自身处理后半部分 │
│ │
│ 各子任务在不同线程执行,结果最终合并 │
└─────────────────────────────────────────────────┘
│
▼
Collector合并结果Mermaid版本:
2.3 ArrayList的Spliterator实现
深入看一下ArrayList内部的ArrayListSpliterator(JDK源码):
// ArrayList内部类,JDK源码简化版
static final class ArrayListSpliterator<E> implements Spliterator<E> {
private final ArrayList<E> list;
private int index; // 当前位置
private int fence; // 结束位置(-1表示还未确定)
private int expectedModCount; // 用于检测并发修改
ArrayListSpliterator(ArrayList<E> list, int origin, int fence, int expectedModCount) {
this.list = list;
this.index = origin;
this.fence = fence;
this.expectedModCount = expectedModCount;
}
private int getFence() {
int hi;
if ((hi = fence) < 0) {
// 第一次调用时确定fence
expectedModCount = list.modCount;
hi = fence = list.size;
}
return hi;
}
@Override
public ArrayListSpliterator<E> trySplit() {
int hi = getFence();
int lo = index;
int mid = (lo + hi) >>> 1; // 中间位置(无符号右移防止溢出)
// 如果剩余元素少于等于1,无法分割
if (lo >= mid) return null;
// 把前半部分给新的Spliterator,自己保留后半部分
index = mid;
return new ArrayListSpliterator<>(list, lo, mid, expectedModCount);
}
@Override
public boolean tryAdvance(Consumer<? super E> action) {
if (action == null) throw new NullPointerException();
int hi = getFence();
int i = index;
if (i < hi) {
index = i + 1;
@SuppressWarnings("unchecked")
E e = (E) list.elementData[i];
action.accept(e);
if (list.modCount != expectedModCount)
throw new ConcurrentModificationException();
return true;
}
return false;
}
@Override
public long estimateSize() {
return (long)(getFence() - index);
}
@Override
public int characteristics() {
// ArrayList:有序、大小已知、子分割大小也已知
return ORDERED | SIZED | SUBSIZED;
}
}关键洞察:trySplit()用的是(lo + hi) >>> 1而不是(lo + hi) / 2,这是防止整数溢出的经典技巧。当lo和hi都很大时,lo + hi可能溢出,>>> 1(无符号右移)能正确处理。
2.4 Stream的惰性求值链
Stream的操作分两类:
- 中间操作(Intermediate):
filter、map、flatMap、sorted等——不立即执行,构建"流水线" - 终止操作(Terminal):
collect、forEach、count、findFirst等——触发实际计算
┌──────────────────────────────────────────────────────┐
│ Stream Pipeline │
│ │
│ Source ──► filter ──► map ──► collect │
│ (Spliterator) (惰性) (惰性) (触发执行) │
│ │
│ 内部实现:AbstractPipeline链表 │
│ Head → StatelessOp(filter) → StatelessOp(map) → │
│ TerminalOp(collect) │
└──────────────────────────────────────────────────────┘2.5 Collector接口解析
Collector<T, A, R>是三个泛型的聚合器接口:
T:输入元素类型A:中间累积容器类型(通常是可变的)R:最终结果类型
public interface Collector<T, A, R> {
// 创建一个新的空累积容器(每个线程/分区各自创建一个)
Supplier<A> supplier();
// 将一个元素T累积到容器A中
BiConsumer<A, T> accumulator();
// 并行场景:合并两个容器(左合并右)
BinaryOperator<A> combiner();
// 将累积容器A转换为最终结果R
Function<A, R> finisher();
// 优化特征(CONCURRENT, UNORDERED, IDENTITY_FINISH)
Set<Characteristics> characteristics();
}三、完整代码示例
3.1 自定义Spliterator实现
import java.util.*;
import java.util.function.*;
import java.util.stream.*;
/**
* 自定义Spliterator示例
* 场景:处理一个大文件,按行分批并行处理
*/
public class CustomSpliteratorDemo {
// ===== 旧写法:手动分批处理 =====
public static Map<String, Long> oldWayWordCount(List<String> lines) {
Map<String, Long> result = new HashMap<>();
for (String line : lines) {
String[] words = line.split("\\s+");
for (String word : words) {
if (!word.isEmpty()) {
result.merge(word.toLowerCase(), 1L, Long::sum);
}
}
}
return result;
}
// ===== 新写法:使用Stream + 自定义Collector =====
public static Map<String, Long> newWayWordCount(List<String> lines) {
return lines.parallelStream()
.flatMap(line -> Arrays.stream(line.split("\\s+")))
.filter(word -> !word.isEmpty())
.map(String::toLowerCase)
.collect(Collectors.groupingBy(
Function.identity(),
Collectors.counting()
));
}
// ===== 自定义Spliterator:范围数字Spliterator =====
static class RangeSpliterator implements Spliterator.OfLong {
private long current;
private final long end;
RangeSpliterator(long start, long end) {
this.current = start;
this.end = end;
}
@Override
public OfLong trySplit() {
long lo = current;
long mid = (current + end) >>> 1;
if (lo >= mid) return null; // 太小,不分割
current = mid; // 自己负责后半
return new RangeSpliterator(lo, mid); // 返回前半
}
@Override
public boolean tryAdvance(LongConsumer action) {
if (current < end) {
action.accept(current++);
return true;
}
return false;
}
@Override
public long estimateSize() {
return end - current;
}
@Override
public int characteristics() {
return ORDERED | SIZED | SUBSIZED | IMMUTABLE | NONNULL | DISTINCT | SORTED;
}
}
public static void testCustomSpliterator() {
// 用自定义Spliterator创建Stream
LongStream stream = StreamSupport.longStream(
new RangeSpliterator(0, 1_000_000),
true // parallel=true
);
long sum = stream.sum();
System.out.println("Sum 0~999999 = " + sum); // 499999500000
// 对比:串行版
long start = System.nanoTime();
long serialSum = StreamSupport.longStream(new RangeSpliterator(0, 1_000_000), false).sum();
long serialTime = System.nanoTime() - start;
start = System.nanoTime();
long parallelSum = StreamSupport.longStream(new RangeSpliterator(0, 1_000_000), true).sum();
long parallelTime = System.nanoTime() - start;
System.out.printf("串行: %d ms, 并行: %d ms%n",
serialTime / 1_000_000, parallelTime / 1_000_000);
}
public static void main(String[] args) {
List<String> lines = Arrays.asList(
"hello world hello java",
"java stream api is great",
"hello stream hello"
);
System.out.println("旧写法: " + oldWayWordCount(lines));
System.out.println("新写法: " + newWayWordCount(lines));
testCustomSpliterator();
}
}3.2 自定义Collector实现
import java.util.*;
import java.util.function.*;
import java.util.stream.*;
/**
* 自定义Collector完整示例
* 场景:统计Top N元素
*/
public class CustomCollectorDemo {
// ===== 旧写法:手动实现TopN =====
public static <T extends Comparable<T>> List<T> oldTopN(List<T> list, int n) {
List<T> sorted = new ArrayList<>(list);
Collections.sort(sorted, Collections.reverseOrder());
return sorted.subList(0, Math.min(n, sorted.size()));
}
// ===== 新写法:使用内置Collector =====
public static <T extends Comparable<T>> List<T> newTopN(List<T> list, int n) {
return list.stream()
.sorted(Comparator.reverseOrder())
.limit(n)
.collect(Collectors.toList());
}
// ===== 更优雅:自定义TopN Collector =====
/**
* 自定义TopN Collector
* T: 输入类型
* A: 中间容器(PriorityQueue,维护最小堆)
* R: 结果类型(List)
*/
public static <T extends Comparable<T>> Collector<T, ?, List<T>> topN(int n) {
return Collector.of(
// supplier:创建大小为n的最小堆
() -> new PriorityQueue<T>(n),
// accumulator:如果堆未满,直接加入;否则与堆顶比较
(heap, elem) -> {
if (heap.size() < n) {
heap.offer(elem);
} else if (elem.compareTo(heap.peek()) > 0) {
heap.poll(); // 移除最小元素
heap.offer(elem); // 加入更大的元素
}
},
// combiner:合并两个堆(并行场景)
(heap1, heap2) -> {
for (T elem : heap2) {
if (heap1.size() < n) {
heap1.offer(elem);
} else if (elem.compareTo(heap1.peek()) > 0) {
heap1.poll();
heap1.offer(elem);
}
}
return heap1;
},
// finisher:将堆转为有序列表
heap -> {
List<T> result = new ArrayList<>(heap);
result.sort(Comparator.reverseOrder());
return result;
},
// characteristics:无特殊优化标志
Collector.Characteristics.UNORDERED
);
}
// ===== 分组统计Collector组合 =====
public static void collectorCombination() {
List<String> words = Arrays.asList(
"apple", "banana", "avocado", "blueberry",
"apricot", "cherry", "blackberry"
);
// 按首字母分组,每组取最长的2个单词
Map<Character, List<String>> result = words.stream()
.collect(Collectors.groupingBy(
s -> s.charAt(0),
Collectors.collectingAndThen(
Collectors.toList(),
list -> list.stream()
.sorted(Comparator.comparingInt(String::length).reversed())
.limit(2)
.collect(Collectors.toList())
)
));
result.forEach((k, v) -> System.out.println(k + ": " + v));
// 统计各首字母的单词平均长度
Map<Character, Double> avgLength = words.stream()
.collect(Collectors.groupingBy(
s -> s.charAt(0),
Collectors.averagingInt(String::length)
));
System.out.println("Avg length by initial: " + avgLength);
// 多级分组
Map<Character, Map<Integer, List<String>>> multiGroup = words.stream()
.collect(Collectors.groupingBy(
s -> s.charAt(0),
Collectors.groupingBy(String::length)
));
System.out.println("Multi group: " + multiGroup);
}
public static void main(String[] args) {
List<Integer> numbers = Arrays.asList(3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5);
System.out.println("旧写法TopN: " + oldTopN(numbers, 3));
System.out.println("新写法TopN: " + newTopN(numbers, 3));
System.out.println("自定义Collector TopN: " +
numbers.stream().collect(topN(3)));
System.out.println("\n=== Collector组合 ===");
collectorCombination();
}
}3.3 Stream性能对比测试
import java.util.*;
import java.util.stream.*;
import java.util.concurrent.*;
/**
* Stream性能对比:串行 vs 并行 vs 传统for循环
* 引入版本:JDK8 GA(2014年3月)
*/
public class StreamPerformanceTest {
static final int SIZE = 1_000_000;
static final int WARM_UP = 5;
static final int ITERATIONS = 20;
static List<Integer> data;
static {
Random random = new Random(42);
data = new ArrayList<>(SIZE);
for (int i = 0; i < SIZE; i++) {
data.add(random.nextInt(10000));
}
}
// 测试1:过滤+求和
public static void filterAndSum() {
System.out.println("\n=== 过滤+求和 (n=" + SIZE + ") ===");
// 传统for循环
long start = System.nanoTime();
for (int i = 0; i < WARM_UP; i++) {
long sum = 0;
for (int x : data) {
if (x > 5000) sum += x;
}
}
long[] times = new long[ITERATIONS];
for (int i = 0; i < ITERATIONS; i++) {
start = System.nanoTime();
long sum = 0;
for (int x : data) {
if (x > 5000) sum += x;
}
times[i] = System.nanoTime() - start;
}
System.out.printf("for循环: avg=%d ms%n", avg(times));
// 串行Stream
for (int i = 0; i < WARM_UP; i++) {
data.stream().filter(x -> x > 5000).mapToLong(Integer::longValue).sum();
}
for (int i = 0; i < ITERATIONS; i++) {
start = System.nanoTime();
data.stream().filter(x -> x > 5000).mapToLong(Integer::longValue).sum();
times[i] = System.nanoTime() - start;
}
System.out.printf("串行Stream: avg=%d ms%n", avg(times));
// 并行Stream
for (int i = 0; i < WARM_UP; i++) {
data.parallelStream().filter(x -> x > 5000).mapToLong(Integer::longValue).sum();
}
for (int i = 0; i < ITERATIONS; i++) {
start = System.nanoTime();
data.parallelStream().filter(x -> x > 5000).mapToLong(Integer::longValue).sum();
times[i] = System.nanoTime() - start;
}
System.out.printf("并行Stream: avg=%d ms%n", avg(times));
// IntStream(避免装箱)
int[] primitiveData = data.stream().mapToInt(Integer::intValue).toArray();
for (int i = 0; i < WARM_UP; i++) {
Arrays.stream(primitiveData).filter(x -> x > 5000).asLongStream().sum();
}
for (int i = 0; i < ITERATIONS; i++) {
start = System.nanoTime();
Arrays.stream(primitiveData).filter(x -> x > 5000).asLongStream().sum();
times[i] = System.nanoTime() - start;
}
System.out.printf("IntStream: avg=%d ms%n", avg(times));
}
static long avg(long[] times) {
long sum = 0;
for (long t : times) sum += t;
return sum / times.length / 1_000_000;
}
public static void main(String[] args) {
System.out.println("CPU核心数: " + Runtime.getRuntime().availableProcessors());
filterAndSum();
// 注意:并行Stream不一定总是更快
// 以下场景串行更好:
// 1. 数据量小(< 10000):分割和合并的开销超过并行收益
// 2. 操作本身很轻量:比如简单过滤
// 3. 有副作用的操作:如写入共享Map
System.out.println("\n结论:");
System.out.println("- 数据量 > 10万 且 CPU密集型操作:parallelStream()更快");
System.out.println("- 简单IO密集型:用虚拟线程(JDK21)更合适");
System.out.println("- 避免在并行Stream中使用有状态操作和共享可变变量");
}
}四、踩坑实录
坑1:并行Stream使用ForkJoinPool.commonPool导致线程饥饿
// 危险:所有parallelStream默认共用ForkJoinPool.commonPool
// 如果在Web服务中大量使用,会影响所有并行任务
// 问题代码:数据库查询结果用parallelStream处理
// 如果这个方法被大量并发调用,commonPool会被打满
public List<String> processData(List<Order> orders) {
return orders.parallelStream() // 共用commonPool!
.map(this::heavyProcess)
.collect(Collectors.toList());
}
// 正确做法:使用自定义ForkJoinPool
public List<String> processDataSafe(List<Order> orders) throws Exception {
ForkJoinPool customPool = new ForkJoinPool(4); // 限制并行度
try {
return customPool.submit(() ->
orders.parallelStream()
.map(this::heavyProcess)
.collect(Collectors.toList())
).get();
} finally {
customPool.shutdown();
}
}
// 更好:复用ForkJoinPool,而不是每次创建
// 或者在JDK21中用虚拟线程替代坑2:Stream只能消费一次
Stream<String> stream = Arrays.asList("a", "b", "c").stream();
stream.forEach(System.out::println); // 正常
// 再次使用会抛出IllegalStateException: stream has already been operated upon or closed
try {
stream.forEach(System.out::println);
} catch (IllegalStateException e) {
System.out.println("Stream已关闭: " + e.getMessage());
}
// 正确做法:每次使用前重新创建Stream
List<String> list = Arrays.asList("a", "b", "c");
list.stream().forEach(System.out::println); // 新Stream
list.stream().filter(s -> !s.equals("b")).forEach(System.out::println); // 又一个新Stream
// 如果需要多次操作,先collect成List
List<String> filtered = list.stream()
.filter(s -> !s.equals("b"))
.collect(Collectors.toList()); // 收集后可多次使用坑3:flatMap与null的陷阱
import java.util.*;
import java.util.stream.*;
public class FlatMapNullTrap {
public static void main(String[] args) {
List<List<String>> nested = Arrays.asList(
Arrays.asList("a", "b"),
null, // 外层有null列表
Arrays.asList("c", null) // 内层有null元素
);
// 错误:NullPointerException
try {
nested.stream()
.flatMap(Collection::stream) // 内层null列表会NPE
.collect(Collectors.toList());
} catch (NullPointerException e) {
System.out.println("NPE: " + e);
}
// 正确:过滤掉null
List<String> result = nested.stream()
.filter(Objects::nonNull) // 过滤null列表
.flatMap(Collection::stream)
.filter(Objects::nonNull) // 过滤null元素
.collect(Collectors.toList());
System.out.println(result); // [a, b, c]
// 或者用Optional包装
List<String> result2 = nested.stream()
.flatMap(list -> list == null ? Stream.empty() : list.stream())
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
}坑4:有状态的中间操作破坏并行性能
// 危险:sorted()是有状态操作,会破坏并行Stream的性能
List<Integer> list = IntStream.range(0, 1_000_000)
.boxed()
.collect(Collectors.toList());
// 这个parallel()几乎没用:sorted()需要汇聚所有数据再排序
List<Integer> result = list.parallelStream()
.filter(x -> x % 2 == 0)
.sorted() // 有状态!强制同步
.limit(100)
.collect(Collectors.toList());
// 更好的做法:把sorted()放在适当位置
// 如果只需要top100,用不同策略
List<Integer> better = list.parallelStream()
.filter(x -> x % 2 == 0)
.collect(Collectors.toList()); // 先并行过滤
better.sort(Comparator.naturalOrder()); // 再串行排序
List<Integer> top100 = better.subList(0, 100);坑5:Collectors.toMap遇到重复key会抛异常
List<String> words = Arrays.asList("apple", "apricot", "banana");
// 危险:如果有重复key,toMap会抛IllegalStateException
try {
Map<Character, String> map = words.stream()
.collect(Collectors.toMap(
s -> s.charAt(0), // key:首字母
s -> s // value:单词
)); // 'a'出现了两次,抛异常!
} catch (IllegalStateException e) {
System.out.println("重复key: " + e.getMessage());
}
// 正确:提供mergeFunction
Map<Character, String> map = words.stream()
.collect(Collectors.toMap(
s -> s.charAt(0),
s -> s,
(existing, newVal) -> existing + "," + newVal // 合并重复key
));
System.out.println(map); // {a=apple,apricot, b=banana}
// 或者用groupingBy得到List
Map<Character, List<String>> groupMap = words.stream()
.collect(Collectors.groupingBy(s -> s.charAt(0)));五、总结与延伸
5.1 Spliterator特征对性能的影响
| 特征 | 作用 | 优化效果 |
|---|---|---|
SIZED | 大小已知 | 可预分配容量,避免数组扩容 |
SUBSIZED | 分割后大小已知 | ForkJoin可以均匀分配任务 |
ORDERED | 有顺序 | 某些操作需要保序,影响并行度 |
DISTINCT | 元素唯一 | distinct()操作可以跳过 |
SORTED | 已排序 | sorted()操作可以跳过 |
IMMUTABLE | 不可变 | 可以省略并发修改检测 |
5.2 选择串行还是并行的经验法则
数据量 × 操作复杂度 > 阈值 → 考虑parallelStream
↓
阈值经验值:约10万元素以上
其他考虑因素:
- CPU核心数(4核以下,并行收益有限)
- 是否有共享可变状态(有则不适合并行)
- 操作是否IO密集(IO密集建议用虚拟线程而非并行流)
- 数据结构是否支持高效分割(LinkedList不适合,ArrayList/Array最佳)5.3 版本兼容建议
- JDK8~JDK10:Stream API稳定,Collector组合基本够用
- JDK9:增加了
Stream.takeWhile(),Stream.dropWhile(),Stream.iterate(seed, predicate, f),更强大 - JDK16:
Stream.toList()(不可变List),比Collectors.toList()更简洁 - JDK21:结合虚拟线程,IO密集型任务推荐用
ExecutorService+虚拟线程,而非parallelStream
