ONNX Model Inference in Java: Complete Guide

ONNX (Open Neural Network Exchange) is an open format for machine learning models that enables interoperability between different frameworks. Java applications can perform ONNX model inference using several libraries, with Microsoft's ONNX Runtime being the most popular choice.


ONNX Runtime for Java

What is ONNX Runtime?

ONNX Runtime is a cross-platform inference engine that runs models from various ML frameworks in ONNX format. It provides Java bindings for deploying models in Java applications.

Key Features

  • Cross-platform: Windows, Linux, macOS
  • Hardware acceleration: multiple execution providers (CPU, CUDA, TensorRT, DirectML, OpenVINO)
  • Memory efficient: Optimized memory usage
  • Thread-safe: Supports concurrent inference

Dependencies

Maven Dependencies

<properties>
    <onnxruntime.version>1.16.0</onnxruntime.version>
</properties>

<dependencies>
    <!-- ONNX Runtime Core (CPU) -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>${onnxruntime.version}</version>
    </dependency>
    <!-- GPU Support (Optional): use this artifact instead of the CPU one, not alongside it -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime_gpu</artifactId>
        <version>${onnxruntime.version}</version>
    </dependency>
</dependencies>

Gradle Dependencies

dependencies {
    implementation 'com.microsoft.onnxruntime:onnxruntime:1.16.0'
    // For GPU support, use onnxruntime_gpu instead of the CPU artifact:
    // implementation 'com.microsoft.onnxruntime:onnxruntime_gpu:1.16.0'
}

Basic ONNX Inference

Example 1: Basic ONNX Model Inference

import ai.onnxruntime.*;

import java.nio.FloatBuffer;
import java.util.*;

public class BasicONNXInference {

    public static void main(String[] args) {
        String modelPath = "model.onnx";
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions())) {
            // Get model info
            printModelInfo(session);
            // Prepare input data
            Map<String, OnnxTensor> inputs = prepareInputs(env, session);
            // Run inference; Result is AutoCloseable, so release it and the inputs when done
            try (OrtSession.Result results = session.run(inputs)) {
                processOutputs(results);
            } finally {
                inputs.values().forEach(OnnxTensor::close);
            }
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }

    private static void printModelInfo(OrtSession session) throws OrtException {
        System.out.println("=== Model Information ===");
        // Input information
        System.out.println("Inputs:");
        for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
            ValueInfo valueInfo = entry.getValue().getInfo();
            System.out.println("  Name: " + entry.getKey());
            System.out.println("  Info: " + valueInfo);
            // getInfo() returns a ValueInfo; only tensors carry a shape
            if (valueInfo instanceof TensorInfo) {
                System.out.println("  Shape: " + Arrays.toString(((TensorInfo) valueInfo).getShape()));
            }
        }
        // Output information
        System.out.println("Outputs:");
        for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
            System.out.println("  Name: " + entry.getKey());
            System.out.println("  Info: " + entry.getValue().getInfo());
        }
    }

    private static Map<String, OnnxTensor> prepareInputs(OrtEnvironment env, OrtSession session)
            throws OrtException {
        Map<String, OnnxTensor> inputs = new HashMap<>();
        for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
            String inputName = entry.getKey();
            TensorInfo tensorInfo = (TensorInfo) entry.getValue().getInfo();
            // Dynamic dimensions are reported as -1; substitute a concrete size (batch = 1)
            long[] shape = tensorInfo.getShape().clone();
            for (int i = 0; i < shape.length; i++) {
                if (shape[i] < 0) {
                    shape[i] = 1;
                }
            }
            System.out.println("Preparing input: " + inputName + " with shape: " + Arrays.toString(shape));
            // Create sample data (replace with actual data)
            float[] inputData = createSampleData(shape);
            inputs.put(inputName, OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), shape));
        }
        return inputs;
    }

    private static float[] createSampleData(long[] shape) {
        // Calculate total elements
        long totalElements = 1;
        for (long dim : shape) {
            totalElements *= dim;
        }
        // Create sample data (random values between 0 and 1)
        float[] data = new float[(int) totalElements];
        Random random = new Random(42); // Fixed seed for reproducibility
        for (int i = 0; i < data.length; i++) {
            data[i] = random.nextFloat();
        }
        return data;
    }

    private static void processOutputs(OrtSession.Result results) throws OrtException {
        System.out.println("=== Inference Results ===");
        // Result iterates as Map.Entry<String, OnnxValue>; tensors need a cast
        for (Map.Entry<String, OnnxValue> output : results) {
            System.out.println("Output: " + output.getKey());
            if (output.getValue() instanceof OnnxTensor) {
                OnnxTensor tensor = (OnnxTensor) output.getValue();
                System.out.println("Type: " + tensor.getInfo().type);
                System.out.println("Shape: " + Arrays.toString(tensor.getInfo().getShape()));
                // Extract and print values based on tensor type
                printTensorValues(tensor.getValue());
            }
        }
    }

    private static void printTensorValues(Object values) {
        if (values instanceof float[]) {
            float[] floatValues = (float[]) values;
            System.out.println("Values: " + Arrays.toString(
                    Arrays.copyOf(floatValues, Math.min(10, floatValues.length))));
        } else if (values instanceof float[][]) {
            float[][] floatValues = (float[][]) values;
            System.out.println("First row: " + Arrays.toString(
                    Arrays.copyOf(floatValues[0], Math.min(10, floatValues[0].length))));
        } else if (values instanceof long[]) {
            long[] longValues = (long[]) values;
            System.out.println("Values: " + Arrays.toString(
                    Arrays.copyOf(longValues, Math.min(10, longValues.length))));
        } else {
            System.out.println("Unsupported tensor type: " + values.getClass().getSimpleName());
        }
    }
}
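A detail worth pulling out of the example above: ONNX models frequently declare dynamic dimensions (typically the batch axis), which the Java API reports as -1, so naively multiplying shape entries yields a negative element count. A small pure-Java helper along these lines (class and method names are illustrative) normalizes the shape before allocating buffers:

```java
import java.util.Arrays;

public class ShapeUtils {

    // Replace dynamic dimensions (reported as -1) with a concrete size, e.g. batch = 1
    public static long[] concretize(long[] shape, long dynamicDimValue) {
        long[] result = shape.clone();
        for (int i = 0; i < result.length; i++) {
            if (result[i] < 0) {
                result[i] = dynamicDimValue;
            }
        }
        return result;
    }

    // Total number of elements for a fully concrete shape
    public static long elementCount(long[] shape) {
        long total = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException(
                        "Shape still has a dynamic dimension: " + Arrays.toString(shape));
            }
            total *= dim;
        }
        return total;
    }

    public static void main(String[] args) {
        long[] modelShape = {-1, 3, 224, 224}; // dynamic batch dimension
        long[] concrete = concretize(modelShape, 1);
        System.out.println(Arrays.toString(concrete) + " -> " + elementCount(concrete) + " elements");
    }
}
```

The element count also tells you how large a `float[]` (or `FloatBuffer`) to allocate before calling `OnnxTensor.createTensor`.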

Image Classification Example

Example 2: Image Classification with ONNX

import ai.onnxruntime.*;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.FloatBuffer;
import java.util.*;
import java.util.List; // disambiguate from java.awt.List

public class ImageClassifier {

    private final OrtEnvironment env;
    private final OrtSession session;
    private final List<String> labels;
    private final int inputWidth;
    private final int inputHeight;
    private final int inputChannels;
    private final float[] mean;
    private final float[] std;

    public ImageClassifier(String modelPath, String labelsPath,
                           int width, int height, int channels) throws OrtException, IOException {
        this.env = OrtEnvironment.getEnvironment();
        // Configure session options
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        sessionOptions.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
        this.session = env.createSession(modelPath, sessionOptions);
        this.labels = loadLabels(labelsPath);
        this.inputWidth = width;
        this.inputHeight = height;
        this.inputChannels = channels;
        // ImageNet normalization parameters
        this.mean = new float[]{0.485f, 0.456f, 0.406f};
        this.std = new float[]{0.229f, 0.224f, 0.225f};
    }

    private List<String> loadLabels(String labelsPath) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(labelsPath))) {
            String line;
            while ((line = reader.readLine()) != null) {
                labels.add(line.trim());
            }
        }
        return labels;
    }

    public ClassificationResult classify(File imageFile) throws OrtException, IOException {
        // Preprocess image
        float[] processedImage = preprocessImage(imageFile);
        // Create input tensor in NCHW layout
        long[] shape = {1, inputChannels, inputHeight, inputWidth};
        // The input name must match the model; use the first declared input
        String inputName = session.getInputInfo().keySet().iterator().next();
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(processedImage), shape);
             OrtSession.Result results = session.run(Collections.singletonMap(inputName, inputTensor))) {
            return processClassificationResults(results);
        }
    }

    private float[] preprocessImage(File imageFile) throws IOException {
        BufferedImage image = ImageIO.read(imageFile);
        // Resize image
        BufferedImage resizedImage = resizeImage(image, inputWidth, inputHeight);
        // Convert to float array and normalize
        float[] floatArray = new float[inputChannels * inputWidth * inputHeight];
        int index = 0;
        for (int c = 0; c < inputChannels; c++) {
            for (int y = 0; y < inputHeight; y++) {
                for (int x = 0; x < inputWidth; x++) {
                    Color color = new Color(resizedImage.getRGB(x, y));
                    float pixelValue;
                    switch (c) {
                        case 0 -> pixelValue = color.getRed();   // R channel
                        case 1 -> pixelValue = color.getGreen(); // G channel
                        case 2 -> pixelValue = color.getBlue();  // B channel
                        default -> pixelValue = 0;
                    }
                    // Normalize: (pixel / 255 - mean) / std
                    floatArray[index++] = (pixelValue / 255.0f - mean[c]) / std[c];
                }
            }
        }
        return floatArray;
    }

    private BufferedImage resizeImage(BufferedImage originalImage, int targetWidth, int targetHeight) {
        BufferedImage resizedImage = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resizedImage.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(originalImage, 0, 0, targetWidth, targetHeight, null);
        g.dispose();
        return resizedImage;
    }

    private ClassificationResult processClassificationResults(OrtSession.Result results) throws OrtException {
        OnnxTensor outputTensor = (OnnxTensor) results.get(0);
        // Assumes a [1, numClasses] float output
        float[][] probabilities = (float[][]) outputTensor.getValue();
        // Get top predictions
        List<Prediction> predictions = getTopPredictions(probabilities[0], 5);
        return new ClassificationResult(predictions);
    }

    private List<Prediction> getTopPredictions(float[] probabilities, int topK) {
        PriorityQueue<Prediction> pq = new PriorityQueue<>(
                (a, b) -> Float.compare(b.probability, a.probability)
        );
        for (int i = 0; i < probabilities.length; i++) {
            pq.offer(new Prediction(labels.get(i), i, probabilities[i]));
        }
        List<Prediction> topPredictions = new ArrayList<>();
        for (int i = 0; i < topK && !pq.isEmpty(); i++) {
            topPredictions.add(pq.poll());
        }
        return topPredictions;
    }

    public void close() throws OrtException {
        session.close();
        env.close();
    }

    // Data classes
    public static class ClassificationResult {
        private final List<Prediction> predictions;
        private final long timestamp;

        public ClassificationResult(List<Prediction> predictions) {
            this.predictions = predictions;
            this.timestamp = System.currentTimeMillis();
        }

        public List<Prediction> getPredictions() { return predictions; }
        public long getTimestamp() { return timestamp; }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("Classification Results:\n");
            for (int i = 0; i < predictions.size(); i++) {
                Prediction p = predictions.get(i);
                sb.append(String.format("%d. %s (%.4f)%n", i + 1, p.label, p.probability));
            }
            return sb.toString();
        }
    }

    public static class Prediction {
        public final String label;
        public final int classId;
        public final float probability;

        public Prediction(String label, int classId, float probability) {
            this.label = label;
            this.classId = classId;
            this.probability = probability;
        }
    }

    public static void main(String[] args) {
        try {
            ImageClassifier classifier = new ImageClassifier(
                    "resnet50.onnx",
                    "imagenet_classes.txt",
                    224, 224, 3
            );
            File imageFile = new File("test_image.jpg");
            ClassificationResult result = classifier.classify(imageFile);
            System.out.println(result);
            classifier.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
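The preprocessing loop above does two things at once: it flattens the image into NCHW order and applies ImageNet normalization. The same arithmetic in isolation, as a small pure-Java sketch (class and method names are illustrative):

```java
public class Preprocessing {

    // Flat index of element (channel c, row y, column x) in an NCHW float buffer
    public static int nchwIndex(int c, int y, int x, int height, int width) {
        return c * height * width + y * width + x;
    }

    // ImageNet-style normalization: scale to [0, 1], subtract mean, divide by std
    public static float normalize(int pixel, float mean, float std) {
        return (pixel / 255.0f - mean) / std;
    }

    public static void main(String[] args) {
        // Red-channel value of pixel (y=0, x=1) in a 224x224 image sits at index 1
        System.out.println(nchwIndex(0, 0, 1, 224, 224)); // 1
        // A fully saturated red pixel after ImageNet normalization
        System.out.println(normalize(255, 0.485f, 0.229f));
    }
}
```

Seeing the index formula explicitly makes it clear why the channel loop must be outermost: all red values come first, then all green, then all blue.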

Advanced ONNX Features

Example 3: Multiple Inputs/Outputs and Batching

import ai.onnxruntime.*;

import java.nio.FloatBuffer;
import java.util.*;

public class AdvancedONNXInference {

    private final OrtEnvironment env;
    private final OrtSession session;
    private final SessionConfig config;

    public AdvancedONNXInference(String modelPath, SessionConfig config) throws OrtException {
        this.env = OrtEnvironment.getEnvironment();
        this.config = config;
        this.session = env.createSession(modelPath, createSessionOptions());
    }

    private OrtSession.SessionOptions createSessionOptions() throws OrtException {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Set optimization level
        switch (config.optimizationLevel) {
            case HIGH -> sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
            case MEDIUM -> sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.BASIC_OPT);
            case LOW -> sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.NO_OPT);
        }
        // Set execution mode and thread pools
        if (config.enableParallelExecution) {
            sessionOptions.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.PARALLEL);
            sessionOptions.setInterOpNumThreads(config.interOpThreads);
            sessionOptions.setIntraOpNumThreads(config.intraOpThreads);
        }
        // Enable memory pattern optimization
        if (config.enableMemoryPattern) {
            sessionOptions.setMemoryPatternOptimization(true);
        }
        // Enable or disable the CPU arena allocator
        sessionOptions.setCPUArenaAllocator(config.cpuArenaAllocator);
        return sessionOptions;
    }

    public BatchInferenceResult runBatchInference(List<Map<String, float[]>> batchInputs)
            throws OrtException {
        // Note: this runs the samples sequentially; packing them into a single
        // [N, ...] tensor is usually faster for throughput.
        List<Map<String, OnnxTensor>> batchTensors = new ArrayList<>();
        List<OrtSession.Result> batchResults = new ArrayList<>();
        try {
            // Process each input in the batch
            for (Map<String, float[]> inputData : batchInputs) {
                Map<String, OnnxTensor> inputTensors = createInputTensors(inputData);
                batchTensors.add(inputTensors);
                // Run inference
                batchResults.add(session.run(inputTensors));
            }
            // Process batch results
            return processBatchResults(batchResults);
        } finally {
            // Clean up tensors and results
            for (Map<String, OnnxTensor> tensors : batchTensors) {
                for (OnnxTensor tensor : tensors.values()) {
                    tensor.close();
                }
            }
            for (OrtSession.Result result : batchResults) {
                result.close();
            }
        }
    }

    private Map<String, OnnxTensor> createInputTensors(Map<String, float[]> inputData)
            throws OrtException {
        Map<String, OnnxTensor> inputTensors = new HashMap<>();
        for (Map.Entry<String, float[]> entry : inputData.entrySet()) {
            String inputName = entry.getKey();
            float[] data = entry.getValue();
            long[] shape = getExpectedShape(inputName, data.length);
            inputTensors.put(inputName, OnnxTensor.createTensor(env, FloatBuffer.wrap(data), shape));
        }
        return inputTensors;
    }

    private long[] getExpectedShape(String inputName, int dataLength) throws OrtException {
        NodeInfo info = session.getInputInfo().get(inputName);
        if (info != null && info.getInfo() instanceof TensorInfo) {
            long[] shape = ((TensorInfo) info.getInfo()).getShape().clone();
            // Dynamic dimensions are reported as -1; substitute batch size 1
            for (int i = 0; i < shape.length; i++) {
                if (shape[i] < 0) {
                    shape[i] = 1;
                }
            }
            return shape;
        }
        // Fallback: assume a [1, N] tensor
        return new long[]{1, dataLength};
    }

    private BatchInferenceResult processBatchResults(List<OrtSession.Result> batchResults)
            throws OrtException {
        List<Map<String, float[]>> batchOutputs = new ArrayList<>();
        Map<String, Statistics> outputStatistics = new HashMap<>();
        for (OrtSession.Result result : batchResults) {
            Map<String, float[]> output = new HashMap<>();
            for (Map.Entry<String, OnnxValue> entry : result) {
                String outputName = entry.getKey();
                // Assumes a flat float output; adjust the cast for higher-rank tensors
                float[] values = (float[]) ((OnnxTensor) entry.getValue()).getValue();
                output.put(outputName, values);
                // Update statistics
                updateStatistics(outputStatistics, outputName, values);
            }
            batchOutputs.add(output);
        }
        return new BatchInferenceResult(batchOutputs, outputStatistics);
    }

    private void updateStatistics(Map<String, Statistics> statistics, String outputName, float[] values) {
        Statistics stats = statistics.computeIfAbsent(outputName, k -> new Statistics());
        for (float value : values) {
            stats.update(value);
        }
    }

    public void close() throws OrtException {
        session.close();
        env.close();
    }

    // Configuration and result classes
    public static class SessionConfig {
        public OptimizationLevel optimizationLevel = OptimizationLevel.HIGH;
        public boolean enableParallelExecution = true;
        public int interOpThreads = 2;
        public int intraOpThreads = 4;
        public boolean enableMemoryPattern = true;
        public boolean cpuArenaAllocator = true;

        public enum OptimizationLevel {
            LOW, MEDIUM, HIGH
        }
    }

    public static class BatchInferenceResult {
        public final List<Map<String, float[]>> batchOutputs;
        public final Map<String, Statistics> statistics;
        public final long timestamp; // wall-clock time when the batch completed

        public BatchInferenceResult(List<Map<String, float[]>> batchOutputs,
                                    Map<String, Statistics> statistics) {
            this.batchOutputs = batchOutputs;
            this.statistics = statistics;
            this.timestamp = System.currentTimeMillis();
        }
    }

    public static class Statistics {
        private double sum = 0;
        private double sumSquares = 0;
        private int count = 0;
        private float min = Float.POSITIVE_INFINITY;
        private float max = Float.NEGATIVE_INFINITY;

        public void update(float value) {
            sum += value;
            sumSquares += value * value;
            count++;
            min = Math.min(min, value);
            max = Math.max(max, value);
        }

        public float getMean() { return (float) (sum / count); }

        public float getStdDev() {
            double mean = sum / count;
            return (float) Math.sqrt((sumSquares / count) - (mean * mean));
        }

        public float getMin() { return min; }
        public float getMax() { return max; }
        public int getCount() { return count; }
    }
}
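One common post-processing step regardless of how the outputs were produced: many exported classifiers emit raw logits rather than probabilities, so ranking or thresholding them directly can be misleading. Converting logits to probabilities is a standard softmax pass; a numerically stable pure-Java sketch (subtracting the max logit before exponentiating avoids overflow):

```java
public class Softmax {

    // Numerically stable softmax: subtract the max logit before exponentiating
    public static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] exps = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }

    public static void main(String[] args) {
        float[] probs = softmax(new float[]{2.0f, 1.0f, 0.1f});
        System.out.println(java.util.Arrays.toString(probs));
    }
}
```

Whether you need this depends on the model: check whether the exported graph already ends in a Softmax node before applying it again.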

Performance Optimization

Example 4: Performance Monitoring and Optimization

import ai.onnxruntime.*;

import java.nio.FloatBuffer;
import java.util.*;

public class PerformanceOptimizedInference {

    private final OrtEnvironment env;
    private final OrtSession session;
    private final PerformanceMonitor monitor;

    public PerformanceOptimizedInference(String modelPath, boolean useGPU) throws OrtException {
        this.env = OrtEnvironment.getEnvironment();
        this.monitor = new PerformanceMonitor();
        this.session = env.createSession(modelPath, createOptimizedSessionOptions(useGPU));
    }

    private OrtSession.SessionOptions createOptimizedSessionOptions(boolean useGPU) throws OrtException {
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Performance optimizations
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        sessionOptions.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.PARALLEL);
        sessionOptions.setInterOpNumThreads(2);
        sessionOptions.setIntraOpNumThreads(Runtime.getRuntime().availableProcessors());
        sessionOptions.setMemoryPatternOptimization(true);
        sessionOptions.setCPUArenaAllocator(true);
        // GPU acceleration if available and requested (requires the onnxruntime_gpu artifact)
        if (useGPU && isGPUSupported()) {
            try {
                sessionOptions.addCUDA(0); // Use first GPU
                System.out.println("GPU acceleration enabled");
            } catch (OrtException e) {
                System.out.println("GPU not available, falling back to CPU: " + e.getMessage());
            }
        }
        return sessionOptions;
    }

    private boolean isGPUSupported() {
        // Check whether the CUDA execution provider can be registered
        try (OrtSession.SessionOptions testOptions = new OrtSession.SessionOptions()) {
            testOptions.addCUDA(0);
            return true;
        } catch (OrtException e) {
            return false;
        }
    }

    public InferenceResult runInferenceWithMonitoring(Map<String, float[]> inputs) throws OrtException {
        long startTime = System.nanoTime();
        // Create input tensors
        Map<String, OnnxTensor> inputTensors = createInputTensors(inputs);
        long tensorCreationTime = System.nanoTime();
        // Run inference; Result is AutoCloseable
        try (OrtSession.Result results = session.run(inputTensors)) {
            long inferenceTime = System.nanoTime();
            // Process results
            Map<String, float[]> outputs = extractOutputs(results);
            long processingTime = System.nanoTime();
            // Record metrics
            monitor.recordInference(
                    tensorCreationTime - startTime,
                    inferenceTime - tensorCreationTime,
                    processingTime - inferenceTime
            );
            return new InferenceResult(outputs, monitor.getLatestMetrics());
        } catch (OrtException e) {
            monitor.recordError();
            throw e;
        } finally {
            inputTensors.values().forEach(OnnxTensor::close);
        }
    }

    private Map<String, OnnxTensor> createInputTensors(Map<String, float[]> inputs) throws OrtException {
        Map<String, OnnxTensor> tensors = new HashMap<>();
        for (Map.Entry<String, float[]> entry : inputs.entrySet()) {
            // In practice, get the actual shape from model metadata
            long[] shape = inferShape(entry.getValue());
            tensors.put(entry.getKey(), OnnxTensor.createTensor(env, FloatBuffer.wrap(entry.getValue()), shape));
        }
        return tensors;
    }

    private long[] inferShape(float[] data) {
        // Simplified - in practice, use model metadata
        return new long[]{1, data.length};
    }

    private Map<String, float[]> extractOutputs(OrtSession.Result results) throws OrtException {
        Map<String, float[]> outputs = new HashMap<>();
        for (Map.Entry<String, OnnxValue> entry : results) {
            // Assumes flat float outputs; adjust the cast for higher-rank tensors
            outputs.put(entry.getKey(), (float[]) ((OnnxTensor) entry.getValue()).getValue());
        }
        return outputs;
    }

    public PerformanceStats getPerformanceStats() {
        return monitor.getStatistics();
    }

    public void close() throws OrtException {
        session.close();
        env.close();
    }

    // Performance monitoring classes
    public static class PerformanceMonitor {
        private final List<Long> tensorCreationTimes = new ArrayList<>();
        private final List<Long> inferenceTimes = new ArrayList<>();
        private final List<Long> processingTimes = new ArrayList<>();
        private int errorCount = 0;
        private int totalInferences = 0;

        public void recordInference(long tensorCreationTime, long inferenceTime, long processingTime) {
            tensorCreationTimes.add(tensorCreationTime);
            inferenceTimes.add(inferenceTime);
            processingTimes.add(processingTime);
            totalInferences++;
        }

        public void recordError() {
            errorCount++;
        }

        public InferenceMetrics getLatestMetrics() {
            if (inferenceTimes.isEmpty()) return new InferenceMetrics(0, 0, 0);
            int lastIndex = inferenceTimes.size() - 1;
            return new InferenceMetrics(
                    tensorCreationTimes.get(lastIndex),
                    inferenceTimes.get(lastIndex),
                    processingTimes.get(lastIndex)
            );
        }

        public PerformanceStats getStatistics() {
            return new PerformanceStats(
                    calculateStats(tensorCreationTimes),
                    calculateStats(inferenceTimes),
                    calculateStats(processingTimes),
                    totalInferences,
                    errorCount
            );
        }

        private TimeStats calculateStats(List<Long> times) {
            if (times.isEmpty()) return new TimeStats(0, 0, 0, 0);
            long sum = 0;
            long min = Long.MAX_VALUE;
            long max = Long.MIN_VALUE;
            for (long time : times) {
                sum += time;
                min = Math.min(min, time);
                max = Math.max(max, time);
            }
            double average = (double) sum / times.size();
            return new TimeStats(average, min, max, times.size());
        }
    }

    public static class InferenceResult {
        public final Map<String, float[]> outputs;
        public final InferenceMetrics metrics;

        public InferenceResult(Map<String, float[]> outputs, InferenceMetrics metrics) {
            this.outputs = outputs;
            this.metrics = metrics;
        }
    }

    public static class InferenceMetrics {
        public final long tensorCreationTimeNs;
        public final long inferenceTimeNs;
        public final long processingTimeNs;
        public final long totalTimeNs;

        public InferenceMetrics(long tensorCreationTimeNs, long inferenceTimeNs, long processingTimeNs) {
            this.tensorCreationTimeNs = tensorCreationTimeNs;
            this.inferenceTimeNs = inferenceTimeNs;
            this.processingTimeNs = processingTimeNs;
            this.totalTimeNs = tensorCreationTimeNs + inferenceTimeNs + processingTimeNs;
        }
    }

    public static class PerformanceStats {
        public final TimeStats tensorCreationStats;
        public final TimeStats inferenceStats;
        public final TimeStats processingStats;
        public final int totalInferences;
        public final int errorCount;
        public final double errorRate;

        public PerformanceStats(TimeStats tensorCreationStats, TimeStats inferenceStats,
                                TimeStats processingStats, int totalInferences, int errorCount) {
            this.tensorCreationStats = tensorCreationStats;
            this.inferenceStats = inferenceStats;
            this.processingStats = processingStats;
            this.totalInferences = totalInferences;
            this.errorCount = errorCount;
            this.errorRate = totalInferences > 0 ? (double) errorCount / totalInferences : 0;
        }
    }

    public static class TimeStats {
        public final double averageNs;
        public final long minNs;
        public final long maxNs;
        public final int count;

        public TimeStats(double averageNs, long minNs, long maxNs, int count) {
            this.averageNs = averageNs;
            this.minNs = minNs;
            this.maxNs = maxNs;
            this.count = count;
        }

        public double getAverageMs() { return averageNs / 1_000_000.0; }
        public double getMinMs() { return minNs / 1_000_000.0; }
        public double getMaxMs() { return maxNs / 1_000_000.0; }
    }
}

Best Practices

1. Resource Management

public class ONNXResourceManager implements AutoCloseable {

    private final OrtEnvironment env;
    private final OrtSession session;
    private final List<OnnxTensor> allocatedTensors = new ArrayList<>();

    public ONNXResourceManager(String modelPath) throws OrtException {
        this.env = OrtEnvironment.getEnvironment();
        this.session = env.createSession(modelPath, new OrtSession.SessionOptions());
    }

    public OnnxTensor createTensor(float[] data, long[] shape) throws OrtException {
        OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(data), shape);
        allocatedTensors.add(tensor);
        return tensor;
    }

    @Override
    public void close() throws OrtException {
        // Close all allocated tensors before the session and environment
        for (OnnxTensor tensor : allocatedTensors) {
            tensor.close();
        }
        session.close();
        env.close();
    }
}

2. Error Handling

public class SafeONNXInference {

    public static Optional<InferenceResult> safeInference(
            String modelPath, Map<String, float[]> inputs) {
        try (OrtEnvironment env = OrtEnvironment.getEnvironment();
             OrtSession session = env.createSession(modelPath, new OrtSession.SessionOptions())) {
            Map<String, OnnxTensor> inputTensors = new HashMap<>();
            try {
                // Create input tensors
                for (Map.Entry<String, float[]> entry : inputs.entrySet()) {
                    OnnxTensor tensor = OnnxTensor.createTensor(env,
                            FloatBuffer.wrap(entry.getValue()),
                            new long[]{1, entry.getValue().length});
                    inputTensors.put(entry.getKey(), tensor);
                }
                // Run inference; close the Result once processed
                try (OrtSession.Result results = session.run(inputTensors)) {
                    return Optional.of(processResults(results));
                }
            } finally {
                // Clean up input tensors
                for (OnnxTensor tensor : inputTensors.values()) {
                    tensor.close();
                }
            }
        } catch (OrtException e) {
            System.err.println("Inference failed: " + e.getMessage());
            return Optional.empty();
        }
    }

    // ... processResults method
}

Common Issues and Solutions

1. Shape Mismatch

public class ShapeValidator {

    public static void validateInputShape(OrtSession session,
                                          String inputName,
                                          long[] expectedShape) throws OrtException {
        NodeInfo info = session.getInputInfo().get(inputName);
        if (info == null) {
            throw new IllegalArgumentException("Input not found: " + inputName);
        }
        long[] actualShape = ((TensorInfo) info.getInfo()).getShape();
        if (actualShape.length != expectedShape.length) {
            throw new IllegalArgumentException(
                    String.format("Rank mismatch for %s. Expected: %s, Actual: %s",
                            inputName, Arrays.toString(expectedShape), Arrays.toString(actualShape)));
        }
        for (int i = 0; i < actualShape.length; i++) {
            // Dynamic dimensions are reported as -1 and match any size
            if (actualShape[i] >= 0 && actualShape[i] != expectedShape[i]) {
                throw new IllegalArgumentException(
                        String.format("Shape mismatch for %s. Expected: %s, Actual: %s",
                                inputName, Arrays.toString(expectedShape), Arrays.toString(actualShape)));
            }
        }
    }
}

2. Memory Management

public class MemoryAwareInference {

    // Minimum heap headroom to keep free beyond the inference allocation
    private static final long MIN_FREE_MEMORY_BYTES = 512 * 1024 * 1024; // 512 MB

    public boolean canRunInference(long estimatedMemoryBytes) {
        Runtime runtime = Runtime.getRuntime();
        long usedMemory = runtime.totalMemory() - runtime.freeMemory();
        long availableMemory = runtime.maxMemory() - usedMemory;
        return availableMemory > estimatedMemoryBytes + MIN_FREE_MEMORY_BYTES;
    }
}
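The headroom check above needs an estimate of how much memory an inference will allocate. For a dense tensor a reasonable lower bound is element count times bytes per element (4 for float32); a sketch of that arithmetic (the class name is illustrative, and the figure ignores ONNX Runtime's own internal allocations):

```java
public class TensorMemory {

    // Rough memory estimate for a dense tensor: elements * bytes per element.
    // Ignores runtime-internal allocations, so treat it as a lower bound.
    public static long estimateBytes(long[] shape, int bytesPerElement) {
        long elements = 1;
        for (long dim : shape) {
            elements *= dim;
        }
        return elements * bytesPerElement;
    }

    public static void main(String[] args) {
        // A float32 batch of 8 ImageNet-sized images
        long bytes = estimateBytes(new long[]{8, 3, 224, 224}, 4);
        System.out.println(bytes + " bytes (~" + bytes / (1024 * 1024) + " MiB)");
    }
}
```

Summing this estimate over all inputs and outputs gives a usable `estimatedMemoryBytes` argument for a check like `canRunInference`.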

Conclusion

ONNX model inference in Java provides:

Key Benefits:

  • Framework Interoperability: Run models from PyTorch, TensorFlow, etc.
  • Performance: Optimized inference with hardware acceleration
  • Cross-platform: Consistent behavior across different systems
  • Production Ready: Battle-tested in enterprise environments

Use Cases:

  • Real-time Inference: Web services, mobile apps
  • Batch Processing: Data pipelines, ETL processes
  • Edge Computing: IoT devices, embedded systems
  • Microservices: Cloud-native applications

Performance Tips:

  1. Use appropriate execution providers (CPU/GPU)
  2. Enable memory patterns and arena allocators
  3. Use batching for throughput optimization
  4. Monitor memory usage and performance metrics
  5. Pre-allocate tensors when possible
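On tip 3, batching mostly amounts to packing several flat inputs into one [N, ...] buffer so a single session.run call amortizes per-call overhead. The packing itself is plain array work, sketched here without any ONNX Runtime types (the class name is illustrative):

```java
import java.util.List;

public class BatchPacker {

    // Concatenate N equally sized flat inputs into one [N, perSample] buffer
    public static float[] pack(List<float[]> samples) {
        int perSample = samples.get(0).length;
        float[] batch = new float[samples.size() * perSample];
        for (int i = 0; i < samples.size(); i++) {
            if (samples.get(i).length != perSample) {
                throw new IllegalArgumentException("All samples must have the same length");
            }
            System.arraycopy(samples.get(i), 0, batch, i * perSample, perSample);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[] batch = pack(List.of(new float[]{1f, 2f}, new float[]{3f, 4f}));
        System.out.println(java.util.Arrays.toString(batch)); // [1.0, 2.0, 3.0, 4.0]
    }
}
```

The resulting array can then be wrapped in a FloatBuffer and submitted as one tensor with shape {N, ...}, provided the model declares a dynamic batch dimension.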

ONNX Runtime for Java enables seamless integration of machine learning models into Java applications, making it an excellent choice for enterprises looking to leverage AI capabilities in their existing Java ecosystem.


Next Steps: Explore ONNX model optimization techniques, investigate different execution providers (TensorRT, OpenVINO), or implement model versioning and A/B testing for production deployments.
