深浅模式
ForkJoin
更新: 3/4/2026 字数: 0 字
Java 并发包中专门为分治(Divide and Conquer)任务设计的并行计算框架,非常适合处理可以拆分成小任务、最终合并结果的场景。
一、核心原理
Fork/Join 框架是 JDK 1.7 引入的,基于「分而治之」思想,核心解决「大任务拆小、小任务并行执行、结果合并」的问题,底层依赖 ForkJoinPool 线程池实现。
1.1 核心设计思路
- Fork(拆分):将一个大任务递归拆分成多个独立的子任务,直到子任务小到可以直接计算(达到「阈值」)
- Join(合并):等待所有子任务执行完成,然后合并子任务的结果,最终得到大任务的结果
- 工作窃取(Work-Stealing):这是 Fork/Join 最核心的优化机制
- 每个线程都有自己的任务队列(双端队列)
- 当一个线程的队列空了,它会从其他线程的队列「尾部」窃取任务执行
- 这种机制最大化利用线程资源,减少空闲,提升并行效率
1.2 核心组件
| 组件 | 作用 |
|---|---|
ForkJoinPool | 核心线程池,管理工作线程,实现工作窃取机制 |
ForkJoinTask<V> | 抽象任务类,是所有 Fork/Join 任务的父类,常用子类:RecursiveTask<V>:有返回值的分治任务(最常用); RecursiveAction:无返回值的分治任务; |
ForkJoinWorkerThread | ForkJoinPool 中的工作线程 |
二、实际使用示例
2.1 场景:并行计算大数组的累加和
假设我们有一个包含 1000 万个整数的数组,需要计算所有元素的和。用 Fork/Join 拆分任务,每个子任务计算数组中一个片段的和,最终合并所有片段的结果。
java
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
/**
* Fork/Join 示例:并行计算大数组的累加和
*/
public class ForkJoinSumDemo {
// 任务拆分的阈值(当数组片段长度小于该值时,直接计算,不再拆分)
private static final int THRESHOLD = 10000;
// 自定义分治任务类(有返回值,继承 RecursiveTask<Long>)
static class SumTask extends RecursiveTask<Long> {
private final int[] array; // 要计算的数组
private final int start; // 片段起始索引
private final int end; // 片段结束索引(不包含)
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
// 核心方法:实现任务拆分和计算逻辑
@Override
protected Long compute() {
// 1. 如果任务足够小,直接计算
int length = end - start;
if (length <= THRESHOLD) {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}
// 2. 任务太大,拆分成两个子任务
int mid = start + length / 2;
SumTask leftTask = new SumTask(array, start, mid); // 左半部分
SumTask rightTask = new SumTask(array, mid, end); // 右半部分
// 执行子任务(Fork:异步提交子任务)
leftTask.fork();
rightTask.fork();
// 3. 合并子任务结果(Join:等待子任务完成并获取结果)
long leftSum = leftTask.join();
long rightSum = rightTask.join();
// 4. 返回合并后的结果
return leftSum + rightSum;
}
}
public static void main(String[] args) {
// 1. 准备测试数据:生成 1000 万个随机整数(1-100)
int[] array = new int[10_000_000];
Random random = new Random();
for (int i = 0; i < array.length; i++) {
array[i] = random.nextInt(100) + 1;
}
// 2. 创建 ForkJoin 线程池(JDK1.8+ 推荐用 commonPool,也可自定义)
// ForkJoinPool pool = new ForkJoinPool(); // 自定义线程池
ForkJoinPool pool = ForkJoinPool.commonPool(); // 使用JVM默认的公共池
// 3. 提交总任务
SumTask totalTask = new SumTask(array, 0, array.length);
long startTime = System.currentTimeMillis();
Long totalSum = pool.invoke(totalTask); // 执行任务并获取结果
long endTime = System.currentTimeMillis();
// 4. 输出结果
System.out.println("数组累加和:" + totalSum);
System.out.println("Fork/Join 耗时:" + (endTime - startTime) + "ms");
// 5. 验证:单线程计算(对比结果和耗时)
long singleSum = 0;
startTime = System.currentTimeMillis();
for (int num : array) {
singleSum += num;
}
endTime = System.currentTimeMillis();
System.out.println("单线程累加和:" + singleSum);
System.out.println("单线程耗时:" + (endTime - startTime) + "ms");
// 关闭线程池(如果是自定义的pool,必须关闭;commonPool 无需手动关闭)
// pool.shutdown();
}
}运行结果(示例)
text
数组累加和:504987654
Fork/Join 耗时:12ms
单线程累加和:504987654
单线程耗时:35msTIP
注:耗时会因电脑配置不同而变化,但 Fork/Join 并行计算通常比单线程快(数据量越大,优势越明显)。
2.2 关键代码解释
任务阈值(THRESHOLD)
这是性能调优的关键:阈值太小,拆分/合并的开销会超过并行收益;阈值太大,并行度不足。示例中设置为 10000,可根据实际场景调整(比如结合 CPU 核心数)。
compute() 方法
核心逻辑入口,必须重写:
- 先判断任务是否足够小,小则直接计算
- 大则拆分为左右两个子任务,通过 fork() 异步提交
- 再通过 join() 等待子任务完成并获取结果,最后合并返回
ForkJoinPool 使用
- commonPool():JVM 全局共享的 ForkJoin 池,适用于大多数场景,无需手动关闭
- 自定义池:new ForkJoinPool(4)(指定核心线程数),使用后需调用 shutdown() 关闭
三、进阶说明(避坑点)
3.1 避免过度拆分
拆分的任务数建议不超过 CPU 核心数的 2-4 倍,过多拆分会增加调度开销。
3.2 异常处理
如果任务执行中抛出异常,join() 会抛出 CompletionException,需捕获并通过 getCause() 获取原始异常。
3.3 适用场景
- 适合「CPU 密集型」任务(如计算、排序),不适合「IO 密集型」任务(如文件读写、网络请求)
- 必须是可拆分、无状态、线程安全的任务(子任务之间无依赖)
四、总结
4.1 核心原理
Fork/Join 基于「分治 + 工作窃取」,将大任务拆分为小任务并行执行,合并结果,最大化利用 CPU 资源。
4.2 核心用法
- 继承
RecursiveTask<V>(有返回值)或RecursiveAction(无返回值),重写compute()实现拆分/计算逻辑 - 通过
ForkJoinPool提交并执行任务
4.3 适用场景
CPU 密集型、可拆分、无状态的批量计算任务(如大数组/集合的计算、排序、统计)。
五、实际应用场景
5.1 统计超大文本文件中每个单词出现次数
一、实现思路
- 文件拆分:将超大文本文件按字节范围拆分成多个小任务(每个任务处理一段字节),避免一次性加载整个文件到内存(防止 OOM)。
- 并行统计:每个子任务统计自己负责的片段中单词的出现次数(用 HashMap 暂存)。
- 结果合并:递归合并所有子任务的统计结果,最终得到全文件的单词频次。
- 细节处理:
- 避免拆分时将单词截断(比如一个单词跨两个片段);
- 统一单词格式(转小写、去除标点);
- 处理大文件的字节流读取,避免内存溢出。
二、完整实现代码
java
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.regex.Pattern;
/**
* Fork/Join 统计超大文本文件中单词出现次数
*/
public class BigFileWordCount extends RecursiveTask<Map<String, Integer>> {
// 1. 配置参数
private static final long THRESHOLD = 1024 * 1024; // 每个任务处理 1MB 数据(可根据内存调整)
private static final Pattern WORD_PATTERN = Pattern.compile("[^a-zA-Z0-9]+"); // 分割单词的正则(非字母数字都作为分隔符)
private final File file;
private final long start; // 处理的起始字节
private final long end; // 处理的结束字节
// 构造方法:初始化任务的文件范围
public BigFileWordCount(File file, long start, long end) {
this.file = file;
this.start = start;
this.end = end;
}
// 2. 核心方法:拆分任务 + 统计单词
@Override
protected Map<String, Integer> compute() {
// 任务足够小,直接统计
if (end - start <= THRESHOLD) {
return countWordsInRange();
}
// 任务太大,拆分成两个子任务
long mid = start + (end - start) / 2;
// 修正 mid:避免拆分在单词中间(找到最近的分隔符位置)
mid = findSafeSplitPoint(mid);
// 创建子任务
BigFileWordCount leftTask = new BigFileWordCount(file, start, mid);
BigFileWordCount rightTask = new BigFileWordCount(file, mid, end);
// 执行子任务
leftTask.fork();
rightTask.fork();
// 合并子任务结果
Map<String, Integer> leftResult = leftTask.join();
Map<String, Integer> rightResult = rightTask.join();
return mergeResults(leftResult, rightResult);
}
// 3. 核心工具方法:统计指定字节范围内的单词
private Map<String, Integer> countWordsInRange() {
Map<String, Integer> wordCount = new HashMap<>();
try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
// 定位到起始位置
raf.seek(start);
// 读取指定范围的字节(预留一点,避免截断最后一个单词)
long readLength = end - start + 1024; // 多读取1024字节,防止单词截断
byte[] buffer = new byte[(int) Math.min(readLength, raf.length() - start)];
int bytesRead = raf.read(buffer);
if (bytesRead <= 0) {
return wordCount;
}
// 转成字符串,处理编码
String content = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8);
// 统一转小写,分割单词,过滤空字符串
String[] words = WORD_PATTERN.split(content.toLowerCase());
for (String word : words) {
word = word.trim();
if (word.length() > 0) { // 过滤空单词
wordCount.put(word, wordCount.getOrDefault(word, 0) + 1);
}
}
} catch (IOException e) {
e.printStackTrace();
}
return wordCount;
}
// 4. 工具方法:找到安全的拆分点(避免单词截断)
private long findSafeSplitPoint(long mid) {
try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
// 从 mid 位置往前找,直到找到分隔符(空格、换行、标点等)
raf.seek(mid);
for (long i = mid; i > start; i--) {
raf.seek(i);
int c = raf.read();
if (c == -1) break;
// 判断是否是分隔符(非字母数字)
if (!Character.isLetterOrDigit(c)) {
return i + 1; // 拆分点在分隔符后,避免截断单词
}
}
} catch (IOException e) {
e.printStackTrace();
}
return mid; // 找不到分隔符,直接用 mid(极端情况)
}
// 5. 工具方法:合并两个单词统计结果
private Map<String, Integer> mergeResults(Map<String, Integer> map1, Map<String, Integer> map2) {
// 优化:将小 map 合并到大 map,减少遍历次数
Map<String, Integer> result = new HashMap<>(map1.size() + map2.size());
result.putAll(map1);
for (Map.Entry<String, Integer> entry : map2.entrySet()) {
String word = entry.getKey();
int count = entry.getValue();
result.put(word, result.getOrDefault(word, 0) + count);
}
return result;
}
// 6. 入口方法:启动统计
public static Map<String, Integer> count(File file) throws IOException {
if (!file.exists() || !file.isFile()) {
throw new IllegalArgumentException("文件不存在或不是普通文件");
}
long fileLength = file.length();
if (fileLength == 0) {
return Collections.emptyMap();
}
// 使用 ForkJoin 池执行任务
ForkJoinPool pool = new ForkJoinPool(); // 默认可用 CPU 核心数
BigFileWordCount task = new BigFileWordCount(file, 0, fileLength);
return pool.invoke(task);
}
// 测试主方法
public static void main(String[] args) {
try {
// 替换为你的超大文本文件路径(比如小说、日志文件)
File bigFile = new File("D:\\big_text_file.txt");
long startTime = System.currentTimeMillis();
// 执行统计
Map<String, Integer> wordCount = count(bigFile);
long endTime = System.currentTimeMillis();
System.out.println("统计完成,耗时:" + (endTime - startTime) + "ms");
System.out.println("总单词数:" + wordCount.values().stream().mapToInt(Integer::intValue).sum());
// 输出出现次数前10的单词
wordCount.entrySet().stream()
.sorted((e1, e2) -> Integer.compare(e2.getValue(), e1.getValue()))
.limit(10)
.forEach(entry -> System.out.println(entry.getKey() + " : " + entry.getValue()));
} catch (Exception e) {
e.printStackTrace();
}
}
}三、关键代码解释
任务拆分阈值(THRESHOLD)
设置为 1MB,每个子任务处理最多 1MB 的数据,可根据服务器内存调整(比如 2MB、4MB)。阈值太小会增加拆分/合并的开销,太大则并行度不足。
安全拆分点(findSafeSplitPoint)
核心解决「单词跨片段」问题:拆分时不会直接按字节数拆分,而是找到最近的分隔符(空格、标点等),确保一个单词完整落在一个片段中。
单词统计(countWordsInRange)
- 使用 RandomAccessFile 随机读取文件指定字节范围,避免加载整个文件到内存
- 正则
[^a-zA-Z0-9]+将非字母数字的字符都作为分隔符,统一转小写避免「Hello」和「hello」被算作不同单词
结果合并(mergeResults)
合并两个 HashMap,将相同单词的计数累加,保证统计结果的正确性。
四、总结
- 核心逻辑:通过 Fork/Join 将大文件拆分为小片段并行统计单词,合并结果,兼顾性能和内存安全
- 关键优化:安全拆分点避免单词截断,按需读取文件片段避免内存溢出,并行计算提升统计效率
- 适用场景:GB 级及以上的超大文本文件单词统计,相比单线程效率提升数倍(取决于 CPU 核心数)
5.2 Fork/Join 实现并行归并排序(适配超大数组)
java
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
/**
* Fork/Join 实现并行归并排序(适配超大数组)
*/
public class ParallelMergeSort extends RecursiveAction {
// 任务拆分阈值:子数组长度小于该值时,用单线程排序(减少拆分开销)
private static final int THRESHOLD = 10000;
private final int[] array;
private final int left;
private final int right;
// 构造方法:指定要排序的数组范围
public ParallelMergeSort(int[] array, int left, int right) {
this.array = array;
this.left = left;
this.right = right;
}
// 核心方法:拆分任务 + 排序 + 合并
@Override
protected void compute() {
// 1. 子数组足够小,直接用单线程排序(Arrays.sort 是优化后的双轴快排,小数据效率高)
if (right - left <= THRESHOLD) {
Arrays.sort(array, left, right + 1);
return;
}
// 2. 拆分:将数组分成左右两部分
int mid = left + (right - left) / 2;
ParallelMergeSort leftTask = new ParallelMergeSort(array, left, mid);
ParallelMergeSort rightTask = new ParallelMergeSort(array, mid + 1, right);
// 3. Fork:异步执行子任务(并行排序左右子数组)
leftTask.fork();
rightTask.fork();
// 4. Join:等待子任务完成
leftTask.join();
rightTask.join();
// 5. 合并:将两个有序子数组合并为一个有序数组
merge(left, mid, right);
}
// 归并排序的核心合并方法
private void merge(int left, int mid, int right) {
// 临时数组存储合并结果
int[] temp = new int[right - left + 1];
int i = left; // 左子数组指针
int j = mid + 1; // 右子数组指针
int k = 0; // 临时数组指针
// 合并两个有序子数组
while (i <= mid && j <= right) {
if (array[i] <= array[j]) {
temp[k++] = array[i++];
} else {
temp[k++] = array[j++];
}
}
// 拷贝左子数组剩余元素
while (i <= mid) {
temp[k++] = array[i++];
}
// 拷贝右子数组剩余元素
while (j <= right) {
temp[k++] = array[j++];
}
// 将临时数组的结果拷贝回原数组
System.arraycopy(temp, 0, array, left, temp.length);
}
// 对外提供的排序入口方法
public static void sort(int[] array) {
if (array == null || array.length <= 1) {
return;
}
// 使用 ForkJoin 池执行并行排序(默认线程数 = CPU 核心数)
ForkJoinPool pool = ForkJoinPool.commonPool();
pool.invoke(new ParallelMergeSort(array, 0, array.length - 1));
}
// 测试:对比并行归并排序和单线程排序的耗时
public static void main(String[] args) {
// 生成 1000 万个随机整数的超大数组
int[] array = new int[10_000_000];
Random random = new Random();
for (int i = 0; i < array.length; i++) {
array[i] = random.nextInt(Integer.MAX_VALUE);
}
// 1. 并行归并排序(Fork/Join)
int[] parallelArray = Arrays.copyOf(array, array.length);
long startTime = System.currentTimeMillis();
ParallelMergeSort.sort(parallelArray);
long parallelTime = System.currentTimeMillis() - startTime;
System.out.println("Fork/Join 并行归并排序耗时:" + parallelTime + "ms");
// 2. 单线程排序(Arrays.sort)
int[] singleArray = Arrays.copyOf(array, array.length);
startTime = System.currentTimeMillis();
Arrays.sort(singleArray);
long singleTime = System.currentTimeMillis() - startTime;
System.out.println("单线程排序耗时:" + singleTime + "ms");
// 验证排序结果一致
System.out.println("排序结果是否一致:" + Arrays.equals(parallelArray, singleArray));
}
}text
Fork/Join 并行归并排序耗时:931ms
单线程排序耗时:1494ms
排序结果是否一致:true