The integration of machine learning into enterprise Java applications has traditionally been challenging because most ML frameworks are Python-centric. Deep Java Library (DJL) bridges this gap with a Java-native API for running pre-trained models from several frameworks, including TensorFlow. This article shows how DJL lets Java developers integrate TensorFlow models into their applications without leaving the JVM ecosystem.
Why DJL for TensorFlow in Java?
Challenges with Traditional Approaches:
- TensorFlow Java API has limited functionality compared to Python
- Complex deployment and dependency management
- Lack of high-level abstractions for common tasks
- Limited model zoo and pre-trained models
DJL Advantages:
- Framework Agnostic: Supports TensorFlow, PyTorch, MXNet, and ONNX
- Java-Native API: No Python runtime required; the TensorFlow native libraries are downloaded and managed automatically
- Easy API: High-level abstractions for common tasks
- Model Zoo: Access to pre-trained models
- Production Ready: Designed for enterprise deployment
Setting Up DJL with TensorFlow
Maven Dependencies:
<properties>
    <djl.version>0.25.0</djl.version>
</properties>

<dependencies>
    <!-- DJL Core -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <!-- TensorFlow Engine -->
    <dependency>
        <groupId>ai.djl.tensorflow</groupId>
        <artifactId>tensorflow-engine</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <!-- DJL Model Zoo (optional) -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>model-zoo</artifactId>
        <version>${djl.version}</version>
    </dependency>
</dependencies>
Gradle:
dependencies {
    implementation("ai.djl:api:0.25.0")
    implementation("ai.djl.tensorflow:tensorflow-engine:0.25.0")
    implementation("ai.djl:model-zoo:0.25.0")
}
Basic Concepts and Architecture
DJL Core Components:
- Model: Container for trained parameters and computation logic
- Predictor: Interface for making predictions
- Translator: Converts between Java objects and NDArrays
- NDList/NDArray: N-dimensional array operations
- Criteria: Model loading and configuration specification
TensorFlow Integration Flow:
TensorFlow SavedModel → DJL model loader → Translator (input/output processing) → Predictor (inference interface) → Java application
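This flow can be illustrated without any DJL classes. The toy pipeline below mirrors the Translator contract — pre-process an application object into an array, run a model, post-process the result — using plain Java; `ToyPipeline` and its methods are illustrative names, not DJL API:

```java
import java.util.function.Function;

public class ToyPipeline {
    // Pre-processing: parse "1.0, -0.2" into a float "tensor" (mirrors processInput).
    static float[] toTensor(String csv) {
        String[] parts = csv.split(",");
        float[] out = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            out[i] = Float.parseFloat(parts[i].trim());
        }
        return out;
    }

    // Post-processing: map a single logit to a label (mirrors processOutput).
    static String toLabel(float[] logits) {
        return logits[0] > 0 ? "POSITIVE" : "NEGATIVE";
    }

    // "Predictor": pre-process, run the model function, post-process.
    static String predict(Function<float[], float[]> model, String csv) {
        return toLabel(model.apply(toTensor(csv)));
    }

    public static void main(String[] args) {
        // "Model": sums its inputs into one logit.
        Function<float[], float[]> model = in -> {
            float sum = 0;
            for (float v : in) sum += v;
            return new float[]{sum};
        };
        System.out.println(predict(model, "1.0, -0.2, 0.5")); // prints POSITIVE
    }
}
```

In real DJL code the Predictor plays the role of `predict` and the engine supplies the model function; the Translator only ever sees NDLists instead of raw arrays.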
Loading and Running TensorFlow Models
Method 1: Loading from SavedModel Format
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import java.nio.file.Paths;

public class TensorFlowImageClassifier {

    private static final String MODEL_DIR = "models/mobilenet_v2";

    public static void main(String[] args) {
        // Define the model criteria
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelPath(Paths.get(MODEL_DIR))
                .optTranslator(createTranslator())
                .optEngine("TensorFlow") // Specify the TensorFlow engine explicitly
                .build();
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            // Load and classify an image
            Image image = ImageFactory.getInstance().fromUrl(
                    "https://resources.djl.ai/images/kitten.jpg");
            Classifications classifications = predictor.predict(image);
            // Print the top five results
            classifications.items().stream()
                    .limit(5)
                    .forEach(item ->
                            System.out.printf("%s: %.5f%n", item.getClassName(), item.getProbability()));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private static Translator<Image, Classifications> createTranslator() {
        return ImageClassificationTranslator.builder()
                .addTransform(new ai.djl.modality.cv.transform.Resize(224, 224)) // MobileNetV2 expects 224x224 input
                .addTransform(new ai.djl.modality.cv.transform.ToTensor())
                .optApplySoftmax(true)
                .build();
    }
}
Method 2: Direct Model Loading
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;

public class CustomTensorFlowModel {

    public static void main(String[] args) {
        String modelPath = "models/custom_model";
        try (Model model = Model.newInstance("custom-model", "TensorFlow")) {
            // Load the SavedModel
            model.load(Paths.get(modelPath));
            // Create a custom translator
            Translator<float[], float[]> translator = new CustomTranslator();
            try (Predictor<float[], float[]> predictor = model.newPredictor(translator)) {
                float[] input = {0.1f, 0.2f, 0.3f, 0.4f};
                float[] output = predictor.predict(input);
                System.out.println("Model output:");
                for (int i = 0; i < output.length; i++) {
                    System.out.printf("Output[%d]: %.4f%n", i, output[i]);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    static class CustomTranslator implements Translator<float[], float[]> {

        @Override
        public NDList processInput(TranslatorContext ctx, float[] input) {
            NDManager manager = ctx.getNDManager();
            NDArray array = manager.create(input);
            return new NDList(array);
        }

        @Override
        public float[] processOutput(TranslatorContext ctx, NDList list) {
            NDArray output = list.get(0);
            return output.toFloatArray();
        }

        @Override
        public Batchifier getBatchifier() {
            return Batchifier.STACK; // adds/removes the batch dimension automatically
        }
    }
}
Using DJL Model Zoo with TensorFlow Models
DJL provides a model zoo with pre-trained TensorFlow models:
import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public class ModelZooExample {

    public static void main(String[] args) {
        // Load a ResNet-50 image classifier from the DJL model zoo
        Criteria<Image, Classifications> criteria = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(Image.class, Classifications.class)
                .optFilter("backbone", "resnet50")
                .optFilter("flavor", "v2")
                .optEngine("TensorFlow")
                .optProgress(new ProgressBar())
                .build();
        try (ZooModel<Image, Classifications> model = criteria.loadModel();
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromUrl(
                    "https://resources.djl.ai/images/dog.jpg");
            Classifications classifications = predictor.predict(image);
            System.out.println("Top 5 predictions:");
            classifications.items().stream()
                    .limit(5)
                    .forEach(item ->
                            System.out.printf("%s: %.5f%n", item.getClassName(), item.getProbability()));
        } catch (ModelException | IOException | TranslateException e) {
            e.printStackTrace();
        }
    }
}
Advanced: Custom TensorFlow Model with Complex Input/Output
Handling Multiple Inputs/Outputs:
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;

public class MultiInputTensorFlowModel {

    static class ModelInput {
        public final float[] feature1;
        public final int[] feature2;
        public final float[][] sequence;

        public ModelInput(float[] feature1, int[] feature2, float[][] sequence) {
            this.feature1 = feature1;
            this.feature2 = feature2;
            this.sequence = sequence;
        }
    }

    static class ModelOutput {
        public final float[] output1;
        public final float probability;

        public ModelOutput(float[] output1, float probability) {
            this.output1 = output1;
            this.probability = probability;
        }
    }

    static class MultiInputTranslator implements Translator<ModelInput, ModelOutput> {

        @Override
        public NDList processInput(TranslatorContext ctx, ModelInput input) {
            NDManager manager = ctx.getNDManager();
            // First input feature (FLOAT32)
            NDArray input1 = manager.create(input.feature1);
            input1.setName("input_feature_1"); // must match the SavedModel's input tensor name
            // Second input feature (creating from int[] already yields an INT32 array)
            NDArray input2 = manager.create(input.feature2);
            input2.setName("input_feature_2");
            // Sequence input (2-D FLOAT32)
            NDArray input3 = manager.create(input.sequence);
            input3.setName("input_sequence");
            return new NDList(input1, input2, input3);
        }

        @Override
        public ModelOutput processOutput(TranslatorContext ctx, NDList list) {
            // Assumes the model has two outputs
            NDArray output1 = list.get(0); // first output
            NDArray output2 = list.get(1); // second output
            float[] output1Array = output1.toFloatArray();
            float probability = output2.getFloat();
            return new ModelOutput(output1Array, probability);
        }

        @Override
        public Batchifier getBatchifier() {
            return Batchifier.STACK;
        }
    }

    public static void main(String[] args) {
        String modelPath = "models/complex_model";
        try (Model model = Model.newInstance("complex-model", "TensorFlow")) {
            model.load(Paths.get(modelPath));
            MultiInputTranslator translator = new MultiInputTranslator();
            try (Predictor<ModelInput, ModelOutput> predictor = model.newPredictor(translator)) {
                // Prepare input data
                float[] feature1 = {0.1f, 0.5f, 0.8f};
                int[] feature2 = {1, 3, 5, 7};
                float[][] sequence = {
                        {0.1f, 0.2f, 0.3f},
                        {0.4f, 0.5f, 0.6f},
                        {0.7f, 0.8f, 0.9f}
                };
                ModelInput input = new ModelInput(feature1, feature2, sequence);
                ModelOutput output = predictor.predict(input);
                System.out.println("Model Output:");
                System.out.printf("Output array length: %d%n", output.output1.length);
                System.out.printf("Probability: %.4f%n", output.probability);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
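Under the hood, `manager.create(input.sequence)` stores the 2-D Java array as one contiguous row-major buffer plus a recorded shape. The standalone sketch below shows that layout; `flatten` and `get` are illustrative helpers, not DJL methods:

```java
import java.util.Arrays;

public class RowMajor {
    // Flatten a rectangular 2-D array row-major, as an NDArray stores it.
    static float[] flatten(float[][] matrix) {
        int rows = matrix.length, cols = matrix[0].length;
        float[] flat = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            System.arraycopy(matrix[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    // Recover element (r, c) from the flat buffer given the column count.
    static float get(float[] flat, int cols, int r, int c) {
        return flat[r * cols + c];
    }

    public static void main(String[] args) {
        float[][] sequence = {{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}};
        float[] flat = flatten(sequence);          // shape is recorded as (2, 3)
        System.out.println(Arrays.toString(flat)); // [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
        System.out.println(get(flat, 3, 1, 0));    // 0.4
    }
}
```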
Batch Processing for High Performance
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class BatchInferenceExample {

    static class BatchTranslator implements Translator<List<float[]>, List<float[]>> {

        @Override
        public NDList processInput(TranslatorContext ctx, List<float[]> batch) {
            NDManager manager = ctx.getNDManager();
            // Stack all arrays into a single (batchSize x featureLength) batch
            NDList parts = new NDList();
            for (float[] row : batch) {
                parts.add(manager.create(row));
            }
            NDArray batchArray = NDArrays.stack(parts);
            return new NDList(batchArray);
        }

        @Override
        public List<float[]> processOutput(TranslatorContext ctx, NDList list) {
            NDArray batchOutput = list.get(0);
            // Split along axis 0 into one result per batch item
            long batchSize = batchOutput.getShape().get(0);
            return batchOutput.split(batchSize).stream()
                    .map(NDArray::toFloatArray)
                    .collect(Collectors.toList());
        }

        @Override
        public Batchifier getBatchifier() {
            // Batching is done manually in processInput, so disable the automatic batchifier
            return null;
        }
    }

    public static void main(String[] args) {
        String modelPath = "models/batch_model";
        try (Model model = Model.newInstance("batch-model", "TensorFlow")) {
            model.load(Paths.get(modelPath));
            BatchTranslator translator = new BatchTranslator();
            try (Predictor<List<float[]>, List<float[]>> predictor = model.newPredictor(translator)) {
                // Create a batch of inputs
                List<float[]> batchInputs = Arrays.asList(
                        new float[]{0.1f, 0.2f, 0.3f},
                        new float[]{0.4f, 0.5f, 0.6f},
                        new float[]{0.7f, 0.8f, 0.9f},
                        new float[]{1.0f, 1.1f, 1.2f}
                );
                long startTime = System.currentTimeMillis();
                List<float[]> batchOutputs = predictor.predict(batchInputs);
                long endTime = System.currentTimeMillis();
                System.out.printf("Processed %d items in %d ms%n",
                        batchInputs.size(), (endTime - startTime));
                for (int i = 0; i < batchOutputs.size(); i++) {
                    System.out.printf("Output %d: %s%n", i, Arrays.toString(batchOutputs.get(i)));
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
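The split in processOutput has a simple pure-Java analogue: slicing a row-major (batchSize × itemLength) buffer back into per-item arrays. A sketch with hypothetical helper names:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BatchSplit {
    // Split a row-major (batchSize x itemLength) buffer into per-item arrays,
    // the plain-Java equivalent of splitting an NDArray along axis 0.
    static List<float[]> split(float[] flat, int batchSize) {
        int itemLength = flat.length / batchSize;
        List<float[]> out = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            out.add(Arrays.copyOfRange(flat, i * itemLength, (i + 1) * itemLength));
        }
        return out;
    }

    public static void main(String[] args) {
        float[] batchOutput = {0.9f, 0.1f, 0.2f, 0.8f}; // two rows of 2-class scores
        for (float[] row : split(batchOutput, 2)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```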
Real-World Example: Sentiment Analysis with TensorFlow
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SentimentAnalysisExample {

    static class SentimentInput {
        public final String text;

        public SentimentInput(String text) {
            this.text = text;
        }
    }

    static class SentimentOutput {
        public final float positiveScore;
        public final float negativeScore;
        public final String sentiment;

        public SentimentOutput(float positiveScore, float negativeScore) {
            this.positiveScore = positiveScore;
            this.negativeScore = negativeScore;
            this.sentiment = positiveScore > negativeScore ? "POSITIVE" : "NEGATIVE";
        }
    }

    static class SentimentTranslator implements Translator<SentimentInput, SentimentOutput> {

        private BertFullTokenizer tokenizer;
        private DefaultVocabulary vocabulary;

        @Override
        public void prepare(TranslatorContext ctx) {
            // Initialize tokenizer and vocabulary; a real model would load its
            // vocabulary file from disk instead of this toy word list
            vocabulary = DefaultVocabulary.builder()
                    .add(Arrays.asList("[CLS]", "[SEP]", "hello", "world", "good", "bad", "great", "terrible"))
                    .optUnknownToken("[UNK]")
                    .build();
            tokenizer = new BertFullTokenizer(vocabulary, false);
        }

        @Override
        public NDList processInput(TranslatorContext ctx, SentimentInput input) {
            NDManager manager = ctx.getNDManager();
            // Tokenize text (copy into a mutable list before adding special tokens)
            List<String> tokens = new ArrayList<>(tokenizer.tokenize(input.text));
            tokens.add(0, "[CLS]");
            tokens.add("[SEP]");
            // Convert tokens to vocabulary indices
            long[] indices = tokens.stream()
                    .mapToLong(token -> vocabulary.getIndex(token))
                    .toArray();
            // Create the input array and add a batch dimension
            NDArray inputArray = manager.create(indices);
            inputArray = inputArray.reshape(1, -1);
            return new NDList(inputArray);
        }

        @Override
        public SentimentOutput processOutput(TranslatorContext ctx, NDList list) {
            NDArray logits = list.get(0);
            float[] scores = logits.squeeze().toFloatArray();
            // Apply softmax; the positive/negative index order depends on how the model was trained
            float expSum = (float) (Math.exp(scores[0]) + Math.exp(scores[1]));
            float positiveScore = (float) Math.exp(scores[0]) / expSum;
            float negativeScore = (float) Math.exp(scores[1]) / expSum;
            return new SentimentOutput(positiveScore, negativeScore);
        }

        @Override
        public Batchifier getBatchifier() {
            // The batch dimension is added manually in processInput
            return null;
        }
    }

    public static void main(String[] args) {
        String modelPath = "models/sentiment_model";
        try (Model model = Model.newInstance("sentiment-model", "TensorFlow")) {
            model.load(Paths.get(modelPath));
            SentimentTranslator translator = new SentimentTranslator();
            try (Predictor<SentimentInput, SentimentOutput> predictor = model.newPredictor(translator)) {
                String[] texts = {
                        "This movie is absolutely fantastic!",
                        "I hated every minute of this film.",
                        "It was okay, nothing special.",
                        "Brilliant acting and great storyline!"
                };
                for (String text : texts) {
                    SentimentInput input = new SentimentInput(text);
                    SentimentOutput output = predictor.predict(input);
                    System.out.printf("Text: %s%n", text);
                    System.out.printf("Sentiment: %s (Positive: %.3f, Negative: %.3f)%n%n",
                            output.sentiment, output.positiveScore, output.negativeScore);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
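One caveat on the softmax in processOutput: exponentiating raw logits can overflow `Math.exp` for large values. The standard fix is to subtract the maximum logit first, which leaves the probabilities mathematically unchanged. A standalone sketch:

```java
public class StableSoftmax {
    // Numerically stable softmax: shifting by max(logits) keeps Math.exp in a
    // safe range without changing the resulting probabilities.
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) max = Math.max(max, v);
        double sum = 0;
        double[] exps = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        float[] out = new float[logits.length];
        for (int i = 0; i < logits.length; i++) out[i] = (float) (exps[i] / sum);
        return out;
    }

    public static void main(String[] args) {
        float[] p = softmax(new float[]{2.0f, 0.5f});
        System.out.printf("Positive: %.3f, Negative: %.3f%n", p[0], p[1]);
        // Logits this large would overflow a naive exp-then-normalize:
        float[] q = softmax(new float[]{1000f, 999f});
        System.out.printf("Still well-defined: %.3f, %.3f%n", q[0], q[1]);
    }
}
```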
Performance Optimization and Best Practices
1. Resource Management:
// Always use try-with-resources
try (Model model = Model.newInstance("model", "TensorFlow");
     Predictor<Input, Output> predictor = model.newPredictor(translator)) {
    // Use predictor
} catch (Exception e) {
    // Handle exception
}
2. NDManager Management:
try (NDManager manager = NDManager.newBaseManager()) {
    NDArray array = manager.create(new float[]{1.0f, 2.0f, 3.0f});
    // Operations on array
} // Automatically closed and memory freed
3. Model Loading Options:
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelPath(Paths.get(MODEL_DIR))
        .optOption("Tags", "serve") // TensorFlow SavedModel serving tag
        .optOption("SignatureDefKey", "serving_default")
        .build();
4. GPU Acceleration:
// Check engine and GPU availability
Engine engine = Engine.getInstance();
System.out.println("Available engines: " + Engine.getAllEngines());
System.out.println("GPU count: " + engine.getGpuCount());
// Prefer GPU when one is present
Device device = engine.getGpuCount() > 0 ? Device.gpu() : Device.cpu();
Criteria<Image, Classifications> criteria = Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optDevice(device)
        .optModelPath(Paths.get(MODEL_DIR))
        .build();
Troubleshooting Common Issues
1. Model Loading Errors:
try {
    model.load(Paths.get(modelPath));
} catch (MalformedModelException e) {
    System.err.println("Invalid model format: " + e.getMessage());
} catch (IOException e) {
    System.err.println("I/O error loading model: " + e.getMessage());
}
2. Input/Output Shape Mismatches:
// Debug input/output shapes
NDList processInput(TranslatorContext ctx, Input input) {
    NDArray array = // ... create array
    System.out.println("Input shape: " + array.getShape());
    return new NDList(array);
}
3. Memory Management:
// Monitor memory usage
Runtime runtime = Runtime.getRuntime();
long usedMemory = runtime.totalMemory() - runtime.freeMemory();
System.out.printf("Used memory: %.2f MB%n", usedMemory / 1024.0 / 1024.0);
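For repeated measurements it helps to wrap this snapshot in a helper; note that heap deltas are only indicative, since garbage collection can run at any time. A small sketch (`usedMb` is a hypothetical helper, not a DJL API):

```java
public class MemorySnapshot {
    // Approximate used heap in MB; treat the value as indicative only.
    static double usedMb() {
        Runtime rt = Runtime.getRuntime();
        return (rt.totalMemory() - rt.freeMemory()) / 1024.0 / 1024.0;
    }

    public static void main(String[] args) {
        double before = usedMb();
        float[] buffer = new float[4_000_000]; // roughly 16 MB of floats
        double after = usedMb();
        System.out.printf("Before: %.2f MB, after: %.2f MB%n", before, after);
        System.out.println("Buffer length: " + buffer.length); // keeps buffer reachable
    }
}
```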
Conclusion
DJL provides a powerful, enterprise-ready solution for integrating TensorFlow models into Java applications:
- Seamless Integration: Java-native API with no Python runtime required
- Framework Agnostic: Support for multiple ML frameworks
- High Performance: GPU acceleration and batch processing
- Production Ready: Robust error handling and resource management
- Developer Friendly: Intuitive API and comprehensive documentation
By leveraging DJL, Java developers can:
- Deploy TensorFlow models in production environments
- Build ML-powered microservices
- Integrate AI capabilities into existing Java applications
- Maintain consistent development workflows across teams
Whether you're building real-time inference services, batch processing pipelines, or intelligent applications, DJL with TensorFlow provides the tools and performance needed for enterprise AI solutions.