Fork/Join Framework for Parallelism in Java

The Fork/Join Framework is a Java implementation of the divide-and-conquer paradigm, designed to efficiently parallelize recursive algorithms using work-stealing.

1. Overview

What is Fork/Join?

  • Part of the java.util.concurrent package (since Java 7)
  • Designed for recursive, divide-and-conquer algorithms
  • Uses a work-stealing algorithm for load balancing
  • Best suited to CPU-intensive tasks that rarely block

Key Components

  • ForkJoinPool - Specialized thread pool that runs ForkJoinTasks
  • ForkJoinTask - Abstract base class for tasks
  • RecursiveTask - ForkJoinTask whose compute() returns a result
  • RecursiveAction - ForkJoinTask whose compute() returns nothing

2. Basic Architecture

Work-Stealing Algorithm

Thread 1: [Task A] -> splits into [A1, A2, A3]
Thread 2: (idle) steals A3 from Thread 1
Thread 3: (idle) steals A2 from Thread 1
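
The diagram above can be observed at runtime. Each worker pushes the tasks it forks onto its own deque and takes work back from the same end, while idle workers steal from the opposite end. The sketch below is only an illustration: the class names StealCountDemo and RangeSum are made up for this example, and the printed steal count is an estimate that varies between runs and machines.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class StealCountDemo {
    // Sums the integers in [from, to) by recursive halving.
    static class RangeSum extends RecursiveTask<Long> {
        private final long from;
        private final long to;

        RangeSum(long from, long to) {
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            if (to - from <= 10_000) {
                long sum = 0;
                for (long i = from; i < to; i++) {
                    sum += i;
                }
                return sum;
            }
            long mid = from + (to - from) / 2;
            RangeSum left = new RangeSum(from, mid);
            left.fork(); // pushed onto this worker's deque; an idle worker may steal it
            long right = new RangeSum(mid, to).compute();
            return left.join() + right;
        }
    }

    static long sumBelow(long n, int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            long result = pool.invoke(new RangeSum(0, n));
            // getStealCount() is an estimate and differs from run to run
            System.out.println("Steal count: " + pool.getStealCount());
            return result;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("Sum: " + sumBelow(10_000_000L, 4));
    }
}
```

A steal count well above zero suggests the work was spread across workers; a count near zero usually means the tasks were too coarse or the input too small.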

3. ForkJoinPool

Creating ForkJoinPool

import java.util.concurrent.ForkJoinPool;
public class ForkJoinPoolCreation {
    public static void main(String[] args) {
        // Method 1: Use the common pool (recommended for most cases)
        ForkJoinPool commonPool = ForkJoinPool.commonPool();
        // Method 2: Create a custom pool
        ForkJoinPool customPool = new ForkJoinPool(4); // parallelism level of 4
        // Method 3: Create with parallelism level, thread factory, handler, and mode
        ForkJoinPool configuredPool = new ForkJoinPool(
            Runtime.getRuntime().availableProcessors(),
            ForkJoinPool.defaultForkJoinWorkerThreadFactory,
            null,  // no UncaughtExceptionHandler
            false  // asyncMode: false keeps the default LIFO order for forked tasks
        );
        System.out.println("Common pool parallelism: " + commonPool.getParallelism());
        System.out.println("Available processors: " + Runtime.getRuntime().availableProcessors());
        // Shut down custom pools when they are no longer needed (the common pool is left alone)
        customPool.shutdown();
        configuredPool.shutdown();
    }
}
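
A custom pool owns its worker threads, so it is worth showing the full lifecycle: create, submit, wait, shut down. The following is a minimal sketch under the assumption that a small dedicated pool is wanted; the class name CustomPoolLifecycle and the trivial job are placeholders.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeUnit;

class CustomPoolLifecycle {
    static int runInCustomPool(int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            Callable<Integer> job = () -> 6 * 7;          // placeholder for real work
            // submit() returns immediately; join() waits for the result
            ForkJoinTask<Integer> task = pool.submit(job);
            return task.join();
        } finally {
            pool.shutdown();                               // stop accepting new tasks
            try {
                if (!pool.awaitTermination(5, TimeUnit.SECONDS)) {
                    pool.shutdownNow();                    // give up on tasks that are still running
                }
            } catch (InterruptedException e) {
                pool.shutdownNow();
                Thread.currentThread().interrupt();        // preserve the interrupt status
            }
        }
    }

    public static void main(String[] args) {
        System.out.println("Result: " + runInCustomPool(2));
    }
}
```

The common pool needs none of this: it is shared by the whole JVM and calling shutdown() on it has no effect.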

4. RecursiveTask (Returns Result)

Factorial Calculation Example

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
// Note: the result is a long, which overflows for n > 20
public class FactorialTask extends RecursiveTask<Long> {
    private final int start;
    private final int end;
    private static final int THRESHOLD = 5; // Sequential computation threshold
    public FactorialTask(int n) {
        this(1, n);
    }
    private FactorialTask(int start, int end) {
        this.start = start;
        this.end = end;
    }
    @Override
    protected Long compute() {
        int length = end - start + 1;
        // If problem is small enough, compute sequentially
        if (length <= THRESHOLD) {
            return computeSequentially();
        }
        // Split task into subtasks
        int mid = start + (end - start) / 2;
        FactorialTask leftTask = new FactorialTask(start, mid);
        FactorialTask rightTask = new FactorialTask(mid + 1, end);
        // Fork the left task (execute asynchronously)
        leftTask.fork();
        // Compute the right task in this thread, then wait for the left task
        Long rightResult = rightTask.compute();
        Long leftResult = leftTask.join();
        // Combine results
        return leftResult * rightResult;
    }
    private Long computeSequentially() {
        long result = 1;
        for (int i = start; i <= end; i++) {
            result *= i;
            // Simulate some work (for demonstration only; real tasks should not sleep)
            try {
                Thread.sleep(10);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status
            }
        }
        System.out.println(Thread.currentThread().getName() +
            " computed " + start + " to " + end);
        return result;
    }
    public static void main(String[] args) {
        ForkJoinPool pool = ForkJoinPool.commonPool();
        FactorialTask task = new FactorialTask(10);
        Long result = pool.invoke(task);
        System.out.println("Factorial result: " + result);
        System.out.println("Pool parallelism: " + pool.getParallelism());
    }
}
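
Because a long can hold factorials only up to 20!, larger inputs need an arbitrary-precision type. The variant below is a sketch that applies the same splitting strategy with BigInteger; the class name BigFactorialTask and the static helper factorial() are introduced here for illustration, and the simulated delay is omitted.

```java
import java.math.BigInteger;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class BigFactorialTask extends RecursiveTask<BigInteger> {
    private static final int THRESHOLD = 5; // same cut-off as the long-based example
    private final int start;
    private final int end;

    BigFactorialTask(int n) {
        this(1, n);
    }

    private BigFactorialTask(int start, int end) {
        this.start = start;
        this.end = end;
    }

    @Override
    protected BigInteger compute() {
        if (end - start + 1 <= THRESHOLD) {
            // Small range: multiply sequentially
            BigInteger product = BigInteger.ONE;
            for (int i = start; i <= end; i++) {
                product = product.multiply(BigInteger.valueOf(i));
            }
            return product;
        }
        int mid = start + (end - start) / 2;
        BigFactorialTask left = new BigFactorialTask(start, mid);
        BigFactorialTask right = new BigFactorialTask(mid + 1, end);
        left.fork();
        BigInteger rightResult = right.compute();
        return left.join().multiply(rightResult);
    }

    static BigInteger factorial(int n) {
        return ForkJoinPool.commonPool().invoke(new BigFactorialTask(n));
    }

    public static void main(String[] args) {
        System.out.println("25! = " + factorial(25));
    }
}
```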

Array Sum Calculation

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
import java.util.Random;
public class ArraySumTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int start;
    private final int end;
    private static final int THRESHOLD = 1000;
    public ArraySumTask(int[] array) {
        this(array, 0, array.length);
    }
    private ArraySumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            return computeSequentially();
        }
        int mid = start + length / 2;
        ArraySumTask leftTask = new ArraySumTask(array, start, mid);
        ArraySumTask rightTask = new ArraySumTask(array, mid, end);
        leftTask.fork();
        Long rightResult = rightTask.compute();
        Long leftResult = leftTask.join();
        return leftResult + rightResult;
    }
    private Long computeSequentially() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        return sum;
    }
    public static void main(String[] args) {
        // Create large array
        int[] array = new int[100000];
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(100);
        }
        // Parallel computation
        long startTime = System.currentTimeMillis();
        ForkJoinPool pool = ForkJoinPool.commonPool();
        ArraySumTask task = new ArraySumTask(array);
        Long parallelSum = pool.invoke(task);
        long parallelTime = System.currentTimeMillis() - startTime;
        // Sequential computation for comparison
        startTime = System.currentTimeMillis();
        long sequentialSum = 0;
        for (int value : array) {
            sequentialSum += value;
        }
        long sequentialTime = System.currentTimeMillis() - startTime;
        System.out.println("Parallel sum: " + parallelSum + " in " + parallelTime + "ms");
        System.out.println("Sequential sum: " + sequentialSum + " in " + sequentialTime + "ms");
        System.out.println("Results match: " + (parallelSum == sequentialSum));
    }
}
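
For a plain reduction like this one, a parallel stream gives the same divide-and-conquer behaviour with far less code, and by default it also runs on the common ForkJoinPool. The snippet below is a sketch for comparison; the class name ParallelStreamSum is made up.

```java
import java.util.Arrays;

class ParallelStreamSum {
    // Widens each int to long before summing, so the total cannot overflow an int.
    static long sum(int[] array) {
        return Arrays.stream(array)
            .parallel()
            .asLongStream()
            .sum();
    }

    public static void main(String[] args) {
        int[] array = new int[100000];
        Arrays.fill(array, 3);
        System.out.println("Parallel stream sum: " + sum(array));
    }
}
```

A hand-written RecursiveTask is still the better fit when the splitting logic or the threshold needs to be controlled explicitly.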

5. RecursiveAction (No Result)

Array Processing Example

import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ForkJoinPool;
import java.util.Random;
public class ArrayTransformAction extends RecursiveAction {
    private final double[] array;
    private final int start;
    private final int end;
    private static final int THRESHOLD = 10000;
    public ArrayTransformAction(double[] array) {
        this(array, 0, array.length);
    }
    private ArrayTransformAction(double[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    @Override
    protected void compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            transformSequentially();
            return;
        }
        int mid = start + length / 2;
        ArrayTransformAction leftTask = new ArrayTransformAction(array, start, mid);
        ArrayTransformAction rightTask = new ArrayTransformAction(array, mid, end);
        invokeAll(leftTask, rightTask); // Run both subtasks and wait for both to finish
    }
    private void transformSequentially() {
        for (int i = start; i < end; i++) {
            // Apply some transformation; abs() keeps the sqrt argument non-negative (avoids NaN)
            array[i] = Math.sqrt(Math.abs(Math.sin(array[i]) * Math.cos(array[i])));
        }
        System.out.println(Thread.currentThread().getName() +
            " processed " + start + " to " + end);
    }
    public static void main(String[] args) {
        double[] array = new double[100000];
        Random random = new Random();
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextDouble() * 100;
        }
        ForkJoinPool pool = new ForkJoinPool();
        try {
            ArrayTransformAction task = new ArrayTransformAction(array);
            pool.invoke(task);
        } finally {
            pool.shutdown(); // release the custom pool's worker threads
        }
        System.out.println("Array transformation completed");
        // Verify some values
        for (int i = 0; i < 5; i++) {
            System.out.printf("array[%d] = %.4f%n", i, array[i]);
        }
    }
}

6. Fibonacci Sequence (Classic Example)

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
public class FibonacciTask extends RecursiveTask<Long> {
    private final int n;
    public FibonacciTask(int n) {
        this.n = n;
    }
    @Override
    protected Long compute() {
        if (n <= 1) {
            return (long) n;
        }
        // For very small n, compute sequentially to avoid overhead
        if (n <= 10) {
            return computeSequentially(n);
        }
        FibonacciTask leftTask = new FibonacciTask(n - 1);
        FibonacciTask rightTask = new FibonacciTask(n - 2);
        leftTask.fork();
        Long rightResult = rightTask.compute();
        Long leftResult = leftTask.join();
        return leftResult + rightResult;
    }
    private Long computeSequentially(int n) {
        if (n <= 1) return (long) n;
        long a = 0, b = 1, result = 0;
        for (int i = 2; i <= n; i++) {
            result = a + b;
            a = b;
            b = result;
        }
        return result;
    }
    public static void main(String[] args) {
        int n = 25;
        // Parallel computation
        long startTime = System.currentTimeMillis();
        ForkJoinPool pool = ForkJoinPool.commonPool();
        FibonacciTask task = new FibonacciTask(n);
        Long parallelResult = pool.invoke(task);
        long parallelTime = System.currentTimeMillis() - startTime;
        // Sequential computation
        startTime = System.currentTimeMillis();
        Long sequentialResult = new FibonacciTask(n).computeSequentially(n);
        long sequentialTime = System.currentTimeMillis() - startTime;
        System.out.println("Fibonacci(" + n + ") = " + parallelResult);
        System.out.println("Parallel time: " + parallelTime + "ms");
        System.out.println("Sequential time: " + sequentialTime + "ms");
        System.out.println("Results match: " + (parallelResult.equals(sequentialResult)));
    }
}

7. Advanced Features

invokeAll() Method

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
import java.util.Arrays;
public class InvokeAllExample extends RecursiveTask<Long> {
    private final int[] array;
    private final int start;
    private final int end;
    private static final int THRESHOLD = 100;
    public InvokeAllExample(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            return Arrays.stream(array, start, end)
                .asLongStream()
                .sum();
        }
        int mid = start + length / 2;
        InvokeAllExample leftTask = new InvokeAllExample(array, start, mid);
        InvokeAllExample rightTask = new InvokeAllExample(array, mid, end);
        // invokeAll() runs both tasks and returns once both are done;
        // it is less error-prone than pairing fork() and join() by hand
        invokeAll(leftTask, rightTask);
        // Both tasks have completed, so join() returns their results immediately
        return leftTask.join() + rightTask.join();
    }
    public static void main(String[] args) {
        int[] array = new int[1000];
        Arrays.fill(array, 1);
        ForkJoinPool pool = ForkJoinPool.commonPool();
        InvokeAllExample task = new InvokeAllExample(array, 0, array.length);
        Long result = pool.invoke(task);
        System.out.println("Sum: " + result);
        System.out.println("Expected: " + array.length);
    }
}
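
invokeAll() also accepts a collection, which is convenient when a task splits into more than two parts. The sketch below splits an array into a caller-chosen number of chunks; the class names ChunkedSum and ChunkSum and the helper sum() are hypothetical, and a fixed chunk count is only one possible splitting policy.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ChunkedSum extends RecursiveTask<Long> {
    // Leaf task: sums array[from, to) sequentially.
    static class ChunkSum extends RecursiveTask<Long> {
        private final int[] array;
        private final int from;
        private final int to;

        ChunkSum(int[] array, int from, int to) {
            this.array = array;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += array[i];
            }
            return sum;
        }
    }

    private final int[] array;
    private final int chunks; // must be at least 1

    ChunkedSum(int[] array, int chunks) {
        this.array = array;
        this.chunks = chunks;
    }

    @Override
    protected Long compute() {
        List<ChunkSum> tasks = new ArrayList<>();
        int chunkSize = (array.length + chunks - 1) / chunks; // ceiling division
        for (int from = 0; from < array.length; from += chunkSize) {
            tasks.add(new ChunkSum(array, from, Math.min(from + chunkSize, array.length)));
        }
        invokeAll(tasks); // runs every chunk and waits for all of them
        long total = 0;
        for (ChunkSum task : tasks) {
            total += task.join(); // already complete, so this does not block
        }
        return total;
    }

    static long sum(int[] array, int chunks) {
        return ForkJoinPool.commonPool().invoke(new ChunkedSum(array, chunks));
    }

    public static void main(String[] args) {
        int[] array = new int[1000];
        java.util.Arrays.fill(array, 1);
        System.out.println("Sum: " + sum(array, 8));
    }
}
```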

Exception Handling in Fork/Join

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
public class ExceptionHandlingTask extends RecursiveTask<Integer> {
    private final int[] array;
    private final int start;
    private final int end;
    public ExceptionHandlingTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    @Override
    protected Integer compute() {
        // No try/catch is needed here: an unchecked exception thrown by a subtask
        // is rethrown by join() and propagates up to the caller of invoke()
        if (end - start <= 1) {
            if (array[start] < 0) {
                throw new IllegalArgumentException("Negative number found: " + array[start]);
            }
            return array[start];
        }
        int mid = start + (end - start) / 2;
        ExceptionHandlingTask leftTask = new ExceptionHandlingTask(array, start, mid);
        ExceptionHandlingTask rightTask = new ExceptionHandlingTask(array, mid, end);
        leftTask.fork();
        int rightResult = rightTask.compute();
        int leftResult = leftTask.join();
        return leftResult + rightResult;
    }
    public static void main(String[] args) {
        int[] array = {1, 2, 3, -4, 5, 6}; // Contains negative number
        ForkJoinPool pool = ForkJoinPool.commonPool();
        ExceptionHandlingTask task = new ExceptionHandlingTask(array, 0, array.length);
        try {
            Integer result = pool.invoke(task);
            System.out.println("Result: " + result);
        } catch (RuntimeException e) {
            // The pool may rethrow a copy of the exception that wraps the original,
            // so walk the cause chain to reach the root cause
            Throwable root = e;
            while (root.getCause() != null) {
                root = root.getCause();
            }
            System.out.println("Caught exception: " + e.getMessage());
            System.out.println("Root cause: " + root.getMessage());
        }
    }
}
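
When a task is submitted rather than invoked, a failure surfaces through the Future-style API instead: get() throws an ExecutionException that wraps the task's exception. The sketch below illustrates this with a task that always fails; the class names FutureStyleErrors and FailingTask and the message "boom" are invented for the example.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

class FutureStyleErrors {
    // A task that always fails, standing in for real work that hits an error.
    static class FailingTask extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            throw new IllegalStateException("boom");
        }
    }

    static String describeFailure() {
        ForkJoinTask<Integer> task = ForkJoinPool.commonPool().submit(new FailingTask());
        try {
            task.get();
            return "no failure";
        } catch (ExecutionException e) {
            // Walk to the root cause; some JDK versions wrap the original in a copy
            Throwable root = e.getCause();
            while (root.getCause() != null) {
                root = root.getCause();
            }
            return root.getClass().getSimpleName() + ": " + root.getMessage();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(describeFailure());
    }
}
```

After such a failure, task.isCompletedAbnormally() returns true and task.getException() returns the exception without throwing it.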

8. Performance Considerations

Threshold Tuning

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
public class ThresholdTuning extends RecursiveTask<Long> {
    private final long[] array;
    private final int start;
    private final int end;
    private final int threshold;
    public ThresholdTuning(long[] array, int start, int end, int threshold) {
        this.array = array;
        this.start = start;
        this.end = end;
        this.threshold = threshold;
    }
    @Override
    protected Long compute() {
        int length = end - start;
        if (length <= threshold) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = start + length / 2;
        ThresholdTuning leftTask = new ThresholdTuning(array, start, mid, threshold);
        ThresholdTuning rightTask = new ThresholdTuning(array, mid, end, threshold);
        leftTask.fork();
        Long rightResult = rightTask.compute();
        Long leftResult = leftTask.join();
        return leftResult + rightResult;
    }
    public static void main(String[] args) {
        long[] array = new long[1000000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        int[] thresholds = {100, 1000, 10000, 100000};
        // Note: the first measurement includes JIT warm-up, so treat these timings
        // as rough indications rather than a rigorous benchmark
        for (int threshold : thresholds) {
            long startTime = System.currentTimeMillis();
            ForkJoinPool pool = ForkJoinPool.commonPool();
            ThresholdTuning task = new ThresholdTuning(array, 0, array.length, threshold);
            Long result = pool.invoke(task);
            long duration = System.currentTimeMillis() - startTime;
            System.out.printf("Threshold: %7d, Time: %4dms, Result: %d%n",
                threshold, duration, result);
        }
    }
}
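
Instead of a fixed threshold, a task can also ask the framework how much work is already queued. ForkJoinTask.getSurplusQueuedTaskCount() estimates how many more tasks the current worker holds than other workers could steal, and a task can stop splitting once that surplus exceeds a small constant. The sketch below assumes a cut-off of 3 and a minimum chunk of 1,000 elements; both values, and the class name AdaptiveSum, are illustrative choices rather than tuned recommendations.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class AdaptiveSum extends RecursiveTask<Long> {
    private static final int MIN_CHUNK = 1_000; // assumed floor below which splitting never pays off
    private static final int MAX_SURPLUS = 3;   // assumed cut-off for locally queued surplus tasks

    private final long[] array;
    private final int start;
    private final int end;

    AdaptiveSum(long[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {
        int length = end - start;
        // Stop splitting when the chunk is small or enough tasks are already queued for stealing
        if (length <= MIN_CHUNK || getSurplusQueuedTaskCount() > MAX_SURPLUS) {
            long sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        }
        int mid = start + length / 2;
        AdaptiveSum left = new AdaptiveSum(array, start, mid);
        AdaptiveSum right = new AdaptiveSum(array, mid, end);
        left.fork();
        long rightResult = right.compute();
        return left.join() + rightResult;
    }

    static long sum(long[] array) {
        return ForkJoinPool.commonPool().invoke(new AdaptiveSum(array, 0, array.length));
    }

    public static void main(String[] args) {
        long[] array = new long[1000000];
        for (int i = 0; i < array.length; i++) {
            array[i] = i + 1;
        }
        System.out.println("Sum: " + sum(array));
    }
}
```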

9. Real-World Example: File Processing

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
import java.io.IOException;
import java.nio.file.*;
import java.util.ArrayList;
import java.util.List;
public class FileSearchTask extends RecursiveTask<List<Path>> {
    private final Path directory;
    private final String fileExtension;
    public FileSearchTask(Path directory, String fileExtension) {
        this.directory = directory;
        this.fileExtension = fileExtension;
    }
    @Override
    protected List<Path> compute() {
        List<Path> files = new ArrayList<>();
        List<FileSearchTask> subtasks = new ArrayList<>();
        try (DirectoryStream<Path> stream = Files.newDirectoryStream(directory)) {
            for (Path entry : stream) {
                if (Files.isDirectory(entry)) {
                    // Collect one subtask per subdirectory; they are started together below
                    subtasks.add(new FileSearchTask(entry, fileExtension));
                } else if (entry.toString().endsWith(fileExtension)) {
                    files.add(entry);
                }
            }
        } catch (IOException e) {
            System.err.println("Error reading directory: " + e.getMessage());
        }
        // Run all subtasks and wait for them to finish.
        // Never call compute() on a task that was already forked: it could run twice.
        invokeAll(subtasks);
        for (FileSearchTask subtask : subtasks) {
            files.addAll(subtask.join());
        }
        return files;
    }
    public static void main(String[] args) {
        Path startDir = Paths.get("."); // Current directory
        String extension = ".java";
        ForkJoinPool pool = ForkJoinPool.commonPool();
        FileSearchTask task = new FileSearchTask(startDir, extension);
        long startTime = System.currentTimeMillis();
        List<Path> foundFiles = pool.invoke(task);
        long duration = System.currentTimeMillis() - startTime;
        System.out.println("Found " + foundFiles.size() + " " + extension + " files in " + duration + "ms");
        foundFiles.forEach(System.out::println);
    }
}

10. Best Practices

1. Choose Appropriate Threshold

// A good threshold depends on:
// - Task complexity
// - Number of available processors
// - Overhead of task creation
// Compute it from the input size at run time (a static final field cannot depend on the array)
int threshold = Math.max(
    array.length / (Runtime.getRuntime().availableProcessors() * 4), 1000);

2. Avoid Blocking Operations

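// Bad - blocking I/O in compute method
protected Long compute() {
    // Reading from network/file - blocks a worker thread
    String data = readFromNetwork(); // AVOID THIS
    return process(data);
}
// Also bad - supplyAsync() runs on the common pool by default, and join() still blocks this worker
protected Long compute() {
    CompletableFuture<String> future = CompletableFuture.supplyAsync(this::readFromNetwork);
    return process(future.join());
}
// Better - do the I/O before the task starts and pass the data in
String data = readFromNetwork(); // on the calling thread or a dedicated I/O executor
Long result = ForkJoinPool.commonPool().invoke(new ProcessTask(data)); // ProcessTask: placeholder RecursiveTask<Long>
// If a task cannot avoid blocking, wrap the call in ForkJoinPool.managedBlock(...)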
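
When a task really has to block, ForkJoinPool.ManagedBlocker lets the pool know, so it can activate a spare worker and keep the target parallelism while the call is blocked. The following is a minimal sketch: slowRead() is a hypothetical stand-in for a blocking read, and the class names ManagedBlockingExample, BlockingRead and LengthTask are invented for the example.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

class ManagedBlockingExample {
    // Wraps one blocking call so the pool can compensate while it blocks.
    static class BlockingRead implements ForkJoinPool.ManagedBlocker {
        private String result;
        private boolean done;

        @Override
        public boolean block() throws InterruptedException {
            if (!done) {
                result = slowRead();
                done = true;
            }
            return true; // no further blocking is necessary
        }

        @Override
        public boolean isReleasable() {
            return done; // true once the value is available without blocking
        }

        String result() {
            return result;
        }
    }

    // Hypothetical stand-in for a blocking network or file read.
    static String slowRead() throws InterruptedException {
        Thread.sleep(50);
        return "payload";
    }

    static class LengthTask extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            BlockingRead read = new BlockingRead();
            try {
                ForkJoinPool.managedBlock(read); // tells the pool this worker is about to block
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while reading", e);
            }
            return read.result().length();
        }
    }

    static int readLength() {
        return ForkJoinPool.commonPool().invoke(new LengthTask());
    }

    public static void main(String[] args) {
        System.out.println("Length: " + readLength());
    }
}
```

This keeps the pool responsive for occasional blocking calls; for workloads that are mostly I/O, a separate executor is usually the simpler choice.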

3. Use invokeAll for Multiple Subtasks

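// invokeAll() starts every subtask before waiting on any of them,
// unlike calling fork() and join() on one subtask at a time
protected Result compute() {
    List<Subtask> subtasks = createSubtasks();
    // Run all subtasks and wait until all have completed
    invokeAll(subtasks);
    return combineResults(subtasks);
}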

4. Handle Exceptions Properly

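@Override
protected Result compute() {
    try {
        return doComputation(); // placeholder for the real work
    } catch (IOException e) {
        // compute() cannot throw checked exceptions: wrap and rethrow.
        // The exception resurfaces from join() or invoke() on the waiting thread.
        throw new UncheckedIOException(e);
    }
    // Unchecked exceptions need no handling here; they propagate the same way
}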

Summary

The Fork/Join Framework is ideal for:

  • Recursive algorithms that can be divided
  • CPU-intensive tasks with minimal I/O
  • Data parallelism scenarios
  • Divide-and-conquer problems

Key Benefits:

  • Automatic workload balancing
  • Efficient use of multiple cores
  • Minimal synchronization overhead
  • Elegant recursive problem solving

When to Avoid:

  • I/O-bound operations
  • Tasks that depend on each other in ways that do not fit a fork/join tree
  • Very small problem sizes
  • Simple sequential operations
