RecursiveTask and RecursiveAction in Java

1. Overview

RecursiveTask and RecursiveAction are abstract classes in the Fork/Join framework designed for parallelizing recursive algorithms.

Key Differences

| RecursiveTask | RecursiveAction |
|---|---|
| Returns a result | No result returned |
| Extends `ForkJoinTask<V>` | Extends `ForkJoinTask<Void>` |
| `compute()` returns `V` | `compute()` returns `void` |
| Use for computations that produce results | Use for operations that modify state (e.g., in-place updates) |

2. RecursiveTask (Returns Result)

Basic Structure

/**
 * Simplified sketch of the Fork/Join framework's RecursiveTask, shown for illustration.
 * In real code, extend java.util.concurrent.RecursiveTask rather than declaring this class.
 *
 * @param <V> the type of result produced by compute()
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
// The task's main computation; subclasses implement it and return the result.
protected abstract V compute();
// Other inherited methods:
// get(), join(), fork(), invoke(), etc.
}

Complete Example: Array Sum with RecursiveTask

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
import java.util.Random;
/**
 * Sums an {@code int[]} with the Fork/Join framework: the index range is halved
 * recursively until a slice has at most {@link #THRESHOLD} elements, and each such
 * slice is summed in a plain loop.
 *
 * <p>The array is only read, so subtasks share it without synchronization. Sums are
 * accumulated as {@code long}, so totals beyond {@code Integer.MAX_VALUE} do not overflow.
 */
public class ArraySumRecursiveTask extends RecursiveTask<Long> {
    private final int[] array;
    private final int start; // first index of this task's slice (inclusive)
    private final int end;   // one past the last index of the slice (exclusive)
    private static final int THRESHOLD = 1000; // Sequential computation threshold

    /**
     * Creates the root task covering the whole array.
     *
     * @param array values to sum; must not be {@code null}
     */
    public ArraySumRecursiveTask(int[] array) {
        this(array, 0, array.length);
    }

    // Private constructor for subtasks covering array[start, end)
    private ArraySumRecursiveTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Returns the sum of {@code array[start, end)}.
     */
    @Override
    protected Long compute() {
        int length = end - start;
        // Base case: if problem is small enough, compute sequentially
        if (length <= THRESHOLD) {
            return computeSequentially();
        }
        // Recursive case: split in half (start + length / 2 cannot overflow int)
        int mid = start + length / 2;
        ArraySumRecursiveTask leftTask = new ArraySumRecursiveTask(array, start, mid);
        ArraySumRecursiveTask rightTask = new ArraySumRecursiveTask(array, mid, end);
        // Fork the left half so an idle worker can steal it, compute the right half in
        // this thread, then wait for the left half. Primitive locals avoid re-boxing.
        leftTask.fork();
        long rightResult = rightTask.compute();
        long leftResult = leftTask.join();
        return leftResult + rightResult;
    }

    // Sums this task's slice in a loop and reports which thread did the work.
    private long computeSequentially() {
        long sum = 0;
        for (int i = start; i < end; i++) {
            sum += array[i];
        }
        System.out.println(Thread.currentThread().getName() +
                " computed range " + start + " to " + (end-1));
        return sum;
    }

    public static void main(String[] args) {
        // Create test data
        int[] array = createTestArray(100000);
        ForkJoinPool pool = ForkJoinPool.commonPool();
        ArraySumRecursiveTask task = new ArraySumRecursiveTask(array);
        // nanoTime is a monotonic clock intended for elapsed-time measurement;
        // currentTimeMillis follows the wall clock and can jump when it is adjusted.
        // Note: the measured time includes the console output printed by each leaf.
        long startTime = System.nanoTime();
        long result = pool.invoke(task);
        long elapsedMillis = (System.nanoTime() - startTime) / 1_000_000;
        // Verify result against a single-threaded sum
        long sequentialSum = computeSequentially(array);
        System.out.println("Parallel result: " + result);
        System.out.println("Sequential result: " + sequentialSum);
        System.out.println("Results match: " + (result == sequentialSum));
        System.out.println("Time taken: " + elapsedMillis + "ms");
        System.out.println("Pool parallelism: " + pool.getParallelism());
    }

    // Builds an array of the given size filled with random values in [0, 100).
    private static int[] createTestArray(int size) {
        int[] array = new int[size];
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(100);
        }
        return array;
    }

    // Single-threaded reference sum used to verify the parallel result.
    private static long computeSequentially(int[] array) {
        long sum = 0;
        for (int value : array) {
            sum += value;
        }
        return sum;
    }
}

Advanced RecursiveTask: Matrix Multiplication

import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinPool;
/**
 * Computes the matrix product {@code matrixA × matrixB} with the Fork/Join framework.
 * The output is split in half along its longer side until both dimensions are at most
 * {@link #THRESHOLD}. Each piece is computed sequentially, and the pieces are stitched
 * back together.
 *
 * <p>The input matrices are only read, so subtasks share them without copying.
 * Products and sums use {@code int} arithmetic and silently wrap on overflow.
 */
public class MatrixMultiplicationTask extends RecursiveTask<int[][]> {
    private final int[][] matrixA;
    private final int[][] matrixB;
    private final int rowStart; // first output row handled by this task (inclusive)
    private final int rowEnd;   // exclusive
    private final int colStart; // first output column handled by this task (inclusive)
    private final int colEnd;   // exclusive
    private static final int THRESHOLD = 64; // Threshold for sequential computation

    /**
     * Creates the root task computing the full product.
     *
     * @param matrixA left operand; every row must have exactly {@code matrixB.length} entries
     * @param matrixB right operand; must have at least one row and be rectangular
     * @throws IllegalArgumentException if the dimensions do not allow {@code matrixA × matrixB}
     */
    public MatrixMultiplicationTask(int[][] matrixA, int[][] matrixB) {
        this(matrixA, matrixB, 0, matrixA.length, 0, checkedColumnCount(matrixA, matrixB));
    }

    private MatrixMultiplicationTask(int[][] matrixA, int[][] matrixB,
                                     int rowStart, int rowEnd, int colStart, int colEnd) {
        this.matrixA = matrixA;
        this.matrixB = matrixB;
        this.rowStart = rowStart;
        this.rowEnd = rowEnd;
        this.colStart = colStart;
        this.colEnd = colEnd;
    }

    // Validates that A × B is defined and returns the column count of the result.
    private static int checkedColumnCount(int[][] matrixA, int[][] matrixB) {
        if (matrixB.length == 0) {
            throw new IllegalArgumentException("matrixB must have at least one row");
        }
        int cols = matrixB[0].length;
        for (int[] row : matrixB) {
            if (row.length != cols) {
                throw new IllegalArgumentException("matrixB must be rectangular");
            }
        }
        for (int[] row : matrixA) {
            if (row.length != matrixB.length) {
                throw new IllegalArgumentException("matrixA rows must have " + matrixB.length
                        + " columns to match matrixB's row count, found " + row.length);
            }
        }
        return cols;
    }

    /**
     * Returns the block of the product covering rows [rowStart, rowEnd) and
     * columns [colStart, colEnd), indexed from 0.
     */
    @Override
    protected int[][] compute() {
        int rows = rowEnd - rowStart;
        int cols = colEnd - colStart;
        // If the submatrix is small enough, compute sequentially
        if (rows <= THRESHOLD && cols <= THRESHOLD) {
            return computeSequentially();
        }
        // Split along the longer side. The merge step is chosen explicitly per branch
        // rather than inferred from array shapes, which is ambiguous for empty blocks.
        if (rows >= cols) {
            int midRow = rowStart + rows / 2;
            MatrixMultiplicationTask topTask = new MatrixMultiplicationTask(
                    matrixA, matrixB, rowStart, midRow, colStart, colEnd);
            MatrixMultiplicationTask bottomTask = new MatrixMultiplicationTask(
                    matrixA, matrixB, midRow, rowEnd, colStart, colEnd);
            topTask.fork();
            int[][] bottomResult = bottomTask.compute();
            int[][] topResult = topTask.join();
            return stackRows(topResult, bottomResult);
        }
        int midCol = colStart + cols / 2;
        MatrixMultiplicationTask leftTask = new MatrixMultiplicationTask(
                matrixA, matrixB, rowStart, rowEnd, colStart, midCol);
        MatrixMultiplicationTask rightTask = new MatrixMultiplicationTask(
                matrixA, matrixB, rowStart, rowEnd, midCol, colEnd);
        leftTask.fork();
        int[][] rightResult = rightTask.compute();
        int[][] leftResult = leftTask.join();
        return joinColumns(leftResult, rightResult, rows, cols);
    }

    // Computes this task's output block directly. The i-k-j loop order walks rows of
    // matrixB contiguously (cache-friendly); int wraparound addition is order-independent,
    // so results match the textbook i-j-k order exactly.
    private int[][] computeSequentially() {
        int[][] result = new int[rowEnd - rowStart][colEnd - colStart];
        int inner = matrixB.length;
        for (int i = rowStart; i < rowEnd; i++) {
            int[] rowA = matrixA[i];
            int[] out = result[i - rowStart];
            for (int k = 0; k < inner; k++) {
                int a = rowA[k];
                int[] rowB = matrixB[k];
                for (int j = colStart; j < colEnd; j++) {
                    out[j - colStart] += a * rowB[j];
                }
            }
        }
        return result;
    }

    // Places the bottom block's rows under the top block's rows. The row arrays are
    // freshly allocated by the subtasks and not shared, so they are reused, not copied.
    private static int[][] stackRows(int[][] top, int[][] bottom) {
        int[][] combined = new int[top.length + bottom.length][];
        System.arraycopy(top, 0, combined, 0, top.length);
        System.arraycopy(bottom, 0, combined, top.length, bottom.length);
        return combined;
    }

    // Places the right block's columns after the left block's columns, row by row.
    private static int[][] joinColumns(int[][] left, int[][] right, int rows, int cols) {
        int[][] combined = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            int leftWidth = left[i].length;
            System.arraycopy(left[i], 0, combined[i], 0, leftWidth);
            System.arraycopy(right[i], 0, combined[i], leftWidth, right[i].length);
        }
        return combined;
    }

    public static void main(String[] args) {
        // Create test matrices
        int size = 256;
        int[][] matrixA = createMatrix(size, size, 1);
        int[][] matrixB = createMatrix(size, size, 2);
        ForkJoinPool pool = ForkJoinPool.commonPool();
        MatrixMultiplicationTask task = new MatrixMultiplicationTask(matrixA, matrixB);
        // nanoTime is monotonic and intended for elapsed-time measurement
        long startTime = System.nanoTime();
        int[][] result = pool.invoke(task);
        long parallelTime = (System.nanoTime() - startTime) / 1_000_000;
        System.out.println("Matrix multiplication completed");
        System.out.println("Result matrix: " + result.length + "x" + result[0].length);
        System.out.println("Parallel time: " + parallelTime + "ms");
    }

    // Builds a rows × cols matrix with every entry set to value.
    private static int[][] createMatrix(int rows, int cols, int value) {
        int[][] matrix = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                matrix[i][j] = value;
            }
        }
        return matrix;
    }
}

3. RecursiveAction (No Result)

Basic Structure

/**
 * Simplified sketch of the Fork/Join framework's RecursiveAction, shown for illustration:
 * a ForkJoinTask whose result type is Void, so compute() returns nothing.
 * In real code, extend java.util.concurrent.RecursiveAction rather than declaring this class.
 */
public abstract class RecursiveAction extends ForkJoinTask<Void> {
// The task's main computation; subclasses implement it for its side effects.
protected abstract void compute();
// Other inherited methods:
// fork(), join(), invoke(), etc.
}

Complete Example: Array Processing with RecursiveAction

import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ForkJoinPool;
import java.util.Random;
/**
 * Transforms every element of a {@code double[]} in place, in parallel. The index
 * range is halved recursively until a slice has at most {@link #THRESHOLD} elements,
 * and each slice is transformed in a loop.
 *
 * <p>Each subtask writes only its own disjoint index range, so no synchronization or
 * merge step is needed.
 */
public class ArrayTransformRecursiveAction extends RecursiveAction {
    private final double[] array;
    private final int start; // inclusive
    private final int end;   // exclusive
    private static final int THRESHOLD = 10000;

    /**
     * Creates the root action covering the whole array.
     *
     * @param array values to transform in place; must not be {@code null}
     */
    public ArrayTransformRecursiveAction(double[] array) {
        this(array, 0, array.length);
    }

    private ArrayTransformRecursiveAction(double[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        int length = end - start;
        if (length <= THRESHOLD) {
            transformSequentially();
            return;
        }
        // Split the task
        int mid = start + length / 2;
        ArrayTransformRecursiveAction leftTask =
                new ArrayTransformRecursiveAction(array, start, mid);
        ArrayTransformRecursiveAction rightTask =
                new ArrayTransformRecursiveAction(array, mid, end);
        // invokeAll forks the subtasks and returns once both have completed. It is
        // shorthand for the fork/compute/join idiom, not a faster alternative to it.
        invokeAll(leftTask, rightTask);
        // No need to combine results since we're modifying the array in place
    }

    // Applies the transformation to this action's slice and reports the worker thread.
    private void transformSequentially() {
        for (int i = start; i < end; i++) {
            // Apply complex transformation
            array[i] = Math.sqrt(Math.abs(Math.sin(array[i]) * Math.cos(array[i]))) +
                    Math.log(Math.abs(array[i]) + 1);
        }
        System.out.println(Thread.currentThread().getName() +
                " transformed elements " + start + " to " + (end-1));
    }

    public static void main(String[] args) {
        // Create test array
        double[] array = createTestArray(100000);
        // A dedicated pool owns worker threads, so shut it down when finished
        ForkJoinPool pool = new ForkJoinPool();
        try {
            ArrayTransformRecursiveAction task = new ArrayTransformRecursiveAction(array);
            System.out.println("Before transformation - first 5 elements:");
            printFirstElements(array, 5);
            // nanoTime is monotonic and intended for elapsed-time measurement
            long startTime = System.nanoTime();
            pool.invoke(task);
            long elapsedMillis = (System.nanoTime() - startTime) / 1_000_000;
            System.out.println("After transformation - first 5 elements:");
            printFirstElements(array, 5);
            System.out.println("Transformation completed in " + elapsedMillis + "ms");
        } finally {
            pool.shutdown();
        }
    }

    // Builds an array of random values in [-50, 50).
    private static double[] createTestArray(int size) {
        double[] array = new double[size];
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            array[i] = random.nextDouble() * 100 - 50; // Values between -50 and 50
        }
        return array;
    }

    // Prints up to count leading elements with four decimal places.
    private static void printFirstElements(double[] array, int count) {
        for (int i = 0; i < Math.min(count, array.length); i++) {
            System.out.printf("array[%d] = %.4f%n", i, array[i]);
        }
    }
}

Advanced RecursiveAction: Parallel QuickSort

import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ForkJoinPool;
import java.util.Random;
/**
 * Sorts an {@code int[]} in place with a parallel quicksort. Ranges larger than
 * {@link #THRESHOLD} are partitioned and their two sides sorted as separate subtasks.
 * Smaller ranges are sorted sequentially.
 *
 * <p>Partitioning uses a median-of-three pivot and 3-way (less / equal / greater)
 * partitioning. Already-sorted, reverse-sorted and all-equal inputs therefore split
 * evenly instead of degrading to O(n^2) time and O(n) recursion depth. Specially
 * crafted adversarial inputs can still cause unbalanced splits.
 */
public class ParallelQuickSortAction extends RecursiveAction {
    private final int[] array;
    private final int start; // inclusive
    private final int end;   // inclusive
    private static final int THRESHOLD = 1000;

    /**
     * Creates the root action that sorts the whole array.
     *
     * @param array values to sort in place; must not be {@code null}
     */
    public ParallelQuickSortAction(int[] array) {
        this(array, 0, array.length - 1);
    }

    private ParallelQuickSortAction(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    @Override
    protected void compute() {
        // Small (or empty, end < start) ranges are sorted sequentially
        if (end - start <= THRESHOLD) {
            sequentialQuickSort(array, start, end);
            return;
        }
        int[] equalRange = partition(array, start, end);
        // Elements equal to the pivot are already in their final place; sort both sides.
        invokeAll(new ParallelQuickSortAction(array, start, equalRange[0] - 1),
                  new ParallelQuickSortAction(array, equalRange[1] + 1, end));
    }

    /**
     * 3-way partitions {@code a[low..high]} (inclusive) around the median of the first,
     * middle and last elements.
     *
     * @return {@code {lt, gt}} such that {@code a[low..lt-1] < pivot},
     *         {@code a[lt..gt] == pivot} and {@code a[gt+1..high] > pivot}
     */
    private static int[] partition(int[] a, int low, int high) {
        int mid = low + (high - low) / 2;
        // Order a[low] <= a[mid] <= a[high] so a[mid] holds the median of the three
        if (a[mid] < a[low]) {
            swap(a, low, mid);
        }
        if (a[high] < a[low]) {
            swap(a, low, high);
        }
        if (a[high] < a[mid]) {
            swap(a, mid, high);
        }
        int pivot = a[mid];
        int lt = low;
        int i = low;
        int gt = high;
        while (i <= gt) {
            if (a[i] < pivot) {
                swap(a, lt, i);
                lt++;
                i++;
            } else if (a[i] > pivot) {
                swap(a, i, gt);
                gt--;
            } else {
                i++;
            }
        }
        return new int[] {lt, gt};
    }

    // Sorts a[low..high] (inclusive). It recurses only into the smaller side and loops
    // on the larger one, which keeps stack depth logarithmic in the range size.
    private static void sequentialQuickSort(int[] array, int low, int high) {
        while (low < high) {
            int[] equalRange = partition(array, low, high);
            int leftSize = equalRange[0] - low;
            int rightSize = high - equalRange[1];
            if (leftSize < rightSize) {
                sequentialQuickSort(array, low, equalRange[0] - 1);
                low = equalRange[1] + 1;
            } else {
                sequentialQuickSort(array, equalRange[1] + 1, high);
                high = equalRange[0] - 1;
            }
        }
    }

    private static void swap(int[] a, int i, int j) {
        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }

    public static void main(String[] args) {
        // Create test array
        int[] array = createTestArray(100000);
        int[] arrayCopy = array.clone();
        // Test parallel sort (nanoTime is monotonic and meant for elapsed time)
        ForkJoinPool pool = ForkJoinPool.commonPool();
        ParallelQuickSortAction parallelTask = new ParallelQuickSortAction(array);
        long startTime = System.nanoTime();
        pool.invoke(parallelTask);
        long parallelNanos = System.nanoTime() - startTime;
        // Test sequential sort for comparison
        startTime = System.nanoTime();
        sequentialQuickSort(arrayCopy, 0, arrayCopy.length - 1);
        long sequentialNanos = System.nanoTime() - startTime;
        System.out.println("Parallel sort time: " + parallelNanos / 1_000_000 + "ms");
        System.out.println("Sequential sort time: " + sequentialNanos / 1_000_000 + "ms");
        // Guard against division by zero on a very fast run
        System.out.println("Speedup: " + (double) sequentialNanos / Math.max(parallelNanos, 1L));
        System.out.println("Array sorted: " + isSorted(array));
    }

    // Builds an array of random values in [0, 1000000).
    private static int[] createTestArray(int size) {
        int[] array = new int[size];
        Random random = new Random();
        for (int i = 0; i < size; i++) {
            array[i] = random.nextInt(1000000);
        }
        return array;
    }

    // Returns true if the array is in non-decreasing order.
    private static boolean isSorted(int[] array) {
        for (int i = 1; i < array.length; i++) {
            if (array[i] < array[i - 1]) {
                return false;
            }
        }
        return true;
    }
}

4. Key Methods and Patterns

Fork/Join Patterns

Pattern 1: Basic Fork-Join

// Fork one half, work on the other half in this thread, then wait for the forked half.
@Override
protected Result compute() {
    if (!isBaseCase()) {
        SubTask forked = new SubTask(leftWork);
        SubTask local = new SubTask(rightWork);
        forked.fork();                         // hand the left half to the pool
        Result localResult = local.compute();  // do the right half in this thread
        Result forkedResult = forked.join();   // then wait for the left half
        return combine(forkedResult, localResult);
    }
    return computeSequentially();
}

Pattern 2: invokeAll (Fork and Wait in One Call)

// Split into any number of subtasks and let invokeAll fork them and wait for all.
@Override
protected Result compute() {
    if (isBaseCase()) return computeSequentially();
    List<SubTask> pieces = createSubtasks();
    invokeAll(pieces);  // returns once every piece has completed
    return combineResults(pieces);
}

Pattern 3: Mixed Sequential/Parallel

// Run half of the subtasks in this thread while the other half can be stolen by workers.
@Override
protected Result compute() {
    if (isBaseCase()) {
        return computeSequentially();
    }
    List<SubTask> subtasks = createSubtasks();
    int half = subtasks.size() / 2;
    // Fork the second half FIRST so idle workers can steal it while this thread is busy;
    // forking after the local work would leave no parallelism during the first half.
    for (int i = half; i < subtasks.size(); i++) {
        subtasks.get(i).fork();
    }
    // Process the first half in the current thread. invoke() (unlike calling compute()
    // directly) records each task's result and completion, so it can be read back later.
    for (int i = 0; i < half; i++) {
        subtasks.get(i).invoke();
    }
    // Join the forked half in reverse fork order (most recently forked first)
    for (int i = subtasks.size() - 1; i >= half; i--) {
        subtasks.get(i).join();
    }
    return combineResults(subtasks);
}

5. Exception Handling

import java.util.concurrent.RecursiveTask;
/**
 * Sums {@code array[start, end)} with one subtask per element, and fails with an
 * {@link IllegalArgumentException} if any element is negative.
 *
 * <p>Exceptions are deliberately not caught in {@code compute()}. The Fork/Join
 * framework records an exception on the task that threw it, and {@code join()} /
 * {@code invoke()} rethrow it in the waiting thread, so the original failure reaches
 * the root task's caller. Catching it and returning {@code null} instead would make a
 * parent that unboxes the result fail with a misleading NullPointerException.
 */
public class ExceptionHandlingTask extends RecursiveTask<Integer> {
    private final int[] array;
    private final int start; // inclusive
    private final int end;   // exclusive

    public ExceptionHandlingTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Returns the sum of the range.
     *
     * @throws IllegalArgumentException if the range contains a negative value
     */
    @Override
    protected Integer compute() {
        int length = end - start;
        if (length <= 0) {
            return 0; // empty range: splitting it further would recurse forever
        }
        if (length == 1) {
            if (array[start] < 0) {
                throw new IllegalArgumentException("Negative value found: " + array[start]);
            }
            return array[start];
        }
        int mid = start + length / 2;
        ExceptionHandlingTask leftTask = new ExceptionHandlingTask(array, start, mid);
        ExceptionHandlingTask rightTask = new ExceptionHandlingTask(array, mid, end);
        leftTask.fork();
        int rightResult = rightTask.compute();
        int leftResult = leftTask.join(); // rethrows the left subtask's exception, if any
        return leftResult + rightResult;
    }

    public static void main(String[] args) {
        int[] array = {1, 2, -3, 4, 5}; // Contains negative number
        java.util.concurrent.ForkJoinPool pool = java.util.concurrent.ForkJoinPool.commonPool();
        ExceptionHandlingTask task = new ExceptionHandlingTask(array, 0, array.length);
        try {
            Integer result = pool.invoke(task);
            System.out.println("Result: " + result);
        } catch (IllegalArgumentException e) {
            System.out.println("Task failed: " + e.getMessage());
            // The rethrown exception may wrap the original (e.g. when it crossed threads),
            // so report the innermost cause; it is the exception itself when there is none.
            Throwable cause = e;
            while (cause.getCause() != null) {
                cause = cause.getCause();
            }
            System.out.println("Cause: " + cause.getMessage());
        }
    }
}

6. Best Practices

1. Choose Appropriate Threshold

// Good threshold calculation: derive it from the input size and the core count.
// `array` is an instance field in the examples above, so it cannot appear in a
// static final initializer. Compute the value once for the root task, e.g.
// thresholdFor(array.length), and pass it to every subtask.
private static int thresholdFor(int length) {
    return Math.max(length / (Runtime.getRuntime().availableProcessors() * 4), 1000);
}

2. Avoid Shared Mutable State

// Good - a task writes only to the indices in its own [start, end) range
private void transformSequentially() {
    for (int index = start; index < end; ++index) {
        array[index] = transform(array[index]); // no other task touches this index
    }
}
// Bad - one static field visible to every task; concurrent updates need synchronization
private static int sharedCounter = 0; // AVOID THIS

3. Use invokeAll for Multiple Subtasks

// invokeAll forks the subtasks and returns once every one has completed,
// replacing a hand-written fork()/join() loop
@Override
protected Result compute() {
    if (!isBaseCase()) {
        List<SubTask> parts = createSubtasks();
        invokeAll(parts); // fork all, then wait for all
        return combineResults(parts);
    }
    return computeSequentially();
}

4. Balance Workload

// Split work evenly: midpoint of [start, end), written as start + (end - start) / 2
// rather than (start + end) / 2 so the intermediate sum does not overflow int
int mid = start + (end - start) / 2;
// For uneven workloads, consider dynamic splitting
if (isWorkHeavyOnLeft()) {
// Create smaller left task and larger right task
// (placeholder: isWorkHeavyOnLeft() is not defined in this snippet; supply a
// workload-specific cost estimate here)
}

7. Performance Considerations

  • Threshold Tuning: Test different thresholds for optimal performance
  • Task Overhead: Avoid creating too many small tasks
  • Work Stealing: Let the framework balance workload automatically
  • Memory Locality: Design tasks to work on contiguous memory when possible

Summary

Use RecursiveTask when:

  • You need to return a result from computation
  • Performing mathematical computations
  • Processing data that produces output

Use RecursiveAction when:

  • Modifying data structures in place
  • Performing operations with side effects
  • No result needs to be returned

Both classes provide an efficient way to parallelize recursive algorithms while leveraging the work-stealing capabilities of the Fork/Join framework.

Leave a Reply

Your email address will not be published. Required fields are marked *


Macro Nepal Helper