Skip to content

ForkJoin

更新: 3/4/2026 字数: 0 字

Java 并发包中专门为分治(Divide and Conquer)任务设计的并行计算框架,非常适合处理可以拆分成小任务、最终合并结果的场景。

一、核心原理

Fork/Join 框架是 JDK 1.7 引入的,基于「分而治之」思想,核心解决「大任务拆小、小任务并行执行、结果合并」的问题,底层依赖 ForkJoinPool 线程池实现。

1.1 核心设计思路

  • Fork(拆分):将一个大任务递归拆分成多个独立的子任务,直到子任务小到可以直接计算(达到「阈值」)
  • Join(合并):等待所有子任务执行完成,然后合并子任务的结果,最终得到大任务的结果
  • 工作窃取(Work-Stealing):这是 Fork/Join 最核心的优化机制
    • 每个线程都有自己的任务队列(双端队列)
    • 当一个线程的队列空了,它会从其他线程的队列「尾部」窃取任务执行
    • 这种机制最大化利用线程资源,减少空闲,提升并行效率

1.2 核心组件

组件作用
ForkJoinPool核心线程池,管理工作线程,实现工作窃取机制
ForkJoinTask<V>抽象任务类,是所有 Fork/Join 任务的父类,常用子类:RecursiveTask<V>:有返回值的分治任务(最常用); RecursiveAction:无返回值的分治任务;
ForkJoinWorkerThreadForkJoinPool 中的工作线程

二、实际使用示例

2.1 场景:并行计算大数组的累加和

假设我们有一个包含 1000 万个整数的数组,需要计算所有元素的和。用 Fork/Join 拆分任务,每个子任务计算数组中一个片段的和,最终合并所有片段的结果。

java
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

/**
 * Fork/Join 示例:并行计算大数组的累加和
 */
public class ForkJoinSumDemo {

    // 任务拆分的阈值(当数组片段长度小于该值时,直接计算,不再拆分)
    private static final int THRESHOLD = 10000;

    // 自定义分治任务类(有返回值,继承 RecursiveTask<Long>)
    static class SumTask extends RecursiveTask<Long> {
        private final int[] array;    // 要计算的数组
        private final int start;      // 片段起始索引
        private final int end;        // 片段结束索引(不包含)

        public SumTask(int[] array, int start, int end) {
            this.array = array;
            this.start = start;
            this.end = end;
        }

        // 核心方法:实现任务拆分和计算逻辑
        @Override
        protected Long compute() {
            // 1. 如果任务足够小,直接计算
            int length = end - start;
            if (length <= THRESHOLD) {
                long sum = 0;
                for (int i = start; i < end; i++) {
                    sum += array[i];
                }
                return sum;
            }

            // 2. 任务太大,拆分成两个子任务
            int mid = start + length / 2;
            SumTask leftTask = new SumTask(array, start, mid);  // 左半部分
            SumTask rightTask = new SumTask(array, mid, end);   // 右半部分

            // 执行子任务(Fork:异步提交子任务)
            leftTask.fork();
            rightTask.fork();

            // 3. 合并子任务结果(Join:等待子任务完成并获取结果)
            long leftSum = leftTask.join();
            long rightSum = rightTask.join();

            // 4. 返回合并后的结果
            return leftSum + rightSum;
        }
    }

    public static void main(String[] args) {
        // 1. 准备测试数据:生成 1000 万个随机整数(1-100)
        int[] array = new int[10_000_000];
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(100) + 1;
        }

        // 2. 创建 ForkJoin 线程池(JDK1.8+ 推荐用 commonPool,也可自定义)
        // ForkJoinPool pool = new ForkJoinPool(); // 自定义线程池
        ForkJoinPool pool = ForkJoinPool.commonPool(); // 使用JVM默认的公共池

        // 3. 提交总任务
        SumTask totalTask = new SumTask(array, 0, array.length);
        long startTime = System.currentTimeMillis();
        Long totalSum = pool.invoke(totalTask); // 执行任务并获取结果
        long endTime = System.currentTimeMillis();

        // 4. 输出结果
        System.out.println("数组累加和:" + totalSum);
        System.out.println("Fork/Join 耗时:" + (endTime - startTime) + "ms");

        // 5. 验证:单线程计算(对比结果和耗时)
        long singleSum = 0;
        startTime = System.currentTimeMillis();
        for (int num : array) {
            singleSum += num;
        }
        endTime = System.currentTimeMillis();
        System.out.println("单线程累加和:" + singleSum);
        System.out.println("单线程耗时:" + (endTime - startTime) + "ms");

        // 关闭线程池(如果是自定义的pool,必须关闭;commonPool 无需手动关闭)
        // pool.shutdown();
    }
}

运行结果(示例)

text
数组累加和:504987654
Fork/Join 耗时:12ms
单线程累加和:504987654
单线程耗时:35ms

TIP

注:耗时会因电脑配置不同而变化,但 Fork/Join 并行计算通常比单线程快(数据量越大,优势越明显)。

2.2 关键代码解释

任务阈值(THRESHOLD)

这是性能调优的关键:阈值太小,拆分/合并的开销会超过并行收益;阈值太大,并行度不足。示例中设置为 10000,可根据实际场景调整(比如结合 CPU 核心数)。

compute() 方法

核心逻辑入口,必须重写:

  • 先判断任务是否足够小,小则直接计算
  • 大则拆分为左右两个子任务,通过 fork() 异步提交
  • 再通过 join() 等待子任务完成并获取结果,最后合并返回

ForkJoinPool 使用

  • commonPool():JVM 全局共享的 ForkJoin 池,适用于大多数场景,无需手动关闭
  • 自定义池:new ForkJoinPool(4)(指定核心线程数),使用后需调用 shutdown() 关闭

三、进阶说明(避坑点)

3.1 避免过度拆分

拆分的任务数建议不超过 CPU 核心数的 2-4 倍,过多拆分会增加调度开销。

3.2 异常处理

如果任务执行中抛出异常,join() 会抛出 CompletionException,需捕获并通过 getCause() 获取原始异常。

3.3 适用场景

  • 适合「CPU 密集型」任务(如计算、排序),不适合「IO 密集型」任务(如文件读写、网络请求)
  • 必须是可拆分、无状态、线程安全的任务(子任务之间无依赖)

四、总结

4.1 核心原理

Fork/Join 基于「分治 + 工作窃取」,将大任务拆分为小任务并行执行,合并结果,最大化利用 CPU 资源。

4.2 核心用法

  • 继承 RecursiveTask<V>(有返回值)或 RecursiveAction(无返回值),重写 compute() 实现拆分/计算逻辑
  • 通过 ForkJoinPool 提交并执行任务

4.3 适用场景

CPU 密集型、可拆分、无状态的批量计算任务(如大数组/集合的计算、排序、统计)。

五、实际应用场景

5.1 统计超大文本文件中每个单词出现次数

一、实现思路

  • 文件拆分:将超大文本文件按字节范围拆分成多个小任务(每个任务处理一段字节),避免一次性加载整个文件到内存(防止 OOM)。
  • 并行统计:每个子任务统计自己负责的片段中单词的出现次数(用 HashMap 暂存)。
  • 结果合并:递归合并所有子任务的统计结果,最终得到全文件的单词频次。
  • 细节处理:
    • 避免拆分时将单词截断(比如一个单词跨两个片段);
    • 统一单词格式(转小写、去除标点);
    • 处理大文件的字节流读取,避免内存溢出。

二、完整实现代码

java
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.regex.Pattern;

/**
 * Fork/Join 统计超大文本文件中单词出现次数
 */
public class BigFileWordCount extends RecursiveTask<Map<String, Integer>> {
    // 1. 配置参数
    private static final long THRESHOLD = 1024 * 1024; // 每个任务处理 1MB 数据(可根据内存调整)
    private static final Pattern WORD_PATTERN = Pattern.compile("[^a-zA-Z0-9]+"); // 分割单词的正则(非字母数字都作为分隔符)
    private final File file;
    private final long start; // 处理的起始字节
    private final long end;   // 处理的结束字节

    // 构造方法:初始化任务的文件范围
    public BigFileWordCount(File file, long start, long end) {
        this.file = file;
        this.start = start;
        this.end = end;
    }

    // 2. 核心方法:拆分任务 + 统计单词
    @Override
    protected Map<String, Integer> compute() {
        // 任务足够小,直接统计
        if (end - start <= THRESHOLD) {
            return countWordsInRange();
        }

        // 任务太大,拆分成两个子任务
        long mid = start + (end - start) / 2;
        // 修正 mid:避免拆分在单词中间(找到最近的分隔符位置)
        mid = findSafeSplitPoint(mid);

        // 创建子任务
        BigFileWordCount leftTask = new BigFileWordCount(file, start, mid);
        BigFileWordCount rightTask = new BigFileWordCount(file, mid, end);

        // 执行子任务
        leftTask.fork();
        rightTask.fork();

        // 合并子任务结果
        Map<String, Integer> leftResult = leftTask.join();
        Map<String, Integer> rightResult = rightTask.join();

        return mergeResults(leftResult, rightResult);
    }

    // 3. 核心工具方法:统计指定字节范围内的单词
    private Map<String, Integer> countWordsInRange() {
        Map<String, Integer> wordCount = new HashMap<>();
        try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
            // 定位到起始位置
            raf.seek(start);
            // 读取指定范围的字节(预留一点,避免截断最后一个单词)
            long readLength = end - start + 1024; // 多读取1024字节,防止单词截断
            byte[] buffer = new byte[(int) Math.min(readLength, raf.length() - start)];
            int bytesRead = raf.read(buffer);
            if (bytesRead <= 0) {
                return wordCount;
            }

            // 转成字符串,处理编码
            String content = new String(buffer, 0, bytesRead, StandardCharsets.UTF_8);
            // 统一转小写,分割单词,过滤空字符串
            String[] words = WORD_PATTERN.split(content.toLowerCase());
            for (String word : words) {
                word = word.trim();
                if (word.length() > 0) { // 过滤空单词
                    wordCount.put(word, wordCount.getOrDefault(word, 0) + 1);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return wordCount;
    }

    // 4. 工具方法:找到安全的拆分点(避免单词截断)
    private long findSafeSplitPoint(long mid) {
        try (RandomAccessFile raf = new RandomAccessFile(file, "r")) {
            // 从 mid 位置往前找,直到找到分隔符(空格、换行、标点等)
            raf.seek(mid);
            for (long i = mid; i > start; i--) {
                raf.seek(i);
                int c = raf.read();
                if (c == -1) break;
                // 判断是否是分隔符(非字母数字)
                if (!Character.isLetterOrDigit(c)) {
                    return i + 1; // 拆分点在分隔符后,避免截断单词
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return mid; // 找不到分隔符,直接用 mid(极端情况)
    }

    // 5. 工具方法:合并两个单词统计结果
    private Map<String, Integer> mergeResults(Map<String, Integer> map1, Map<String, Integer> map2) {
        // 优化:将小 map 合并到大 map,减少遍历次数
        Map<String, Integer> result = new HashMap<>(map1.size() + map2.size());
        result.putAll(map1);

        for (Map.Entry<String, Integer> entry : map2.entrySet()) {
            String word = entry.getKey();
            int count = entry.getValue();
            result.put(word, result.getOrDefault(word, 0) + count);
        }
        return result;
    }

    // 6. 入口方法:启动统计
    public static Map<String, Integer> count(File file) throws IOException {
        if (!file.exists() || !file.isFile()) {
            throw new IllegalArgumentException("文件不存在或不是普通文件");
        }
        long fileLength = file.length();
        if (fileLength == 0) {
            return Collections.emptyMap();
        }

        // 使用 ForkJoin 池执行任务
        ForkJoinPool pool = new ForkJoinPool(); // 默认可用 CPU 核心数
        BigFileWordCount task = new BigFileWordCount(file, 0, fileLength);
        return pool.invoke(task);
    }

    // 测试主方法
    public static void main(String[] args) {
        try {
            // 替换为你的超大文本文件路径(比如小说、日志文件)
            File bigFile = new File("D:\\big_text_file.txt");
            long startTime = System.currentTimeMillis();

            // 执行统计
            Map<String, Integer> wordCount = count(bigFile);

            long endTime = System.currentTimeMillis();
            System.out.println("统计完成,耗时:" + (endTime - startTime) + "ms");
            System.out.println("总单词数:" + wordCount.values().stream().mapToInt(Integer::intValue).sum());

            // 输出出现次数前10的单词
            wordCount.entrySet().stream()
                    .sorted((e1, e2) -> Integer.compare(e2.getValue(), e1.getValue()))
                    .limit(10)
                    .forEach(entry -> System.out.println(entry.getKey() + " : " + entry.getValue()));

        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

三、关键代码解释

任务拆分阈值(THRESHOLD)

设置为 1MB,每个子任务处理最多 1MB 的数据,可根据服务器内存调整(比如 2MB、4MB)。阈值太小会增加拆分/合并的开销,太大则并行度不足。

安全拆分点(findSafeSplitPoint)

核心解决「单词跨片段」问题:拆分时不会直接按字节数拆分,而是找到最近的分隔符(空格、标点等),确保一个单词完整落在一个片段中。

单词统计(countWordsInRange)
  • 使用 RandomAccessFile 随机读取文件指定字节范围,避免加载整个文件到内存
  • 正则 [^a-zA-Z0-9]+ 将非字母数字的字符都作为分隔符,统一转小写避免「Hello」和「hello」被算作不同单词
结果合并(mergeResults)

合并两个 HashMap,将相同单词的计数累加,保证统计结果的正确性。

四、总结

  • 核心逻辑:通过 Fork/Join 将大文件拆分为小片段并行统计单词,合并结果,兼顾性能和内存安全
  • 关键优化:安全拆分点避免单词截断,按需读取文件片段避免内存溢出,并行计算提升统计效率
  • 适用场景:GB 级及以上的超大文本文件单词统计,相比单线程效率提升数倍(取决于 CPU 核心数)

5.2 Fork/Join 实现并行归并排序(适配超大数组)

java
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

/**
 * Fork/Join 实现并行归并排序(适配超大数组)
 */
public class ParallelMergeSort extends RecursiveAction {
    // 任务拆分阈值:子数组长度小于该值时,用单线程排序(减少拆分开销)
    private static final int THRESHOLD = 10000;
    private final int[] array;
    private final int left;
    private final int right;

    // 构造方法:指定要排序的数组范围
    public ParallelMergeSort(int[] array, int left, int right) {
        this.array = array;
        this.left = left;
        this.right = right;
    }

    // 核心方法:拆分任务 + 排序 + 合并
    @Override
    protected void compute() {
        // 1. 子数组足够小,直接用单线程排序(Arrays.sort 是优化后的双轴快排,小数据效率高)
        if (right - left <= THRESHOLD) {
            Arrays.sort(array, left, right + 1);
            return;
        }

        // 2. 拆分:将数组分成左右两部分
        int mid = left + (right - left) / 2;
        ParallelMergeSort leftTask = new ParallelMergeSort(array, left, mid);
        ParallelMergeSort rightTask = new ParallelMergeSort(array, mid + 1, right);

        // 3. Fork:异步执行子任务(并行排序左右子数组)
        leftTask.fork();
        rightTask.fork();

        // 4. Join:等待子任务完成
        leftTask.join();
        rightTask.join();

        // 5. 合并:将两个有序子数组合并为一个有序数组
        merge(left, mid, right);
    }

    // 归并排序的核心合并方法
    private void merge(int left, int mid, int right) {
        // 临时数组存储合并结果
        int[] temp = new int[right - left + 1];
        int i = left;    // 左子数组指针
        int j = mid + 1; // 右子数组指针
        int k = 0;       // 临时数组指针

        // 合并两个有序子数组
        while (i <= mid && j <= right) {
            if (array[i] <= array[j]) {
                temp[k++] = array[i++];
            } else {
                temp[k++] = array[j++];
            }
        }

        // 拷贝左子数组剩余元素
        while (i <= mid) {
            temp[k++] = array[i++];
        }

        // 拷贝右子数组剩余元素
        while (j <= right) {
            temp[k++] = array[j++];
        }

        // 将临时数组的结果拷贝回原数组
        System.arraycopy(temp, 0, array, left, temp.length);
    }

    // 对外提供的排序入口方法
    public static void sort(int[] array) {
        if (array == null || array.length <= 1) {
            return;
        }
        // 使用 ForkJoin 池执行并行排序(默认线程数 = CPU 核心数)
        ForkJoinPool pool = ForkJoinPool.commonPool();
        pool.invoke(new ParallelMergeSort(array, 0, array.length - 1));
    }

    // 测试:对比并行归并排序和单线程排序的耗时
    public static void main(String[] args) {
        // 生成 1000 万个随机整数的超大数组
        int[] array = new int[10_000_000];
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(Integer.MAX_VALUE);
        }

        // 1. 并行归并排序(Fork/Join)
        int[] parallelArray = Arrays.copyOf(array, array.length);
        long startTime = System.currentTimeMillis();
        ParallelMergeSort.sort(parallelArray);
        long parallelTime = System.currentTimeMillis() - startTime;
        System.out.println("Fork/Join 并行归并排序耗时:" + parallelTime + "ms");

        // 2. 单线程排序(Arrays.sort)
        int[] singleArray = Arrays.copyOf(array, array.length);
        startTime = System.currentTimeMillis();
        Arrays.sort(singleArray);
        long singleTime = System.currentTimeMillis() - startTime;
        System.out.println("单线程排序耗时:" + singleTime + "ms");

        // 验证排序结果一致
        System.out.println("排序结果是否一致:" + Arrays.equals(parallelArray, singleArray));
    }
}
text
Fork/Join 并行归并排序耗时:931ms
单线程排序耗时:1494ms
排序结果是否一致:true