GadaaLabs
Machine Learning Engineering
Lesson 6

Inference Serving with FastAPI

16 min

Serving a model is a software engineering problem, not a machine learning one. A well-designed inference server handles malformed requests gracefully, reports its own health, batches requests to maximise GPU utilisation, and scales horizontally without state. This lesson builds that server from scratch.

API Contract Design

Define the request/response schemas before writing any model code — they are your API's public contract:

python
from pydantic import BaseModel, Field
from typing import List

class PredictRequest(BaseModel):
    # Pydantic v2 naming; v1 used min_items / max_items
    texts: List[str] = Field(..., min_length=1, max_length=32,
                             description="Input texts to classify")

class PredictionResult(BaseModel):
    label: str
    confidence: float

class PredictResponse(BaseModel):
    predictions: List[PredictionResult]
    model_version: str
    latency_ms: float

Pydantic validates every request automatically. A request with 33 items returns HTTP 422 with a clear error message before touching the model.
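You can exercise the contract directly, without running a server. A minimal sketch, assuming Pydantic v2 (where list constraints are named min_length/max_length):

```python
from typing import List

from pydantic import BaseModel, Field, ValidationError

class PredictRequest(BaseModel):
    texts: List[str] = Field(..., min_length=1, max_length=32)

ok = PredictRequest(texts=["great product"])   # passes validation

try:
    PredictRequest(texts=["x"] * 33)           # 33 items: over the limit
except ValidationError:
    print("rejected before any model code runs")
```

In the server, FastAPI converts that ValidationError into the HTTP 422 response for you.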

Server Setup with ONNX Runtime

python
import time
import asyncio
from contextlib import asynccontextmanager
from typing import Optional

import onnxruntime as ort
from fastapi import FastAPI

MODEL_PATH = "model.onnx"
MODEL_VERSION = "1.2.0"
session: Optional[ort.InferenceSession] = None  # populated in the lifespan hook

@asynccontextmanager
async def lifespan(app: FastAPI):
    global session
    session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
    print("Model loaded")
    yield
    session = None  # cleanup on shutdown

app = FastAPI(title="Inference API", version=MODEL_VERSION, lifespan=lifespan)

Loading the session in the lifespan hook ensures the model is ready before the first request and is properly released on graceful shutdown.

Prediction Endpoint with Timing

python
import numpy as np

@app.post("/predict", response_model=PredictResponse)
def predict(request: PredictRequest):
    # Plain `def`, not `async def`: FastAPI runs sync endpoints in a thread
    # pool, so the blocking ONNX call below does not stall the event loop.
    t0 = time.perf_counter()

    # Tokenise / preprocess (example: fixed-length padding)
    inputs = preprocess(request.texts)           # returns np.ndarray
    ort_inputs = {"input": inputs.astype(np.float32)}

    logits = session.run(None, ort_inputs)[0]    # shape: (B, num_classes)
    probs  = softmax(logits, axis=-1)            # helper: row-wise softmax
    labels = LABEL_MAP[probs.argmax(axis=-1)]    # LABEL_MAP: array of class names
    confs  = probs.max(axis=-1).tolist()

    latency_ms = (time.perf_counter() - t0) * 1000
    return PredictResponse(
        predictions=[PredictionResult(label=l, confidence=c)
                     for l, c in zip(labels, confs)],
        model_version=MODEL_VERSION,
        latency_ms=round(latency_ms, 2),
    )
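The endpoint leans on three helpers the lesson does not define. A purely illustrative sketch of what they might look like — the label set, the character-code featuriser, and max_len are all assumptions; a real service would use the model's own tokeniser:

```python
import numpy as np

# Hypothetical label set -- index position maps to class name.
LABEL_MAP = np.array(["negative", "neutral", "positive"])

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def preprocess(texts, max_len: int = 16) -> np.ndarray:
    # Toy featuriser: character codes, padded/truncated to max_len.
    out = np.zeros((len(texts), max_len), dtype=np.float32)
    for i, t in enumerate(texts):
        codes = [float(ord(c)) for c in t[:max_len]]
        out[i, : len(codes)] = codes
    return out
```

Making LABEL_MAP a NumPy array rather than a dict is what lets `LABEL_MAP[probs.argmax(axis=-1)]` vectorise cleanly over the whole batch.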

Health and Readiness Checks

python
from fastapi import HTTPException

@app.get("/health")
async def health():
    return {"status": "ok"}

@app.get("/ready")
async def ready():
    if session is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return {"status": "ready", "model_version": MODEL_VERSION}

Kubernetes uses /health (liveness probe) and /ready (readiness probe) separately. A slow model load should fail readiness without killing the pod.
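Wired into a pod spec, the two probes might look like this — a hedged sketch, with container name, port, and timings as illustrative assumptions:

```yaml
containers:
  - name: inference-api        # illustrative name
    ports:
      - containerPort: 8000
    livenessProbe:             # restart the pod only if the process is wedged
      httpGet:
        path: /health
        port: 8000
      periodSeconds: 10
    readinessProbe:            # withhold traffic until the model is loaded
      httpGet:
        path: /ready
        port: 8000
      initialDelaySeconds: 5
      periodSeconds: 5
```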

Async Batching Under Load

python
import asyncio
from collections import deque

BATCH_SIZE     = 16
BATCH_WAIT_MS  = 20
request_queue: deque = deque()

async def batch_processor():
    while True:
        await asyncio.sleep(BATCH_WAIT_MS / 1000)
        if not request_queue:
            continue
        batch = [request_queue.popleft()
                 for _ in range(min(BATCH_SIZE, len(request_queue)))]
        # Assumes each queued item contributes exactly one input row
        inputs = np.concatenate([b["input"] for b in batch], axis=0)
        # session.run blocks; run it in a worker thread so the loop stays free
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None, lambda: session.run(None, {"input": inputs})[0])
        for i, b in enumerate(batch):
            b["future"].set_result(results[i])

Batching trades a small latency increase (up to BATCH_WAIT_MS) for significantly higher throughput and GPU utilisation.
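The loop above is only the consumer side; each request handler also needs a way to enqueue its input and await the result. A self-contained sketch of the whole round trip, using a stand-in doubling function in place of the ONNX session (the `infer_batched` helper and the demo model are assumptions, not part of the lesson's server):

```python
import asyncio
from collections import deque

import numpy as np

BATCH_SIZE = 16
BATCH_WAIT_MS = 20
request_queue: deque = deque()

async def infer_batched(row: np.ndarray) -> np.ndarray:
    """Enqueue one (1, F) input row and await its result."""
    fut = asyncio.get_running_loop().create_future()
    request_queue.append({"input": row, "future": fut})
    return await fut

async def batch_processor(model):
    while True:
        await asyncio.sleep(BATCH_WAIT_MS / 1000)
        if not request_queue:
            continue
        batch = [request_queue.popleft()
                 for _ in range(min(BATCH_SIZE, len(request_queue)))]
        inputs = np.concatenate([b["input"] for b in batch], axis=0)
        outputs = model(inputs)          # one forward pass for the whole batch
        for i, b in enumerate(batch):
            b["future"].set_result(outputs[i])

async def main():
    # Stand-in model: doubles its input (demo assumption).
    task = asyncio.create_task(batch_processor(lambda x: x * 2))
    rows = [np.array([[float(i)]]) for i in range(4)]
    results = await asyncio.gather(*(infer_batched(r) for r in rows))
    task.cancel()
    return results

results = asyncio.run(main())
print([r.tolist() for r in results])
```

Four concurrent callers are served by a single forward pass; `asyncio.gather` preserves request order even though the work was done as one batch.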

Deployment and Scaling

| Concern | Solution |
|---|---|
| Multiple replicas | Stateless server — any replica handles any request |
| Load balancing | Kubernetes Service with `sessionAffinity: None` |
| Resource limits | `requests: cpu: 500m, memory: 1Gi` per replica |
| Zero-downtime deploys | Rolling update strategy with readiness probe |

Summary

  • Define Pydantic request/response schemas first; they act as contracts and provide free input validation.
  • Load the model in FastAPI's lifespan hook, not at module level, to support graceful startup and shutdown.
  • Expose separate /health (liveness) and /ready (readiness) endpoints for Kubernetes probes.
  • Implement async batching when serving GPU models to improve throughput without unbounded latency.
  • Keep the server stateless so it scales horizontally by simply adding more identical replicas.