Building Distributed ML Systems with Data Privacy
Federated Learning (FL) is a distributed machine learning approach that enables model training across decentralized devices or servers holding local data samples, without exchanging the data itself. This preserves data privacy and reduces network bandwidth usage while still benefiting from collective learning.
Federated Learning Architecture Overview
Key Components:
- Central Server: Coordinates training and aggregates model updates
- Clients/Nodes: Local devices with data that train models
- Global Model: Shared model maintained by the server
- Local Models: Models trained on client data
- Aggregation Algorithm: Method to combine client updates (e.g., Federated Averaging)
FL Workflow:
1. Server initializes the global model
2. Server sends the global model to clients
3. Clients train locally on their own data
4. Clients send model updates to the server
5. Server aggregates the updates to improve the global model
6. Repeat until convergence
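Step 5 is typically Federated Averaging (FedAvg). With K participating clients, where client k trains on n_k local samples, the new global weights are the data-weighted average of the locally trained weights:

w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{(k)}, \qquad n = \sum_{k=1}^{K} n_k

Equivalently, the server can average the parameter deltas w_{t+1}^{(k)} - w_t with the same weights and add the result to w_t; that delta form is what the federatedAverage() method in Section 3 implements (scaled by a server-side learning rate).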
1. Project Setup and Dependencies
Maven Configuration (pom.xml):
<properties>
<dl4j.version>1.0.0-M2.1</dl4j.version>
<nd4j.version>1.0.0-M2.1</nd4j.version>
</properties>
<dependencies>
<!-- Deep Learning for Java -->
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- ND4J Backend -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>${nd4j.version}</version>
</dependency>
<!-- DataVec for data processing -->
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${dl4j.version}</version>
</dependency>
<!-- Apache Commons Math -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
<!-- Network Communication -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<version>2.7.0</version>
</dependency>
<!-- JSON Processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.13.3</version>
</dependency>
<!-- Testing (JUnit 5 + Spring test support, matching the tests in the final section) -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<version>2.7.0</version>
<scope>test</scope>
</dependency>
</dependencies>
2. Core Model Definitions
Neural Network Model Factory:
package com.federatedlearning.model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class ModelFactory {
public static MultiLayerNetwork createSimpleNN(int numInputs, int numOutputs, int hiddenNodes) {
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
.seed(1234)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam(0.001))
.list()
.layer(0, new DenseLayer.Builder()
.nIn(numInputs)
.nOut(hiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(hiddenNodes)
.nOut(numOutputs)
.weightInit(WeightInit.XAVIER)
.activation(Activation.SOFTMAX)
.build())
.build();
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
return model;
}
public static MultiLayerNetwork createMNISTModel() {
return createSimpleNN(784, 10, 128);
}
public static MultiLayerNetwork createIrisModel() {
return createSimpleNN(4, 3, 8);
}
}
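A quick sanity check after building a model is to print its layer summary and parameter count; summary() and numParams() are standard MultiLayerNetwork methods (the demo class name is illustrative):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class ModelFactoryDemo {
    public static void main(String[] args) {
        // Build the Iris model and inspect its structure
        MultiLayerNetwork model = ModelFactory.createIrisModel();
        System.out.println(model.summary());
        System.out.println("Parameter count: " + model.numParams());
    }
}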
Model Serialization Utilities:
package com.federatedlearning.model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import java.util.Base64;
public class ModelSerializer {
public static String modelToBase64(MultiLayerNetwork model) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
org.deeplearning4j.util.ModelSerializer.writeModel(model, baos, true);
return Base64.getEncoder().encodeToString(baos.toByteArray());
}
public static MultiLayerNetwork modelFromBase64(String base64Model) throws IOException {
byte[] modelBytes = Base64.getDecoder().decode(base64Model);
ByteArrayInputStream bais = new ByteArrayInputStream(modelBytes);
return org.deeplearning4j.util.ModelSerializer.restoreMultiLayerNetwork(bais, true);
}
public static String paramsToBase64(INDArray params) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
DataOutputStream dos = new DataOutputStream(baos);
Nd4j.write(params, dos);
dos.flush();
return Base64.getEncoder().encodeToString(baos.toByteArray());
}
public static INDArray paramsFromBase64(String base64Params) throws IOException {
byte[] paramBytes = Base64.getDecoder().decode(base64Params);
ByteArrayInputStream bais = new ByteArrayInputStream(paramBytes);
DataInputStream dis = new DataInputStream(bais);
return Nd4j.read(dis);
}
public static ModelUpdate extractUpdate(MultiLayerNetwork globalModel, MultiLayerNetwork localModel) {
INDArray globalParams = globalModel.params();
INDArray localParams = localModel.params();
INDArray update = localParams.sub(globalParams);
return new ModelUpdate(update, localModel.score());
}
public static void applyUpdate(MultiLayerNetwork model, INDArray update, double learningRate) {
INDArray currentParams = model.params();
INDArray newParams = currentParams.add(update.mul(learningRate));
model.setParams(newParams);
}
public static class ModelUpdate {
private final INDArray parameterUpdate;
private final double loss;
public ModelUpdate(INDArray parameterUpdate, double loss) {
this.parameterUpdate = parameterUpdate;
this.loss = loss;
}
public INDArray getParameterUpdate() { return parameterUpdate; }
public double getLoss() { return loss; }
}
}
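A minimal round-trip check of these utilities, using the ModelFactory above (the demo class is illustrative): a model encoded to Base64 and decoded back should have identical parameters, and the delta between two identical models is zero.

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class SerializationDemo {
    public static void main(String[] args) throws Exception {
        MultiLayerNetwork global = ModelFactory.createIrisModel();

        // Full-model round trip, as performed when a client registers
        String encoded = ModelSerializer.modelToBase64(global);
        MultiLayerNetwork restored = ModelSerializer.modelFromBase64(encoded);
        System.out.println("Params preserved: " + global.params().equals(restored.params()));

        // The update between two identical models is zero; applying it is a no-op
        ModelSerializer.ModelUpdate update = ModelSerializer.extractUpdate(global, restored);
        ModelSerializer.applyUpdate(restored, update.getParameterUpdate(), 1.0);
    }
}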
3. Federated Server Implementation
Central Server Core:
package com.federatedlearning.server;
import com.federatedlearning.model.ModelFactory;
import com.federatedlearning.model.ModelSerializer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@Service
public class FederatedServer {
private MultiLayerNetwork globalModel;
private final List<ClientRegistration> registeredClients;
private final Map<String, ClientUpdate> clientUpdates;
private final ServerMetrics metrics;
private int currentRound;
private final int targetRounds;
private final double minClientRatio;
private final AggregationStrategy aggregationStrategy;
public FederatedServer() {
this.registeredClients = new CopyOnWriteArrayList<>();
this.clientUpdates = new ConcurrentHashMap<>();
this.metrics = new ServerMetrics();
this.targetRounds = 100;
this.minClientRatio = 0.3; // At least 30% of clients must participate
this.aggregationStrategy = AggregationStrategy.FED_AVG;
this.currentRound = 0;
initializeGlobalModel();
}
private void initializeGlobalModel() {
// Initialize with a simple model - can be loaded from file or created fresh
this.globalModel = ModelFactory.createMNISTModel();
metrics.setInitialLoss(globalModel.score());
}
public synchronized String registerClient(String clientId, String clientInfo) {
ClientRegistration registration = new ClientRegistration(clientId, clientInfo, System.currentTimeMillis());
registeredClients.add(registration);
metrics.clientRegistered(clientId);
System.out.println("Client registered: " + clientId + " - Total: " + registeredClients.size());
return getGlobalModelBase64();
}
public synchronized void submitClientUpdate(String clientId, String modelUpdateBase64, int dataSize, double loss) {
try {
INDArray update = ModelSerializer.paramsFromBase64(modelUpdateBase64);
ClientUpdate clientUpdate = new ClientUpdate(clientId, update, dataSize, loss, System.currentTimeMillis());
clientUpdates.put(clientId, clientUpdate);
metrics.updateReceived(clientId, dataSize, loss);
System.out.println("Update received from client: " + clientId + " with " + dataSize + " samples");
// Check if we have enough updates to aggregate
if (shouldAggregate()) {
performAggregation();
}
} catch (Exception e) {
System.err.println("Error processing update from client " + clientId + ": " + e.getMessage());
metrics.updateFailed(clientId);
}
}
private boolean shouldAggregate() {
int minClients = (int) Math.ceil(registeredClients.size() * minClientRatio);
return clientUpdates.size() >= minClients;
}
private synchronized void performAggregation() {
System.out.println("Starting aggregation round " + currentRound +
" with " + clientUpdates.size() + " client updates");
long startTime = System.currentTimeMillis();
try {
INDArray aggregatedUpdate = aggregateUpdates();
applyAggregatedUpdate(aggregatedUpdate);
currentRound++;
metrics.roundCompleted(currentRound, System.currentTimeMillis() - startTime, globalModel.score());
// Clear updates for next round
clientUpdates.clear();
System.out.println("Aggregation completed. Round " + currentRound +
" - Global loss: " + globalModel.score());
// Notify clients about new model
notifyClientsNewModelAvailable();
} catch (Exception e) {
System.err.println("Error during aggregation: " + e.getMessage());
metrics.aggregationFailed();
}
}
private INDArray aggregateUpdates() {
switch (aggregationStrategy) {
case FED_AVG:
return federatedAverage();
case FED_SGD:
return federatedSGD();
case WEIGHTED_AVG:
return weightedAverage();
default:
return federatedAverage();
}
}
private INDArray federatedAverage() {
int totalDataSize = clientUpdates.values().stream()
.mapToInt(ClientUpdate::getDataSize)
.sum();
INDArray weightedSum = null;
for (ClientUpdate update : clientUpdates.values()) {
double weight = (double) update.getDataSize() / totalDataSize;
INDArray weightedUpdate = update.getParameterUpdate().mul(weight);
if (weightedSum == null) {
weightedSum = weightedUpdate;
} else {
weightedSum.addi(weightedUpdate);
}
}
return weightedSum != null ? weightedSum : Nd4j.zerosLike(globalModel.params());
}
private INDArray weightedAverage() {
// Weight by both data size and loss improvement
double totalWeight = clientUpdates.values().stream()
.mapToDouble(this::calculateClientWeight)
.sum();
INDArray weightedSum = null;
for (ClientUpdate update : clientUpdates.values()) {
double weight = calculateClientWeight(update) / totalWeight;
INDArray weightedUpdate = update.getParameterUpdate().mul(weight);
if (weightedSum == null) {
weightedSum = weightedUpdate;
} else {
weightedSum.addi(weightedUpdate);
}
}
return weightedSum != null ? weightedSum : Nd4j.zerosLike(globalModel.params());
}
private INDArray federatedSGD() {
// Simple average of all updates
INDArray sum = null;
int count = 0;
for (ClientUpdate update : clientUpdates.values()) {
if (sum == null) {
sum = update.getParameterUpdate().dup();
} else {
sum.addi(update.getParameterUpdate());
}
count++;
}
return sum != null ? sum.div(count) : Nd4j.zerosLike(globalModel.params());
}
private double calculateClientWeight(ClientUpdate update) {
// Combine data size and loss for weighting
double dataWeight = Math.log(1 + update.getDataSize());
double lossWeight = 1.0 / (1.0 + update.getLoss()); // Prefer clients with lower loss
return dataWeight * lossWeight;
}
private void applyAggregatedUpdate(INDArray aggregatedUpdate) {
ModelSerializer.applyUpdate(globalModel, aggregatedUpdate, getLearningRate());
}
private double getLearningRate() {
// Decay learning rate over rounds
double initialLR = 0.1;
double decayRate = 0.95;
return initialLR * Math.pow(decayRate, currentRound);
}
private void notifyClientsNewModelAvailable() {
// In a real implementation, this would push notifications to clients
// For now, clients will poll for updates
System.out.println("New global model available. Notifying clients...");
}
public String getGlobalModelBase64() {
try {
return ModelSerializer.modelToBase64(globalModel);
} catch (Exception e) {
throw new RuntimeException("Failed to serialize global model", e);
}
}
public ServerStatus getServerStatus() {
return new ServerStatus(
currentRound,
targetRounds,
registeredClients.size(),
clientUpdates.size(),
globalModel.score(),
metrics.getSummary()
);
}
// Getters
public int getCurrentRound() { return currentRound; }
public int getRegisteredClientCount() { return registeredClients.size(); }
public ServerMetrics getMetrics() { return metrics; }
public enum AggregationStrategy {
FED_AVG, // Federated Averaging
FED_SGD, // Federated SGD
WEIGHTED_AVG // Weighted by data size and loss
}
}
Server Data Classes:
package com.federatedlearning.server;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.concurrent.atomic.AtomicLong;
public class ClientRegistration {
private final String clientId;
private final String clientInfo;
private final long registrationTime;
private long lastUpdateTime;
public ClientRegistration(String clientId, String clientInfo, long registrationTime) {
this.clientId = clientId;
this.clientInfo = clientInfo;
this.registrationTime = registrationTime;
this.lastUpdateTime = registrationTime;
}
// Getters and setters
public String getClientId() { return clientId; }
public String getClientInfo() { return clientInfo; }
public long getRegistrationTime() { return registrationTime; }
public long getLastUpdateTime() { return lastUpdateTime; }
public void setLastUpdateTime(long lastUpdateTime) { this.lastUpdateTime = lastUpdateTime; }
}
class ClientUpdate {
private final String clientId;
private final INDArray parameterUpdate;
private final int dataSize;
private final double loss;
private final long timestamp;
public ClientUpdate(String clientId, INDArray parameterUpdate, int dataSize, double loss, long timestamp) {
this.clientId = clientId;
this.parameterUpdate = parameterUpdate;
this.dataSize = dataSize;
this.loss = loss;
this.timestamp = timestamp;
}
// Getters
public String getClientId() { return clientId; }
public INDArray getParameterUpdate() { return parameterUpdate; }
public int getDataSize() { return dataSize; }
public double getLoss() { return loss; }
public long getTimestamp() { return timestamp; }
}
// Public (own source file): returned across packages by the REST controller
public class ServerStatus {
private final int currentRound;
private final int targetRounds;
private final int registeredClients;
private final int activeClients;
private final double globalLoss;
private final MetricsSummary metrics;
public ServerStatus(int currentRound, int targetRounds, int registeredClients,
int activeClients, double globalLoss, MetricsSummary metrics) {
this.currentRound = currentRound;
this.targetRounds = targetRounds;
this.registeredClients = registeredClients;
this.activeClients = activeClients;
this.globalLoss = globalLoss;
this.metrics = metrics;
}
// Getters
public int getCurrentRound() { return currentRound; }
public int getTargetRounds() { return targetRounds; }
public int getRegisteredClients() { return registeredClients; }
public int getActiveClients() { return activeClients; }
public double getGlobalLoss() { return globalLoss; }
public MetricsSummary getMetrics() { return metrics; }
}
// Public (own source file): exposed via FederatedServer.getMetrics()
public class ServerMetrics {
private final AtomicLong totalRounds;
private final AtomicLong successfulAggregations;
private final AtomicLong failedAggregations;
private final AtomicLong totalClientsRegistered;
private final AtomicLong totalUpdatesReceived;
private final AtomicLong totalDataProcessed;
private double initialLoss;
private double bestLoss;
public ServerMetrics() {
this.totalRounds = new AtomicLong(0);
this.successfulAggregations = new AtomicLong(0);
this.failedAggregations = new AtomicLong(0);
this.totalClientsRegistered = new AtomicLong(0);
this.totalUpdatesReceived = new AtomicLong(0);
this.totalDataProcessed = new AtomicLong(0);
this.initialLoss = Double.MAX_VALUE;
this.bestLoss = Double.MAX_VALUE;
}
public void clientRegistered(String clientId) {
totalClientsRegistered.incrementAndGet();
}
public void updateReceived(String clientId, int dataSize, double loss) {
totalUpdatesReceived.incrementAndGet();
totalDataProcessed.addAndGet(dataSize);
}
public void updateFailed(String clientId) {
// Track failed updates if needed
}
public void roundCompleted(int round, long duration, double currentLoss) {
totalRounds.incrementAndGet();
successfulAggregations.incrementAndGet();
if (currentLoss < bestLoss) {
bestLoss = currentLoss;
}
}
public void aggregationFailed() {
failedAggregations.incrementAndGet();
}
public void setInitialLoss(double initialLoss) {
this.initialLoss = initialLoss;
this.bestLoss = initialLoss;
}
public MetricsSummary getSummary() {
return new MetricsSummary(
totalRounds.get(),
successfulAggregations.get(),
failedAggregations.get(),
totalClientsRegistered.get(),
totalUpdatesReceived.get(),
totalDataProcessed.get(),
initialLoss,
bestLoss
);
}
}
// Public (own source file): serialized in REST responses
public class MetricsSummary {
private final long totalRounds;
private final long successfulAggregations;
private final long failedAggregations;
private final long totalClientsRegistered;
private final long totalUpdatesReceived;
private final long totalDataProcessed;
private final double initialLoss;
private final double bestLoss;
public MetricsSummary(long totalRounds, long successfulAggregations, long failedAggregations,
long totalClientsRegistered, long totalUpdatesReceived, long totalDataProcessed,
double initialLoss, double bestLoss) {
this.totalRounds = totalRounds;
this.successfulAggregations = successfulAggregations;
this.failedAggregations = failedAggregations;
this.totalClientsRegistered = totalClientsRegistered;
this.totalUpdatesReceived = totalUpdatesReceived;
this.totalDataProcessed = totalDataProcessed;
this.initialLoss = initialLoss;
this.bestLoss = bestLoss;
}
// Getters
public long getTotalRounds() { return totalRounds; }
public long getSuccessfulAggregations() { return successfulAggregations; }
public long getFailedAggregations() { return failedAggregations; }
public long getTotalClientsRegistered() { return totalClientsRegistered; }
public long getTotalUpdatesReceived() { return totalUpdatesReceived; }
public long getTotalDataProcessed() { return totalDataProcessed; }
public double getInitialLoss() { return initialLoss; }
public double getBestLoss() { return bestLoss; }
}
4. Federated Client Implementation
Client Core:
package com.federatedlearning.client;
import com.federatedlearning.model.ModelFactory;
import com.federatedlearning.model.ModelSerializer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Service
public class FederatedClient {
private final String clientId;
private final String serverUrl;
private MultiLayerNetwork localModel;
private final LocalDataProvider dataProvider;
private final ClientMetrics metrics;
private int localEpochs = 1;
private double learningRate = 0.01;
private boolean isTraining = false;
public FederatedClient(@Value("${federated.client.id}") String clientId,
@Value("${federated.server.url}") String serverUrl) {
this.clientId = clientId;
this.serverUrl = serverUrl;
this.dataProvider = new LocalDataProvider(clientId);
this.metrics = new ClientMetrics(clientId);
initializeClient();
}
private void initializeClient() {
try {
// Register with server and get initial model
String initialModel = registerWithServer();
this.localModel = ModelSerializer.modelFromBase64(initialModel);
metrics.clientInitialized();
System.out.println("Client " + clientId + " initialized with model from server");
} catch (Exception e) {
System.err.println("Failed to initialize client: " + e.getMessage());
// Fallback: create local model
this.localModel = ModelFactory.createMNISTModel();
}
}
public synchronized void performLocalTraining() {
if (isTraining) {
System.out.println("Client " + clientId + " is already training");
return;
}
isTraining = true;
long startTime = System.currentTimeMillis();
try {
System.out.println("Client " + clientId + " starting local training");
// Get local data
DataSetIterator localData = dataProvider.getTrainingData();
if (localData == null || !localData.hasNext()) {
System.out.println("No local data available for training");
return;
}
// Store current model state for update calculation
MultiLayerNetwork modelBeforeTraining = localModel.clone();
// Perform local training
double lossBefore = localModel.score();
trainLocalModel(localData);
double lossAfter = localModel.score();
// Calculate model update
ModelSerializer.ModelUpdate update =
ModelSerializer.extractUpdate(modelBeforeTraining, localModel);
// Send update to server
sendUpdateToServer(update, dataProvider.getDataSize());
metrics.trainingCompleted(lossBefore, lossAfter, System.currentTimeMillis() - startTime);
System.out.println("Client " + clientId + " completed local training. " +
"Loss: " + lossBefore + " -> " + lossAfter);
} catch (Exception e) {
System.err.println("Error during local training: " + e.getMessage());
metrics.trainingFailed();
} finally {
isTraining = false;
}
}
public synchronized void updateModelFromServer() {
try {
String latestModel = fetchLatestModelFromServer();
if (latestModel != null) {
this.localModel = ModelSerializer.modelFromBase64(latestModel);
metrics.modelUpdated();
System.out.println("Client " + clientId + " updated model from server");
}
} catch (Exception e) {
System.err.println("Failed to update model from server: " + e.getMessage());
}
}
private void trainLocalModel(DataSetIterator trainingData) {
for (int epoch = 0; epoch < localEpochs; epoch++) {
trainingData.reset();
while (trainingData.hasNext()) {
DataSet batch = trainingData.next();
localModel.fit(batch);
}
}
}
private String registerWithServer() {
// HTTP call to server registration endpoint
// Implementation would use RestTemplate or similar
System.out.println("Registering client " + clientId + " with server");
// Mock response - in real implementation, this would be an HTTP call
return MockServerAPI.registerClient(clientId, "ClientInfo");
}
private void sendUpdateToServer(ModelSerializer.ModelUpdate update, int dataSize) {
try {
String updateBase64 = ModelSerializer.paramsToBase64(update.getParameterUpdate());
// HTTP call to server update endpoint
MockServerAPI.submitUpdate(clientId, updateBase64, dataSize, update.getLoss());
metrics.updateSent();
} catch (Exception e) {
System.err.println("Failed to send update to server: " + e.getMessage());
metrics.updateFailed();
}
}
private String fetchLatestModelFromServer() {
// HTTP call to server model endpoint
return MockServerAPI.getLatestModel(clientId);
}
public double evaluateLocalModel() {
try {
DataSetIterator testData = dataProvider.getTestData();
if (testData == null) return Double.NaN;
// MultiLayerNetwork.score() takes a single DataSet, so average over batches
double totalScore = 0.0;
int batches = 0;
while (testData.hasNext()) {
totalScore += localModel.score(testData.next());
batches++;
}
return batches > 0 ? totalScore / batches : Double.NaN;
} catch (Exception e) {
System.err.println("Error evaluating model: " + e.getMessage());
return Double.NaN;
}
}
public ClientStatus getClientStatus() {
return new ClientStatus(
clientId,
isTraining,
dataProvider.getDataSize(),
localModel.score(),
metrics.getSummary()
);
}
// Getters and setters
public String getClientId() { return clientId; }
public boolean isTraining() { return isTraining; }
public ClientMetrics getMetrics() { return metrics; }
public void setLocalEpochs(int localEpochs) { this.localEpochs = localEpochs; }
public void setLearningRate(double learningRate) { this.learningRate = learningRate; }
}
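registerWithServer, sendUpdateToServer, and fetchLatestModelFromServer above delegate to MockServerAPI. Against a live server running the ServerController from Section 5, they could be backed by Spring's RestTemplate; the following is a sketch under that assumption (class name and error handling are illustrative):

import org.springframework.web.client.RestTemplate;

// Hypothetical HTTP-backed replacement for the MockServerAPI calls,
// targeting the ServerController endpoints defined in Section 5.
public class ServerHttpApi {
    private final RestTemplate rest = new RestTemplate();
    private final String baseUrl; // e.g. http://localhost:8080/api/federated/server

    public ServerHttpApi(String baseUrl) {
        this.baseUrl = baseUrl;
    }

    public String registerClient(String clientId, String clientInfo) {
        return rest.postForObject(baseUrl + "/register?clientId={id}&clientInfo={info}",
                null, String.class, clientId, clientInfo);
    }

    public void submitUpdate(String clientId, String update, int dataSize, double loss) {
        rest.postForObject(baseUrl + "/update?clientId={id}&modelUpdate={u}&dataSize={n}&loss={l}",
                null, Void.class, clientId, update, dataSize, loss);
    }

    public String getLatestModel() {
        return rest.getForObject(baseUrl + "/model", String.class);
    }
}

Note that a Base64-encoded parameter array is far too large for a URL query string; a production version should send it in the request body (and the controller should accept it via @RequestBody).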
Local Data Provider:
package com.federatedlearning.client;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.IOException;
import java.util.Random;
public class LocalDataProvider {
private final String clientId;
private final Random random;
private int dataSize;
public LocalDataProvider(String clientId) {
this.clientId = clientId;
this.random = new Random(clientId.hashCode());
this.dataSize = 1000 + random.nextInt(2000); // Simulate varying data sizes
}
public DataSetIterator getTrainingData() {
try {
// In real implementation, this would load client's local data
// For demonstration, we'll use MNIST with client-specific subsets
int batchSize = 64;
// Simulate different clients seeing differently shuffled data: the final
// constructor argument of MnistDataSetIterator is an RNG seed, so each
// client derives its own shuffle seed from its ID
long clientSeed = Math.abs(clientId.hashCode());
// Mock: return an MNIST iterator (in reality, this would be the client's private data)
return new MnistDataSetIterator(batchSize, dataSize, false, true, true, clientSeed);
} catch (IOException e) {
System.err.println("Failed to load training data: " + e.getMessage());
return null;
}
}
public DataSetIterator getTestData() {
try {
// Mock test data
return new MnistDataSetIterator(64, 1000, false, false, true, 0);
} catch (IOException e) {
System.err.println("Failed to load test data: " + e.getMessage());
return null;
}
}
public int getDataSize() {
return dataSize;
}
public void simulateDataChange() {
// Simulate changing local data over time
this.dataSize = 800 + random.nextInt(2400);
}
}
Client Data Classes:
package com.federatedlearning.client;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
public class ClientMetrics {
private final String clientId;
private final AtomicInteger trainingSessions;
private final AtomicInteger successfulUpdates;
private final AtomicInteger failedUpdates;
private final AtomicInteger modelUpdates;
private final AtomicLong totalTrainingTime;
private double bestLocalLoss;
private double currentLoss;
public ClientMetrics(String clientId) {
this.clientId = clientId;
this.trainingSessions = new AtomicInteger(0);
this.successfulUpdates = new AtomicInteger(0);
this.failedUpdates = new AtomicInteger(0);
this.modelUpdates = new AtomicInteger(0);
this.totalTrainingTime = new AtomicLong(0);
this.bestLocalLoss = Double.MAX_VALUE;
this.currentLoss = Double.MAX_VALUE;
}
public void clientInitialized() {
// Metrics for client initialization
}
public void trainingCompleted(double lossBefore, double lossAfter, long duration) {
trainingSessions.incrementAndGet();
totalTrainingTime.addAndGet(duration);
currentLoss = lossAfter;
if (lossAfter < bestLocalLoss) {
bestLocalLoss = lossAfter;
}
}
public void trainingFailed() {
// Track training failures
}
public void updateSent() {
successfulUpdates.incrementAndGet();
}
public void updateFailed() {
failedUpdates.incrementAndGet();
}
public void modelUpdated() {
modelUpdates.incrementAndGet();
}
public ClientMetricsSummary getSummary() {
return new ClientMetricsSummary(
clientId,
trainingSessions.get(),
successfulUpdates.get(),
failedUpdates.get(),
modelUpdates.get(),
totalTrainingTime.get(),
bestLocalLoss,
currentLoss
);
}
}
// Public (own source file): included in ClientStatus REST responses
public class ClientMetricsSummary {
private final String clientId;
private final int trainingSessions;
private final int successfulUpdates;
private final int failedUpdates;
private final int modelUpdates;
private final long totalTrainingTime;
private final double bestLocalLoss;
private final double currentLoss;
public ClientMetricsSummary(String clientId, int trainingSessions, int successfulUpdates,
int failedUpdates, int modelUpdates, long totalTrainingTime,
double bestLocalLoss, double currentLoss) {
this.clientId = clientId;
this.trainingSessions = trainingSessions;
this.successfulUpdates = successfulUpdates;
this.failedUpdates = failedUpdates;
this.modelUpdates = modelUpdates;
this.totalTrainingTime = totalTrainingTime;
this.bestLocalLoss = bestLocalLoss;
this.currentLoss = currentLoss;
}
// Getters
public String getClientId() { return clientId; }
public int getTrainingSessions() { return trainingSessions; }
public int getSuccessfulUpdates() { return successfulUpdates; }
public int getFailedUpdates() { return failedUpdates; }
public int getModelUpdates() { return modelUpdates; }
public long getTotalTrainingTime() { return totalTrainingTime; }
public double getBestLocalLoss() { return bestLocalLoss; }
public double getCurrentLoss() { return currentLoss; }
}
// Public (own source file): returned by the client REST controller
public class ClientStatus {
private final String clientId;
private final boolean isTraining;
private final int dataSize;
private final double currentLoss;
private final ClientMetricsSummary metrics;
public ClientStatus(String clientId, boolean isTraining, int dataSize,
double currentLoss, ClientMetricsSummary metrics) {
this.clientId = clientId;
this.isTraining = isTraining;
this.dataSize = dataSize;
this.currentLoss = currentLoss;
this.metrics = metrics;
}
// Getters
public String getClientId() { return clientId; }
public boolean isTraining() { return isTraining; }
public int getDataSize() { return dataSize; }
public double getCurrentLoss() { return currentLoss; }
public ClientMetricsSummary getMetrics() { return metrics; }
}
5. REST API Controllers
Server API:
package com.federatedlearning.api;
import com.federatedlearning.server.FederatedServer;
import com.federatedlearning.server.ServerStatus;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/api/federated/server")
public class ServerController {
@Autowired
private FederatedServer federatedServer;
@PostMapping("/register")
public String registerClient(@RequestParam String clientId,
@RequestParam String clientInfo) {
return federatedServer.registerClient(clientId, clientInfo);
}
@PostMapping("/update")
public void submitUpdate(@RequestParam String clientId,
@RequestParam String modelUpdate,
@RequestParam int dataSize,
@RequestParam double loss) {
federatedServer.submitClientUpdate(clientId, modelUpdate, dataSize, loss);
}
@GetMapping("/model")
public String getGlobalModel() {
return federatedServer.getGlobalModelBase64();
}
@GetMapping("/status")
public ServerStatus getServerStatus() {
return federatedServer.getServerStatus();
}
@GetMapping("/metrics")
public Object getServerMetrics() {
return federatedServer.getMetrics().getSummary();
}
}
Client API:
package com.federatedlearning.api;
import com.federatedlearning.client.FederatedClient;
import com.federatedlearning.client.ClientStatus;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/api/federated/client")
public class ClientController {
@Autowired
private FederatedClient federatedClient;
@PostMapping("/train")
public String startTraining() {
new Thread(() -> federatedClient.performLocalTraining()).start();
return "Training started for client: " + federatedClient.getClientId();
}
@PostMapping("/update-model")
public String updateModel() {
federatedClient.updateModelFromServer();
return "Model update initiated";
}
@GetMapping("/status")
public ClientStatus getClientStatus() {
return federatedClient.getClientStatus();
}
@GetMapping("/evaluate")
public double evaluateModel() {
return federatedClient.evaluateLocalModel();
}
@PostMapping("/config")
public String updateConfig(@RequestParam(required = false) Integer localEpochs,
@RequestParam(required = false) Double learningRate) {
if (localEpochs != null) {
federatedClient.setLocalEpochs(localEpochs);
}
if (learningRate != null) {
federatedClient.setLearningRate(learningRate);
}
return "Configuration updated";
}
}
6. Mock Server API (for testing)
package com.federatedlearning.client;
import com.federatedlearning.model.ModelFactory;
import com.federatedlearning.model.ModelSerializer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import java.util.HashMap;
import java.util.Map;
public class MockServerAPI {
private static MultiLayerNetwork globalModel = ModelFactory.createMNISTModel();
private static Map<String, Long> clientLastUpdate = new HashMap<>();
private static int round = 0;
public static String registerClient(String clientId, String clientInfo) {
try {
clientLastUpdate.put(clientId, System.currentTimeMillis());
return ModelSerializer.modelToBase64(globalModel);
} catch (Exception e) {
throw new RuntimeException("Registration failed", e);
}
}
public static void submitUpdate(String clientId, String modelUpdate, int dataSize, double loss) {
clientLastUpdate.put(clientId, System.currentTimeMillis());
System.out.println("Mock server received update from " + clientId +
" with " + dataSize + " samples, loss: " + loss);
// Simulate occasional model updates
if (clientLastUpdate.size() >= 3) { // When 3 clients have submitted updates
simulateModelAggregation();
clientLastUpdate.clear();
}
}
public static String getLatestModel(String clientId) {
try {
return ModelSerializer.modelToBase64(globalModel);
} catch (Exception e) {
throw new RuntimeException("Failed to get model", e);
}
}
private static void simulateModelAggregation() {
round++;
System.out.println("Mock server performing aggregation round " + round);
// In real implementation, this would aggregate client updates
}
}
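Because MockServerAPI exposes static methods, a single register → train → submit → fetch round can be exercised without starting Spring. A sketch (the class and client ID are illustrative; the submitted update is the zero delta between two identical models):

import com.federatedlearning.model.ModelSerializer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class MockRoundDemo {
    public static void main(String[] args) throws Exception {
        // Register and receive the current global model
        String encoded = MockServerAPI.registerClient("demo-client", "demo");
        MultiLayerNetwork local = ModelSerializer.modelFromBase64(encoded);
        MultiLayerNetwork before = local.clone();

        // Local training would run here; we submit a zero update for illustration
        ModelSerializer.ModelUpdate update = ModelSerializer.extractUpdate(before, local);
        String updateBase64 = ModelSerializer.paramsToBase64(update.getParameterUpdate());
        MockServerAPI.submitUpdate("demo-client", updateBase64, 1000, update.getLoss());

        // Poll for the latest global model
        String latest = MockServerAPI.getLatestModel("demo-client");
        System.out.println("Received model of " + latest.length() + " Base64 chars");
    }
}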
7. Configuration and Main Application
Application Configuration:
package com.federatedlearning;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication
@EnableScheduling
public class FederatedLearningApplication {
public static void main(String[] args) {
SpringApplication.run(FederatedLearningApplication.class, args);
}
}
Client Scheduler:
package com.federatedlearning.scheduler;
import com.federatedlearning.client.FederatedClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import java.util.Random;
@Component
public class ClientTrainingScheduler {
@Autowired
private FederatedClient federatedClient;
private final Random random = new Random();
@Scheduled(fixedRate = 60000) // Run every minute
public void scheduleTraining() {
// Randomize training start to avoid all clients training simultaneously
if (random.nextDouble() < 0.3) { // 30% chance to train each minute
federatedClient.performLocalTraining();
}
}
@Scheduled(fixedRate = 300000) // Run every 5 minutes
public void scheduleModelUpdate() {
federatedClient.updateModelFromServer();
}
}
Application Properties:
# Server configuration
federated.server.port=8080
federated.server.min-clients=3
federated.server.target-rounds=100
# Client configuration
federated.client.id=client-${random.uuid}
federated.client.server-url=http://localhost:8080
federated.client.local-epochs=1
federated.client.learning-rate=0.01
# DL4J configuration
dl4j.backend=cpu
nd4j.dtype=float
8. Advanced Federated Learning Features
Differential Privacy:
package com.federatedlearning.privacy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class DifferentialPrivacy {
public static INDArray addGaussianNoise(INDArray update, double epsilon, double delta, double sensitivity) {
double sigma = calculateNoiseScale(epsilon, delta, sensitivity);
INDArray noise = Nd4j.randn(update.shape()).mul(sigma);
return update.add(noise);
}
public static INDArray clipUpdate(INDArray update, double clipNorm) {
double currentNorm = update.norm2Number().doubleValue();
if (currentNorm > clipNorm) {
return update.mul(clipNorm / currentNorm);
}
return update;
}
private static double calculateNoiseScale(double epsilon, double delta, double sensitivity) {
return sensitivity * Math.sqrt(2 * Math.log(1.25 / delta)) / epsilon;
}
public static class PrivacyBudget {
private double epsilon;
private double delta;
private final double maxEpsilon;
public PrivacyBudget(double initialEpsilon, double delta, double maxEpsilon) {
this.epsilon = initialEpsilon;
this.delta = delta;
this.maxEpsilon = maxEpsilon;
}
public boolean canSpend(double amount) {
return (epsilon + amount) <= maxEpsilon;
}
public void spend(double amount) {
if (canSpend(amount)) {
epsilon += amount;
} else {
throw new IllegalStateException("Privacy budget exhausted");
}
}
public double getEpsilon() { return epsilon; }
public double getDelta() { return delta; }
}
}
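clipUpdate bounds each update's L2 norm, which makes clipNorm the sensitivity that addGaussianNoise calibrates against via the Gaussian-mechanism scale sigma = sensitivity * sqrt(2 ln(1.25/delta)) / epsilon. A minimal sketch chaining the two (values and class name are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PrivacyDemo {
    public static void main(String[] args) {
        double clipNorm = 1.0;  // L2 bound; equals the sensitivity after clipping
        double epsilon = 1.0;
        double delta = 1e-5;

        INDArray rawUpdate = Nd4j.randn(1, 100);

        // Clip first so no single update exceeds the norm bound...
        INDArray clipped = DifferentialPrivacy.clipUpdate(rawUpdate, clipNorm);
        // ...then add noise scaled to sensitivity = clipNorm
        INDArray privatized = DifferentialPrivacy.addGaussianNoise(clipped, epsilon, delta, clipNorm);

        System.out.println("Clipped norm: " + clipped.norm2Number());
        System.out.println("Noisy norm:   " + privatized.norm2Number());
    }
}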
Secure Aggregation:
package com.federatedlearning.security;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.security.SecureRandom;
import java.util.*;
public class SecureAggregation {
public static Map<String, INDArray> maskUpdates(Map<String, INDArray> clientUpdates) {
Map<String, INDArray> maskedUpdates = new HashMap<>();
SecureRandom random = new SecureRandom();
List<String> clientIds = new ArrayList<>(clientUpdates.keySet());
Collections.sort(clientIds);
int n = clientIds.size();
if (n == 0) return maskedUpdates;
long[] shape = clientUpdates.get(clientIds.get(0)).shape();
// Generate ONE shared mask per unordered pair (i, j) with i < j. Client i adds it
// and client j subtracts it, so every mask cancels when the server sums the updates.
INDArray[][] pairMasks = new INDArray[n][n];
for (int i = 0; i < n; i++) {
for (int j = i + 1; j < n; j++) {
pairMasks[i][j] = generateSymmetricMask(shape, random);
}
}
for (int i = 0; i < n; i++) {
INDArray maskedUpdate = clientUpdates.get(clientIds.get(i)).dup();
for (int j = 0; j < n; j++) {
if (i < j) {
maskedUpdate.addi(pairMasks[i][j]);
} else if (i > j) {
maskedUpdate.subi(pairMasks[j][i]);
}
}
maskedUpdates.put(clientIds.get(i), maskedUpdate);
}
return maskedUpdates;
}
private static INDArray generateSymmetricMask(long[] shape, SecureRandom random) {
INDArray mask = Nd4j.create(shape);
for (int i = 0; i < mask.length(); i++) {
mask.putScalar(i, random.nextDouble() - 0.5);
}
return mask;
}
public static INDArray reconstructAggregate(Map<String, INDArray> maskedUpdates) {
if (maskedUpdates.isEmpty()) {
throw new IllegalArgumentException("No masked updates to aggregate");
}
INDArray sum = null;
for (INDArray maskedUpdate : maskedUpdates.values()) {
if (sum == null) {
sum = maskedUpdate.dup();
} else {
sum.addi(maskedUpdate);
}
}
return sum.div(maskedUpdates.size());
}
}
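Because each pairwise mask is added by exactly one client and subtracted by the other, all masks cancel in the server's sum, and reconstructAggregate recovers the true average without ever seeing an unmasked update. A small demonstration (client IDs and values are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap;
import java.util.Map;

public class SecureAggregationDemo {
    public static void main(String[] args) {
        Map<String, INDArray> updates = new HashMap<>();
        updates.put("client-a", Nd4j.ones(1, 4));
        updates.put("client-b", Nd4j.ones(1, 4).mul(2));
        updates.put("client-c", Nd4j.ones(1, 4).mul(3));

        // Individually the masked updates look like noise...
        Map<String, INDArray> masked = SecureAggregation.maskUpdates(updates);

        // ...but the reconstructed mean matches the unmasked mean, ~[2, 2, 2, 2]
        INDArray aggregate = SecureAggregation.reconstructAggregate(masked);
        System.out.println("Reconstructed mean: " + aggregate);
    }
}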
Best Practices for Federated Learning
1. Privacy Protection:
- Implement differential privacy
- Use secure aggregation protocols
- Minimize data exposure
- Regular privacy audits
2. Performance Optimization:
- Compress model updates
- Implement update sparsification (see the top-k sketch after this section)
- Use adaptive learning rates
- Monitor communication costs
3. Robustness:
- Handle client dropouts gracefully
- Implement model validation
- Detect and mitigate poisoning attacks
- Maintain model versioning
4. Monitoring and Debugging:
public class FLMonitor {
public static void logTrainingRound(int round, double loss, int clients) {
System.out.printf("Round %d: Loss=%.4f, Clients=%d%n", round, loss, clients);
}
public static void alertAnomaly(String clientId, double anomalyScore) {
System.err.printf("Anomaly detected: Client=%s, Score=%.4f%n", clientId, anomalyScore);
}
}
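For the update-sparsification point above, a hypothetical top-k sparsifier shows the idea: keep only the k largest-magnitude entries of an update and zero the rest, so far fewer meaningful values need to be uploaded (the class and its threshold logic are illustrative, not part of the system above):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;

public class TopKSparsifier {
    public static INDArray sparsify(INDArray update, int k) {
        INDArray flat = update.reshape(update.length());
        // Find the magnitude of the k-th largest entry
        double[] magnitudes = Transforms.abs(flat).toDoubleVector();
        Arrays.sort(magnitudes); // ascending
        double threshold = magnitudes[Math.max(0, magnitudes.length - k)];
        // Zero every entry whose magnitude falls below that threshold
        INDArray mask = Transforms.abs(flat).gte(threshold).castTo(flat.dataType());
        return flat.mul(mask).reshape(update.shape());
    }
}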
Testing Federated Learning System
package com.federatedlearning.test;
import com.federatedlearning.model.ModelFactory;
import com.federatedlearning.model.ModelSerializer;
import com.federatedlearning.server.FederatedServer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import static org.junit.jupiter.api.Assertions.*;
@SpringBootTest
public class FederatedLearningTest {
@Autowired
private FederatedServer federatedServer;
@Test
public void testServerInitialization() {
assertNotNull(federatedServer);
assertTrue(federatedServer.getRegisteredClientCount() >= 0);
}
@Test
public void testClientRegistration() {
String clientId = "test-client-1";
String model = federatedServer.registerClient(clientId, "Test Client");
assertNotNull(model);
assertTrue(federatedServer.getRegisteredClientCount() > 0);
}
@Test
public void testModelSerialization() throws Exception {
// Verify a model survives the Base64 round trip with identical parameters
MultiLayerNetwork original = ModelFactory.createIrisModel();
String encoded = ModelSerializer.modelToBase64(original);
MultiLayerNetwork restored = ModelSerializer.modelFromBase64(encoded);
assertEquals(original.params(), restored.params());
}
}
Conclusion
Federated Learning in Java enables privacy-preserving machine learning across distributed data sources. Key benefits include:
- Data Privacy: Raw data never leaves client devices
- Reduced Bandwidth: Only model updates are transmitted
- Regulatory Compliance: Helps satisfy data protection regulations, since raw data stays local
- Collaborative Learning: Enables learning from diverse data sources
Implementation Checklist:
- ✅ Design federated architecture
- ✅ Implement model serialization
- ✅ Create server aggregation logic
- ✅ Build client training pipeline
- ✅ Add privacy protections
- ✅ Implement monitoring and metrics
- ✅ Test with multiple clients
- ✅ Plan deployment strategy
Federated Learning is particularly valuable for:
- Healthcare applications with sensitive patient data
- Financial services with transaction data
- Mobile applications with user behavior data
- IoT devices with sensor data
- Cross-organizational collaborations
By following this comprehensive guide, you can build robust, scalable, and privacy-preserving federated learning systems in Java that enable collaborative machine learning while respecting data privacy.