GadaaLabs
Machine Learning Engineering — Production ML Systems
Lesson 4

Training Infrastructure — GPUs, Distributed Training & Optimisation

28 min

A model that trains in 8 hours when it could train in 2 hours is a 4x productivity tax. Multiply that by dozens of experiments and you are losing weeks of iteration time. Training infrastructure is not about raw hardware — it is about using hardware efficiently. This lesson covers every major technique for maximising GPU utilisation and making training infrastructure reliable enough to run unattended in production.

GPU Memory Hierarchy

Understanding GPU memory is prerequisite to understanding every optimisation in this lesson.

VRAM (HBM): the high-bandwidth memory stacked on the GPU package. Everything that participates in training must fit in VRAM: model weights, gradients, optimizer states (Adam keeps two moment tensors per parameter — 2x the weight size), and activations from the forward pass. An 80GB A100 sounds like a lot until you store a 7B-parameter model in float32 (7B × 4 bytes = 28GB just for weights, before gradients and optimizer state).
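That budget is worth working out explicitly. A back-of-the-envelope sketch (pure Python, FP32 training with Adam; activations are excluded because they depend on batch size and architecture):

```python
def training_memory_gb(n_params: float, bytes_per_param: int = 4) -> dict:
    """Rough VRAM budget for full-precision training with Adam.

    Weights, gradients, and the two Adam moment tensors each store
    one value per parameter; activations are extra and batch-dependent.
    """
    gb = bytes_per_param * n_params / 1e9
    return {
        "weights_gb": gb,
        "gradients_gb": gb,
        "adam_moments_gb": 2 * gb,   # first + second moment
        "total_gb": 4 * gb,          # before any activations
    }

budget = training_memory_gb(7e9)   # 7B parameters in FP32
# weights alone: 28 GB; weights + grads + Adam state: 112 GB —
# past an 80GB A100 before a single activation is stored
```

This is why the techniques below — mixed precision, sharding, checkpointing — exist: the naive FP32 budget alone overflows a top-end GPU for a mid-sized model.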

Bandwidth vs compute: most operations in deep learning are bandwidth-limited, not compute-limited. Moving 1GB of activations from VRAM to the CUDA cores takes time proportional to memory bandwidth (HBM2e on an 80GB A100: ~2 TB/s; HBM3 on an H100: ~3.35 TB/s). Operations that do very little compute per byte read (element-wise activations, layer norms) are bandwidth-bound. Operations that do a lot of compute per byte (large matrix multiplications) are compute-bound. Mixed precision and fused kernels help by reducing the bandwidth needed per operation.
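To see where an operation lands, compare its arithmetic intensity (FLOPs per byte moved) against the hardware's balance point. A sketch with illustrative A100-class numbers (the exact figures are assumptions for the example):

```python
def matmul_intensity(m: int, n: int, k: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte for C = A @ B with A (m,k), B (k,n), FP16 storage."""
    flops = 2 * m * n * k                              # one multiply + one add
    bytes_moved = bytes_per_el * (m * k + k * n + m * n)
    return flops / bytes_moved

def elementwise_intensity(bytes_per_el: int = 2) -> float:
    """FLOPs per byte for y = f(x): 1 op per element, one read + one write."""
    return 1 / (2 * bytes_per_el)

# A100-class balance point: ~312 TFLOPS FP16 / ~2 TB/s ≈ 156 FLOPs/byte
balance = 312e12 / 2e12

big_matmul = matmul_intensity(4096, 4096, 4096)   # ≈ 1365 FLOPs/byte
relu_like  = elementwise_intensity()              # 0.25 FLOPs/byte
# the matmul sits well above the balance point (compute-bound);
# the element-wise op sits far below it (bandwidth-bound)
```

Anything below the balance point leaves the compute units idle waiting on memory — which is exactly what mixed precision (half the bytes per element) and kernel fusion (fewer round trips to VRAM) attack.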

Mixed Precision Training

Mixed precision training runs the forward and backward passes in FP16 or BF16, while keeping a master copy of weights in FP32 for the optimizer update. This halves memory usage for activations and nearly doubles throughput on Tensor Cores (which natively accelerate FP16/BF16 matmuls).

FP16 vs BF16: BF16 has the same 8-bit exponent as FP32 (so the same numeric range), but only 7 mantissa bits (vs 23 in FP32 and 10 in FP16). FP16 has more precision but can overflow on large activation values. BF16 is generally preferred for training; FP16 is more common for inference where the range is predictable.
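The range difference follows directly from the bit layouts. A small sketch computing the largest finite value of a binary float format from its exponent and mantissa widths (IEEE-style, which matches FP32/FP16/BF16):

```python
def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value for a binary float format with the given widths."""
    max_exponent = 2 ** (exp_bits - 1) - 1          # top unbiased exponent
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exponent

fp32 = max_finite(exp_bits=8, mantissa_bits=23)   # ≈ 3.40e38
bf16 = max_finite(exp_bits=8, mantissa_bits=7)    # ≈ 3.39e38 — same range as FP32
fp16 = max_finite(exp_bits=5, mantissa_bits=10)   # 65504.0 — overflows easily
```

Any activation above 65504 becomes inf in FP16 — hence the GradScaler machinery below — while BF16 keeps FP32's headroom at the cost of precision.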

python
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler  # torch.cuda.amp is deprecated;
                                            # torch.amp.autocast takes device_type

def train_one_epoch_mixed_precision(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    device: torch.device,
    gradient_clip_norm: float = 1.0,
) -> float:
    model.train()
    total_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs  = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)   # set_to_none=True saves memory

        # Forward pass in FP16/BF16
        with autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(inputs)
            loss    = criterion(outputs, targets)

        # Backward pass: scaler prevents underflow in FP16 gradients
        scaler.scale(loss).backward()

        # Unscale before gradient clipping (required before clip_grad_norm_)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_norm)

        # Optimizer step: only updates if gradients are finite (no NaN/inf)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(loader)

# Initialise scaler once, pass it through the training loop
scaler = GradScaler()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Gradient Accumulation

When your batch size is constrained by VRAM, gradient accumulation simulates a larger effective batch by accumulating gradients over multiple forward/backward passes before stepping the optimizer.

python
def train_with_gradient_accumulation(
    model: nn.Module,
    loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: GradScaler,
    device: torch.device,
    accumulation_steps: int = 8,   # effective_batch = batch_size × 8
) -> float:
    model.train()
    total_loss = 0.0
    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs  = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(inputs)
            # Normalise loss so gradient magnitude is equivalent to
            # computing it on the full effective batch at once
            loss = criterion(outputs, targets) / accumulation_steps

        scaler.scale(loss).backward()

        # Only step after accumulation_steps mini-batches
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        total_loss += loss.item() * accumulation_steps

    # Flush leftover gradients when len(loader) is not a multiple of
    # accumulation_steps (otherwise the final partial batch is silently dropped)
    if (batch_idx + 1) % accumulation_steps != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    return total_loss / len(loader)

DataLoader Optimisation

The DataLoader is frequently the bottleneck in training pipelines. A GPU that is starved of data will show near-zero utilisation between batches.

python
import multiprocessing
from torch.utils.data import DataLoader, Dataset

def make_optimised_dataloader(
    dataset: Dataset,
    batch_size: int,
    training: bool = True,
) -> DataLoader:
    """
    Production-quality DataLoader with all performance settings.
    """
    num_workers = min(8, multiprocessing.cpu_count())

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_workers,       # parallel data loading on CPU
        pin_memory=True,               # allocate batches in pinned (page-locked)
                                       # memory for faster CPU→GPU (H2D) transfer
        persistent_workers=True,       # keep worker processes alive between epochs
                                       # (avoid fork overhead on every epoch)
        prefetch_factor=2,             # each worker pre-fetches 2 batches ahead
        drop_last=training,            # drop incomplete final batch during training
    )

# Tip: profile DataLoader vs GPU time with a simple benchmark
def benchmark_dataloader(loader: DataLoader, n_batches: int = 50) -> None:
    import time
    start = time.perf_counter()
    for i, batch in enumerate(loader):
        if i >= n_batches:
            break
    elapsed = time.perf_counter() - start
    print(f"DataLoader: {n_batches} batches in {elapsed:.2f}s "
          f"({elapsed/n_batches*1000:.1f}ms/batch)")

Distributed Training

When a single GPU is not enough — either because the model is too large or because you need to train faster — you move to distributed training.

DataParallel (single node, easy but inefficient)

DataParallel wraps a model and splits each batch across multiple GPUs on a single machine. Simple to use but suboptimal: it runs in a single process, so the Python GIL throttles the per-GPU threads; the model is re-replicated to every GPU on each forward pass; and outputs are gathered on the primary GPU, which becomes both a memory and a compute bottleneck.

python
# DataParallel: wraps model, splits batch across GPUs
model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
model = model.cuda()
# Inference and training code is unchanged

DistributedDataParallel (the production choice)

DDP spawns one process per GPU. Each process has its own model replica. Gradients are synchronised with an AllReduce operation (using NCCL on NVIDIA hardware). No primary GPU bottleneck.

python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os

def setup_ddp(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def cleanup_ddp() -> None:
    dist.destroy_process_group()

def train_ddp(rank: int, world_size: int, model_cls, dataset) -> None:
    setup_ddp(rank, world_size)
    device = torch.device(f"cuda:{rank}")

    model = model_cls().to(device)
    model = DDP(model, device_ids=[rank])          # wrap with DDP

    # Each process gets a non-overlapping slice of the data
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                        num_workers=4, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    scaler    = GradScaler()

    for epoch in range(10):
        sampler.set_epoch(epoch)          # reshuffle per epoch
        train_one_epoch_mixed_precision(
            model, loader, optimizer, criterion, scaler, device
        )

    cleanup_ddp()

# Launch with: torch.multiprocessing.spawn(train_ddp, args=(4, MyModel, dataset), nprocs=4)
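What NCCL's AllReduce does to the gradients can be sketched in plain Python — every replica contributes its local gradient and receives the element-wise mean, so all replicas stay bit-identical after each step without any primary GPU (a simulation of the semantics, not of the ring algorithm NCCL actually uses):

```python
def all_reduce_mean(per_rank_grads: list[list[float]]) -> list[list[float]]:
    """Simulate AllReduce-with-average across world_size replicas.

    Each rank contributes its local gradient and receives the
    element-wise mean; afterwards all replicas hold identical values.
    """
    world_size = len(per_rank_grads)
    averaged = [
        sum(rank[i] for rank in per_rank_grads) / world_size
        for i in range(len(per_rank_grads[0]))
    ]
    return [list(averaged) for _ in range(world_size)]

# 4 ranks, each with a 3-element local gradient
local = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0], [0.0, 0.0, 0.0], [4.0, 4.0, 4.0]]
synced = all_reduce_mean(local)
# every rank now holds [2.0, 2.0, 2.0]
```

Because each replica then applies the same optimizer step to the same averaged gradient, the model copies never diverge — that invariant is the whole contract of DDP.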

FSDP (Fully Sharded Data Parallel)

FSDP shards model parameters, gradients, and optimizer state across GPUs. This allows training models that are too large to fit on a single GPU. PyTorch FSDP is the open-source equivalent of ZeRO-3 (DeepSpeed). Use it when a DDP model exceeds single-GPU VRAM.

python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools

# Auto-wrap: shard layers with more than 1M parameters
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=1_000_000,
)
model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
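The memory win is easy to estimate: with full sharding, parameters, gradients, and optimizer state are each divided by the number of GPUs. A rough ZeRO-3-style sketch using the common 16-bytes-per-parameter estimate (FP16 weights and grads plus FP32 master copy and Adam moments — illustrative numbers, activations excluded):

```python
def fsdp_memory_per_gpu_gb(n_params: float, world_size: int) -> float:
    """Approximate per-GPU memory for fully sharded mixed-precision training.

    Per parameter: 2B FP16 weights + 2B FP16 grads
                 + 12B FP32 optimizer state (master copy + two Adam moments),
    all sharded across world_size GPUs. Activations are extra.
    """
    bytes_per_param = 2 + 2 + 12          # 16 bytes/param, the usual ZeRO-3 estimate
    return bytes_per_param * n_params / world_size / 1e9

single  = fsdp_memory_per_gpu_gb(7e9, world_size=1)   # 112 GB — no single GPU fits it
sharded = fsdp_memory_per_gpu_gb(7e9, world_size=8)   # 14 GB per GPU — fits easily
```

The price is extra communication: parameters must be all-gathered before each forward/backward and gradients reduce-scattered afterwards, so FSDP is worth it only when DDP genuinely cannot fit the model.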

Gradient Checkpointing

Gradient checkpointing trades compute for memory. During the forward pass, activations are discarded instead of stored. During the backward pass, they are recomputed on the fly as needed. With a well-chosen number of segments this reduces activation memory from O(layers) to O(sqrt(layers)), at the cost of roughly one extra forward pass (~30% more compute).

python
from torch.utils.checkpoint import checkpoint_sequential

def forward_with_checkpointing(model_layers: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
    """
    Recompute activations during backward pass instead of storing them.
    Reduces memory by ~50-70% for large models.
    segments controls how many checkpointed segments to split into.
    """
    segments = 4
    # use_reentrant=False is the recommended mode in recent PyTorch
    return checkpoint_sequential(model_layers, segments, x, use_reentrant=False)

# For transformer blocks: wrap individual blocks
class CheckpointedTransformerBlock(nn.Module):
    def __init__(self, block: nn.Module) -> None:
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # use_reentrant=False is the recommended (non-deprecated) mode
        return torch.utils.checkpoint.checkpoint(
            self.block, x, use_reentrant=False
        )
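The O(sqrt(layers)) figure comes from balancing two costs: one stored boundary activation per segment, plus one segment's worth of activations recomputed at a time. A small sketch (counting one unit of activation per layer) shows the minimum sits at sqrt(n_layers):

```python
def peak_activations(n_layers: int, segments: int) -> float:
    """Peak live activations under checkpointing:
    one stored boundary per segment + one segment recomputed in full."""
    return segments + n_layers / segments

n_layers = 64
best = min(range(1, n_layers + 1),
           key=lambda s: peak_activations(n_layers, s))
# best == 8 == sqrt(64): peak drops from 64 (no checkpointing) to 16
```

In practice frameworks checkpoint at natural boundaries (e.g. one transformer block per segment) rather than solving this exactly, but the square-root shape is why the savings are so large for deep models.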

PyTorch Profiler

Profile before optimising. The profiler tells you exactly where time is spent and where VRAM is consumed.

python
from torch.profiler import profile, record_function, ProfilerActivity

def profile_training_step(
    model: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
) -> None:
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function("forward"):
            outputs = model(inputs)
            loss    = criterion(outputs, targets)
        with record_function("backward"):
            loss.backward()
        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()

    # Print top operations by CUDA time
    print(prof.key_averages().table(
        sort_by="cuda_time_total", row_limit=10
    ))
    # Export to TensorBoard or Chrome trace
    prof.export_chrome_trace("trace.json")

Training on Spot Instances

Spot (preemptible) instances cost 60-90% less than on-demand but can be interrupted with as little as two minutes' notice. Reliable spot training requires checkpoint-on-interrupt and resume-from-checkpoint.

python
import signal, os
from pathlib import Path

CHECKPOINT_DIR = Path("checkpoints/")
CHECKPOINT_DIR.mkdir(exist_ok=True)

def save_checkpoint(
    epoch: int,
    step: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    loss: float,
) -> None:
    path = CHECKPOINT_DIR / f"ckpt_epoch{epoch:03d}_step{step:06d}.pt"
    torch.save({
        "epoch":      epoch,
        "step":       step,
        "model":      model.state_dict(),
        "optimizer":  optimizer.state_dict(),
        "scaler":     scaler.state_dict(),
        "loss":       loss,
    }, path)
    # Keep only the last 3 checkpoints
    ckpts = sorted(CHECKPOINT_DIR.glob("ckpt_*.pt"))
    for old in ckpts[:-3]:
        old.unlink()

def load_latest_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
) -> tuple[int, int]:
    ckpts = sorted(CHECKPOINT_DIR.glob("ckpt_*.pt"))
    if not ckpts:
        return 0, 0
    ckpt = torch.load(ckpts[-1], map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])
    print(f"Resumed from epoch={ckpt['epoch']}, step={ckpt['step']}")
    return ckpt["epoch"], ckpt["step"]

def register_interrupt_handler(
    model, optimizer, scaler, current_epoch, current_step
) -> None:
    # current_epoch/current_step are one-element lists so the closure
    # always sees the latest values mutated in place by the training loop
    def handler(signum, frame):
        print("Interrupt received, saving checkpoint...")
        save_checkpoint(current_epoch[0], current_step[0],
                        model, optimizer, scaler, float("nan"))
        os._exit(0)
    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)

Optuna Hyperparameter Search

Optuna uses the Tree-structured Parzen Estimator (TPE) to explore the hyperparameter space intelligently, and a MedianPruner to stop unpromising trials early.

python
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

def objective(trial: optuna.Trial) -> float:
    """Define hyperparameter search space and return validation metric."""
    # Suggest hyperparameters
    lr            = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    batch_size    = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
    weight_decay  = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
    n_layers      = trial.suggest_int("n_layers", 2, 8)
    hidden_size   = trial.suggest_categorical("hidden_size", [128, 256, 512])
    dropout       = trial.suggest_float("dropout", 0.0, 0.5)

    # Build and train model with these hyperparameters
    model     = build_model(n_layers, hidden_size, dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    loader    = make_optimised_dataloader(train_dataset, batch_size=batch_size)
    scaler    = GradScaler()

    best_val_auc = 0.0
    for epoch in range(20):
        train_one_epoch_mixed_precision(
            model, loader, optimizer, nn.BCEWithLogitsLoss(), scaler, device
        )
        val_auc = evaluate(model, val_loader)

        # Report intermediate value for pruning
        trial.report(val_auc, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        best_val_auc = max(best_val_auc, val_auc)

    return best_val_auc

study = optuna.create_study(
    direction="maximize",
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=3),
)
study.optimize(objective, n_trials=50, n_jobs=1)
print("Best params:", study.best_params)
print("Best ROC-AUC:", study.best_value)

Complete Optimised Training Loop

Assembling all the techniques into a single production training script.

python
import torch, mlflow, optuna
import torch.nn as nn
from torch.amp import autocast, GradScaler  # torch.cuda.amp is deprecated
from pathlib import Path

def train(
    model_cls,
    train_dataset,
    val_dataset,
    config: dict,
    run_name: str,
    checkpoint_every_n_steps: int = 500,
) -> float:
    device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model     = model_cls(
        n_layers    = config["n_layers"],
        hidden_size = config["hidden_size"],
        dropout     = config["dropout"],
    ).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr           = config["lr"],
        weight_decay = config["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config["n_epochs"]
    )
    scaler    = GradScaler()
    criterion = nn.BCEWithLogitsLoss()

    train_loader = make_optimised_dataloader(
        train_dataset, batch_size=config["batch_size"], training=True
    )
    val_loader   = make_optimised_dataloader(
        val_dataset, batch_size=config["batch_size"] * 2, training=False
    )

    # Resume from checkpoint if available
    start_epoch, global_step = load_latest_checkpoint(model, optimizer, scaler)
    current_epoch = [start_epoch]
    current_step  = [global_step]
    register_interrupt_handler(model, optimizer, scaler, current_epoch, current_step)

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(config)

        for epoch in range(start_epoch, config["n_epochs"]):
            current_epoch[0] = epoch
            model.train()
            epoch_loss = 0.0

            for batch_idx, (inputs, targets) in enumerate(train_loader):
                inputs  = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)

                with autocast(device_type="cuda", dtype=torch.float16):
                    outputs = model(inputs)
                    loss    = criterion(outputs, targets)

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()

                epoch_loss += loss.item()
                current_step[0] += 1

                if current_step[0] % checkpoint_every_n_steps == 0:
                    save_checkpoint(
                        epoch, current_step[0], model, optimizer, scaler, epoch_loss
                    )
                    mlflow.log_metric("train_loss", epoch_loss / (batch_idx + 1),
                                      step=current_step[0])

            scheduler.step()
            val_auc = evaluate(model, val_loader)
            mlflow.log_metric("val_auc", val_auc, step=epoch)
            print(f"Epoch {epoch}: loss={epoch_loss/len(train_loader):.4f} val_auc={val_auc:.4f}")

        # Log final model
        mlflow.pytorch.log_model(model, "model")
        return val_auc

Key Takeaways

  • GPU memory holds weights + gradients + optimizer state + activations; mixed precision halves activation memory and nearly doubles throughput via Tensor Cores.
  • Use BF16 for training (same range as FP32, avoids overflow) and FP16 for inference where range is predictable; always use GradScaler with FP16 to prevent gradient underflow.
  • Gradient accumulation allows large effective batch sizes without large VRAM; divide the loss by accumulation_steps to keep gradient magnitude consistent.
  • Set pin_memory=True, persistent_workers=True, prefetch_factor=2, and num_workers sized to the available CPU cores on every training DataLoader; DataLoader starvation wastes expensive GPU time.
  • Prefer DDP over DataParallel for multi-GPU training — DDP has no primary GPU bottleneck and scales linearly. Use FSDP when the model does not fit on a single GPU.
  • Gradient checkpointing reduces activation memory by 50-70% at a 30% compute cost; always enable it when VRAM is the constraint.
  • Profile before optimising: the PyTorch profiler will tell you whether you are CPU-bound, memory-bandwidth-bound, or actually compute-bound.
  • Checkpoint every few hundred steps and register SIGTERM handlers when training on spot instances; a 2-minute warning is enough to save state if your handler is in place.