Work-Stealing Algorithm in Java

Work-stealing is a scheduling strategy in which each worker thread keeps its own queue of tasks, and a thread that runs out of work "steals" queued tasks from a busier thread. Java supports work-stealing out of the box through the Fork/Join framework (ForkJoinPool).

ForkJoinPool - Built-in Work-Stealing

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.ForkJoinTask;
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
public class WorkStealingDemo {
public static void main(String[] args) {
// ForkJoinPool uses work-stealing by default
ForkJoinPool pool = ForkJoinPool.commonPool();
System.out.println("Parallelism: " + pool.getParallelism());
System.out.println("Pool size: " + pool.getPoolSize());
// Example: Parallel array processing
int[] array = createLargeArray(1000000);
SumTask task = new SumTask(array, 0, array.length);
long result = pool.invoke(task);
System.out.println("Sum result: " + result);
}
private static int[] createLargeArray(int size) {
int[] array = new int[size];
Arrays.setAll(array, i -> ThreadLocalRandom.current().nextInt(100));
return array;
}
}
/**
 * {@link RecursiveTask} that sums the elements of an {@code int} array over
 * the half-open range {@code [start, end)}.
 *
 * <p>Ranges of at most {@code THRESHOLD} elements are summed with a plain
 * loop. Longer ranges are halved: the lower half is forked so another worker
 * may pick it up, while the upper half is computed by the current thread.
 */
class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 10000;

    private final int[] array;
    private final int start;
    private final int end;

    /**
     * @param array values to sum (not copied)
     * @param start index of the first element, inclusive
     * @param end   index one past the last element, exclusive
     */
    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /**
     * Sums the range, splitting it in two while it is longer than the threshold.
     *
     * @return the sum of {@code array[start..end)}
     */
    @Override
    protected Long compute() {
        final int span = end - start;
        if (span <= THRESHOLD) {
            return sumSequentially();
        }

        final int middle = start + span / 2;
        final SumTask lowerHalf = new SumTask(array, start, middle);
        final SumTask upperHalf = new SumTask(array, middle, end);

        // Queue the lower half for asynchronous execution, do the upper half
        // on this thread, then wait for the forked half to finish.
        lowerHalf.fork();
        final long upperSum = upperHalf.compute();
        final long lowerSum = lowerHalf.join();
        return lowerSum + upperSum;
    }

    /** Base case: adds up the range with a single loop on the calling thread. */
    private long sumSequentially() {
        long total = 0;
        int index = start;
        while (index < end) {
            total += array[index];
            index++;
        }
        return total;
    }
}

Custom Work-Stealing Thread Pool

import java.util.concurrent.*;
import java.util.*;
import java.util.concurrent.atomic.*;
/**
 * Hand-rolled work-stealing pool: one {@link BlockingQueue} and one
 * {@code WorkerThread} per slot. Every worker receives the complete queue
 * array, so it can take tasks from the other slots as well as its own.
 */
public class CustomWorkStealingPool {
    private final BlockingQueue<Runnable>[] queues;
    private final WorkerThread[] workers;
    private final AtomicBoolean shutdown;

    /**
     * Creates the queues and starts {@code poolSize} worker threads.
     *
     * @param poolSize number of queues and worker threads to create
     */
    @SuppressWarnings("unchecked") // generic array creation; only BlockingQueue<Runnable> is stored
    public CustomWorkStealingPool(int poolSize) {
        this.queues = new BlockingQueue[poolSize];
        this.workers = new WorkerThread[poolSize];
        this.shutdown = new AtomicBoolean(false);

        // Fill the whole queue array before any worker is started. A worker
        // scans every slot of this array when it steals, so starting it while
        // later slots are still null could make it dereference a null queue.
        for (int i = 0; i < poolSize; i++) {
            queues[i] = new LinkedBlockingQueue<>();
        }
        for (int i = 0; i < poolSize; i++) {
            workers[i] = new WorkerThread(i, queues);
        }
        for (WorkerThread worker : workers) {
            worker.start();
        }
    }

    /**
     * Enqueues {@code task} on the queue of a specific worker. The task is
     * silently dropped if {@link #shutdown()} has already been called.
     *
     * @param threadId index of the target queue, {@code 0 <= threadId < poolSize}
     * @param task     work to enqueue
     */
    public void submit(int threadId, Runnable task) {
        if (!shutdown.get()) {
            queues[threadId].offer(task);
        }
    }

    /**
     * Enqueues {@code task} on a randomly chosen queue. The task is silently
     * dropped if {@link #shutdown()} has already been called.
     *
     * @param task work to enqueue
     */
    public void submit(Runnable task) {
        if (!shutdown.get()) {
            // Distribute tasks randomly for load balancing
            int randomThread = ThreadLocalRandom.current().nextInt(workers.length);
            queues[randomThread].offer(task);
        }
    }

    /**
     * Stops accepting new tasks and interrupts every worker thread. Tasks
     * that are still queued at that point are not drained by this method.
     */
    public void shutdown() {
        shutdown.set(true);
        for (WorkerThread worker : workers) {
            worker.interrupt();
        }
    }
}
/**
 * Worker of the custom pool. It serves the queue at index {@code id} of the
 * shared queue array and, whenever that queue yields nothing within the poll
 * timeout, tries to take a single task from one of the other queues.
 */
class WorkerThread extends Thread {
    private final int myId;
    private final BlockingQueue<Runnable>[] allQueues;
    /** Numbers executed tasks across all workers; used only in log output. */
    private static final AtomicInteger taskCounter = new AtomicInteger(0);

    /**
     * @param id        index of this worker's own queue in {@code allQueues}
     * @param allQueues the queues of every worker, shared by reference
     */
    public WorkerThread(int id, BlockingQueue<Runnable>[] allQueues) {
        super("Worker-" + id);
        this.myId = id;
        this.allQueues = allQueues;
    }

    /**
     * Main loop: waits up to 100 ms for a task on the own queue and runs it,
     * otherwise attempts one steal. Ends when the thread is interrupted.
     */
    @Override
    public void run() {
        final BlockingQueue<Runnable> ownQueue = allQueues[myId];
        while (!Thread.currentThread().isInterrupted()) {
            final Runnable next;
            try {
                next = ownQueue.poll(100, TimeUnit.MILLISECONDS);
            } catch (InterruptedException interrupted) {
                // Restore the flag for anyone inspecting it, then stop.
                Thread.currentThread().interrupt();
                return;
            }
            if (next == null) {
                stealWork();
            } else {
                executeTask(next);
            }
        }
    }

    /** Runs one task, logging start and completion; exceptions are reported, not rethrown. */
    private void executeTask(Runnable task) {
        final int sequence = taskCounter.incrementAndGet();
        System.out.println(currentWorkerName() + " executing task " + sequence);
        try {
            task.run();
            System.out.println(currentWorkerName() + " completed task " + sequence);
        } catch (Exception e) {
            System.err.println("Task failed: " + e.getMessage());
        }
    }

    /**
     * Scans the other queues in index order and runs the first task found.
     * At most one task is taken per call.
     */
    private void stealWork() {
        for (int victim = 0; victim < allQueues.length; victim++) {
            if (victim != myId) {
                // poll() without a timeout never blocks the thief.
                final Runnable loot = allQueues[victim].poll();
                if (loot != null) {
                    System.out.println(currentWorkerName() + " stole work from Worker-" + victim);
                    executeTask(loot);
                    return;
                }
            }
        }
    }

    /** Name of the thread currently executing, as used in the log lines. */
    private static String currentWorkerName() {
        return Thread.currentThread().getName();
    }
}

Advanced Work-Stealing with Completion Tracking

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
/**
 * Named unit of simulated work: sleeps for a fixed number of milliseconds and
 * counts down a latch once it has finished, whether or not it was interrupted.
 */
class StealableTask implements Runnable {
    private final String name;
    private final long duration;
    private final CountDownLatch completionLatch;

    /**
     * @param name            label used in log output and {@link #toString()}
     * @param duration        simulated work time, in milliseconds
     * @param completionLatch latch counted down exactly once when the task ends
     */
    public StealableTask(String name, long duration, CountDownLatch completionLatch) {
        this.name = name;
        this.duration = duration;
        this.completionLatch = completionLatch;
    }

    /**
     * Logs the start, sleeps for {@code duration} ms and logs completion.
     * If the sleep is interrupted, the interrupt flag is restored and the
     * completion line is skipped; the latch is counted down either way.
     */
    @Override
    public void run() {
        System.out.println(Thread.currentThread().getName() + " executing " + name);
        try {
            simulateWork();
        } catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
        } finally {
            completionLatch.countDown();
        }
    }

    /** Sleeps for the configured duration, then reports completion. */
    private void simulateWork() throws InterruptedException {
        Thread.sleep(duration); // Simulate work
        System.out.println(Thread.currentThread().getName() + " completed " + name);
    }

    /** @return the task's name */
    @Override
    public String toString() {
        return name;
    }
}
/**
 * Work-stealing executor for {@code StealableTask}s. It owns {@code poolSize}
 * task queues, each served by one daemon thread named
 * {@code StealingWorker-<index>}. A worker whose own queue stays empty for
 * 50 ms tries to take one task from another queue, visiting the other queues
 * in random order.
 */
public class AdvancedWorkStealingExecutor {
    private final ExecutorService[] executors;
    private final BlockingQueue<StealableTask>[] taskQueues;
    private final int poolSize;
    private final AtomicBoolean running;

    /**
     * Creates the queues and starts one worker loop per queue.
     *
     * @param poolSize number of queues and worker threads; must be positive
     * @throws IllegalArgumentException if {@code poolSize} is not positive
     */
    @SuppressWarnings("unchecked") // generic array creation; only BlockingQueue<StealableTask> is stored
    public AdvancedWorkStealingExecutor(int poolSize) {
        if (poolSize <= 0) {
            throw new IllegalArgumentException("poolSize must be positive: " + poolSize);
        }
        this.poolSize = poolSize;
        this.executors = new ExecutorService[poolSize];
        this.taskQueues = new BlockingQueue[poolSize];
        this.running = new AtomicBoolean(true);

        // Create every queue before any worker runs: a stealing worker reads
        // all slots of taskQueues, so none of them may still be null when the
        // first worker loop is handed to its thread.
        for (int i = 0; i < poolSize; i++) {
            taskQueues[i] = new LinkedBlockingQueue<>();
        }
        for (int i = 0; i < poolSize; i++) {
            final int threadId = i;
            executors[i] = Executors.newSingleThreadExecutor(r -> {
                Thread t = new Thread(r, "StealingWorker-" + threadId);
                t.setDaemon(true);
                return t;
            });
            // Start worker for each queue
            executors[i].execute(createWorker(threadId));
        }
    }

    /**
     * Builds the loop run by worker {@code myId}: poll the own queue with a
     * 50 ms timeout, run the task if there is one, otherwise try to steal.
     * The loop ends when {@code running} is cleared or the thread is interrupted.
     */
    private Runnable createWorker(int myId) {
        return () -> {
            BlockingQueue<StealableTask> myQueue = taskQueues[myId];
            while (running.get() && !Thread.currentThread().isInterrupted()) {
                try {
                    // Try to get task from own queue
                    StealableTask task = myQueue.poll(50, TimeUnit.MILLISECONDS);
                    if (task != null) {
                        task.run();
                    } else if (running.get()) {
                        // Try to steal work
                        attemptWorkStealing(myId);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        };
    }

    /**
     * Visits the other queues in random order and runs the first task found
     * on the calling thread. At most one task is taken per call.
     */
    private void attemptWorkStealing(int thiefId) {
        // Create random order for stealing to reduce contention
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < poolSize; i++) {
            if (i != thiefId) {
                indices.add(i);
            }
        }
        Collections.shuffle(indices);
        for (int victimId : indices) {
            BlockingQueue<StealableTask> victimQueue = taskQueues[victimId];
            StealableTask stolenTask = victimQueue.poll();
            if (stolenTask != null) {
                System.out.printf("Worker-%d stole task %s from Worker-%d%n",
                        thiefId, stolenTask, victimId);
                stolenTask.run();
                return;
            }
        }
    }

    /**
     * Enqueues {@code task} on a randomly chosen queue.
     *
     * @param task task to enqueue
     * @throws RejectedExecutionException if {@link #shutdown()} has been called
     */
    public void submit(StealableTask task) {
        if (!running.get()) {
            throw new RejectedExecutionException("Executor is shutdown");
        }
        // Submit to random queue for initial load distribution
        int randomQueue = ThreadLocalRandom.current().nextInt(poolSize);
        taskQueues[randomQueue].offer(task);
    }

    /**
     * Enqueues {@code task} on the queue of a specific worker.
     *
     * @param threadId index of the target queue, {@code 0 <= threadId < poolSize}
     * @param task     task to enqueue
     * @throws RejectedExecutionException if {@link #shutdown()} has been called
     */
    public void submitTo(int threadId, StealableTask task) {
        if (!running.get()) {
            throw new RejectedExecutionException("Executor is shutdown");
        }
        taskQueues[threadId].offer(task);
    }

    /**
     * Clears the running flag and shuts down every worker's executor. Worker
     * loops stop polling once the flag is cleared, so tasks still sitting in
     * a queue at that point are not executed.
     */
    public void shutdown() {
        running.set(false);
        for (ExecutorService executor : executors) {
            executor.shutdown();
        }
    }

    /**
     * Waits until all worker executors have terminated or the timeout elapses.
     * The timeout is a single budget shared by all workers, so the call blocks
     * for at most roughly {@code timeout}, not {@code poolSize * timeout}.
     *
     * @param timeout maximum total time to wait
     * @param unit    unit of {@code timeout}
     * @return {@code true} if every worker executor terminated in time
     * @throws InterruptedException if interrupted while waiting
     */
    public boolean awaitTermination(long timeout, TimeUnit unit)
            throws InterruptedException {
        long remainingNanos = unit.toNanos(timeout);
        boolean terminated = true;
        for (ExecutorService executor : executors) {
            long waitStart = System.nanoTime();
            // Once the budget is used up, a zero timeout still reports whether
            // this executor has already terminated, without blocking.
            terminated &= executor.awaitTermination(
                    Math.max(remainingNanos, 0L), TimeUnit.NANOSECONDS);
            remainingNanos -= System.nanoTime() - waitStart;
        }
        return terminated;
    }
}

Work-Stealing for Divide and Conquer Problems

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
/**
 * Generic divide-and-conquer task. A problem whose size is at most
 * {@code threshold} is solved directly; a larger one is split by
 * {@link #divide()}, the children are run, and their results are merged by
 * {@link #combine(List)}.
 *
 * <p>Children are run one after another on the calling thread; this class
 * itself does not hand them to a pool.
 *
 * @param <T> type of the computed result
 */
abstract class DivideAndConquerTask<T> implements Runnable {
    protected final String name;
    protected final int problemSize;
    protected final int threshold;
    /** Children produced by the most recent {@link #run()}; empty for a base case. */
    protected final List<DivideAndConquerTask<T>> subtasks = new ArrayList<>();
    protected final AtomicReference<T> result = new AtomicReference<>();

    /**
     * @param name        label used in log output and {@link #toString()}
     * @param problemSize size of this problem, compared against {@code threshold}
     * @param threshold   largest size that is solved directly
     */
    public DivideAndConquerTask(String name, int problemSize, int threshold) {
        this.name = name;
        this.problemSize = problemSize;
        this.threshold = threshold;
    }

    /**
     * Solves the problem and stores the outcome for {@link #getResult()}.
     * When the problem is divided, the children are also recorded so that
     * {@link #getSubtasks()} reflects them.
     */
    @Override
    public void run() {
        if (problemSize <= threshold) {
            // Base case: solve directly
            result.set(computeDirectly());
            System.out.println(Thread.currentThread().getName() +
                    " directly computed " + name);
            return;
        }

        // Divide and create subtasks
        List<DivideAndConquerTask<T>> children = divide();
        System.out.println(Thread.currentThread().getName() +
                " divided " + name + " into " + children.size() + " subtasks");

        // Remember the children; without this getSubtasks() would always be
        // empty. The identity check covers a subclass whose divide() already
        // filled and returned the subtasks list itself.
        if (children != subtasks) {
            subtasks.clear();
            subtasks.addAll(children);
        }

        // Execute subtasks sequentially on this thread
        for (DivideAndConquerTask<T> child : children) {
            child.run();
        }

        // Combine results
        result.set(combine(children));
    }

    /** Solves a problem that is small enough not to be divided. */
    protected abstract T computeDirectly();

    /** Splits this problem into child tasks; must not return {@code null}. */
    protected abstract List<DivideAndConquerTask<T>> divide();

    /** Merges the results of the already-executed {@code subtasks}. */
    protected abstract T combine(List<DivideAndConquerTask<T>> subtasks);

    /** @return the computed result, or {@code null} if {@link #run()} has not stored one yet */
    public T getResult() {
        return result.get();
    }

    /** @return a read-only view of the children created by {@link #run()} */
    public List<DivideAndConquerTask<T>> getSubtasks() {
        return Collections.unmodifiableList(subtasks);
    }

    @Override
    public String toString() {
        return name + "[size=" + problemSize + "]";
    }
}
/**
 * Quicksort expressed as a {@code DivideAndConquerTask}. It sorts the
 * half-open range {@code [start, end)} of {@code array} in place. Ranges of
 * at most {@code threshold} elements are handed to {@link Arrays#sort};
 * larger ranges are partitioned around their last element and the two sides
 * become child tasks.
 */
class ParallelQuickSortTask extends DivideAndConquerTask<int[]> {
    private final int[] array;
    private final int start;
    private final int end;

    /**
     * @param array     array to sort in place (not copied)
     * @param start     first index of the range, inclusive
     * @param end       end index of the range, exclusive
     * @param threshold largest range length that is sorted directly
     */
    public ParallelQuickSortTask(int[] array, int start, int end, int threshold) {
        super("QuickSort[" + start + "-" + end + "]", end - start, threshold);
        this.array = array;
        this.start = start;
        this.end = end;
    }

    /** Sorts the range with the library sort and returns the shared array. */
    @Override
    protected int[] computeDirectly() {
        Arrays.sort(array, start, end);
        return array;
    }

    /**
     * Partitions the range and returns a child task for each side that still
     * holds more than one element. The pivot is already in its final position
     * after partitioning, so it belongs to neither child.
     */
    @Override
    protected List<DivideAndConquerTask<int[]>> divide() {
        if (start >= end - 1) {
            return Collections.emptyList();
        }
        int pivotIndex = partition(array, start, end);
        List<DivideAndConquerTask<int[]>> children = new ArrayList<>();
        // Left side: [start, pivotIndex)
        if (pivotIndex - start > 1) {
            children.add(new ParallelQuickSortTask(array, start, pivotIndex, threshold));
        }
        // Right side: (pivotIndex, end). Starting at pivotIndex would re-sort
        // the pivot and, when the pivot is the smallest element, produce a
        // child covering exactly the parent's range.
        int rightStart = pivotIndex + 1;
        if (end - rightStart > 1) {
            children.add(new ParallelQuickSortTask(array, rightStart, end, threshold));
        }
        return children;
    }

    @Override
    protected int[] combine(List<DivideAndConquerTask<int[]>> subtasks) {
        // In quicksort, combining is trivial as sorting happens in-place
        return array;
    }

    /**
     * Partitions {@code [start, end)} around its last element: afterwards,
     * elements left of the returned index are {@code <=} the pivot and
     * elements right of it are greater.
     *
     * @return the final index of the pivot
     */
    private int partition(int[] array, int start, int end) {
        int pivot = array[end - 1];
        int i = start - 1; // last index of the "<= pivot" region
        for (int j = start; j < end - 1; j++) {
            if (array[j] <= pivot) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, end - 1);
        return i + 1;
    }

    private void swap(int[] array, int i, int j) {
        int temp = array[i];
        array[i] = array[j];
        array[j] = temp;
    }
}

Performance Monitoring Work-Stealing Pool

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
/**
 * Wraps an {@code AdvancedWorkStealingExecutor} and keeps simple counters
 * about the tasks passed through it: how many were submitted, how many have
 * finished, and the total time spent running them.
 */
public class MonitoredWorkStealingPool {
    private final AdvancedWorkStealingExecutor executor;
    private final AtomicLong tasksSubmitted = new AtomicLong();
    private final AtomicLong tasksCompleted = new AtomicLong();
    // NOTE(review): nothing in this class increments this counter, and the
    // executor exposes no steal notification, so it currently stays at 0.
    private final AtomicLong tasksStolen = new AtomicLong();
    private final Map<String, AtomicLong> performanceStats = new ConcurrentHashMap<>();

    /**
     * @param poolSize number of workers of the underlying executor
     */
    public MonitoredWorkStealingPool(int poolSize) {
        this.executor = new AdvancedWorkStealingExecutor(poolSize);
        initializeStats();
    }

    /** Registers the counters under their statistic names. */
    private void initializeStats() {
        performanceStats.put("tasks.submitted", tasksSubmitted);
        performanceStats.put("tasks.completed", tasksCompleted);
        performanceStats.put("tasks.stolen", tasksStolen);
    }

    /**
     * Submits {@code task} to the underlying executor, wrapped so that its
     * execution is logged, timed and counted. A task that throws a
     * {@link RuntimeException} is reported on {@code System.err} and still
     * counted as completed; the exception is not propagated to the worker.
     *
     * @param task work to run; must not be {@code null}
     * @throws NullPointerException if {@code task} is {@code null}
     */
    public void submit(Runnable task) {
        Objects.requireNonNull(task, "task");
        // Use the value returned by the increment for the name: reading the
        // counter a second time could yield the same number for two
        // concurrent submitters.
        final long sequence = tasksSubmitted.incrementAndGet();
        final String taskName = "Task-" + sequence;
        final CountDownLatch done = new CountDownLatch(1);
        // The wrapper overrides run() completely so that the caller's task is
        // executed instead of the simulated sleep; hence a duration of 0.
        StealableTask monitoredTask = new StealableTask(taskName, 0L, done) {
            @Override
            public void run() {
                System.out.println(Thread.currentThread().getName() + " executing " + taskName);
                long startTime = System.nanoTime();
                try {
                    task.run();
                    System.out.println(Thread.currentThread().getName() + " completed " + taskName);
                } catch (RuntimeException e) {
                    System.err.println(taskName + " failed: " + e);
                } finally {
                    long endTime = System.nanoTime();
                    tasksCompleted.incrementAndGet();
                    performanceStats.computeIfAbsent("total.execution.time",
                            k -> new AtomicLong()).addAndGet(endTime - startTime);
                    done.countDown();
                }
            }
        };
        executor.submit(monitoredTask);
    }

    /**
     * Prints the counters. The stealing rate is reported as 0 while no task
     * has completed, instead of dividing by zero.
     */
    public void printStatistics() {
        final long completed = tasksCompleted.get();
        final long stolen = tasksStolen.get();
        final double stealingRate = completed == 0 ? 0.0 : (stolen * 100.0) / completed;
        System.out.println("\n=== Work-Stealing Pool Statistics ===");
        System.out.printf("Tasks Submitted: %d%n", tasksSubmitted.get());
        System.out.printf("Tasks Completed: %d%n", completed);
        System.out.printf("Tasks Stolen: %d%n", stolen);
        System.out.printf("Stealing Rate: %.2f%%%n", stealingRate);
    }

    /**
     * Shuts the executor down, waits up to five seconds for its workers to
     * stop, then prints the statistics.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    public void shutdown() throws InterruptedException {
        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);
        printStatistics();
    }
}

Key Benefits of Work-Stealing

  1. Load Balancing - Automatically balances work across threads
  2. Reduced Contention - Threads work on local queues most of the time
  3. High Throughput - Idle threads help busy threads
  4. Scalability - Works well with large numbers of CPU cores

Best Practices

  1. Use Appropriate Threshold - Balance between task creation overhead and parallelism
  2. Avoid Fine-Grained Tasks - Tasks should have enough work to justify overhead
  3. Minimize Shared State - Reduce synchronization between tasks
  4. Use ForkJoinPool for Recursive Problems - Built-in work-stealing is highly optimized

Work-stealing algorithms are particularly effective for recursive, divide-and-conquer problems and can significantly improve performance on multi-core systems.

Leave a Reply

Your email address will not be published. Required fields are marked *


Macro Nepal Helper