A model that cannot be reliably loaded six months from now is not a production asset — it is a time bomb. pickle files break across Python versions, .pt files require the same PyTorch version and the same class definition, and a notebook's implicit environment is never reproducible. Production model packaging means choosing a format that is durable, portable, and self-describing, then wrapping it in a container that encapsulates every dependency.
Serialisation Risks
The risks of each serialisation format determine which one you should choose.
pickle: Python's default serialisation. Extremely fragile: requires the exact same class definitions, the same Python version, and the same library versions. Contains executable bytecode — loading a pickle file from an untrusted source executes arbitrary code. Never use pickle for long-lived model storage or inter-team sharing.
joblib: a wrapper around pickle that handles large numpy arrays efficiently via memory-mapped files. Still version-dependent, still executable. Acceptable for short-lived local caching; not appropriate for production artifact stores.
cloudpickle: serialises closures and dynamically-defined classes. Useful for distributing functions to Spark workers. Same security risks as pickle. Appropriate only within a homogeneous environment.
```python
import pickle, joblib, cloudpickle

# pickle: fragile, insecure, avoid for production
with open("model.pkl", "wb") as f:
    pickle.dump(model, f)

# joblib: better for numpy-heavy models, still version-dependent
joblib.dump(model, "model.joblib", compress=3)
model = joblib.load("model.joblib")

# cloudpickle: handles closures, same security caveats
with open("model_cp.pkl", "wb") as f:
    cloudpickle.dump(model, f)

# The risk: this executes arbitrary code
# model = pickle.load(open("untrusted_model.pkl", "rb"))  # NEVER do this

# Mitigation: always validate SHA256 of model files before loading
import hashlib
import io

def safe_load_model(path: str, expected_sha256: str):
    with open(path, "rb") as f:
        data = f.read()
    actual_sha256 = hashlib.sha256(data).hexdigest()
    if actual_sha256 != expected_sha256:
        raise ValueError(
            f"Model file checksum mismatch: expected {expected_sha256}, "
            f"got {actual_sha256}"
        )
    # joblib has no loads(); load the verified bytes via a file-like object
    return joblib.load(io.BytesIO(data))
```
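The "executes arbitrary code" warning is literal: a pickle stream is a small program for the pickle virtual machine, and it can direct the loader to call any importable callable. A minimal, benign sketch (the `MaliciousPayload` class and the use of `eval` are illustrative, not part of any real attack tooling):

```python
import pickle

# Benign demonstration that unpickling executes attacker-chosen code.
# __reduce__ tells pickle how to "reconstruct" an object; here it
# instructs the loader to call eval() on an arbitrary expression.
class MaliciousPayload:
    def __reduce__(self):
        # A real attack would typically return (os.system, (command,))
        return (eval, ("2 + 2",))

payload = pickle.dumps(MaliciousPayload())

# The victim only calls pickle.loads on "model bytes", yet eval() runs,
# and the object that comes back is whatever the payload computed.
obj = pickle.loads(payload)
print(obj)  # → 4
```

This is why checksum validation belongs in any loading path that touches pickled artifacts: the bytes must be trusted before they are ever handed to the deserialiser.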
ONNX Export
ONNX (Open Neural Network Exchange) is an open format that represents a model as a computation graph. ONNX models are framework-agnostic: export from PyTorch, run with ONNX Runtime, TensorRT, CoreML, or any other backend.
```python
import torch
import torch.nn as nn
import onnx
import onnxruntime as ort
import numpy as np


def export_to_onnx(
    model: nn.Module,
    input_shape: tuple,
    output_path: str,
    opset_version: int = 17,
    dynamic_batch: bool = True,
) -> None:
    """Export a PyTorch model to ONNX with dynamic batch size support."""
    model.eval()
    device = next(model.parameters()).device
    dummy_input = torch.randn(*input_shape, device=device)

    dynamic_axes = None
    if dynamic_batch:
        dynamic_axes = {
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        }

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=opset_version,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        export_params=True,        # include weights in the file
        do_constant_folding=True,  # fold constant operations at export time
        verbose=False,
    )
    print(f"Model exported to {output_path}")


def validate_onnx_output(
    model: nn.Module,
    onnx_path: str,
    input_shape: tuple,
    atol: float = 1e-5,
) -> bool:
    """Verify that ONNX Runtime output matches PyTorch output."""
    model.eval()
    dummy_input = torch.randn(*input_shape)

    # PyTorch output
    with torch.no_grad():
        pt_output = model(dummy_input).numpy()

    # ONNX Runtime output
    ort_session = ort.InferenceSession(
        onnx_path,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
    ort_output = ort_session.run(None, ort_inputs)[0]

    match = np.allclose(pt_output, ort_output, atol=atol)
    if not match:
        max_diff = np.abs(pt_output - ort_output).max()
        print(f"ONNX validation FAILED: max diff = {max_diff:.2e}")
    else:
        print(f"ONNX validation passed (max diff < {atol})")
    return match


class ONNXModelServer:
    """Production ONNX Runtime inference wrapper."""

    def __init__(self, onnx_path: str, use_gpu: bool = False) -> None:
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if use_gpu
            else ["CPUExecutionProvider"]
        )
        so = ort.SessionOptions()
        so.intra_op_num_threads = 4
        so.inter_op_num_threads = 1
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        self.session = ort.InferenceSession(onnx_path, so, providers=providers)
        self.input_name = self.session.get_inputs()[0].name

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.session.run(None, {self.input_name: X.astype(np.float32)})[0]
```
ONNX Common Gotchas
Certain PyTorch ops are not supported or produce incorrect ONNX graphs:
Dynamic control flow (if statements on tensor values): use torch.jit.script to handle these before ONNX export, or rewrite as masked operations.
torch.Tensor.item(): converts a tensor to a Python scalar inside the graph; use torch.squeeze instead.
Custom CUDA kernels: will not export to ONNX.
F.interpolate with recompute_scale_factor=True: can produce incorrect dynamic shape graphs in older opsets.
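The first gotcha above can often be eliminated without `torch.jit.script` by rewriting the branch as a masked tensor operation. A minimal sketch (the function names `branchy` and `masked` are illustrative):

```python
import torch

# Data-dependent Python branch: torch.jit.trace bakes in whichever
# branch the example input happened to take, and .item() breaks export.
def branchy(score: torch.Tensor, threshold: float) -> torch.Tensor:
    if score.item() > threshold:   # requires a scalar tensor
        return torch.ones_like(score)
    return torch.zeros_like(score)

# Equivalent masked operation: a pure tensor op with no Python branch,
# so tracing and ONNX export capture the full behaviour, batched or not.
def masked(score: torch.Tensor, threshold: float) -> torch.Tensor:
    return (score > threshold).to(score.dtype)

print(masked(torch.tensor([0.2, 0.7]), 0.5))  # tensor([0., 1.])
```

The masked version also works on batches, which the scalar `.item()` branch never could.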
TorchScript
TorchScript compiles a PyTorch model to a static computation graph that can be serialised and run without a Python interpreter. This is essential for C++ inference or for avoiding the GIL in multi-threaded servers.
```python
import torch
import torch.nn as nn


# torch.jit.trace: records operations on concrete inputs.
# Works for models with fixed control flow (no data-dependent branching).
def export_torchscript_trace(
    model: nn.Module,
    example_input: torch.Tensor,
    output_path: str,
) -> None:
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
    traced.save(output_path)
    print(f"TorchScript (trace) saved to {output_path}")


# torch.jit.script: parses the model source code.
# Handles if/for/while on tensor values; requires type annotations.
class ScriptableModel(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.threshold: float = 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logit = self.linear(x)
        score = torch.sigmoid(logit)
        # Data-dependent control flow: works with torch.jit.script
        if score.item() > self.threshold:
            return torch.ones(1)
        return torch.zeros(1)


def export_torchscript_script(
    model: nn.Module,
    output_path: str,
) -> None:
    model.eval()
    scripted = torch.jit.script(model)
    scripted.save(output_path)


# Loading (no Python class definition required at load time)
def load_torchscript(path: str) -> torch.jit.ScriptModule:
    return torch.jit.load(path)
```
MLflow pyfunc
MLflow's pyfunc interface wraps any Python prediction logic in a standard interface. This is the most flexible packaging option: it can wrap ONNX, TorchScript, sklearn, or any custom logic, while providing a consistent predict(DataFrame) API.
```python
import joblib
import mlflow
import mlflow.pyfunc
import numpy as np
import pandas as pd
import onnxruntime as ort


class ChurnModelPyfunc(mlflow.pyfunc.PythonModel):
    """Custom MLflow pyfunc wrapping an ONNX model with preprocessing."""

    def load_context(self, context: mlflow.pyfunc.PythonModelContext) -> None:
        """Called once at load time. Load model and any preprocessing artifacts."""
        onnx_path = context.artifacts["onnx_model"]
        scaler_path = context.artifacts["scaler"]
        self.session = ort.InferenceSession(
            onnx_path, providers=["CPUExecutionProvider"]
        )
        self.input_name = self.session.get_inputs()[0].name
        self.scaler = joblib.load(scaler_path)
        self.features = [
            "tenure_days",
            "total_spend_30d",
            "support_tickets_90d",
            "last_purchase_days",
        ]

    def predict(
        self,
        context: mlflow.pyfunc.PythonModelContext,
        model_input: pd.DataFrame,
    ) -> pd.DataFrame:
        """Called at inference time. Input is always a DataFrame."""
        X = model_input[self.features].values.astype(np.float32)
        X_scaled = self.scaler.transform(X).astype(np.float32)
        scores = self.session.run(None, {self.input_name: X_scaled})[0]
        return pd.DataFrame({"churn_score": scores.flatten()})


def log_pyfunc_model(
    run_id: str,
    onnx_path: str,
    scaler_path: str,
) -> None:
    """Log the pyfunc model to an existing MLflow run."""
    artifacts = {
        "onnx_model": onnx_path,
        "scaler": scaler_path,
    }
    conda_env = {
        "channels": ["defaults", "conda-forge"],
        "dependencies": [
            "python=3.11",
            {"pip": [
                "mlflow==2.12.0",
                "onnxruntime==1.17.0",
                "pandas==2.2.0",
                "numpy==1.26.0",
                "scikit-learn==1.4.0",
            ]},
        ],
    }
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            artifact_path="pyfunc_model",
            python_model=ChurnModelPyfunc(),
            artifacts=artifacts,
            conda_env=conda_env,
            code_path=["src/"],
        )
```
Docker Best Practices for ML Models
A Docker image for a model server must be minimal (fast pull, small attack surface), reproducible (same build = same image), and secure (non-root user, no secrets in the image).
```dockerfile
# ---- Stage 1: builder ----
# Install all build dependencies (compilers, dev headers) in this stage.
# None of them appear in the final image.
FROM python:3.11-slim AS builder

WORKDIR /build

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first (layer caching: only reinstall when requirements change)
COPY requirements.txt .
RUN pip install --user --no-cache-dir -r requirements.txt

# ---- Stage 2: runtime ----
FROM python:3.11-slim AS runtime

WORKDIR /app

# Create a non-root user (security: container should not run as root)
RUN groupadd --gid 1000 mlapp \
    && useradd --uid 1000 --gid mlapp --shell /bin/bash --create-home mlapp

# Copy only the installed packages from the builder stage
COPY --from=builder /root/.local /home/mlapp/.local

# Copy application code and model artifacts
COPY --chown=mlapp:mlapp src/ ./src/
COPY --chown=mlapp:mlapp models/ ./models/

# Switch to non-root user
USER mlapp
ENV PATH=/home/mlapp/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Health check: verify the server starts and responds
HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"

EXPOSE 8080

CMD ["python", "-m", "uvicorn", "src.server:app", "--host", "0.0.0.0", "--port", "8080", "--workers", "1"]
```
```
# .dockerignore — prevent accidental inclusion of large files and secrets
.env
.env.*
# build models inside the image, or COPY specific model files explicitly
*.pkl
data/
notebooks/
tests/
.git/
__pycache__/
*.pyc
mlruns/
```

```shell
# Build and tag the image
docker build -t churn-model-server:v1.2.0 .
docker run -p 8080:8080 -e MODEL_PATH=/app/models/churn_v2.onnx churn-model-server:v1.2.0
```
Model Quantisation
Quantisation reduces model size and inference latency by representing weights and activations in lower precision (INT8 instead of FP32 or FP16). PyTorch's dynamic quantisation is the easiest to apply.
```python
import os
import time

import torch
import torch.quantization as quant


def apply_dynamic_quantisation(
    model: torch.nn.Module,
    output_path: str,
) -> dict:
    """
    Dynamic INT8 quantisation for linear layers.

    No calibration data needed — scales are computed at inference time.
    Works best for LSTM, Linear, Embedding layers.
    Does NOT help for conv layers (use static quantisation for those).
    """
    model.eval()
    model_fp32_size = sum(p.numel() * 4 for p in model.parameters())  # bytes

    quantised = quant.quantize_dynamic(
        model,
        qconfig_spec={torch.nn.Linear},
        dtype=torch.qint8,
    )
    torch.save(quantised.state_dict(), output_path)
    model_int8_size = os.path.getsize(output_path)

    # Quick latency comparison on a dummy input
    dummy = torch.randn(1, 128)

    t0 = time.perf_counter()
    for _ in range(1000):
        with torch.no_grad():
            _ = model(dummy)
    fp32_latency = (time.perf_counter() - t0) / 1000 * 1000  # ms

    t0 = time.perf_counter()
    for _ in range(1000):
        with torch.no_grad():
            _ = quantised(dummy)
    int8_latency = (time.perf_counter() - t0) / 1000 * 1000  # ms

    return {
        "fp32_size_mb": round(model_fp32_size / 1e6, 2),
        "int8_size_mb": round(model_int8_size / 1e6, 2),
        "size_reduction": f"{model_fp32_size / model_int8_size:.1f}x",
        "fp32_latency_ms": round(fp32_latency, 3),
        "int8_latency_ms": round(int8_latency, 3),
        "speedup": round(fp32_latency / int8_latency, 2),
    }
```
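As a quick sanity check of `quantize_dynamic`, here is a hedged sketch on a toy two-layer MLP (layer sizes are arbitrary) that compares serialised sizes in memory instead of writing to disk:

```python
import io

import torch
import torch.nn as nn
import torch.quantization as quant

# Toy model: two Linear layers, exactly the case dynamic quantisation targets.
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1))
model.eval()

quantised = quant.quantize_dynamic(
    model, qconfig_spec={nn.Linear}, dtype=torch.qint8
)

# Serialise both state_dicts to in-memory buffers to compare sizes.
def serialised_size(m: nn.Module) -> int:
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes

fp32_bytes = serialised_size(model)
int8_bytes = serialised_size(quantised)
print(f"fp32: {fp32_bytes} B, int8: {int8_bytes} B, "
      f"ratio: {fp32_bytes / int8_bytes:.1f}x")

# Outputs should agree closely on the same input.
x = torch.randn(4, 128)
with torch.no_grad():
    diff = (model(x) - quantised(x)).abs().max().item()
print(f"max |fp32 - int8| = {diff:.4f}")
```

The exact ratio depends on per-layer overhead (scales and packing metadata), so small models see less than the theoretical 4x.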
```
# requirements.txt (pin all versions for reproducible builds)
fastapi==0.111.0
uvicorn[standard]==0.29.0
onnxruntime==1.17.1
numpy==1.26.4
pydantic==2.7.0
scikit-learn==1.4.2
joblib==1.4.0
```
Key Takeaways
Never use raw pickle for long-lived model artifacts: it is version-dependent, framework-dependent, and executes arbitrary code on load. Always validate file checksums before loading serialised models.
ONNX is the best format for cross-framework and cross-language portability; use ONNX Runtime with ORT_ENABLE_ALL graph optimisations and always validate that ONNX output matches PyTorch output before deploying.
Use torch.jit.trace for models with fixed control flow and torch.jit.script when you have data-dependent branching; both produce self-contained TorchScript files that do not require the model class definition at load time.
MLflow pyfunc is the most flexible packaging format: it wraps any prediction logic in a consistent predict(DataFrame) interface with a captured environment, making it easy to serve from any MLflow-compatible platform.
Multi-stage Docker builds separate the build environment (compilers, dev headers) from the runtime environment, producing smaller and more secure images.
Copy requirements.txt before application code in your Dockerfile so Docker's layer cache is invalidated only when dependencies change, not on every code edit.
Dynamic INT8 quantisation with quantize_dynamic requires no calibration data and typically reduces model size by 3-4x and improves CPU inference latency by 1.5-2x for linear layers.