GadaaLabs
Machine Learning Engineering
Lesson 5

Model Packaging and Export

A trained .pt checkpoint is not a deployable artifact. Loading a state dict requires a compatible PyTorch version and the exact Python class definition to reconstruct the model. ONNX and TorchScript decouple the computation graph from the Python runtime, enabling deployment to C++ servers, mobile devices, and inference runtimes that can outperform plain PyTorch by 2–5x on CPU.
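
As a minimal sketch of why the class definition matters: a state dict holds only weight tensors, so loading it means first rebuilding the identical architecture. The layer sizes and file name below are arbitrary placeholders.

```python
import torch
import torch.nn as nn

# A state dict stores tensors only -- no architecture, no code.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
torch.save(model.state_dict(), "checkpoint.pt")

# Loading requires reconstructing the exact same module structure first;
# change a layer size and load_state_dict raises a shape-mismatch error.
rebuilt = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()
rebuilt.load_state_dict(torch.load("checkpoint.pt"))
```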

Export Formats Compared

| Format | Portability | Speed vs eager | Requirements |
|---|---|---|---|
| PyTorch eager (.pt state dict) | Python only | Baseline | Model class definition |
| TorchScript (.pt traced/scripted) | C++ + Python | 1.2–2x | Supported ops only |
| ONNX (.onnx) | Any ONNX runtime | 1.5–4x | No dynamic control flow |
| ONNX + quantized INT8 | Any ONNX runtime | 3–6x | Slight accuracy trade-off |

Exporting to ONNX

```python
import torch
import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)  # match training input shape

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={
        "input":  {0: "batch_size"},   # allow variable batch at runtime
        "logits": {0: "batch_size"},
    },
)
print("ONNX export complete")
```

dynamic_axes is critical: without it, the exported graph is locked to the dummy input's batch size (1 here) and will fail on any other batch.

Validating the ONNX Graph

```python
import onnx
import onnxruntime as ort
import numpy as np

# Structural validation
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Numerical validation — outputs must match PyTorch within float32 tolerance
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
np_input = dummy_input.numpy()
ort_outputs = session.run(None, {"input": np_input})

with torch.no_grad():
    pt_outputs = model(dummy_input).numpy()

np.testing.assert_allclose(pt_outputs, ort_outputs[0], rtol=1e-3, atol=1e-5)
print("Outputs match within tolerance")
```

If the numerical check fails, the graph optimiser may have introduced numerical drift. If the difference is small, loosen the tolerances (larger rtol/atol); otherwise re-export with do_constant_folding=False to rule out the optimiser.

TorchScript Export

```python
# Option 1: tracing (faster, less flexible — no dynamic control flow)
traced = torch.jit.trace(model, dummy_input)
traced.save("model_traced.pt")

# Option 2: scripting (supports if/for/while inside the model)
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")

# Load in C++ or Python
loaded = torch.jit.load("model_scripted.pt")
```

Prefer scripting when your model has conditional logic. Use tracing for pure feed-forward networks — it is faster to export and produces cleaner graphs.
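To see the difference concretely, here is a sketch with a hypothetical module containing a data-dependent branch. Tracing records only the branch the example input happened to take; scripting compiles both:

```python
import torch
import torch.nn as nn

class Branchy(nn.Module):
    def forward(self, x):
        # Data-dependent control flow: the branch depends on input values
        if x.sum() > 0:
            return x * 2
        return x * -1

model = Branchy()
pos, neg = torch.ones(3), -torch.ones(3)

traced = torch.jit.trace(model, pos)   # records only the x * 2 branch
scripted = torch.jit.script(model)     # compiles both branches

print(traced(neg))    # replays x * 2 -- the wrong branch for this input
print(scripted(neg))  # takes the x * -1 branch correctly
```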

Benchmarking Inference Latency

```python
import time

def benchmark(runner, input_tensor, n_warmup=50, n_runs=500):
    # Warmup
    for _ in range(n_warmup):
        _ = runner(input_tensor)

    start = time.perf_counter()
    for _ in range(n_runs):
        _ = runner(input_tensor)
    elapsed = time.perf_counter() - start

    return (elapsed / n_runs) * 1000  # ms per inference

# Disable autograd so eager timing is not inflated by gradient tracking
with torch.no_grad():
    eager_ms = benchmark(lambda x: model(x), dummy_input)

ort_ms = benchmark(
    lambda x: session.run(None, {"input": x.numpy()}),
    dummy_input,
)
print(f"PyTorch eager: {eager_ms:.2f} ms")
print(f"ONNX Runtime:  {ort_ms:.2f} ms  ({eager_ms/ort_ms:.1f}x speedup)")
```

Always warm up before benchmarking — the first several calls load kernels and JIT-compile layers, inflating measured latency.

Summary

  • State dicts alone are not deployable artifacts; export to ONNX or TorchScript for portability and speed.
  • Use dynamic_axes in ONNX export to support variable batch sizes at inference time.
  • Always validate exported models numerically against PyTorch outputs before deploying.
  • Prefer TorchScript scripting over tracing when the model contains conditional Python logic.
  • Benchmark with warmup runs and measure over at least 500 iterations to get stable latency estimates.