Fork/Join Framework in Java

The Fork/Join framework, introduced in Java 7, is built around ForkJoinPool, an ExecutorService implementation that spreads work across multiple processors. It is designed for work that can be split recursively into smaller pieces.

1. Introduction to Fork/Join

Key Concepts

  • Fork: Splitting a task into smaller subtasks
  • Join: Waiting for subtasks to complete and combining results
  • Work Stealing: Efficient thread utilization where idle threads steal work from busy threads

Core Components

  • ForkJoinPool - Special thread pool for fork/join tasks
  • ForkJoinTask - Abstract base class for tasks
  • RecursiveTask - ForkJoinTask subclass whose compute() returns a result
  • RecursiveAction - ForkJoinTask subclass whose compute() returns no result
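
To show how these components fit together, here is a minimal sketch. RangeCountTask and SubmissionStyles are hypothetical names invented for illustration; the sketch only demonstrates the three ForkJoinPool entry points (invoke, submit, execute) on a trivial task.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

// Hypothetical task used only for illustration: counts the integers in [from, to).
class RangeCountTask extends RecursiveTask<Integer> {
    private final int from;
    private final int to;

    RangeCountTask(int from, int to) {
        this.from = from;
        this.to = to;
    }

    @Override
    protected Integer compute() {
        if (to - from <= 4) {
            return to - from;                 // small enough: answer directly
        }
        int mid = from + (to - from) / 2;
        RangeCountTask left = new RangeCountTask(from, mid);
        RangeCountTask right = new RangeCountTask(mid, to);
        left.fork();                          // fork: schedule left asynchronously
        int rightCount = right.compute();     // run right in the current thread
        return left.join() + rightCount;      // join: wait for left, then combine
    }
}

class SubmissionStyles {
    static int[] runAll() {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // invoke: blocks until the task finishes and returns its result
            int viaInvoke = pool.invoke(new RangeCountTask(0, 100));

            // submit: returns the task itself as a handle to wait on later
            ForkJoinTask<Integer> handle = pool.submit(new RangeCountTask(0, 100));
            int viaSubmit = handle.join();

            // execute: schedules the task without returning anything
            RangeCountTask async = new RangeCountTask(0, 100);
            pool.execute(async);
            int viaExecute = async.join();

            return new int[] {viaInvoke, viaSubmit, viaExecute};
        } finally {
            pool.shutdown();
        }
    }
}
```

All three calls run the same kind of task; they differ only in whether the caller waits immediately (invoke) or later (submit, execute).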

2. Basic Setup and Imports

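import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;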

3. RecursiveTask (Returns Result)

Example: Summing Array Elements

public class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 10_000;
    private final int[] array;
    private final int start;
    private final int end;

    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        int length = end - start;
        // If below threshold, compute directly
        if (length <= THRESHOLD) {
            return computeDirectly();
        }
        // Split task into subtasks
        int mid = start + (end - start) / 2;
        SumTask leftTask = new SumTask(array, start, mid);
        SumTask rightTask = new SumTask(array, mid, end);
        // Fork the left task (execute asynchronously)
        leftTask.fork();
        // Compute right task directly and wait for left task
        Long rightResult = rightTask.compute();
        Long leftResult = leftTask.join();
        return leftResult + rightResult;
    }

    private Long computeDirectly() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        return sum;
    }
}

Using the SumTask

public class SumTaskExample {
    public static void main(String[] args) {
        // Create a large array
        int[] numbers = new int[1_000_000];
        for (int i = 0; i < numbers.length; i++) {
            numbers[i] = i + 1;
        }
        // Create ForkJoinPool
        ForkJoinPool pool = new ForkJoinPool();
        // Create task
        SumTask task = new SumTask(numbers, 0, numbers.length);
        // Execute and get result
        long result = pool.invoke(task);
        System.out.println("Sum: " + result);
        System.out.println("Expected: " + ((long) numbers.length * (numbers.length + 1)) / 2);
        pool.shutdown();
    }
}

4. RecursiveAction (No Result)

Example: Parallel Array Processing

public class ArrayProcessor extends RecursiveAction {
    private static final int THRESHOLD = 10_000;
    private final int[] array;
    private final int start;
    private final int end;
    private final int multiplier;

    public ArrayProcessor(int[] array, int start, int end, int multiplier) {
        this.array = array;
        this.start = start;
        this.end = end;
        this.multiplier = multiplier;
    }

    @Override
    protected void compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            processDirectly();
        } else {
            // Split the task
            int mid = start + (end - start) / 2;
            ArrayProcessor leftTask = new ArrayProcessor(array, start, mid, multiplier);
            ArrayProcessor rightTask = new ArrayProcessor(array, mid, end, multiplier);
            // invokeAll forks one subtask, runs the other in this thread, and waits
            // for both; it is more concise than separate fork/join calls
            invokeAll(leftTask, rightTask);
        }
    }

    private void processDirectly() {
        for (int i = start; i < end; i++) {
            array[i] = array[i] * multiplier;
        }
    }
}

Using RecursiveAction

public class ArrayProcessorExample {
    public static void main(String[] args) {
        int[] data = new int[100_000];
        Arrays.fill(data, 5);
        ForkJoinPool pool = new ForkJoinPool();
        ArrayProcessor task = new ArrayProcessor(data, 0, data.length, 3);
        pool.invoke(task);
        // Verify results
        boolean allCorrect = true;
        for (int value : data) {
            if (value != 15) {
                allCorrect = false;
                break;
            }
        }
        System.out.println("All values correctly multiplied: " + allCorrect);
        System.out.println("First 10 values: " +
                Arrays.toString(Arrays.copyOf(data, 10)));
        pool.shutdown();
    }
}

5. Advanced Example: Parallel Merge Sort

public class ParallelMergeSort extends RecursiveAction {
    private static final int THRESHOLD = 10_000;
    private final int[] array;
    private final int start;
    private final int end;
    private final int[] temp;

    public ParallelMergeSort(int[] array, int start, int end, int[] temp) {
        this.array = array;
        this.start = start;
        this.end = end;
        this.temp = temp;
    }

    @Override
    protected void compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            // Use sequential sort for small arrays
            Arrays.sort(array, start, end);
            return;
        }
        int mid = start + (end - start) / 2;
        // Create subtasks
        ParallelMergeSort leftTask = new ParallelMergeSort(array, start, mid, temp);
        ParallelMergeSort rightTask = new ParallelMergeSort(array, mid, end, temp);
        // Execute subtasks in parallel
        invokeAll(leftTask, rightTask);
        // Merge results
        merge(array, start, mid, end, temp);
    }

    private void merge(int[] array, int start, int mid, int end, int[] temp) {
        System.arraycopy(array, start, temp, start, end - start);
        int i = start, j = mid, k = start;
        while (i < mid && j < end) {
            if (temp[i] <= temp[j]) {
                array[k++] = temp[i++];
            } else {
                array[k++] = temp[j++];
            }
        }
        while (i < mid) {
            array[k++] = temp[i++];
        }
        while (j < end) {
            array[k++] = temp[j++];
        }
    }

    public static void sort(int[] array) {
        int[] temp = new int[array.length];
        ForkJoinPool pool = new ForkJoinPool();
        pool.invoke(new ParallelMergeSort(array, 0, array.length, temp));
        pool.shutdown();
    }
}

Testing Parallel Merge Sort

public class MergeSortExample {
    public static void main(String[] args) {
        int[] largeArray = new int[1_000_000];
        Random random = new Random();
        for (int i = 0; i < largeArray.length; i++) {
            largeArray[i] = random.nextInt(1_000_000);
        }
        int[] sequentialArray = largeArray.clone();
        int[] parallelArray = largeArray.clone();
        // Sequential sort
        long startTime = System.currentTimeMillis();
        Arrays.sort(sequentialArray);
        long sequentialTime = System.currentTimeMillis() - startTime;
        // Parallel sort
        startTime = System.currentTimeMillis();
        ParallelMergeSort.sort(parallelArray);
        long parallelTime = System.currentTimeMillis() - startTime;
        System.out.println("Sequential sort time: " + sequentialTime + "ms");
        System.out.println("Parallel sort time: " + parallelTime + "ms");
        System.out.println("Speedup: " + (double) sequentialTime / parallelTime + "x");
        // Verify both arrays are sorted and equal
        System.out.println("Arrays equal: " + Arrays.equals(sequentialArray, parallelArray));
    }
}
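
As a further baseline for the comparison above, the JDK ships Arrays.parallelSort, which (as far as the JDK documentation states) runs its parallel work on the fork/join common pool. The sketch below is an assumption-laden illustration, not part of the original example: ParallelSortBaseline is a hypothetical class name, and the warm-up call is only a rough attempt to reduce JIT noise in the timing.

```java
import java.util.Arrays;
import java.util.Random;

class ParallelSortBaseline {
    // Sorts a copy of the input with Arrays.parallelSort and returns the sorted copy.
    static int[] sortedCopy(int[] input) {
        int[] copy = input.clone();
        Arrays.parallelSort(copy);
        return copy;
    }

    public static void main(String[] args) {
        // Fixed seed so repeated runs sort the same data
        int[] data = new Random(42).ints(1_000_000, 0, 1_000_000).toArray();

        sortedCopy(data); // warm-up run, result discarded

        long start = System.nanoTime();
        int[] sorted = sortedCopy(data);
        long elapsedMs = (System.nanoTime() - start) / 1_000_000;

        System.out.println("Arrays.parallelSort time: " + elapsedMs + "ms");
        System.out.println("First element: " + sorted[0]);
    }
}
```

Timings from a single run like this are only indicative; a benchmarking harness gives more trustworthy numbers.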

6. Fibonacci Sequence with Fork/Join

public class FibonacciTask extends RecursiveTask<Long> {
    private final int n;

    public FibonacciTask(int n) {
        this.n = n;
    }

    @Override
    protected Long compute() {
        if (n <= 1) {
            return (long) n;
        }
        if (n <= 10) { // Threshold for sequential computation
            return computeSequentially(n);
        }
        FibonacciTask first = new FibonacciTask(n - 1);
        FibonacciTask second = new FibonacciTask(n - 2);
        first.fork();
        Long secondResult = second.compute();
        Long firstResult = first.join();
        return firstResult + secondResult;
    }

    // Package-private (not private) so that FibonacciExample can call it
    Long computeSequentially(int n) {
        if (n <= 1) return (long) n;
        long a = 0, b = 1;
        for (int i = 2; i <= n; i++) {
            long temp = a + b;
            a = b;
            b = temp;
        }
        return b;
    }
}

Fibonacci Example

public class FibonacciExample {
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        for (int i = 0; i <= 20; i++) {
            FibonacciTask task = new FibonacciTask(i);
            long result = pool.invoke(task);
            System.out.println("Fibonacci(" + i + ") = " + result);
        }
        // Compare performance. Note: the sequential baseline is an iterative O(n)
        // loop, while the task recursion creates exponentially many subtasks, so
        // this compares two algorithms rather than measuring parallel speedup.
        int largeN = 40;
        long startTime = System.currentTimeMillis();
        FibonacciTask parallelTask = new FibonacciTask(largeN);
        long parallelResult = pool.invoke(parallelTask);
        long parallelTime = System.currentTimeMillis() - startTime;
        startTime = System.currentTimeMillis();
        long sequentialResult = new FibonacciTask(largeN).computeSequentially(largeN);
        long sequentialTime = System.currentTimeMillis() - startTime;
        System.out.println("\nFibonacci(" + largeN + ") = " + parallelResult);
        System.out.println("Parallel time: " + parallelTime + "ms");
        System.out.println("Sequential time: " + sequentialTime + "ms");
        System.out.println("Results match: " + (parallelResult == sequentialResult));
        pool.shutdown();
    }
}

7. File Processing with Fork/Join

public class FileSearchTask extends RecursiveTask<List<File>> {
    private final File directory;
    private final String fileExtension;

    public FileSearchTask(File directory, String fileExtension) {
        this.directory = directory;
        this.fileExtension = fileExtension.toLowerCase();
    }

    @Override
    protected List<File> compute() {
        List<File> result = new ArrayList<>();
        List<FileSearchTask> subtasks = new ArrayList<>();
        File[] files = directory.listFiles();
        if (files == null) {
            return result;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                // Create subtask for directory
                FileSearchTask subtask = new FileSearchTask(file, fileExtension);
                subtask.fork();
                subtasks.add(subtask);
            } else {
                // Check file extension
                if (file.getName().toLowerCase().endsWith(fileExtension)) {
                    result.add(file);
                }
            }
        }
        // Collect results from subtasks
        for (FileSearchTask subtask : subtasks) {
            result.addAll(subtask.join());
        }
        return result;
    }
}

File Search Example

public class FileSearchExample {
    public static void main(String[] args) {
        File startDirectory = new File("."); // Current directory
        ForkJoinPool pool = new ForkJoinPool();
        FileSearchTask task = new FileSearchTask(startDirectory, ".java");
        List<File> javaFiles = pool.invoke(task);
        System.out.println("Found " + javaFiles.size() + " Java files:");
        javaFiles.stream()
                .map(File::getAbsolutePath)
                .sorted()
                .forEach(System.out::println);
        pool.shutdown();
    }
}

8. ForkJoinPool Configuration and Management

Creating Custom ForkJoinPool

public class ForkJoinPoolManagement {
    public static void main(String[] args) {
        // Create ForkJoinPool with custom parallelism level
        int parallelism = Runtime.getRuntime().availableProcessors();
        ForkJoinPool customPool = new ForkJoinPool(parallelism);
        // Obtain the shared common pool (usually preferred); it is not created
        // here and is unaffected by shutdown()
        ForkJoinPool commonPool = ForkJoinPool.commonPool();
        System.out.println("Custom pool parallelism: " + customPool.getParallelism());
        System.out.println("Common pool parallelism: " + commonPool.getParallelism());
        System.out.println("Available processors: " + Runtime.getRuntime().availableProcessors());
        // Submit some work so the pool statistics below have something to report
        customPool.execute(() -> {
            SumTask task = new SumTask(new int[100_000], 0, 100_000);
            Long result = task.invoke();
            System.out.println("Task result: " + result);
        });
        // Wait for tasks to complete
        customPool.awaitQuiescence(1, TimeUnit.MINUTES);
        System.out.println("Active threads: " + customPool.getActiveThreadCount());
        System.out.println("Pool size: " + customPool.getPoolSize());
        System.out.println("Steal count: " + customPool.getStealCount());
        customPool.shutdown();
    }
}
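
When no custom configuration is needed, a task can run on the common pool without creating or shutting down anything. The sketch below assumes a hypothetical RangeSumTask and CommonPoolDemo, both invented for illustration.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical task for illustration: sums the integers in [from, to).
class RangeSumTask extends RecursiveTask<Long> {
    private final int from;
    private final int to;

    RangeSumTask(int from, int to) {
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        if (to - from <= 1_000) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += i;
            }
            return sum;
        }
        int mid = from + (to - from) / 2;
        RangeSumTask left = new RangeSumTask(from, mid);
        RangeSumTask right = new RangeSumTask(mid, to);
        left.fork();
        long rightSum = right.compute();
        return left.join() + rightSum;
    }
}

class CommonPoolDemo {
    static long sumOnCommonPool(int n) {
        // The common pool is shared JVM-wide, so it is neither created nor shut down here
        return ForkJoinPool.commonPool().invoke(new RangeSumTask(0, n));
    }
}
```

The common pool's parallelism can be set at JVM startup with the system property java.util.concurrent.ForkJoinPool.common.parallelism.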

9. Best Practices and Performance Tips

1. Choose Appropriate Threshold

public class OptimizedSumTask extends RecursiveTask<Long> {
    // Dynamic threshold based on array size and available processors
    private static int calculateThreshold(int arraySize) {
        int processors = Runtime.getRuntime().availableProcessors();
        return Math.max(arraySize / (processors * 4), 10_000);
    }
    // Rest of implementation...
}

2. Avoid Too Many Small Tasks

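@Override
protected Long compute() {
    int length = end - start;
    // threshold is a field computed once from the full array size and passed to
    // every subtask; recomputing it from each subtask's length would shrink it
    // on every split, so only the 10_000 floor would ever apply
    if (length <= threshold) {
        return computeDirectly();
    }
    // Split and process...
}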
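
One way to complete the two fragments above is sketched here. FixedThresholdSumTask is a hypothetical name; the sketch assumes the threshold should be derived once from the full array length and then handed unchanged to every subtask.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class FixedThresholdSumTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int start;
    private final int end;
    private final int threshold; // computed once at the root, then inherited

    // Root constructor: derives the threshold from the full array length.
    FixedThresholdSumTask(int[] array) {
        this(array, 0, array.length, calculateThreshold(array.length));
    }

    private FixedThresholdSumTask(int[] array, int start, int end, int threshold) {
        this.array = array;
        this.start = start;
        this.end = end;
        this.threshold = threshold;
    }

    // Aims for roughly four chunks per processor, but never below 10_000 elements.
    static int calculateThreshold(int arraySize) {
        int processors = Runtime.getRuntime().availableProcessors();
        return Math.max(arraySize / (processors * 4), 10_000);
    }

    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= threshold) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = start + length / 2;
        // Subtasks reuse the same threshold instead of recomputing it
        FixedThresholdSumTask left = new FixedThresholdSumTask(array, start, mid, threshold);
        FixedThresholdSumTask right = new FixedThresholdSumTask(array, mid, end, threshold);
        left.fork();
        long rightSum = right.compute();
        return left.join() + rightSum;
    }
}
```

A call such as ForkJoinPool.commonPool().invoke(new FixedThresholdSumTask(data)) would then sum the whole array.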

3. Use invokeAll for Multiple Subtasks

// invokeAll forks the subtasks and waits for all of them, which is more concise
// and less error-prone than hand-written fork/join calls
@Override
protected void compute() {
    if (shouldComputeDirectly()) {
        computeDirectly();
    } else {
        List<SubTask> subtasks = createSubtasks();
        invokeAll(subtasks); // returns once every subtask is done
        combineResults(subtasks);
    }
}
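
The outline above leaves its helper methods undefined. A concrete, runnable version might look like the following sketch, in which ChunkSumTask and ChunkedSumTask are hypothetical classes and fixed-size chunking is an assumed splitting strategy chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical leaf task: sums one chunk of the array sequentially.
class ChunkSumTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int start;
    private final int end;

    ChunkSumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        return sum;
    }
}

// Hypothetical parent task: splits the array into chunks and runs them with invokeAll.
class ChunkedSumTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int chunkSize;

    ChunkedSumTask(int[] array, int chunkSize) {
        if (chunkSize <= 0) {
            throw new IllegalArgumentException("chunkSize must be positive");
        }
        this.array = array;
        this.chunkSize = chunkSize;
    }

    @Override
    protected Long compute() {
        List<ChunkSumTask> subtasks = new ArrayList<>();
        for (int start = 0; start < array.length; start += chunkSize) {
            int end = Math.min(start + chunkSize, array.length);
            subtasks.add(new ChunkSumTask(array, start, end));
        }
        invokeAll(subtasks); // forks the subtasks and returns once all are done
        long total = 0;
        for (ChunkSumTask subtask : subtasks) {
            total += subtask.join(); // already complete, so this only fetches the result
        }
        return total;
    }
}
```

The collection form of invokeAll is convenient when the number of subtasks is not fixed at two.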

4. Handle Exceptions Properly

public class ExceptionHandlingTask extends RecursiveTask<Long> {
    @Override
    protected Long compute() {
        try {
            // Task computation
            return computeInternal();
        } catch (Exception e) {
            // Handle or rethrow
            completeExceptionally(e);
            return null;
        }
    }

    private Long computeInternal() {
        // Actual computation
        return 0L;
    }
}
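
An alternative is to let unchecked exceptions escape compute() and handle them where the task is invoked or joined. The sketch below assumes a hypothetical ValidatingTask and ExceptionPropagationDemo; it relies only on the documented behavior that invoke() and join() rethrow an exception thrown by the computation.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical task for illustration: fails on negative input.
class ValidatingTask extends RecursiveTask<Long> {
    private final long value;

    ValidatingTask(long value) {
        this.value = value;
    }

    @Override
    protected Long compute() {
        if (value < 0) {
            // Unchecked exceptions need no special handling inside compute()
            throw new IllegalArgumentException("negative input: " + value);
        }
        return value * 2;
    }
}

class ExceptionPropagationDemo {
    // Returns the task result, or -1 if the task failed with IllegalArgumentException.
    static long runSafely(long value) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new ValidatingTask(value));
        } catch (IllegalArgumentException e) {
            // The failure surfaces here, at the call site, with its original type
            return -1;
        } finally {
            pool.shutdown();
        }
    }
}
```

Callers that use get() instead of invoke() or join() receive the failure wrapped in an ExecutionException.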

10. Common Pitfalls and Solutions

1. Stack Overflow with Deep Recursion

Solution: Use an iterative approach for very deep recursion, or raise the threshold so the recursion stops splitting sooner

2. Memory Overhead

Solution: Reuse objects where possible and avoid creating too many small tasks

3. Poor Load Balancing

Solution: Use work stealing effectively by creating reasonably sized tasks

4. Blocking Operations

Solution: Avoid blocking operations in the compute() method

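// Bad - blocking IO in compute method
@Override
protected Long compute() {
    // This blocks the worker thread
    String data = readFromNetwork();
    return process(data);
}

// Better - run the blocking IO on a separate executor. ioExecutor is assumed to
// be a dedicated thread pool defined elsewhere; without it, supplyAsync would
// use the common ForkJoinPool and block one of its workers instead.
@Override
protected Long compute() {
    CompletableFuture<String> future =
            CompletableFuture.supplyAsync(this::readFromNetwork, ioExecutor);
    // join() still waits for the data; where possible, do the IO before forking
    return process(future.join());
}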
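
When a blocking call inside a task cannot be avoided, the JDK offers ForkJoinPool.managedBlock, which tells the pool that the current worker is about to block so it can compensate with a spare thread if needed. The following is a minimal sketch under stated assumptions: BlockingRead is a hypothetical wrapper, and slowRead is a stand-in that simulates blocking IO with a short sleep.

```java
import java.util.concurrent.ForkJoinPool;

// Hypothetical wrapper for illustration: runs one blocking call under managedBlock.
class BlockingRead implements ForkJoinPool.ManagedBlocker {
    private volatile String result;

    @Override
    public boolean block() throws InterruptedException {
        if (result == null) {
            result = slowRead(); // the blocking call itself
        }
        return true; // no further blocking is necessary
    }

    @Override
    public boolean isReleasable() {
        return result != null; // true once the data is available
    }

    // Stand-in for real blocking IO (assumption: replace with the actual call).
    private static String slowRead() throws InterruptedException {
        Thread.sleep(50);
        return "payload";
    }

    static String readWithCompensation() {
        BlockingRead blocker = new BlockingRead();
        try {
            // Inside a fork/join worker, the pool may activate a spare thread meanwhile
            ForkJoinPool.managedBlock(blocker);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException("interrupted while reading", e);
        }
        return blocker.result;
    }
}
```

A compute() method could call BlockingRead.readWithCompensation() in place of a direct blocking read.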

Summary

The Fork/Join Framework is ideal for:

  • CPU-intensive tasks that can be recursively divided
  • Large data processing where work can be split
  • Recursive algorithms like sorting, searching, mathematical computations

Key Benefits:

  • Automatic work stealing for load balancing
  • Efficient use of multiple processors
  • Simplified parallel programming model

When to Use:

  • Tasks are CPU-bound
  • Work can be recursively divided
  • You have sufficient processors
  • Task granularity is appropriate (not too fine, not too coarse)
