GadaaLabs
Machine Learning Engineering — Production ML Systems
Lesson 8

Inference Serving — Latency, Throughput & Scaling

28 min

A model that takes 800 ms to return a prediction is useless in a real-time product. A model server that falls over under 50 RPS is useless in production. Serving is where ML engineering meets systems engineering, and it has its own performance vocabulary, failure modes, and design patterns that are entirely separate from model training.

Latency vs Throughput

These two goals are in tension:

  • Latency: time from request received to response returned. Measured as p50, p90, p99.
  • Throughput: requests successfully served per second (RPS).

The tension: batching improves throughput (more efficient GPU/CPU utilisation) but increases latency (you wait for the batch to fill). Dynamic batching — accumulating requests for a short window then inferring the whole batch — is the standard production compromise.
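
The trade-off can be made concrete with a toy cost model (all numbers here are illustrative assumptions, not measurements): suppose each batch costs a fixed overhead plus a per-item cost, and a request waits on average half the batching window before inference starts.

```python
def batch_tradeoff(batch_size: int,
                   overhead_ms: float = 5.0,
                   per_item_ms: float = 0.5,
                   window_ms: float = 50.0) -> tuple[float, float]:
    """Return (throughput_rps, mean_latency_ms) for a given batch size,
    under a toy cost model: batch time = overhead + n * per_item, and a
    request waits on average half the batching window before inference."""
    batch_time_ms = overhead_ms + batch_size * per_item_ms
    throughput_rps = batch_size / (batch_time_ms / 1000)
    mean_latency_ms = window_ms / 2 + batch_time_ms
    return throughput_rps, mean_latency_ms

for n in (1, 8, 64):
    rps, lat = batch_tradeoff(n)
    print(f"batch={n:3d}  throughput={rps:8.0f} rps  latency={lat:6.1f} ms")
```

Under these made-up constants, going from batch size 1 to 64 multiplies throughput severalfold while roughly doubling mean latency, which is the shape of the trade-off you should expect to measure on real hardware.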

Why p99 Matters

If your p99 latency is 2000 ms, one in every hundred users waits 2 seconds. In a UI that makes 10 ML calls per session, that user almost certainly experiences at least one 2-second wait. p99 drives user experience; p50 drives your dashboards. Always measure and set SLAs on p99.

python
import time
import statistics

def measure_latencies(model, inputs: list, n_warmup: int = 10):
    """Measure p50, p90, p99 latencies after warm-up."""
    # Warm up: first N calls are slower due to JIT, cache misses, etc.
    for inp in inputs[:n_warmup]:
        model.predict(inp)

    latencies = []
    for inp in inputs:
        t0 = time.perf_counter()
        model.predict(inp)
        latencies.append((time.perf_counter() - t0) * 1000)  # ms

    latencies.sort()
    n = len(latencies)
    return {
        "p50_ms":  latencies[int(0.50 * n)],
        "p90_ms":  latencies[int(0.90 * n)],
        "p99_ms":  latencies[int(0.99 * n)],
        "mean_ms": statistics.mean(latencies),
    }
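
The index-based percentile arithmetic above can be cross-checked against the stdlib's interpolating implementation. A quick sanity check on synthetic, long-tailed latencies (illustrative only):

```python
import random
import statistics

random.seed(0)
# Synthetic latencies: mostly fast, with a long tail (lognormal)
lats = sorted(random.lognormvariate(3.0, 0.5) for _ in range(10_000))

n = len(lats)
p99_index = lats[int(0.99 * n)]                      # index method, as above
p99_stdlib = statistics.quantiles(lats, n=100)[-1]   # stdlib interpolated p99

print(f"index method: {p99_index:.1f} ms   stdlib: {p99_stdlib:.1f} ms")
```

With thousands of samples the two agree to well under a percent; the index method is fine for dashboards, while `statistics.quantiles` is the safer choice for small sample counts.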

Complete FastAPI Model Server

python
# src/server.py  (~80 lines, production-ready)
import asyncio
import hashlib
import json
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Literal

import joblib
import numpy as np
import pandas as pd
import redis.asyncio as aioredis
from fastapi import BackgroundTasks, FastAPI, HTTPException
from pydantic import BaseModel, Field

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_PATH   = os.getenv("MODEL_PATH", "churn_pipeline_v1.joblib")
REDIS_URL    = os.getenv("REDIS_URL",  "redis://localhost:6379")
CACHE_TTL    = int(os.getenv("CACHE_TTL_SECONDS", "3600"))

_state: dict[str, Any] = {}

# ---------------------------------------------------------------------------
# Lifespan: load model and Redis pool at startup, release at shutdown
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    _state["model"]    = joblib.load(MODEL_PATH)
    _state["redis"]    = aioredis.from_url(REDIS_URL, decode_responses=True)  # synchronous; connections are created lazily
    _state["job_store"] = {}   # in-process job store (use Redis in real production)

    # Model warm-up: run a dummy prediction to trigger any lazy initialisation
    # and pre-load ONNX Runtime or CUDA kernels before the first real request.
    dummy = pd.DataFrame([{
        "tenure_months": 12, "monthly_spend": 99.0, "support_tickets": 2,
        "feature_usage_pct": 0.4, "last_login_days_ago": 7,
        "num_integrations": 3, "contract_type": "annual",
        "plan_tier": "growth", "is_enterprise": 0,
    }])
    _state["model"].predict_proba(dummy)
    print(f"[startup] Model loaded from {MODEL_PATH}. Warm-up complete.")
    yield
    await _state["redis"].aclose()
    _state.clear()
    print("[shutdown] Resources released.")

app = FastAPI(title="Inference Server", version="1.0.0", lifespan=lifespan)

# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class PredictRequest(BaseModel):
    tenure_months:       int   = Field(..., ge=1, le=120)
    monthly_spend:       float = Field(..., ge=0.0)
    support_tickets:     int   = Field(..., ge=0, le=50)
    feature_usage_pct:   float = Field(..., ge=0.0, le=1.0)
    last_login_days_ago: int   = Field(..., ge=0, le=365)
    num_integrations:    int   = Field(..., ge=0, le=50)
    contract_type:       Literal["monthly", "annual", "multi-year"]
    plan_tier:           Literal["starter", "growth", "enterprise"]
    is_enterprise:       Literal[0, 1]

class PredictResponse(BaseModel):
    churn_probability: float
    churn_predicted:   bool
    cache_hit:         bool = False
    latency_ms:        float

# ---------------------------------------------------------------------------
# Health & readiness endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health():
    return {"status": "ok", "ts": time.time()}

@app.get("/ready")
async def ready():
    if "model" not in _state:
        raise HTTPException(503, "Model not loaded")
    try:
        await _state["redis"].ping()
    except Exception:
        raise HTTPException(503, "Redis unavailable")
    return {"status": "ready"}

Prediction Caching with Redis

python
# Continued from server.py

def _cache_key(req: PredictRequest) -> str:
    """Deterministic cache key: SHA256 of sorted, normalised input fields."""
    payload = json.dumps(req.model_dump(), sort_keys=True, default=str)
    return "pred:" + hashlib.sha256(payload.encode()).hexdigest()

@app.post("/predict", response_model=PredictResponse)
async def predict(req: PredictRequest):
    t0  = time.perf_counter()
    key = _cache_key(req)

    # Cache lookup
    cached = await _state["redis"].get(key)
    if cached is not None:
        prob = float(cached)
        return PredictResponse(
            churn_probability=round(prob, 4),
            churn_predicted=prob >= 0.4,
            cache_hit=True,
            latency_ms=round((time.perf_counter() - t0) * 1000, 2),
        )

    # Cache miss: run inference
    try:
        row  = pd.DataFrame([req.model_dump()])
        prob = float(_state["model"].predict_proba(row)[0, 1])
    except Exception as exc:
        raise HTTPException(500, f"Inference failed: {exc}") from exc

    # Store result in Redis with TTL
    await _state["redis"].setex(key, CACHE_TTL, str(prob))

    return PredictResponse(
        churn_probability=round(prob, 4),
        churn_predicted=prob >= 0.4,
        cache_hit=False,
        latency_ms=round((time.perf_counter() - t0) * 1000, 2),
    )

Cache invalidation strategy: use a TTL of 1 hour for features that change slowly. For real-time features (e.g. last_login_days_ago), either exclude them from the cache key or use a shorter TTL.
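
One way to implement the exclusion approach is to drop volatile fields before hashing. A minimal sketch (the field names match the schema above; `VOLATILE_FIELDS` is a hypothetical configuration, not part of the server code):

```python
import hashlib
import json

VOLATILE_FIELDS = {"last_login_days_ago"}   # changes daily; excluded from the key

def stable_cache_key(payload: dict) -> str:
    """Deterministic cache key over the slow-moving fields only."""
    stable = {k: v for k, v in payload.items() if k not in VOLATILE_FIELDS}
    blob = json.dumps(stable, sort_keys=True, default=str)
    return "pred:" + hashlib.sha256(blob.encode()).hexdigest()

a = {"tenure_months": 12, "monthly_spend": 99.0, "last_login_days_ago": 7}
b = {"tenure_months": 12, "monthly_spend": 99.0, "last_login_days_ago": 30}
assert stable_cache_key(a) == stable_cache_key(b)   # volatile field ignored
```

The cost of this trick is that cached predictions no longer reflect changes to the excluded feature within the TTL, so it is only safe when the model's output is weakly sensitive to that feature over the cache lifetime.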


Dynamic Batching

For neural network inference, batching amortises the GPU kernel launch overhead. The pattern: accumulate requests for up to 50 ms, then run a single batched inference.

python
import asyncio
from dataclasses import dataclass
from typing import Any

import pandas as pd

@dataclass
class PendingRequest:
    features: Any
    future:   asyncio.Future

class DynamicBatcher:
    """
    Accumulate requests for `window_ms` milliseconds, then run a single
    batched inference. Reduces GPU round-trips without a blocking queue.
    """
    def __init__(self, model, window_ms: float = 50, max_batch: int = 64):
        self._model     = model
        self._window_ms = window_ms / 1000  # convert to seconds
        self._max_batch = max_batch
        self._queue: list[PendingRequest] = []
        self._lock   = asyncio.Lock()
        self._task   = None

    async def predict(self, features) -> float:
        loop   = asyncio.get_running_loop()
        future = loop.create_future()
        async with self._lock:
            self._queue.append(PendingRequest(features=features, future=future))
            if self._task is None or self._task.done():
                self._task = asyncio.create_task(self._flush_after_window())
        return await future

    async def _flush_after_window(self):
        await asyncio.sleep(self._window_ms)
        async with self._lock:
            batch = self._queue[:self._max_batch]
            self._queue = self._queue[self._max_batch:]

        if not batch:
            return

        # Build batch matrix and run single inference call
        X_batch = pd.DataFrame([r.features for r in batch])
        try:
            probs = self._model.predict_proba(X_batch)[:, 1]
        except Exception as exc:
            # Fail the whole batch, but fall through so any queued
            # leftovers below still get a flush scheduled
            for r in batch:
                if not r.future.done():
                    r.future.set_exception(exc)
        else:
            for r, prob in zip(batch, probs):
                if not r.future.done():
                    r.future.set_result(float(prob))

        # If more items remain in the queue, schedule another flush
        async with self._lock:
            if self._queue:
                self._task = asyncio.create_task(self._flush_after_window())
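
To watch the batching behaviour in isolation, here is a stripped-down, standalone version of the same window-flush pattern driven by a stub model (all names here are illustrative, not part of the server above):

```python
import asyncio

class StubModel:
    """Stand-in model: records each batch size, returns 0.5 per row."""
    def __init__(self):
        self.batch_sizes = []

    def predict_batch(self, rows):
        self.batch_sizes.append(len(rows))
        return [0.5] * len(rows)

class MiniBatcher:
    """Window-flush batching; no lock needed since everything runs
    on a single asyncio event loop."""
    def __init__(self, model, window_s: float = 0.02, max_batch: int = 64):
        self._model = model
        self._window_s = window_s
        self._max_batch = max_batch
        self._queue = []
        self._task = None

    async def predict(self, features) -> float:
        future = asyncio.get_running_loop().create_future()
        self._queue.append((features, future))
        if self._task is None or self._task.done():
            self._task = asyncio.create_task(self._flush())
        return await future

    async def _flush(self):
        await asyncio.sleep(self._window_s)
        batch, self._queue = self._queue[:self._max_batch], self._queue[self._max_batch:]
        probs = self._model.predict_batch([f for f, _ in batch])
        for (_, fut), prob in zip(batch, probs):
            fut.set_result(prob)
        if self._queue:
            self._task = asyncio.create_task(self._flush())

async def demo():
    model = StubModel()
    batcher = MiniBatcher(model)
    # Five concurrent requests arrive within one window...
    results = await asyncio.gather(*(batcher.predict({"x": i}) for i in range(5)))
    # ...and collapse into a single batched model call
    return model.batch_sizes, results
```

Running `asyncio.run(demo())` shows five concurrent callers served by exactly one model invocation, which is the whole point of the pattern.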

Async Job Queue for Long-Running Inference

For transformer models or anything taking >200 ms, the pattern is: POST returns a job ID immediately; GET polls for results.

python
import asyncio
import uuid
from enum import Enum

class JobStatus(str, Enum):
    pending  = "pending"
    running  = "running"
    complete = "complete"
    failed   = "failed"

class JobRequest(BaseModel):
    features: PredictRequest

class JobResult(BaseModel):
    job_id:  str
    status:  JobStatus
    result:  float | None = None
    error:   str  | None = None

@app.post("/jobs", status_code=202)
async def submit_job(req: JobRequest, background_tasks: BackgroundTasks):
    job_id = str(uuid.uuid4())
    _state["job_store"][job_id] = {"status": JobStatus.pending, "result": None, "error": None}
    background_tasks.add_task(_run_inference_job, job_id, req.features)
    return {"job_id": job_id}

@app.get("/jobs/{job_id}", response_model=JobResult)
async def get_job(job_id: str):
    if job_id not in _state["job_store"]:
        raise HTTPException(404, "Job not found")
    job = _state["job_store"][job_id]
    return JobResult(job_id=job_id, **job)

async def _run_inference_job(job_id: str, req: PredictRequest):
    _state["job_store"][job_id]["status"] = JobStatus.running
    try:
        row  = pd.DataFrame([req.model_dump()])
        # Run blocking inference in a thread pool to avoid blocking the event loop
        prob = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: float(_state["model"].predict_proba(row)[0, 1])
        )
        _state["job_store"][job_id].update({"status": JobStatus.complete, "result": prob})
    except Exception as exc:
        _state["job_store"][job_id].update({"status": JobStatus.failed, "error": str(exc)})
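
On the client side, this endpoint implies a poll loop. A generic sketch with exponential backoff; the `fetch_status` callable stands in for an HTTP GET to /jobs/{job_id} and is hypothetical:

```python
import time
from typing import Callable

def poll_until_done(fetch_status: Callable[[], dict],
                    timeout_s: float = 30.0,
                    initial_interval_s: float = 0.1,
                    max_interval_s: float = 2.0) -> dict:
    """Poll fetch_status() until the job leaves pending/running, backing
    off exponentially; raise TimeoutError past the deadline."""
    deadline = time.monotonic() + timeout_s
    interval = initial_interval_s
    while time.monotonic() < deadline:
        job = fetch_status()
        if job["status"] not in ("pending", "running"):
            return job
        time.sleep(interval)
        interval = min(interval * 2, max_interval_s)
    raise TimeoutError("job did not complete within the deadline")
```

Backoff matters at scale: a fleet of clients polling every 100 ms forever is its own load problem, while capping the interval at a couple of seconds keeps worst-case result delivery latency bounded.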

Horizontal Scaling: Nginx + Uvicorn Workers

nginx
# nginx.conf snippet for load balancing across 4 Uvicorn processes
upstream ml_backend {
    least_conn;                          # route to the least-busy worker
    server 127.0.0.1:8001;
    server 127.0.0.1:8002;
    server 127.0.0.1:8003;
    server 127.0.0.1:8004;
    keepalive 32;                        # persist connections to backends
}

server {
    listen 80;

    location /predict {
        proxy_pass         http://ml_backend;
        proxy_http_version 1.1;
        proxy_set_header   Connection "";
        proxy_set_header   Host $host;
        proxy_read_timeout 30s;
    }

    location /health {
        proxy_pass http://ml_backend;
        access_log off;
    }
}
bash
# Start 4 independent Uvicorn processes (each is a separate Python process,
# bypassing the GIL entirely for CPU-bound inference)
uvicorn src.server:app --host 127.0.0.1 --port 8001 --workers 1 &
uvicorn src.server:app --host 127.0.0.1 --port 8002 --workers 1 &
uvicorn src.server:app --host 127.0.0.1 --port 8003 --workers 1 &
uvicorn src.server:app --host 127.0.0.1 --port 8004 --workers 1 &

For GPU inference, use one process per GPU (multi-process service). For very large models, explore NVIDIA MIG (Multi-Instance GPU) which partitions a single A100 into up to 7 independent GPU instances, each with dedicated memory and compute.


Graceful Shutdown

The lifespan context manager in the server above already handles graceful shutdown — when Uvicorn receives SIGTERM, it stops accepting new connections, lets in-flight requests complete (bounded by --timeout-graceful-shutdown if you set it; by default Uvicorn waits for them indefinitely), then resumes the lifespan context past its yield, releasing the Redis pool and clearing state.

python
# Optional: a custom SIGTERM handler for standalone scripts (e.g. batch
# scoring jobs). Under Uvicorn this is unnecessary: Uvicorn installs its
# own signal handlers at startup, and extra cleanup such as flushing
# metrics belongs after `yield` in the lifespan context above.
import signal
import sys

def handle_sigterm(signum, frame):
    print("[script] SIGTERM received, running cleanup before exit")
    sys.exit(0)  # raises SystemExit so finally-blocks still run

signal.signal(signal.SIGTERM, handle_sigterm)

Key Takeaways

  • Always measure and set SLAs on p99 latency, not just p50 — p99 is what users actually experience in a multi-call session.
  • Load the model once at startup using FastAPI's lifespan context manager and run a warm-up prediction to eliminate cold-start latency on the first real request.
  • Redis prediction caching with a content-addressed key (SHA256 of normalised inputs) can eliminate redundant inference for repeated or near-identical inputs at zero additional ML cost.
  • Dynamic batching accumulates requests over a short time window (20-100 ms) and runs a single batched inference call — this is the primary throughput lever for GPU-bound models.
  • For inference taking more than 200 ms, use an async job pattern: POST returns a job ID immediately, GET polls for the result, and the actual inference runs in a background task.
  • Horizontal scaling requires stateless server processes — each process may hold its own read-only copy of the model, but shared mutable state (the prediction cache, job results) belongs in Redis, not in process-local dictionaries like the in-process job store above.
  • Use separate Uvicorn processes (not threads) for CPU-bound inference to bypass the GIL; for GPU inference, one process per GPU is the standard pattern.
  • Graceful shutdown is not optional in production: Kubernetes sends SIGTERM before SIGKILL; a server that doesn't drain in-flight requests will produce failed predictions on every deployment.