The Fork/Join Framework (java.util.concurrent, Java 7+) helps take advantage of multiple processors for parallel processing. It is built around ForkJoinPool, a specialized implementation of ExecutorService, and is designed for work that can be broken down recursively into smaller pieces.
1. Introduction to Fork/Join
Key Concepts
- Fork: Splitting a task into smaller subtasks
- Join: Waiting for subtasks to complete and combining results
- Work Stealing: Efficient thread utilization; each worker thread keeps its own task queue, and idle workers steal queued tasks from busy workers
Core Components
- ForkJoinPool: Special thread pool for fork/join tasks
- ForkJoinTask: Abstract base class for tasks
- RecursiveTask: A task that returns a result
- RecursiveAction: A task that returns no result
2. Basic Setup and Imports
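import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;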
3. RecursiveTask (Returns Result)
Example: Summing Array Elements
public class SumTask extends RecursiveTask<Long> {
private static final int THRESHOLD = 10_000;
private final int[] array;
private final int start;
private final int end;
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
int length = end - start;
// If below threshold, compute directly
if (length <= THRESHOLD) {
return computeDirectly();
}
// Split task into subtasks
int mid = start + (end - start) / 2;
SumTask leftTask = new SumTask(array, start, mid);
SumTask rightTask = new SumTask(array, mid, end);
// Fork the left task (execute asynchronously)
leftTask.fork();
// Compute right task directly and wait for left task
Long rightResult = rightTask.compute();
Long leftResult = leftTask.join();
return leftResult + rightResult;
}
private Long computeDirectly() {
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
}
}
Using the SumTask
public class SumTaskExample {
public static void main(String[] args) {
// Create a large array
int[] numbers = new int[1_000_000];
for (int i = 0; i < numbers.length; i++) {
numbers[i] = i + 1;
}
// Create ForkJoinPool
ForkJoinPool pool = new ForkJoinPool();
// Create task
SumTask task = new SumTask(numbers, 0, numbers.length);
// Execute and get result
long result = pool.invoke(task);
System.out.println("Sum: " + result);
System.out.println("Expected: " + ((long)numbers.length * (numbers.length + 1)) / 2);
pool.shutdown();
}
}
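For a plain sum, a hand-written task is mostly instructive: parallel streams are executed by the same framework, on ForkJoinPool.commonPool() by default. A minimal comparison sketch, assuming the same 1..1_000_000 input as above (the class and method names are illustrative):

```java
import java.util.stream.IntStream;

public class StreamSumComparison {
    // Sums 1..n with a parallel stream; the work runs on ForkJoinPool.commonPool()
    static long parallelSum(int n) {
        return IntStream.rangeClosed(1, n)
                .parallel()
                .asLongStream() // widen to long before summing to avoid int overflow
                .sum();
    }

    public static void main(String[] args) {
        int n = 1_000_000;
        long expected = (long) n * (n + 1) / 2;
        long result = parallelSum(n);
        System.out.println("Stream sum: " + result);
        System.out.println("Matches formula: " + (result == expected));
    }
}
```

The stream version trades control over the threshold and splitting strategy for much less code.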
4. RecursiveAction (No Result)
Example: Parallel Array Processing
public class ArrayProcessor extends RecursiveAction {
private static final int THRESHOLD = 10_000;
private final int[] array;
private final int start;
private final int end;
private final int multiplier;
public ArrayProcessor(int[] array, int start, int end, int multiplier) {
this.array = array;
this.start = start;
this.end = end;
this.multiplier = multiplier;
}
@Override
protected void compute() {
int length = end - start;
if (length <= THRESHOLD) {
processDirectly();
} else {
// Split the task
int mid = start + (end - start) / 2;
ArrayProcessor leftTask = new ArrayProcessor(array, start, mid, multiplier);
ArrayProcessor rightTask = new ArrayProcessor(array, mid, end, multiplier);
// Run both subtasks in parallel and wait for both; this avoids the mistake of forking both and then joining
invokeAll(leftTask, rightTask);
}
}
private void processDirectly() {
for (int i = start; i < end; i++) {
array[i] = array[i] * multiplier;
}
}
}
Using RecursiveAction
public class ArrayProcessorExample {
public static void main(String[] args) {
int[] data = new int[100_000];
Arrays.fill(data, 5);
ForkJoinPool pool = new ForkJoinPool();
ArrayProcessor task = new ArrayProcessor(data, 0, data.length, 3);
pool.invoke(task);
// Verify results
boolean allCorrect = true;
for (int value : data) {
if (value != 15) {
allCorrect = false;
break;
}
}
System.out.println("All values correctly multiplied: " + allCorrect);
System.out.println("First 10 values: " +
Arrays.toString(Arrays.copyOf(data, 10)));
pool.shutdown();
}
}
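For a simple element-wise update like this one, the JDK's Arrays.parallelSetAll is an alternative to a hand-written RecursiveAction (in current OpenJDK it is built on a parallel stream, and therefore on the common pool). A sketch, assuming the same input and multiplier as above; the class and method names are illustrative:

```java
import java.util.Arrays;

public class ParallelSetAllExample {
    // Multiplies every element in place; each index reads and writes only its own slot,
    // so the parallel updates never interfere with each other
    static void multiplyAll(int[] data, int multiplier) {
        Arrays.parallelSetAll(data, i -> data[i] * multiplier);
    }

    public static void main(String[] args) {
        int[] data = new int[100_000];
        Arrays.fill(data, 5);
        multiplyAll(data, 3);
        boolean allCorrect = Arrays.stream(data).allMatch(v -> v == 15);
        System.out.println("All values correctly multiplied: " + allCorrect);
    }
}
```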
5. Advanced Example: Parallel Merge Sort
public class ParallelMergeSort extends RecursiveAction {
private static final int THRESHOLD = 10_000;
private final int[] array;
private final int start;
private final int end;
private final int[] temp;
public ParallelMergeSort(int[] array, int start, int end, int[] temp) {
this.array = array;
this.start = start;
this.end = end;
this.temp = temp;
}
@Override
protected void compute() {
int length = end - start;
if (length <= THRESHOLD) {
// Use sequential sort for small arrays
Arrays.sort(array, start, end);
return;
}
int mid = start + (end - start) / 2;
// Create subtasks
ParallelMergeSort leftTask = new ParallelMergeSort(array, start, mid, temp);
ParallelMergeSort rightTask = new ParallelMergeSort(array, mid, end, temp);
// Execute subtasks in parallel
invokeAll(leftTask, rightTask);
// Merge results
merge(array, start, mid, end, temp);
}
private void merge(int[] array, int start, int mid, int end, int[] temp) {
System.arraycopy(array, start, temp, start, end - start);
int i = start, j = mid, k = start;
while (i < mid && j < end) {
if (temp[i] <= temp[j]) {
array[k++] = temp[i++];
} else {
array[k++] = temp[j++];
}
}
while (i < mid) {
array[k++] = temp[i++];
}
while (j < end) {
array[k++] = temp[j++];
}
}
public static void sort(int[] array) {
int[] temp = new int[array.length];
ForkJoinPool pool = new ForkJoinPool();
pool.invoke(new ParallelMergeSort(array, 0, array.length, temp));
pool.shutdown();
}
}
Testing Parallel Merge Sort
public class MergeSortExample {
public static void main(String[] args) {
int[] largeArray = new int[1_000_000];
Random random = new Random();
for (int i = 0; i < largeArray.length; i++) {
largeArray[i] = random.nextInt(1_000_000);
}
int[] sequentialArray = largeArray.clone();
int[] parallelArray = largeArray.clone();
// Sequential sort (single run without JIT warm-up, so treat the timings as rough)
long startTime = System.nanoTime();
Arrays.sort(sequentialArray);
long sequentialNanos = System.nanoTime() - startTime;
// Parallel sort
startTime = System.nanoTime();
ParallelMergeSort.sort(parallelArray);
long parallelNanos = System.nanoTime() - startTime;
System.out.println("Sequential sort time: " + sequentialNanos / 1_000_000 + "ms");
System.out.println("Parallel sort time: " + parallelNanos / 1_000_000 + "ms");
System.out.println("Speedup: " + (double) sequentialNanos / parallelNanos + "x");
// Verify both arrays are sorted and equal
System.out.println("Arrays equal: " + Arrays.equals(sequentialArray, parallelArray));
}
}
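The JDK already ships a fork/join-based sort: Arrays.parallelSort is documented as a parallel sort-merge that runs its tasks on the common pool. It is a sensible baseline before maintaining a hand-written version such as ParallelMergeSort. A sketch comparing it against Arrays.sort, assuming random input with a fixed seed (class and method names are illustrative):

```java
import java.util.Arrays;
import java.util.Random;

public class ParallelSortComparison {
    // Sorts a copy with the JDK's built-in parallel sort-merge (runs on the common pool)
    static int[] parallelSorted(int[] input) {
        int[] copy = input.clone();
        Arrays.parallelSort(copy);
        return copy;
    }

    // Checks that the array is in non-decreasing order
    static boolean isSorted(int[] a) {
        for (int i = 1; i < a.length; i++) {
            if (a[i - 1] > a[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[] data = new Random(42).ints(1_000_000, 0, 1_000_000).toArray();
        int[] expected = data.clone();
        Arrays.sort(expected);
        int[] actual = parallelSorted(data);
        System.out.println("Sorted: " + isSorted(actual));
        System.out.println("Matches Arrays.sort: " + Arrays.equals(expected, actual));
    }
}
```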
6. Fibonacci Sequence with Fork/Join
public class FibonacciTask extends RecursiveTask<Long> {
private final int n;
public FibonacciTask(int n) {
this.n = n;
}
@Override
protected Long compute() {
if (n <= 1) {
return (long) n;
}
if (n <= 10) { // Threshold for sequential computation
return computeSequentially(n);
}
FibonacciTask first = new FibonacciTask(n - 1);
FibonacciTask second = new FibonacciTask(n - 2);
first.fork();
Long secondResult = second.compute();
Long firstResult = first.join();
return firstResult + secondResult;
}
long computeSequentially(int n) { // package-private so FibonacciExample can call it; iterative O(n) loop
if (n <= 1) return (long) n;
long a = 0, b = 1;
for (int i = 2; i <= n; i++) {
long temp = a + b;
a = b;
b = temp;
}
return b;
}
}
Fibonacci Example
public class FibonacciExample {
public static void main(String[] args) {
ForkJoinPool pool = new ForkJoinPool();
for (int i = 0; i <= 20; i++) {
FibonacciTask task = new FibonacciTask(i);
long result = pool.invoke(task);
System.out.println("Fibonacci(" + i + ") = " + result);
}
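// Compare performance. computeSequentially is an O(n) loop, so it will usually beat the
// exponential task tree by a wide margin: this compares two different algorithms,
// not sequential vs. parallel execution of the same one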
int largeN = 40;
long startTime = System.currentTimeMillis();
FibonacciTask parallelTask = new FibonacciTask(largeN);
long parallelResult = pool.invoke(parallelTask);
long parallelTime = System.currentTimeMillis() - startTime;
startTime = System.currentTimeMillis();
long sequentialResult = new FibonacciTask(largeN).computeSequentially(largeN);
long sequentialTime = System.currentTimeMillis() - startTime;
System.out.println("\nFibonacci(" + largeN + ") = " + parallelResult);
System.out.println("Parallel time: " + parallelTime + "ms");
System.out.println("Sequential time: " + sequentialTime + "ms");
System.out.println("Results match: " + (parallelResult == sequentialResult));
pool.shutdown();
}
}
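Because computeSequentially uses an O(n) loop, the timing comparison above pits an exponential task tree against a linear algorithm. A fairer demonstration, sketched here in the style of the RecursiveTask Javadoc example, parallelizes the naive recursion and compares it with the same recursion run sequentially. The class name and the cutoff of 20 are illustrative assumptions:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class NaiveFibonacciTask extends RecursiveTask<Long> {
    // Hypothetical cutoff: at or below it, the plain recursive version runs sequentially
    private static final int SEQUENTIAL_CUTOFF = 20;
    private final int n;

    public NaiveFibonacciTask(int n) {
        this.n = n;
    }

    // Same exponential recursion as the parallel version, without any tasks
    static long fibSequential(int n) {
        return n <= 1 ? n : fibSequential(n - 1) + fibSequential(n - 2);
    }

    @Override
    protected Long compute() {
        if (n <= SEQUENTIAL_CUTOFF) {
            return fibSequential(n);
        }
        NaiveFibonacciTask first = new NaiveFibonacciTask(n - 1);
        first.fork();                                           // run n-1 asynchronously
        long second = new NaiveFibonacciTask(n - 2).compute();  // compute n-2 in this thread
        return first.join() + second;
    }

    public static void main(String[] args) {
        int n = 35;
        long start = System.nanoTime();
        long parallel = ForkJoinPool.commonPool().invoke(new NaiveFibonacciTask(n));
        long parallelMs = (System.nanoTime() - start) / 1_000_000;

        start = System.nanoTime();
        long sequential = fibSequential(n);
        long sequentialMs = (System.nanoTime() - start) / 1_000_000;

        System.out.println("fib(" + n + ") = " + parallel + " in " + parallelMs + "ms (fork/join)");
        System.out.println("fib(" + n + ") = " + sequential + " in " + sequentialMs + "ms (sequential)");
    }
}
```

Both sides now do the same amount of arithmetic, so any difference reflects parallel execution rather than a change of algorithm.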
7. File Processing with Fork/Join
public class FileSearchTask extends RecursiveTask<List<File>> {
private final File directory;
private final String fileExtension;
public FileSearchTask(File directory, String fileExtension) {
this.directory = directory;
this.fileExtension = fileExtension.toLowerCase();
}
@Override
protected List<File> compute() {
List<File> result = new ArrayList<>();
List<FileSearchTask> subtasks = new ArrayList<>();
File[] files = directory.listFiles();
if (files == null) {
return result;
}
for (File file : files) {
if (file.isDirectory()) {
// Create subtask for directory
FileSearchTask subtask = new FileSearchTask(file, fileExtension);
subtask.fork();
subtasks.add(subtask);
} else {
// Check file extension
if (file.getName().toLowerCase().endsWith(fileExtension)) {
result.add(file);
}
}
}
// Collect results from subtasks
for (FileSearchTask subtask : subtasks) {
result.addAll(subtask.join());
}
return result;
}
}
File Search Example
public class FileSearchExample {
public static void main(String[] args) {
File startDirectory = new File("."); // Current directory
ForkJoinPool pool = new ForkJoinPool();
FileSearchTask task = new FileSearchTask(startDirectory, ".java");
List<File> javaFiles = pool.invoke(task);
System.out.println("Found " + javaFiles.size() + " Java files:");
javaFiles.stream()
.map(File::getAbsolutePath)
.sorted()
.forEach(System.out::println);
pool.shutdown();
}
}
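A variant of the same idea, sketched under the assumption that only a count is needed: it collects the subdirectory tasks and hands them to ForkJoinTask.invokeAll, which forks them and returns once all have completed. Like FileSearchTask, it performs blocking filesystem calls inside compute(), which is the kind of blocking the pitfalls section below warns about. DirectoryCountTask and createSampleTree are illustrative names; the sample tree is created in the system temp directory and not cleaned up.

```java
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class DirectoryCountTask extends RecursiveTask<Integer> {
    private final File directory;
    private final String extension;

    public DirectoryCountTask(File directory, String extension) {
        this.directory = directory;
        this.extension = extension;
    }

    @Override
    protected Integer compute() {
        File[] entries = directory.listFiles();
        if (entries == null) {
            return 0; // not a directory, or an I/O error occurred
        }
        int count = 0;
        List<DirectoryCountTask> subtasks = new ArrayList<>();
        for (File entry : entries) {
            if (entry.isDirectory()) {
                subtasks.add(new DirectoryCountTask(entry, extension));
            } else if (entry.getName().endsWith(extension)) {
                count++;
            }
        }
        // Fork all subdirectory tasks and wait until every one has completed
        for (DirectoryCountTask task : ForkJoinTask.invokeAll(subtasks)) {
            count += task.join(); // already complete, so join() just returns the result
        }
        return count;
    }

    // Builds a throwaway tree: root/A.java, root/sub/B.java, root/sub/notes.txt
    static Path createSampleTree() {
        try {
            Path root = Files.createTempDirectory("fj-search");
            Path sub = Files.createDirectory(root.resolve("sub"));
            Files.createFile(root.resolve("A.java"));
            Files.createFile(sub.resolve("B.java"));
            Files.createFile(sub.resolve("notes.txt"));
            return root;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        Path root = createSampleTree();
        int found = ForkJoinPool.commonPool().invoke(new DirectoryCountTask(root.toFile(), ".java"));
        System.out.println("Found " + found + " .java files under " + root);
    }
}
```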
8. ForkJoinPool Configuration and Management
Creating Custom ForkJoinPool
public class ForkJoinPoolManagement {
public static void main(String[] args) {
// Create ForkJoinPool with custom parallelism level
int parallelism = Runtime.getRuntime().availableProcessors();
ForkJoinPool customPool = new ForkJoinPool(parallelism);
// Obtain the shared common pool (usually preferred); it is created by the JVM, and shutdown() has no effect on it
ForkJoinPool commonPool = ForkJoinPool.commonPool();
System.out.println("Custom pool parallelism: " + customPool.getParallelism());
System.out.println("Common pool parallelism: " + commonPool.getParallelism());
System.out.println("Available processors: " + Runtime.getRuntime().availableProcessors());
// Submit work to the custom pool: the Runnable runs a SumTask inside a pool worker
customPool.execute(() -> {
SumTask task = new SumTask(new int[100_000], 0, 100_000);
Long result = task.invoke();
System.out.println("Task result: " + result);
});
// Wait (up to one minute) until the pool has no running or queued tasks
customPool.awaitQuiescence(1, TimeUnit.MINUTES);
// Inspect pool statistics
System.out.println("Active threads: " + customPool.getActiveThreadCount());
System.out.println("Pool size: " + customPool.getPoolSize());
System.out.println("Steal count: " + customPool.getStealCount());
customPool.shutdown();
}
}
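The example above focuses on monitoring. For a dedicated pool, the lifecycle usually matters as much: submit a task, wait for its result, then shut the pool down and wait for its threads to exit. A sketch under those assumptions; RangeSum and sumOnCustomPool are illustrative names, and the parallelism of 2 is arbitrary:

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;

public class CustomPoolLifecycle {
    // Small illustrative task: sums the integers in [from, to) by halving the range
    static class RangeSum extends RecursiveTask<Long> {
        private final long from;
        private final long to;

        RangeSum(long from, long to) {
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= 1_000) {
                long sum = 0;
                for (long i = from; i < to; i++) {
                    sum += i;
                }
                return sum;
            }
            long mid = from + (to - from) / 2;
            RangeSum left = new RangeSum(from, mid);
            left.fork();
            return new RangeSum(mid, to).compute() + left.join();
        }
    }

    // Runs the task on a dedicated pool with the given parallelism, then shuts the pool down
    static long sumOnCustomPool(long n, int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            ForkJoinTask<Long> future = pool.submit(new RangeSum(0, n)); // asynchronous submission
            return future.join(); // waits for the result; task exceptions are rethrown unchecked
        } finally {
            pool.shutdown(); // stop accepting new tasks
            awaitQuietly(pool);
        }
    }

    static void awaitQuietly(ForkJoinPool pool) {
        try {
            pool.awaitTermination(1, TimeUnit.MINUTES); // wait for worker threads to finish
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status for callers
        }
    }

    public static void main(String[] args) {
        System.out.println("Sum of [0, 1_000_000): " + sumOnCustomPool(1_000_000, 2));
    }
}
```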
9. Best Practices and Performance Tips
1. Choose Appropriate Threshold
public class OptimizedSumTask extends RecursiveTask<Long> {
// Threshold derived from the total array size and processor count: roughly 4 leaf tasks per processor, but never below 10_000 elements
private static int calculateThreshold(int arraySize) {
int processors = Runtime.getRuntime().availableProcessors();
return Math.max(arraySize / (processors * 4), 10_000);
}
// Rest of implementation...
}
2. Avoid Too Many Small Tasks
@Override
protected Long compute() {
int length = end - start;
int threshold = calculateThreshold(array.length); // whole array, not this slice: a slice-based value shrinks at every level
if (length <= threshold) {
return computeDirectly();
}
// Split and process...
}
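Putting the two snippets together, a minimal complete version might look like the following. It is a sketch, assuming the threshold is computed once at the root from the total array size and passed to every subtask, so all levels of the recursion agree on when to stop splitting. The class name ThresholdSumTask and the static sum entry point are illustrative.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ThresholdSumTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int start;
    private final int end;
    private final int threshold; // fixed for the whole computation

    private ThresholdSumTask(int[] array, int start, int end, int threshold) {
        this.array = array;
        this.start = start;
        this.end = end;
        this.threshold = threshold;
    }

    // Same heuristic as calculateThreshold above: about 4 leaf tasks per processor, never below 10_000
    static int calculateThreshold(int arraySize) {
        int processors = Runtime.getRuntime().availableProcessors();
        return Math.max(arraySize / (processors * 4), 10_000);
    }

    // Entry point: computes the threshold once from the full array size
    static long sum(int[] array) {
        int threshold = calculateThreshold(array.length);
        return ForkJoinPool.commonPool().invoke(new ThresholdSumTask(array, 0, array.length, threshold));
    }

    @Override
    protected Long compute() {
        if (end - start <= threshold) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = start + (end - start) / 2;
        ThresholdSumTask left = new ThresholdSumTask(array, start, mid, threshold);
        ThresholdSumTask right = new ThresholdSumTask(array, mid, end, threshold);
        left.fork();                     // queue the left half so another worker can steal it
        long rightSum = right.compute(); // work on the right half in this thread
        return left.join() + rightSum;
    }

    public static void main(String[] args) {
        int[] numbers = new int[1_000_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        System.out.println("Sum: " + sum(numbers));
    }
}
```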
3. Use invokeAll for Multiple Subtasks
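// invokeAll(Collection) forks the subtasks and returns once all of them have completed
// (shouldComputeDirectly, computeDirectly, createSubtasks, combineResults and SubTask are placeholders)
@Override
protected void compute() {
if (shouldComputeDirectly()) {
computeDirectly();
} else {
List<SubTask> subtasks = createSubtasks();
invokeAll(subtasks); // returns once every subtask is done
combineResults(subtasks);
}
}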
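A runnable version of this pattern, sketched under the assumption that each level splits its range into four chunks instead of two (FAN_OUT and ChunkedSumTask are illustrative names): the subtasks go into a list, ForkJoinTask.invokeAll runs them, and their results are combined with join(), which returns immediately because every task is already done.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ChunkedSumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 10_000;
    private static final int FAN_OUT = 4; // hypothetical: four chunks per level
    private final int[] array;
    private final int start;
    private final int end;

    public ChunkedSumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    static long sum(int[] array) {
        return ForkJoinPool.commonPool().invoke(new ChunkedSumTask(array, 0, array.length));
    }

    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        // Create subtasks covering [start, end) without gaps or overlap
        List<ChunkedSumTask> subtasks = new ArrayList<>();
        int chunk = (length + FAN_OUT - 1) / FAN_OUT; // ceiling division
        for (int from = start; from < end; from += chunk) {
            subtasks.add(new ChunkedSumTask(array, from, Math.min(from + chunk, end)));
        }
        // invokeAll forks the tasks and waits until every one has completed
        long total = 0;
        for (ChunkedSumTask task : ForkJoinTask.invokeAll(subtasks)) {
            total += task.join(); // already done, so this just returns the result
        }
        return total;
    }

    public static void main(String[] args) {
        int[] numbers = new int[1_000_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        System.out.println("Sum: " + sum(numbers));
    }
}
```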
4. Handle Exceptions Properly
public class ExceptionHandlingTask extends RecursiveTask<Long> {
@Override
protected Long compute() {
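// Unchecked exceptions thrown here are recorded by the framework and rethrown by
// join() and invoke() (get() wraps them in ExecutionException); catch only to translate or wrap
try {
return computeInternal();
} catch (Exception e) {
// Mark this task as failed; callers of join()/invoke() will receive the exception
completeExceptionally(e);
return null;
}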
}
private Long computeInternal() {
// Actual computation
return 0L;
}
}
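To see how a failure surfaces to the caller, here is a small sketch in which the task simply throws an unchecked exception and the caller catches it around invoke(). The rethrown exception may be a re-created instance of the same type (so its stack trace can cover both threads), which is why the sketch catches by type rather than by identity. FailingTask and runSafely are illustrative names.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class FailingTask extends RecursiveTask<Long> {
    private final int value;

    public FailingTask(int value) {
        this.value = value;
    }

    @Override
    protected Long compute() {
        if (value < 0) {
            // No try/catch needed: the framework records the exception for the caller
            throw new IllegalArgumentException("negative input: " + value);
        }
        return (long) value * 2;
    }

    // Returns the doubled value, or -1 if the task failed
    static long runSafely(int value) {
        try {
            return ForkJoinPool.commonPool().invoke(new FailingTask(value));
        } catch (IllegalArgumentException e) {
            System.out.println("Task failed: " + e);
            return -1;
        }
    }

    public static void main(String[] args) {
        System.out.println(runSafely(21));
        System.out.println(runSafely(-1));
    }
}
```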
10. Common Pitfalls and Solutions
1. Stack Overflow with Deep Recursion
Solution: Split work into balanced halves so recursion depth stays near log2(n / threshold); if depth is still a problem, raise the threshold or switch to an iterative approach
2. Memory Overhead
Solution: Pass index ranges into a shared array (as the examples above do) instead of copying sub-arrays, and avoid creating too many small tasks
3. Poor Load Balancing
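Solution: Create enough tasks (several per worker thread) that idle workers always have something to steal, while keeping each task large enough to outweigh its scheduling overhead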
4. Blocking Operations
Solution: Avoid blocking operations inside compute(); a blocked worker thread can neither run nor steal other tasks
// Bad - blocking IO in compute method
@Override
protected Long compute() {
// This blocks the worker thread
String data = readFromNetwork();
return process(data);
}
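// Also problematic - supplyAsync without an executor runs readFromNetwork on
// ForkJoinPool.commonPool(), and future.join() still blocks this worker while it waits
@Override
protected Long compute() {
CompletableFuture<String> future = CompletableFuture.supplyAsync(this::readFromNetwork);
return process(future.join());
}
// Better - do the blocking I/O outside the pool (on the calling thread or a dedicated
// I/O executor) and hand only the CPU-bound processing to fork/join
// (ProcessTask is a placeholder for a RecursiveTask<Long> that wraps process())
String data = readFromNetwork();
long result = pool.invoke(new ProcessTask(data));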
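If a blocking call cannot be moved out of compute(), ForkJoinPool.managedBlock tells the pool that a worker is about to block, so the pool may activate a spare thread to keep parallelism up while it waits. A sketch, assuming a hypothetical slowLookup method that simulates I/O with Thread.sleep; ManagedBlockingTask and LookupBlocker are illustrative names.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ManagedBlockingTask extends RecursiveTask<String> {
    private final String key;

    public ManagedBlockingTask(String key) {
        this.key = key;
    }

    // Hypothetical blocking call standing in for network or disk I/O
    static String slowLookup(String key) throws InterruptedException {
        Thread.sleep(50);
        return "value-for-" + key;
    }

    // Wraps the blocking call so the pool can compensate while this worker waits
    static final class LookupBlocker implements ForkJoinPool.ManagedBlocker {
        private final String key;
        private String result; // null until block() has run

        LookupBlocker(String key) {
            this.key = key;
        }

        @Override
        public boolean block() throws InterruptedException {
            if (result == null) {
                result = slowLookup(key);
            }
            return true; // no further blocking needed
        }

        @Override
        public boolean isReleasable() {
            return result != null; // true once the value is available
        }
    }

    @Override
    protected String compute() {
        LookupBlocker blocker = new LookupBlocker(key);
        try {
            ForkJoinPool.managedBlock(blocker);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while blocking", e);
        }
        return blocker.result;
    }

    public static void main(String[] args) {
        String value = ForkJoinPool.commonPool().invoke(new ManagedBlockingTask("user42"));
        System.out.println(value);
    }
}
```

Moving the I/O out of the pool remains the simpler option; managedBlock is a fallback for cases where that is impractical.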
Summary
The Fork/Join Framework is ideal for:
- CPU-intensive tasks that can be recursively divided
- Large data processing where work can be split
- Recursive algorithms like sorting, searching, mathematical computations
Key Benefits:
- Automatic work stealing for load balancing
- Efficient use of multiple processors
- Simplified parallel programming model
When to Use:
- Tasks are CPU-bound
- Work can be recursively divided
- You have sufficient processors
- Task granularity is appropriate (not too fine, not too coarse)