GadaaLabs
Machine Learning Engineering — Production ML Systems
Lesson 7

Model Packaging — ONNX, TorchScript & Containers

22 min

A model that cannot be reliably loaded six months from now is not a production asset — it is a time bomb. pickle files break across Python versions, .pt files require the same PyTorch version and the same class definition, and a notebook's implicit environment is never reproducible. Production model packaging means choosing a format that is durable, portable, and self-describing, then wrapping it in a container that encapsulates every dependency.

Serialisation Risks

Each serialisation option fails in different ways; understanding those failure modes determines which format to choose.

pickle: Python's default serialisation. Extremely fragile: requires the exact same class definitions, the same Python version, and the same library versions. Contains executable bytecode — loading a pickle file from an untrusted source executes arbitrary code. Never use pickle for long-lived model storage or inter-team sharing.

joblib: a wrapper around pickle that handles large numpy arrays efficiently via memory-mapped files. Still version-dependent, still executable. Acceptable for short-lived local caching; not appropriate for production artifact stores.

cloudpickle: serialises closures and dynamically-defined classes. Useful for distributing functions to Spark workers. Same security risks as pickle. Appropriate only within a homogeneous environment.

python
import pickle, joblib, cloudpickle

# pickle: fragile, insecure, avoid for production
with open("model.pkl", "wb") as f:
    pickle.dump(model, f)

# joblib: better for numpy-heavy models, still version-dependent
joblib.dump(model, "model.joblib", compress=3)
model = joblib.load("model.joblib")

# cloudpickle: handles closures, same security caveats
with open("model_cp.pkl", "wb") as f:
    cloudpickle.dump(model, f)

# The risk: this executes arbitrary code
# model = pickle.load(open("untrusted_model.pkl", "rb"))  # NEVER do this

# Mitigation: always validate SHA256 of model files before loading
import hashlib
import io

def safe_load_model(path: str, expected_sha256: str):
    with open(path, "rb") as f:
        data = f.read()
    actual_sha256 = hashlib.sha256(data).hexdigest()
    if actual_sha256 != expected_sha256:
        raise ValueError(
            f"Model file checksum mismatch: expected {expected_sha256}, "
            f"got {actual_sha256}"
        )
    return joblib.load(io.BytesIO(data))
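The expected checksum has to be produced at save time and recorded alongside the artifact (in the model registry or a metadata file). A minimal sketch — `sha256_of_file` is a hypothetical helper name — that streams the file so large model artifacts never have to fit in memory:

```python
import hashlib

def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream-hash a file in 1 MiB chunks and return the hex digest."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Record the returned digest wherever the artifact path is stored, and pass it to safe_load_model at load time.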

ONNX Export

ONNX (Open Neural Network Exchange) is an open format that represents a model as a computation graph. ONNX models are framework-agnostic: export from PyTorch, run with ONNX Runtime, TensorRT, CoreML, or any other backend.

python
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np

def export_to_onnx(
    model: nn.Module,
    input_shape: tuple,
    output_path: str,
    opset_version: int = 17,
    dynamic_batch: bool = True,
) -> None:
    """Export a PyTorch model to ONNX with dynamic batch size support."""
    model.eval()
    device = next(model.parameters()).device
    dummy_input = torch.randn(*input_shape, device=device)

    dynamic_axes = None
    if dynamic_batch:
        dynamic_axes = {
            "input":  {0: "batch_size"},
            "output": {0: "batch_size"},
        }

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=opset_version,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        export_params=True,          # include weights in the file
        do_constant_folding=True,    # fold constant operations at export time
        verbose=False,
    )
    print(f"Model exported to {output_path}")

def validate_onnx_output(
    model: nn.Module,
    onnx_path: str,
    input_shape: tuple,
    atol: float = 1e-5,
) -> bool:
    """Verify that ONNX Runtime output matches PyTorch output."""
    model.eval()
    dummy_input = torch.randn(*input_shape)

    # PyTorch output
    with torch.no_grad():
        pt_output = model(dummy_input).numpy()

    # ONNX Runtime output
    ort_session = ort.InferenceSession(
        onnx_path,
        providers=["CPUExecutionProvider"],  # CPU suffices for a parity check
    )
    ort_inputs  = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
    ort_output  = ort_session.run(None, ort_inputs)[0]

    match = np.allclose(pt_output, ort_output, atol=atol)
    if not match:
        max_diff = np.abs(pt_output - ort_output).max()
        print(f"ONNX validation FAILED: max diff = {max_diff:.2e}")
    else:
        print(f"ONNX validation passed (max diff < {atol})")
    return match

class ONNXModelServer:
    """Production ONNX Runtime inference wrapper."""

    def __init__(self, onnx_path: str, use_gpu: bool = False) -> None:
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )
        so = ort.SessionOptions()
        so.intra_op_num_threads  = 4
        so.inter_op_num_threads  = 1
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        self.session    = ort.InferenceSession(onnx_path, so, providers=providers)
        self.input_name = self.session.get_inputs()[0].name

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.session.run(None, {self.input_name: X.astype(np.float32)})[0]

ONNX Common Gotchas

Certain PyTorch ops are not supported or produce incorrect ONNX graphs:

  • Dynamic control flow (if statements on tensor values): use torch.jit.script to handle these before ONNX export, or rewrite as masked operations.
  • torch.Tensor.item(): converts a tensor to a Python scalar inside the graph; use torch.squeeze instead.
  • Custom CUDA kernels: will not export to ONNX.
  • F.interpolate with recompute_scale_factor=True: can produce incorrect dynamic shape graphs in older opsets.
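To make the first point concrete, here is a sketch of the masked-operation rewrite (toy functions, not from the lesson): a Python-level branch on tensor values disappears from a traced graph, while torch.where stays inside it as a single Where node.

```python
import torch

def clip_scores_branchy(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Python-level branching per element: breaks tracing and ONNX export
    out = torch.empty_like(scores)
    for i, s in enumerate(scores):
        out[i] = s if s > threshold else 0.0
    return out

def clip_scores_masked(scores: torch.Tensor, threshold: float) -> torch.Tensor:
    # Same logic as one masked op: exports cleanly as a single Where node
    return torch.where(scores > threshold, scores, torch.zeros_like(scores))
```

Both return identical results in eager mode; only the masked version survives export intact.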

TorchScript

TorchScript compiles a PyTorch model to a static computation graph that can be serialised and run without a Python interpreter. This is essential for C++ inference or for avoiding the GIL in multi-threaded servers.

python
import torch
import torch.nn as nn

# torch.jit.trace: records operations on concrete inputs
# Works for models with fixed control flow (no data-dependent branching)
def export_torchscript_trace(
    model: nn.Module,
    example_input: torch.Tensor,
    output_path: str,
) -> None:
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    traced.save(output_path)
    print(f"TorchScript (trace) saved to {output_path}")

# torch.jit.script: parses the model source code
# Handles if/for/while on tensor values, requires type annotations
class ScriptableModel(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.threshold: float = 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logit = self.linear(x)
        score = torch.sigmoid(logit)
        # Data-dependent control flow: works with torch.jit.script
        # (.item() assumes a single-element score, i.e. batch size 1)
        if score.item() > self.threshold:
            return torch.ones(1)
        return torch.zeros(1)

def export_torchscript_script(
    model: nn.Module,
    output_path: str,
) -> None:
    model.eval()
    scripted = torch.jit.script(model)
    scripted.save(output_path)

# Loading (no Python class definition required at load time)
def load_torchscript(path: str) -> torch.jit.ScriptModule:
    return torch.jit.load(path)
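The practical difference between trace and script is easiest to see on a toy module with a data-dependent branch (illustrative example; tracing it also emits a TracerWarning):

```python
import torch
import torch.nn as nn

class AbsLike(nn.Module):
    """Return x when its sum is positive, otherwise -x (data-dependent branch)."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x
        return -x

m = AbsLike().eval()
pos = torch.tensor([1.0, 2.0])
neg = torch.tensor([-1.0, -2.0])

traced = torch.jit.trace(m, pos)    # records only the branch taken for `pos`
scripted = torch.jit.script(m)      # compiles both branches from source

assert torch.equal(scripted(neg), torch.tensor([1.0, 2.0]))   # correct
assert torch.equal(traced(neg), torch.tensor([-1.0, -2.0]))   # branch frozen at trace time
```

The traced module silently returns the wrong answer on the untraced branch, which is why data-dependent control flow demands torch.jit.script.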

MLflow pyfunc

MLflow's pyfunc interface wraps any Python prediction logic in a standard interface. This is the most flexible packaging option: it can wrap ONNX, TorchScript, sklearn, or any custom logic, while providing a consistent predict(DataFrame) API.

python
import mlflow.pyfunc
import mlflow
import pandas as pd
import numpy as np
import onnxruntime as ort

class ChurnModelPyfunc(mlflow.pyfunc.PythonModel):
    """Custom MLflow pyfunc wrapping an ONNX model with preprocessing."""

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """Called once at load time. Load model and any preprocessing artifacts."""
        onnx_path   = context.artifacts["onnx_model"]
        scaler_path = context.artifacts["scaler"]

        self.session = ort.InferenceSession(
            onnx_path, providers=["CPUExecutionProvider"]
        )
        self.input_name = self.session.get_inputs()[0].name

        import joblib
        self.scaler = joblib.load(scaler_path)

        self.features = [
            "tenure_days", "total_spend_30d",
            "support_tickets_90d", "last_purchase_days",
        ]

    def predict(
        self,
        context: mlflow.pyfunc.PythonModelContext,
        model_input: pd.DataFrame,
    ) -> pd.DataFrame:
        """Called at inference time. Input is always a DataFrame."""
        X = model_input[self.features].values.astype(np.float32)
        X_scaled = self.scaler.transform(X).astype(np.float32)
        scores = self.session.run(None, {self.input_name: X_scaled})[0]
        return pd.DataFrame({"churn_score": scores.flatten()})

def log_pyfunc_model(
    run_id: str,
    onnx_path: str,
    scaler_path: str,
) -> None:
    """Log the pyfunc model to an existing MLflow run."""
    artifacts = {
        "onnx_model": onnx_path,
        "scaler":     scaler_path,
    }
    conda_env = {
        "channels": ["defaults", "conda-forge"],
        "dependencies": [
            "python=3.11",
            {"pip": [
                "mlflow==2.12.0",
                "onnxruntime==1.17.0",
                "pandas==2.2.0",
                "numpy==1.26.0",
                "scikit-learn==1.4.0",
            ]},
        ],
    }
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            artifact_path="pyfunc_model",
            python_model=ChurnModelPyfunc(),
            artifacts=artifacts,
            conda_env=conda_env,
            code_path=["src/"],
        )
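Loading the logged model back uses MLflow's runs:/ URI scheme, where the final path segment must match the artifact_path passed to log_model. A sketch (run_id is whatever run the model was logged to; the load call is shown commented because it needs a live tracking server):

```python
def pyfunc_model_uri(run_id: str, artifact_path: str = "pyfunc_model") -> str:
    """Build the runs:/ URI that mlflow.pyfunc.load_model expects."""
    return f"runs:/{run_id}/{artifact_path}"

# import mlflow.pyfunc
# model = mlflow.pyfunc.load_model(pyfunc_model_uri(run_id))
# preds = model.predict(df)  # df: DataFrame with the four feature columns
```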

Docker Best Practices for ML Models

A Docker image for a model server must be minimal (fast pull, small attack surface), reproducible (same build = same image), and secure (non-root user, no secrets in the image).

dockerfile
# ---- Stage 1: builder ----
# Install all build dependencies (compilers, dev headers) in this stage.
# None of them appear in the final image.
FROM python:3.11-slim AS builder

WORKDIR /build

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (layer caching: only reinstall when requirements change)
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

# ---- Stage 2: runtime ----
FROM python:3.11-slim AS runtime

WORKDIR /app

# Create a non-root user (security: container should not run as root)
RUN groupadd --gid 1000 mlapp \
    && useradd  --uid 1000 --gid mlapp --shell /bin/bash --create-home mlapp

# Copy only the installed packages from the builder stage
COPY --from=builder /root/.local /home/mlapp/.local

# Copy application code and model artifacts
COPY --chown=mlapp:mlapp src/ ./src/
COPY --chown=mlapp:mlapp models/ ./models/

# Switch to non-root user
USER mlapp

ENV PATH=/home/mlapp/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Health check: verify the server starts and responds
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"

EXPOSE 8080

CMD ["python", "-m", "uvicorn", "src.server:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1"]
text
# .dockerignore — prevent accidental inclusion of large files and secrets
.env
.env.*
# Exclude stray pickles: build models inside the image or COPY specific files
*.pkl
data/
notebooks/
tests/
.git/
__pycache__/
*.pyc
mlruns/

shell
# Build and tag the image, then run it with the model path injected
docker build -t churn-model-server:v1.2.0 .
docker run -p 8080:8080 -e MODEL_PATH=/app/models/churn_v2.onnx churn-model-server:v1.2.0

Model Quantisation

Quantisation reduces model size and inference latency by representing weights and activations in lower precision (INT8 instead of FP32 or FP16). PyTorch's dynamic quantisation is the easiest to apply.

python
import torch
import torch.quantization as quant
import os

def apply_dynamic_quantisation(
    model: torch.nn.Module,
    output_path: str,
) -> dict:
    """
    Dynamic INT8 quantisation for linear layers.
    No calibration data needed — scales are computed at inference time.
    Works best for LSTM, Linear, Embedding layers.
    Does NOT help for conv layers (use static quantisation for those).
    """
    model.eval()
    model_fp32_size = sum(p.numel() * 4 for p in model.parameters())  # bytes

    quantised = quant.quantize_dynamic(
        model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
    )

    torch.save(quantised.state_dict(), output_path)
    model_int8_size = os.path.getsize(output_path)

    import time
    dummy = torch.randn(1, 128)  # assumes a 128-feature input; adjust to your model

    model.eval()
    t0 = time.perf_counter()
    for _ in range(1000):
        with torch.no_grad():
            _ = model(dummy)
    fp32_latency = (time.perf_counter() - t0) / 1000 * 1000  # ms

    t0 = time.perf_counter()
    for _ in range(1000):
        with torch.no_grad():
            _ = quantised(dummy)
    int8_latency = (time.perf_counter() - t0) / 1000 * 1000  # ms

    return {
        "fp32_size_mb":     round(model_fp32_size / 1e6, 2),
        "int8_size_mb":     round(model_int8_size / 1e6, 2),
        "size_reduction":   f"{model_fp32_size / model_int8_size:.1f}x",
        "fp32_latency_ms":  round(fp32_latency, 3),
        "int8_latency_ms":  round(int8_latency, 3),
        "speedup":          round(fp32_latency / int8_latency, 2),
    }
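One reload caveat worth a sketch (toy architecture, illustrative file name): a state_dict saved from a dynamically quantised model only loads into a model that has been quantised the same way first — calling load_state_dict on the plain FP32 module fails with missing/unexpected keys.

```python
import torch
import torch.nn as nn
import torch.quantization as quant

net = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1)).eval()
q_net = quant.quantize_dynamic(net, {nn.Linear}, dtype=torch.qint8)
torch.save(q_net.state_dict(), "q_net.pt")

# Rebuild the FP32 architecture, re-apply quantize_dynamic, THEN load weights
fresh = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1)).eval()
q_fresh = quant.quantize_dynamic(fresh, {nn.Linear}, dtype=torch.qint8)
# weights_only=False: quantised state_dicts hold packed params, not plain tensors
q_fresh.load_state_dict(torch.load("q_net.pt", weights_only=False))

x = torch.randn(2, 128)
assert torch.allclose(q_net(x), q_fresh(x))
```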

Complete Dockerfile for a FastAPI Model Server

dockerfile
FROM python:3.11-slim AS builder
WORKDIR /build
RUN apt-get update && apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

FROM python:3.11-slim AS runtime
WORKDIR /app

RUN groupadd --gid 1000 mlapp \
    && useradd --uid 1000 --gid mlapp --shell /bin/bash --create-home mlapp

COPY --from=builder /root/.local /home/mlapp/.local
COPY --chown=mlapp:mlapp src/ ./src/
COPY --chown=mlapp:mlapp models/churn_v2.onnx ./models/

USER mlapp
ENV PATH=/home/mlapp/.local/bin:$PATH
ENV MODEL_PATH=/app/models/churn_v2.onnx
ENV PYTHONUNBUFFERED=1

HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"

EXPOSE 8080
CMD ["python", "-m", "uvicorn", "src.server:app", \
     "--host", "0.0.0.0", "--port", "8080", \
     "--workers", "1", "--log-level", "info"]
text
# requirements.txt — pin all versions for reproducible builds
fastapi==0.111.0
uvicorn[standard]==0.29.0
onnxruntime==1.17.1
numpy==1.26.4
pydantic==2.7.0
scikit-learn==1.4.2
joblib==1.4.0

Key Takeaways

  • Never use raw pickle for long-lived model artifacts: it is version-dependent, framework-dependent, and executes arbitrary code on load. Always validate file checksums before loading serialised models.
  • ONNX is the best format for cross-framework and cross-language portability; use ONNX Runtime with ORT_ENABLE_ALL graph optimisations and always validate that ONNX output matches PyTorch output before deploying.
  • Use torch.jit.trace for models with fixed control flow and torch.jit.script when you have data-dependent branching; both produce self-contained TorchScript files that do not require the model class definition at load time.
  • MLflow pyfunc is the most flexible packaging format: it wraps any prediction logic in a consistent predict(DataFrame) interface with a captured environment, making it easy to serve from any MLflow-compatible platform.
  • Multi-stage Docker builds separate the build environment (compilers, dev headers) from the runtime environment, producing smaller and more secure images.
  • Copy requirements.txt before application code in your Dockerfile so Docker's layer cache is invalidated only when dependencies change, not on every code edit.
  • Dynamic INT8 quantisation with quantize_dynamic requires no calibration data and typically reduces model size by 3-4x and improves CPU inference latency by 1.5-2x for linear layers.