Building Large Language Models from Scratch: A Beginner’s Guide with Python and PyTorch

Scaling Up — From Toy Model to Real LLM

36 min read Chapter 9 of 11
Summary

This chapter bridges the gap between the toy model trained in CH7-CH8 and production-scale LLMs. Gradient accumulation simulates larger batch sizes within memory constraints. Mixed precision training with FP16/BF16 reduces memory usage and accelerates computation. Checkpointing enables resumable training across interruptions. Model size configurations compare parameter counts from 2M to 350M+ with corresponding hardware requirements. Scaling laws relate model size, data quantity, and compute to model capability. Data quality and quantity guidelines follow Chinchilla-optimal ratios. Weight initialization strategies ensure stable training convergence.

In the previous chapters, you built a GPT model from scratch, trained it on text, and generated new sentences from it. That model worked — the loss went down, the generated text was vaguely coherent, and you understood every piece of the pipeline from tokenization to generation.

But let’s be honest: the text quality was rough. Your model had maybe 2 million parameters, 2 transformer layers, and 128 dimensions. That’s a toy. GPT-3 has 175 billion parameters. Even the smallest useful open-source models have hundreds of millions. The gap between what you built and what actually works well is enormous.

This chapter bridges that gap. Not by building a 175-billion-parameter model (you’d need a data center for that), but by giving you every technique you need to scale up systematically. You’ll learn how to train with larger effective batch sizes without running out of memory, how to cut memory usage in half with mixed precision, how to save and resume training across days or weeks, and how to estimate the compute and data requirements for any model size.

By the end of this chapter, you will:

  • Understand why more parameters produce better models (scaling laws)
  • Implement gradient accumulation for larger effective batch sizes
  • Use mixed precision training (FP16/BF16) to halve memory usage
  • Save and load training checkpoints for resumable training
  • Know the parameter counts, hardware requirements, and training times for models from 2M to 350M+
  • Understand how much data you need (Chinchilla scaling)
  • Apply proper weight initialization for stable training
  • Estimate training costs on cloud hardware

Let’s scale up.


1. Why Scale Matters

Your tiny model from Chapter 7 had about 2 million parameters across 2 layers with 128 dimensions. It could learn simple patterns — common phrases, basic grammar, word associations. But it couldn’t hold a conversation, answer questions, or write coherent paragraphs. Why?

Think of parameters as the model’s memory capacity. Each parameter stores a tiny piece of information about language. With 2 million parameters, the model can memorize a few thousand patterns: “the cat sat on the,” “once upon a time,” “in the morning.” But human language has millions of subtle patterns — sarcasm, metaphor, technical jargon, narrative structure, logical reasoning. You need more storage space to capture them.

This isn’t just intuition. Researchers at OpenAI and DeepMind have discovered remarkably consistent scaling laws — mathematical relationships between three variables:

  1. Model size (number of parameters, $N$)
  2. Data size (number of training tokens, $D$)
  3. Compute budget (FLOPs spent training, $C$)

The key finding: model performance (measured by loss) follows a power law in each of these variables. Double the parameters and you get a predictable improvement. Double the data and you get a predictable improvement. The relationship is smooth — there aren’t sudden jumps where a model “learns to think.” It’s a gradual slope where more resources consistently produce better results.

$$L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}$$

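Where $L$ is the loss, $N$ is the parameter count, $N_c$ is a fitted constant, and $\alpha_N \approx 0.076$ is the fitted exponent. The loss decreases as a power law as you add more parameters. Because the exponent is small, doubling the parameter count lowers the loss by only about 5% ($2^{-0.076} \approx 0.95$).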
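
To get a feel for how slowly this curve falls, here is a minimal sketch that evaluates the formula for a few model sizes. The constant `n_c=8.8e13` is an assumption taken from the published OpenAI fit (this chapter does not define it), and `power_law_loss` is a hypothetical helper name. That fit was made on non-embedding parameters, so read the absolute values loosely; the useful part is the ratio between sizes.

```python
def power_law_loss(n_params, n_c=8.8e13, alpha_n=0.076):
    """Evaluate L(N) = (N_c / N) ** alpha_N.

    n_c is an assumed fitted constant; alpha_n is the exponent quoted above.
    """
    return (n_c / n_params) ** alpha_n


for n in [2e6, 25e6, 124e6, 350e6, 1.5e9]:
    print(f"{n / 1e6:>8.0f}M params -> predicted loss {power_law_loss(n):.3f}")

# Doubling the parameter count multiplies the loss by 2 ** -0.076,
# which is a reduction of roughly 5%.
ratio = power_law_loss(2e8) / power_law_loss(1e8)
print(f"Loss ratio after doubling N: {ratio:.4f}")
```

The small exponent is why each visible jump in quality tends to need an order of magnitude more parameters, not a mere doubling.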

What This Means in Practice

Here’s what different model sizes can typically do:

2M parameters   → Memorizes common phrases, poor grammar
25M parameters  → Reasonable grammar, limited coherence
124M parameters → Coherent paragraphs, basic reasoning (GPT-2 small)
350M parameters → Solid text quality, follows instructions somewhat
1.5B parameters → Good general text, can follow many instructions (GPT-2 XL)
7B+ parameters  → Conversational, can reason, write code (Llama-2 7B)
70B+ parameters → Strong reasoning, nuanced understanding (Llama-2 70B)
175B parameters → GPT-3 scale

The jump from 2M to 124M is massive — it’s the difference between gibberish and readable text. The jump from 124M to 7B is another transformation — it’s the difference between readable text and genuinely useful AI. Each order of magnitude in parameters unlocks capabilities that simply didn’t exist at the previous scale.

But you can’t just make the model bigger and hope for the best. Larger models need:

  • More memory: a 124M model needs ~500 MB and a 7B model needs ~28 GB just to hold its weights in FP32 (4 bytes per parameter)
  • More data: training a 124M model on too little data leads to overfitting
  • More compute: for a fixed number of training tokens, training time scales roughly linearly with parameter count
  • Better training techniques: the tricks that were optional at 2M parameters become essential at scale

The rest of this chapter covers those techniques.
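
The memory figures in the list above follow from simple arithmetic on the weights alone. A minimal sketch, assuming 4 bytes per parameter for FP32 and 2 bytes for FP16/BF16 (the helper name `weight_memory_gb` is hypothetical, and gradients, optimizer state, and activations are deliberately ignored):

```python
def weight_memory_gb(num_params, bytes_per_param=4):
    """Memory needed just to store the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


for name, n in [("124M", 124e6), ("7B", 7e9)]:
    fp32 = weight_memory_gb(n, bytes_per_param=4)
    fp16 = weight_memory_gb(n, bytes_per_param=2)
    print(f"{name}: FP32 weights ~{fp32:.2f} GB, FP16/BF16 weights ~{fp16:.2f} GB")
```

Training needs several times this amount, as the memory estimate later in the chapter shows.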


2. Gradient Accumulation

Here’s a common problem: research papers say “we trained with batch size 512” but your GPU can only fit batch size 8. Larger batches produce more stable gradient estimates because each update averages over more examples, reducing noise. But you simply don’t have enough memory for 512 examples at once.

Gradient accumulation solves this elegantly. Instead of processing 512 examples in one go, you process 8 examples 64 times, accumulating the gradients before making a single weight update. The math works out to be equivalent — you get the same gradient as if you’d processed all 512 at once.

Think of it like this: you need to carry 64 bags of groceries from your car to the kitchen. You could try to carry all 64 at once (batch size 64 — you’ll drop everything). Or you could make 8 trips carrying 8 bags each (gradient accumulation with 8 micro-batches of size 8). The groceries all end up in the same place.

The Math

In normal training, you compute the loss for a batch, compute gradients, and update weights:

$$\theta \leftarrow \theta - \eta \cdot \nabla_\theta L(\text{batch})$$

With gradient accumulation over $K$ micro-batches, you accumulate gradients and then update:

$$g = \frac{1}{K} \sum_{k=1}^{K} \nabla_\theta L(\text{micro-batch}_k)$$

$$\theta \leftarrow \theta - \eta \cdot g$$

This is mathematically identical to computing the gradient over the full batch of size $K \times \text{micro\_batch\_size}$, provided every micro-batch has the same size and the loss is a mean over its examples.
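
A quick way to convince yourself of this equivalence is to check it numerically. The sketch below uses a hypothetical one-parameter linear model with a mean-squared-error loss in plain Python (not the GPT model from this book), so the gradient can be written by hand. The data values are made up for illustration.

```python
def mse_gradient(w, xs, ys):
    """Gradient of mean((w*x - y)**2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy data: 8 examples, split into K = 4 micro-batches of size 2
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5
micro_batch_size = 2
K = len(xs) // micro_batch_size

# Gradient over the full batch in one go
full_grad = mse_gradient(w, xs, ys)

# Gradient accumulated over K micro-batches, each scaled by 1/K
accumulated = 0.0
unscaled = 0.0
for k in range(K):
    lo, hi = k * micro_batch_size, (k + 1) * micro_batch_size
    micro_grad = mse_gradient(w, xs[lo:hi], ys[lo:hi])
    accumulated += micro_grad / K   # correct: average of micro-batch gradients
    unscaled += micro_grad          # forgetting 1/K: K times too large

print(f"full batch:  {full_grad:.6f}")
print(f"accumulated: {accumulated:.6f}")
print(f"unscaled:    {unscaled:.6f}  (= K x full batch)")
```

The first two numbers agree up to floating-point rounding, while the third shows what happens when the $\frac{1}{K}$ factor is dropped.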

Implementation

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train_with_gradient_accumulation(
    model,
    train_loader,
    optimizer,
    accumulation_steps=8,
    epochs=10,
):
    """
    Train with gradient accumulation. Effective batch size =
    actual batch size × accumulation_steps.

    Args:
        model: your GPT model
        train_loader: DataLoader with your training data
        optimizer: optimizer (e.g., AdamW)
        accumulation_steps: how many micro-batches to accumulate
        epochs: number of training epochs
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        total_loss = 0
        optimizer.zero_grad()  # Zero gradients at the start

        for step, (input_ids, targets) in enumerate(train_loader):
            # Forward pass
            logits = model(input_ids)  # (batch, seq_len, vocab_size)

            # Reshape for cross-entropy
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)
            loss = loss_fn(logits, targets)

            # Scale loss by accumulation steps so the average is correct
            loss = loss / accumulation_steps

            # Backward pass — gradients ACCUMULATE (they add up)
            loss.backward()

            total_loss += loss.item() * accumulation_steps

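            # Update weights every accumulation_steps, and also on the final
            # batch so leftover gradients are not thrown away by the next
            # epoch's zero_grad()
            is_last_batch = (step + 1) == len(train_loader)
            if (step + 1) % accumulation_steps == 0 or is_last_batch: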
                # Optional: clip gradients to prevent explosions
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()     # Update weights using accumulated gradients
                optimizer.zero_grad()  # Reset gradients for next accumulation cycle

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} — Loss: {avg_loss:.4f}")

Why Loss Scaling Matters

Notice the line loss = loss / accumulation_steps. This is crucial. Without it, the accumulated gradient would be $K$ times too large, because loss.backward() adds gradients to the existing .grad tensors. Dividing by $K$ ensures the accumulated gradient equals the average over all micro-batches, just like a single large batch.

Choosing Accumulation Steps

# Your GPU fits batch size 8
# Research suggests batch size 256 works best for your model
actual_batch_size = 8
desired_effective_batch_size = 256
accumulation_steps = desired_effective_batch_size // actual_batch_size  # = 32

print(f"Actual batch size: {actual_batch_size}")
print(f"Accumulation steps: {accumulation_steps}")
print(f"Effective batch size: {actual_batch_size * accumulation_steps}")
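# Effective batch size: 256

Effective Batch Size | Stability       | Speed                  | Memory
---------------------+-----------------+------------------------+-------
8                    | Noisy gradients | Fast steps             | Low
32                   | Moderate        | Moderate               | Low
128                  | Stable          | Slower per “real” step | Low
512                  | Very stable     | Slow per “real” step   | Low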

The key insight: gradient accumulation trades time for memory. Each weight update takes $K$ times longer (because you do $K$ forward-backward passes), but memory usage stays the same as processing a single micro-batch. For most training scenarios, this trade-off is well worth it.


3. Mixed Precision Training (FP16/BF16)

By default, PyTorch stores every number as a 32-bit floating-point value (FP32). Each model parameter, each gradient, each intermediate activation — all 32 bits. But do you really need 32 bits of precision for every calculation?

For most of the math in training, the answer is no. Mixed precision training uses 16-bit numbers (half the storage) for the bulk of computation, keeping 32-bit only where precision really matters. This gives you two benefits simultaneously:

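  1. Up to 2× memory savings: 16-bit numbers take half the space. With PyTorch AMP the savings come mostly from activations, because the weights and optimizer state stay in 32-bit
  2. Faster math: modern GPUs have specialized hardware for 16-bit operations (Tensor Cores on NVIDIA GPUs)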

Think of it like sketching. When an artist draws a complex scene, they start with rough pencil sketches (low precision) to figure out composition, proportions, and layout. Only the final details are rendered in fine detail (high precision). The rough sketches need to be roughly right, not perfectly precise. Mixed precision works the same way — use rough math for the bulk of computation, and fine math only for the critical accumulation steps.

FP16 vs BF16

There are two 16-bit formats:

  • FP16 (float16): 1 sign bit + 5 exponent bits + 10 mantissa bits. More precision but smaller range. Can cause overflows/underflows with large gradients.
  • BF16 (bfloat16): 1 sign bit + 8 exponent bits + 7 mantissa bits. Less precision but same range as FP32. More stable for training.
FP32:  1 sign | 8 exponent | 23 mantissa  →  ~7 decimal digits, range ±3.4×10³⁸
FP16:  1 sign | 5 exponent | 10 mantissa  →  ~3 decimal digits, range ±65,504
BF16:  1 sign | 8 exponent |  7 mantissa  →  ~2 decimal digits, range ±3.4×10³⁸

BF16 is generally preferred for training because it has the same exponent range as FP32, avoiding overflow issues. FP16 is more widely supported on older hardware. Most modern GPUs (NVIDIA A100+, H100) support both.
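
The range and precision trade-off can be explored without a GPU. The sketch below is an illustration only: it uses Python's standard-library `struct` module, whose `"e"` format is IEEE 754 half precision, and it emulates BF16 by truncating a float32 to its top 16 bits (real hardware rounds to nearest). The helper names are hypothetical.

```python
import struct


def to_fp16(x):
    """Round-trip a float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    This truncates instead of rounding to nearest, which is enough for a demo.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def fp16_overflows(x):
    """Return True if x is too large to be stored in FP16."""
    try:
        struct.pack("<e", x)
        return False
    except OverflowError:
        return True


# Precision: FP16 keeps more digits than BF16
print(to_fp16(1.001), to_bf16(1.001))

# Range: 70,000 does not fit in FP16 but does fit (coarsely) in BF16
print(fp16_overflows(70000.0), to_bf16(70000.0))

# Tiny values: 1e-8 underflows to zero in FP16 but survives in BF16
print(to_fp16(1e-8), to_bf16(1e-8))
```

The three comparisons line up with the table above: FP16 wins on digits, BF16 wins on range at both the large and the small end.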

Implementation with PyTorch AMP

PyTorch’s Automatic Mixed Precision (AMP) makes this easy:

import torch
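# Note: recent PyTorch releases deprecate torch.cuda.amp in favor of
# torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").
from torch.cuda.amp import autocast, GradScaler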

def train_with_mixed_precision(
    model,
    train_loader,
    optimizer,
    epochs=10,
    accumulation_steps=1,
    device="cuda",
):
    """
    Train with mixed precision (FP16) for 2x memory savings and faster math.

    Requires a CUDA GPU; on a CPU-only machine, train in FP32 instead.
    """
    model = model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # GradScaler prevents FP16 underflow by scaling the loss
    scaler = GradScaler()

    for epoch in range(epochs):
        total_loss = 0
        optimizer.zero_grad()

        for step, (input_ids, targets) in enumerate(train_loader):
            input_ids = input_ids.to(device)
            targets = targets.to(device)

            # autocast: automatically use FP16 where safe, FP32 where needed
            with autocast():
                logits = model(input_ids)
                logits = logits.view(-1, logits.size(-1))
                targets = targets.view(-1)
                loss = loss_fn(logits, targets)
                loss = loss / accumulation_steps

            # Scale loss to prevent FP16 gradient underflow, then backward
            scaler.scale(loss).backward()

            total_loss += loss.item() * accumulation_steps

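            # Also step on the final batch so leftover gradients are not discarded
            is_last_batch = (step + 1) == len(train_loader)
            if (step + 1) % accumulation_steps == 0 or is_last_batch: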
                # Unscale gradients before clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # Step the optimizer with the now-unscaled gradients
                # (the step is skipped automatically if they contain inf/NaN)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} — Loss: {avg_loss:.4f}")

What the GradScaler Does

FP16 has a problem: very small gradient values (like $10^{-8}$) get rounded to zero because FP16 can’t represent numbers that small. If gradients are zero, the model stops learning.

The GradScaler fixes this by multiplying the loss by a large number (say, 1024; PyTorch’s default starting value is 65,536) before computing gradients. This “scales up” all the gradients so they’re large enough to survive FP16 rounding. Then, before the optimizer step, it divides the gradients back down by the same factor. The result closely tracks FP32 training, but the intermediate values are big enough to fit in FP16.

Without scaling:
  loss = 0.003 → gradient = 0.00000002 → below FP16's smallest value (~0.00000006) → rounds to 0.0 → PROBLEM

With scaling (scale factor = 1024):
  scaled_loss = 0.003 × 1024 = 3.072 → gradient = 0.0000205 → FP16 keeps it
  → unscale: 0.0000205 / 1024 = 0.00000002 → correct gradient recovered

The GradScaler also dynamically adjusts the scale factor — if it detects inf or nan values (the scale was too high), it reduces the scale and skips that update.
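
The scaling trick and the dynamic adjustment can be illustrated without a GPU. The sketch below is a toy, not PyTorch's implementation: `ToyGradScaler` is a hypothetical class, FP16 rounding is simulated with the standard-library `struct` module, and the growth and backoff settings are illustrative assumptions.

```python
import math
import struct


def to_fp16(x):
    """Simulate FP16 storage; values that are too large become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class ToyGradScaler:
    """Toy dynamic loss scaler (illustrative assumption, not torch's GradScaler)."""

    def __init__(self, init_scale=1024.0, growth_interval=3):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def scaled_fp16_grad(self, true_grad):
        # Scaling the loss scales every gradient by the same factor
        return to_fp16(true_grad * self.scale)

    def unscale_and_update(self, fp16_grad):
        """Return the unscaled gradient, or None if the step must be skipped."""
        if math.isinf(fp16_grad) or math.isnan(fp16_grad):
            self.scale *= 0.5      # scale was too high: back off and skip
            self.good_steps = 0
            return None
        self.good_steps += 1
        grad = fp16_grad / self.scale
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2.0      # several clean steps: try a larger scale
        return grad


tiny_grad = 2e-8
print("unscaled in FP16:", to_fp16(tiny_grad))  # underflows to 0.0

scaler = ToyGradScaler(init_scale=1024.0)
recovered = scaler.unscale_and_update(scaler.scaled_fp16_grad(tiny_grad))
print("recovered after scaling:", recovered)

# A large gradient overflows once scaled, so the step is skipped
skipped = scaler.unscale_and_update(scaler.scaled_fp16_grad(100.0))
print("overflow step:", skipped, "new scale:", scaler.scale)
```

The recovered gradient is within about 1% of the true value; the small error is FP16 rounding of the scaled number.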

When Mixed Precision Helps (and When It Doesn’t)

Situation                                      | Benefit
-----------------------------------------------+------------------------------------------------
Training on modern GPU (A100, H100, RTX 3090+) | 2-3× speedup, 2× memory savings
Training on older GPU (V100, RTX 2080)         | 1.5-2× speedup, 2× memory savings
Training on CPU                                | No benefit (CPUs don’t have FP16 acceleration)
Inference only                                 | Moderate speedup, good memory savings
Very small models (< 10M params)               | Overhead may negate gains

# Check if your hardware supports mixed precision
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")

    # Check for BF16 support (Ampere and newer)
    if torch.cuda.is_bf16_supported():
        print("BF16 supported — use dtype=torch.bfloat16 for best results")
    else:
        print("BF16 not supported — use FP16 with GradScaler")
else:
    print("No GPU available — mixed precision won't help, use FP32")

4. Checkpointing and Resuming

Training a real language model takes days, weeks, or even months. Machines crash. Power goes out. Cloud instances get preempted. You accidentally close the terminal. Without checkpointing, all that training is lost and you start from scratch.

Checkpointing saves a snapshot of everything needed to resume training: the model weights, the optimizer state (momentum, adaptive learning rates), the current epoch, and the current loss. When you restart, you load the checkpoint and pick up exactly where you left off.

Think of it like saving your progress in a video game. You wouldn’t play for 40 hours without saving. Training a model for 40 hours without checkpointing is equally reckless.

What to Save

A checkpoint needs to include:

  1. Model state dict — all the learned weights
  2. Optimizer state dict — Adam’s momentum and variance estimates (without these, the optimizer “forgets” its history and training quality degrades)
  3. Epoch number — so you know where you are
  4. Step number — for learning rate schedulers that depend on step count
  5. Loss — to verify training resumed correctly
  6. Any other state — learning rate scheduler, gradient scaler (for mixed precision), random number generator state
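
Item 6 is easy to overlook. As a small illustration of why random number generator state belongs in a checkpoint, the sketch below uses Python's standard-library `random` module. A PyTorch run would save the framework's own RNG state in the same spirit (for example via `torch.get_rng_state()`). The dictionary keys here are hypothetical.

```python
import pickle
import random

random.seed(1234)
_ = [random.random() for _ in range(5)]  # pretend these were training steps

# Save the RNG state alongside the other checkpoint fields
checkpoint_bytes = pickle.dumps({"step": 5, "rng_state": random.getstate()})

# What an uninterrupted run would draw next
expected_next = [random.random() for _ in range(3)]

# Simulate a crash and restart: the RNG is now in some unrelated state
random.seed(999)

# Restore from the checkpoint and continue
checkpoint = pickle.loads(checkpoint_bytes)
random.setstate(checkpoint["rng_state"])
resumed_next = [random.random() for _ in range(3)]

print(resumed_next == expected_next)
```

Without the restored state, data shuffling and dropout would follow a different random sequence after the restart, so the resumed run would not reproduce the uninterrupted one.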

Implementation

import os
import torch


def save_checkpoint(
    model,
    optimizer,
    epoch,
    step,
    loss,
    path="checkpoints/checkpoint.pt",
    scaler=None,
    scheduler=None,
):
    """
    Save a training checkpoint.

    Args:
        model: the model being trained
        optimizer: the optimizer
        epoch: current epoch number
        step: current global step
        loss: current loss value
        path: where to save the checkpoint file
        scaler: GradScaler (if using mixed precision)
        scheduler: learning rate scheduler (if using one)
    """
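    # Create the directory if needed (skip when path has no directory part,
    # because os.makedirs("") raises an error)
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)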

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "loss": loss,
    }

    # Optional components
    if scaler is not None:
        checkpoint["scaler_state_dict"] = scaler.state_dict()
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()

    torch.save(checkpoint, path)
    print(f"Checkpoint saved: epoch={epoch}, step={step}, loss={loss:.4f}")


def load_checkpoint(path, model, optimizer, scaler=None, scheduler=None):
    """
    Load a training checkpoint and restore all state.

    Returns:
        epoch: the epoch to resume from
        step: the step to resume from
        loss: the loss at checkpoint time
    """
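    # weights_only=False unpickles arbitrary Python objects, so only load
    # checkpoints you created or trust
    checkpoint = torch.load(path, weights_only=False)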

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    if scaler is not None and "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    step = checkpoint["step"]
    loss = checkpoint["loss"]

    print(f"Checkpoint loaded: epoch={epoch}, step={step}, loss={loss:.4f}")
    return epoch, step, loss

Training Loop with Periodic Checkpointing

def train_with_checkpointing(
    model,
    train_loader,
    optimizer,
    epochs=100,
    checkpoint_dir="checkpoints",
    save_every_n_epochs=5,
    resume_from=None,
):
    """
    Full training loop with periodic checkpointing and optional resume.

    Args:
        model: your GPT model
        train_loader: DataLoader with training data
        optimizer: optimizer (e.g., AdamW)
        epochs: total number of epochs
        checkpoint_dir: where to save checkpoints
        save_every_n_epochs: save a checkpoint every N epochs
        resume_from: path to a checkpoint to resume from (or None)
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    start_epoch = 0
    global_step = 0
    avg_loss = float("nan")  # stays defined even if the epoch loop never runs

    # Resume from checkpoint if provided
    if resume_from is not None and os.path.exists(resume_from):
        start_epoch, global_step, last_loss = load_checkpoint(
            resume_from, model, optimizer
        )
        start_epoch += 1  # Start from the NEXT epoch
        print(f"Resuming from epoch {start_epoch}")

    model.train()
    for epoch in range(start_epoch, epochs):
        total_loss = 0

        for step, (input_ids, targets) in enumerate(train_loader):
            optimizer.zero_grad()

            logits = model(input_ids)
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)

            loss = loss_fn(logits, targets)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            global_step += 1

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs} — Loss: {avg_loss:.4f}")

        # Save checkpoint periodically
        if (epoch + 1) % save_every_n_epochs == 0:
            save_checkpoint(
                model, optimizer, epoch, global_step, avg_loss,
                path=os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pt"),
            )

    # Always save final checkpoint
    save_checkpoint(
        model, optimizer, epochs - 1, global_step, avg_loss,
        path=os.path.join(checkpoint_dir, "checkpoint_final.pt"),
    )

Checkpoint Management

Training for 100 epochs with save_every_n_epochs=5 creates 20 checkpoint files. Each one stores the full model and optimizer state, which can be hundreds of megabytes. A practical strategy is to keep only the last few:

def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
    """Keep only the most recent N checkpoints, delete the rest."""
    import glob

    checkpoint_files = sorted(
        glob.glob(os.path.join(checkpoint_dir, "checkpoint_epoch_*.pt")),
        key=os.path.getmtime,
    )

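    # Delete all but the last N (a plain [:-keep_last_n] slice would delete
    # nothing when keep_last_n is 0)
    files_to_delete = checkpoint_files[: max(0, len(checkpoint_files) - keep_last_n)]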
    for f in files_to_delete:
        os.remove(f)
        print(f"Deleted old checkpoint: {f}")

Tip: Always keep the “best” checkpoint (lowest validation loss) in addition to the most recent ones. You may want to roll back to the best checkpoint if the model starts overfitting.
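
One way to follow this tip is to copy a checkpoint to a fixed file name whenever the validation loss improves. The sketch below is a minimal illustration: `update_best_checkpoint` and the file names are hypothetical, the validation losses are made up, and placeholder text files stand in for real checkpoints.

```python
import os
import shutil
import tempfile


def update_best_checkpoint(latest_path, val_loss, best_loss, best_path):
    """Copy latest_path to best_path if val_loss improved.

    Returns the best validation loss seen so far.
    """
    if val_loss < best_loss:
        shutil.copyfile(latest_path, best_path)
        return val_loss
    return best_loss


# Demo with placeholder files standing in for real checkpoints
with tempfile.TemporaryDirectory() as tmp:
    best_path = os.path.join(tmp, "checkpoint_best.pt")
    best_loss = float("inf")
    history = []

    for epoch, val_loss in enumerate([3.2, 2.7, 2.9, 2.5, 2.6], start=1):
        latest_path = os.path.join(tmp, f"checkpoint_epoch_{epoch}.pt")
        with open(latest_path, "w") as f:
            f.write(f"epoch {epoch}")
        best_loss = update_best_checkpoint(latest_path, val_loss, best_loss, best_path)
        history.append(best_loss)

    with open(best_path) as f:
        best_contents = f.read()

print(history, best_contents)
```

Because the best copy is named `checkpoint_best.pt`, it does not match the `checkpoint_epoch_*.pt` pattern used by `cleanup_old_checkpoints` and so survives the cleanup.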


5. Model Size Configurations

Now let’s talk specifics. How do you configure a model for different sizes? In Chapter 6, we defined our model with a config object that includes d_model, n_layers, n_heads, and other hyperparameters. Here are practical configurations spanning three orders of magnitude:

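Name   | Params | d_model | Layers | Heads | d_ff | Context | Training Time
-------+--------+---------+--------+-------+------+---------+------------------
Tiny   | ~2M    | 128     | 2      | 4     | 512  | 256     | Minutes (CPU)
Small  | ~25M   | 384     | 6      | 6     | 1536 | 512     | Hours (1 GPU)
Medium | ~124M  | 768     | 12     | 12    | 3072 | 1024    | Days (1 GPU)
Large  | ~350M  | 1024    | 24     | 16    | 4096 | 1024    | Weeks (1+ GPU)
XL     | ~1.5B  | 1600    | 48     | 25    | 6400 | 1024    | Weeks (multi-GPU)

# Model configurations you can directly use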

configs = {
    "tiny": {
        "vocab_size": 50257,
        "d_model": 128,
        "n_layers": 2,
        "n_heads": 4,
        "d_ff": 512,
        "max_seq_len": 256,
        "dropout": 0.1,
    },
    "small": {
        "vocab_size": 50257,
        "d_model": 384,
        "n_layers": 6,
        "n_heads": 6,
        "d_ff": 1536,
        "max_seq_len": 512,
        "dropout": 0.1,
    },
    "medium_gpt2": {
        "vocab_size": 50257,
        "d_model": 768,
        "n_layers": 12,
        "n_heads": 12,
        "d_ff": 3072,
        "max_seq_len": 1024,
        "dropout": 0.1,
    },
    "large": {
        "vocab_size": 50257,
        "d_model": 1024,
        "n_layers": 24,
        "n_heads": 16,
        "d_ff": 4096,
        "max_seq_len": 1024,
        "dropout": 0.1,
    },
}

Counting Parameters

How many parameters does each configuration actually have? Let’s count:

def count_parameters(config):
    """
    Estimate the parameter count for a transformer language model.

    Major parameter groups:
    - Token embedding: vocab_size × d_model
    - Position embedding: max_seq_len × d_model
    - Per layer:
        - Attention (Q, K, V, O projections): 4 × d_model²
        - Feed-forward: 2 × d_model × d_ff
        - Layer norms: 4 × d_model
    - Final layer norm: 2 × d_model
    - Output projection: d_model × vocab_size (often tied with embedding)
    """
    V = config["vocab_size"]
    D = config["d_model"]
    L = config["n_layers"]
    F = config["d_ff"]
    S = config["max_seq_len"]

    token_embedding = V * D
    position_embedding = S * D

    # Each layer
    attention = 4 * D * D          # Q, K, V, Output projections
    feed_forward = 2 * D * F      # Up-projection + down-projection
    layer_norms = 4 * D           # 2 layer norms × 2 params each (weight + bias)
    per_layer = attention + feed_forward + layer_norms

    total_layers = L * per_layer
    final_layer_norm = 2 * D
    # Output projection often tied with token embedding, so we don't count it twice

    total = token_embedding + position_embedding + total_layers + final_layer_norm

    return total

for name, config in configs.items():
    params = count_parameters(config)
    print(f"{name:15s}: {params:>12,} parameters ({params/1e6:.1f}M)")

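# Output:
# tiny           :    6,860,160 parameters (6.9M)
# small          :   30,122,112 parameters (30.1M)
# medium_gpt2    :  124,356,864 parameters (124.4M)
# large          :  354,601,984 parameters (354.6M)
#
# These counts ignore bias terms and use the 50,257-token GPT-2 vocabulary,
# so "tiny" and "small" land above the ~2M and ~25M listed in the table:
# the token embedding alone contributes 50257 × d_model parameters.

In the tiny configuration the token embedding (50257 × 128 ≈ 6.4M) accounts for most of the parameters. As models get bigger, the transformer layers become the dominant cost, because the per-layer attention and feed-forward parameters grow as $d_{model}^2$.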
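
To see where the parameters sit, the sketch below splits the estimate into embedding and transformer-layer parts. It repeats the same simplified formula (no bias terms, tied output projection). `parameter_breakdown` is a hypothetical helper, and only two of the configurations are reproduced to keep the block self-contained.

```python
def parameter_breakdown(vocab_size, d_model, n_layers, d_ff, max_seq_len):
    """Return (embedding_params, layer_params) using the simplified formula."""
    embeddings = vocab_size * d_model + max_seq_len * d_model
    per_layer = 4 * d_model * d_model + 2 * d_model * d_ff + 4 * d_model
    layers = n_layers * per_layer + 2 * d_model  # include the final layer norm
    return embeddings, layers


examples = {
    "tiny": dict(vocab_size=50257, d_model=128, n_layers=2, d_ff=512, max_seq_len=256),
    "medium_gpt2": dict(vocab_size=50257, d_model=768, n_layers=12, d_ff=3072, max_seq_len=1024),
}

shares = {}
for name, cfg in examples.items():
    emb, layers = parameter_breakdown(**cfg)
    shares[name] = emb / (emb + layers)
    print(f"{name:12s}: embeddings {emb:>11,} | layers {layers:>11,} "
          f"| embedding share {shares[name]:.0%}")
```

With a 50,257-token vocabulary, embeddings make up over 90% of the tiny configuration but under a third of the GPT-2-sized one. A smaller vocabulary would shrink the tiny model considerably.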

Memory Requirements

A rough rule: each parameter takes 4 bytes (FP32) or 2 bytes (FP16/BF16). During training, you also need memory for optimizer state (Adam stores 2 extra values per parameter) and gradients:

def estimate_memory(config, precision="fp32"):
    """Estimate training memory in GB."""
    params = count_parameters(config)
    bytes_per_param = 4 if precision == "fp32" else 2

    model_memory = params * bytes_per_param
    # Adam stores momentum + variance = 2 extra FP32 values per parameter
    optimizer_memory = params * 4 * 2
    gradient_memory = params * bytes_per_param
    # Activations depend on batch size and sequence length (rough estimate)
    activation_memory = params * bytes_per_param * 2  # very rough

    total = model_memory + optimizer_memory + gradient_memory + activation_memory
    total_gb = total / (1024 ** 3)
    return total_gb

for name, config in configs.items():
    mem_fp32 = estimate_memory(config, "fp32")
    mem_fp16 = estimate_memory(config, "fp16")
    print(f"{name:15s}: FP32 ~{mem_fp32:.1f} GB | FP16 ~{mem_fp16:.1f} GB")

Note: These are rough estimates. Actual memory depends heavily on batch size, sequence length, and whether activation checkpointing is used. The optimizer state (Adam’s momentum and variance) is always stored in FP32 even during mixed precision training — that’s where the “mixed” in mixed precision comes from.


6. Scaling Laws and Data Requirements

How much data do you need? This turns out to have a surprisingly precise answer, thanks to the Chinchilla scaling laws published by DeepMind in 2022.

The Chinchilla paper found that most existing LLMs were over-parameterized and under-trained — they had too many parameters relative to the amount of training data. The optimal balance is roughly:

$$D_{\text{optimal}} \approx 20 \times N$$

Where $D$ is the number of training tokens and $N$ is the number of parameters. In other words, you need about 20 tokens of training data per parameter.
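
The rule of thumb is easy to turn into a calculator. In this minimal sketch the helper names are hypothetical, and the 4 bytes of text per token is an assumption for English, so the sizes it prints are rough and need not match other estimates exactly.

```python
def chinchilla_optimal_tokens(num_params, tokens_per_param=20):
    """Rule-of-thumb token budget: about 20 training tokens per parameter."""
    return num_params * tokens_per_param


def approx_text_size_gb(num_tokens, bytes_per_token=4):
    """Very rough size of the raw text, assuming ~4 bytes of English per token."""
    return num_tokens * bytes_per_token / 1e9


for name, n in [("Tiny", 2_000_000), ("Medium", 124_000_000), ("XL", 1_500_000_000)]:
    tokens = chinchilla_optimal_tokens(n)
    print(f"{name:7s}: {tokens / 1e9:6.2f}B tokens, ~{approx_text_size_gb(tokens):.1f} GB of text")
```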

What This Means for Different Model Sizes

Model Size    | Parameters    | Optimal Training Tokens | ~Data Size
--------------+---------------+-------------------------+-------------
Tiny (2M)     | 2,000,000     | 40,000,000 (40M)        | ~200 MB text
Small (25M)   | 25,000,000    | 500,000,000 (500M)      | ~2.5 GB text
Medium (124M) | 124,000,000   | 2,480,000,000 (2.5B)    | ~12 GB text
Large (350M)  | 350,000,000   | 7,000,000,000 (7B)      | ~35 GB text
XL (1.5B)     | 1,500,000,000 | 30,000,000,000 (30B)    | ~150 GB text

For reference, common training datasets:

  • Wikipedia English ≈ 4 billion tokens (~16 GB text)
  • BookCorpus ≈ 1 billion tokens (~4 GB text)
  • The Pile ≈ 380 billion tokens (~800 GB text)
  • RedPajama ≈ 1.2 trillion tokens
  • Common Crawl ≈ hundreds of trillions of tokens (raw, needs filtering)

By this ratio, even the ~2M-parameter model from Chapter 7 wants roughly 40M tokens, which is a few hundred books' worth of text, not a single book. A GPT-2 sized model needs Wikipedia-scale data. Larger models need curated web dumps.

Data Quality Matters More Than Quantity

Having 100 GB of text doesn’t help if it’s full of spam, duplicates, and nonsense. In fact, data quality often matters more than data quantity. A model trained on 10 GB of clean, well-written text can outperform one trained on 100 GB of noisy web scrapes.

Here’s a simple data filtering pipeline:

import re
from collections import Counter


def filter_text(text, min_line_length=20, max_duplicate_ratio=0.3):
    """
    Basic text quality filter for training data.

    Args:
        text: raw text string
        min_line_length: drop lines shorter than this
        max_duplicate_ratio: drop documents with too many duplicate lines

    Returns:
        cleaned text or None if document is too low quality
    """
    lines = text.split("\n")

    # Filter out very short lines (likely navigation, headers, etc.)
    lines = [line.strip() for line in lines if len(line.strip()) >= min_line_length]

    if len(lines) == 0:
        return None

    # Check for excessive duplication
    line_counts = Counter(lines)
    duplicate_lines = sum(count - 1 for count in line_counts.values() if count > 1)
    duplicate_ratio = duplicate_lines / len(lines)

    if duplicate_ratio > max_duplicate_ratio:
        return None  # Too many duplicates — likely boilerplate

    # Remove lines that look like boilerplate
    filtered_lines = []
    boilerplate_patterns = [
        r"^copyright\s",
        r"^all rights reserved",
        r"^click here",
        r"^subscribe to",
        r"^cookie policy",
        r"^terms of service",
    ]

    for line in lines:
        is_boilerplate = any(
            re.search(pattern, line.lower()) for pattern in boilerplate_patterns
        )
        if not is_boilerplate:
            filtered_lines.append(line)

    if len(filtered_lines) < 3:
        return None  # Too little content left

    return "\n".join(filtered_lines)


def estimate_tokens(text, chars_per_token=4):
    """Rough estimate: ~4 characters per token for English text."""
    return len(text) // chars_per_token


# Example usage
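raw_texts = [
    # good: three distinct lines, each at least 20 characters long
    "Machine learning models learn patterns from data.\n"
    "Transformers process entire sequences in parallel.\n"
    "Careful data filtering improves downstream quality.",
    "Click here\nSubscribe\nCookies\nTerms",  # boilerplate: every line is too short
    "The same line repeated over and over\n" * 100,  # duplicated
]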

for i, text in enumerate(raw_texts):
    result = filter_text(text)
    status = "KEPT" if result else "DROPPED"
    print(f"Document {i}: {status}")

The Deduplication Problem

Another critical issue is deduplication. If the same Wikipedia article appears 50 times in your training set, the model will memorize it rather than learning general patterns. Even near-duplicates (similar articles with minor differences) can hurt.

At scale, deduplication uses techniques like MinHash or SimHash to efficiently find near-duplicate documents. For smaller datasets, a simpler approach works:

import hashlib

def deduplicate_documents(documents):
    """Remove exact duplicate documents using content hashing."""
    seen_hashes = set()
    unique_documents = []

    for doc in documents:
        # Normalize whitespace before hashing
        normalized = " ".join(doc.split())
        doc_hash = hashlib.md5(normalized.encode()).hexdigest()

        if doc_hash not in seen_hashes:
            seen_hashes.add(doc_hash)
            unique_documents.append(doc)

    removed = len(documents) - len(unique_documents)
    if documents:  # avoid dividing by zero on an empty input list
        print(f"Removed {removed} duplicates ({removed/len(documents)*100:.1f}%)")
    return unique_documents
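
Exact hashing misses near-duplicates. As a rough illustration of the MinHash idea mentioned above, the sketch below compares documents by their overlapping word shingles. It is a minimal pure-Python version for small inputs, not a production deduplication pipeline; the function names, the shingle size `k=3`, and the 64 salted hashes are assumptions chosen for the example.

```python
import hashlib


def shingles(text, k=3):
    """Return the set of overlapping k-word shingles in a document."""
    words = text.lower().split()
    if len(words) <= k:
        return {" ".join(words)}
    return {" ".join(words[i:i + k]) for i in range(len(words) - k + 1)}


def minhash_signature(text, num_hashes=64, k=3):
    """For each of num_hashes salted hash functions, keep the smallest shingle hash."""
    doc_shingles = shingles(text, k)
    signature = []
    for seed in range(num_hashes):
        salt = seed.to_bytes(8, "little")  # a different salt acts as a different hash function
        signature.append(min(
            hashlib.blake2b(s.encode(), digest_size=8, salt=salt).digest()
            for s in doc_shingles
        ))
    return signature


def estimated_jaccard(sig_a, sig_b):
    """The fraction of matching slots estimates the Jaccard similarity of the shingle sets."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


doc_a = "the quick brown fox jumps over the lazy dog near the river bank today"
doc_b = "the quick brown fox jumps over the lazy dog near the river bank yesterday"
doc_c = "gradient accumulation lets you simulate large batches on small gpus"

sig_a, sig_b, sig_c = (minhash_signature(d) for d in (doc_a, doc_b, doc_c))
sim_ab = estimated_jaccard(sig_a, sig_b)
sim_ac = estimated_jaccard(sig_a, sig_c)
print(f"a vs b: {sim_ab:.2f}")  # near-duplicates: high similarity
print(f"a vs c: {sim_ac:.2f}")  # unrelated: close to zero
```

A document whose similarity to an already-kept document exceeds a threshold you choose (for example 0.8) would be treated as a near-duplicate and dropped. Comparing every pair of signatures does not scale to large corpora; large pipelines typically add an indexing step so that only likely matches are compared.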

7. Compute Requirements

How long does training actually take? And how much does it cost?

Estimating FLOPs

The total floating-point operations (FLOPs) needed to train a transformer is approximately:

$$C \approx 6 \times N \times D$$

Where $C$ is total FLOPs, $N$ is parameters, and $D$ is training tokens. The factor of 6 comes from: forward pass ≈ 2ND, backward pass ≈ 4ND.

def estimate_training_time(
    num_params,
    num_tokens,
    gpu_flops_per_second,
    utilization=0.3,
):
    """
    Estimate training time.

    Args:
        num_params: model parameter count
        num_tokens: total training tokens
        gpu_flops_per_second: peak GPU throughput
        utilization: what fraction of peak you actually achieve (0.3-0.5 typical)

    Returns:
        estimated hours
    """
    total_flops = 6 * num_params * num_tokens
    effective_flops = gpu_flops_per_second * utilization
    seconds = total_flops / effective_flops
    hours = seconds / 3600
    return hours, total_flops

# GPU FLOPS (approximate peak FP16/BF16 throughput)
gpus = {
    "RTX 3090":      71e12,    # 71 TFLOPS
    "RTX 4090":      165e12,   # 165 TFLOPS
    "A100 (80GB)":   312e12,   # 312 TFLOPS
    "H100":          990e12,   # 990 TFLOPS
}

# Example: training a 124M model on 2.5B tokens
num_params = 124_000_000
num_tokens = 2_500_000_000

print(f"Model: {num_params/1e6:.0f}M params, {num_tokens/1e9:.1f}B tokens")
print(f"{'GPU':<20} {'Hours':>10} {'Days':>8} {'Cost*':>10}")
print("-" * 52)

for gpu_name, flops in gpus.items():
    hours, total_flops = estimate_training_time(num_params, num_tokens, flops)
    days = hours / 24
    # Rough cloud costs per GPU-hour
    cost_per_hour = {"RTX 3090": 0.5, "RTX 4090": 0.8, "A100 (80GB)": 2.0, "H100": 3.5}
    cost = hours * cost_per_hour.get(gpu_name, 1.0)
    print(f"{gpu_name:<20} {hours:>10.1f} {days:>8.1f} {f'${cost:.0f}':>10}")

# Output:
# Model: 124M params, 2.5B tokens
# GPU                       Hours     Days      Cost*
# ----------------------------------------------------
# RTX 3090                   24.3      1.0        $12
# RTX 4090                   10.4      0.4         $8
# A100 (80GB)                 5.5      0.2        $11
# H100                        1.7      0.1         $6

*Cloud costs are approximate and vary by provider. As of 2024, A100 instances cost $1.50-$3.00/hour on major cloud providers. Spot/preemptible instances can be 60-70% cheaper.
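
As a check on the estimate, here is the RTX 3090 case worked by hand with the same assumptions as the code (71 TFLOPS peak, 30% utilization):

$$C \approx 6 \times (1.24 \times 10^{8}) \times (2.5 \times 10^{9}) = 1.86 \times 10^{18}\ \text{FLOPs}$$

$$t \approx \frac{1.86 \times 10^{18}\ \text{FLOPs}}{0.3 \times 71 \times 10^{12}\ \text{FLOPs/s}} \approx 8.7 \times 10^{4}\ \text{s} \approx 24\ \text{hours}$$

The utilization figure dominates the uncertainty: at 15% utilization the same run would take about twice as long.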

Why Most People Fine-Tune

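Look at those numbers. Even for a “small” 124M model, you need hours to a full day of GPU time (assuming 30% utilization) and billions of tokens. For a 7B model, the parameter count and the Chinchilla-optimal token count each grow by about 56×, so the compute grows by roughly 3,000×. Training from scratch requires: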

  • Data you probably don’t have (filtered, deduplicated, diverse)
  • Compute that costs thousands of dollars minimum
  • Expertise to debug training instabilities at scale
  • Time — weeks to months for useful model sizes

This is why fine-tuning is the practical approach for most people. You start with a pre-trained model (someone else already spent the millions of dollars on pre-training), then train it on your specific, smaller dataset. With fine-tuning, you can adapt a 7B model in hours with a single GPU. We’ll cover fine-tuning in Chapter 10.
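
To put the from-scratch cost in numbers, the sketch below applies the same $C \approx 6ND$ estimate with a Chinchilla-style budget of 20 tokens per parameter. The function name is hypothetical, and the defaults (A100 peak throughput, 30% utilization, $2 per GPU-hour) are the rough figures used earlier in this section, not measurements.

```python
def chinchilla_training_cost(num_params, peak_flops=312e12, utilization=0.3,
                             cost_per_gpu_hour=2.0, tokens_per_param=20):
    """Estimate tokens, GPU-hours, and cost for a from-scratch run at 20 tokens/param."""
    num_tokens = tokens_per_param * num_params
    total_flops = 6 * num_params * num_tokens
    gpu_hours = total_flops / (peak_flops * utilization) / 3600
    return num_tokens, gpu_hours, gpu_hours * cost_per_gpu_hour


for name, n in [("124M", 124e6), ("1.5B", 1.5e9), ("7B", 7e9)]:
    tokens, hours, cost = chinchilla_training_cost(n)
    print(f"{name:>5}: {tokens/1e9:7.1f}B tokens, {hours:10.0f} A100-hours, ~${cost:,.0f}")
```

Because the token budget grows with the parameter count, compute grows with the square of model size. That quadratic growth is what pushes multi-billion-parameter pre-training out of reach for most individuals.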

Training from scratch makes sense when:

  • You’re a large company building a foundation model
  • You need a model for a language not well-represented in existing models
  • You have unique data that requires a specialized tokenizer
  • You’re learning (like in this book!)

8. Weight Initialization

Before training starts, every weight in your model is random. Surprisingly, how random matters enormously. Bad initialization can cause training to fail completely, even if everything else is correct.

The Problem with Naive Initialization

Consider what happens in a deep neural network when you multiply by layer after layer of weights:

import torch

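# Simulate passing through 12 linear layers
# (a simplified stand-in for a 12-layer transformer)
d_model = 768
x = torch.randn(1, d_model)  # input with reasonable values

# Bad initialization: weights too large
for layer in range(12):
    W = torch.randn(d_model, d_model) * 0.1  # std=0.1
    x = x @ W

print(f"Output magnitude after 12 layers: {x.abs().mean().item():.6f}")
# Each layer scales the signal by about sqrt(d_model) * std = sqrt(768) * 0.1 ≈ 2.8,
# so after 12 layers the magnitude is on the order of 1e5. With std=0.01 the
# per-layer factor is ≈ 0.28 and the signal shrinks toward zero instead.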

If weights are too large, the signal explodes — activations grow exponentially, eventually becoming inf. If weights are too small, the signal vanishes — activations shrink to zero, and gradients disappear with them. In either case, the model can’t learn.

The visualization:

Too large: input → 1.0 → 3.2 → 10.5 → 34.1 → 112 → 365 → NaN → CRASH
Too small: input → 1.0 → 0.3 → 0.09 → 0.03 → 0.01 → 0.003 → ~0 → STUCK
Just right: input → 1.0 → 0.95 → 1.05 → 0.98 → 1.02 → 0.99 → STABLE
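
The three regimes above can be reproduced with a small simulation. This is a sketch in pure Python rather than PyTorch, with a deliberately small `d_model=64` so it runs quickly; the function name `propagate` and the three example standard deviations are assumptions chosen for illustration.

```python
import math
import random


def propagate(weight_std, d_model=64, n_layers=12, seed=0):
    """Push a random vector through n_layers random linear maps.

    Returns the mean absolute activation after the last layer.
    """
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(d_model)]
    for _ in range(n_layers):
        # Each output unit is a dot product of x with a fresh random weight row
        x = [
            sum(rng.gauss(0.0, weight_std) * xj for xj in x)
            for _ in range(d_model)
        ]
    return sum(abs(v) for v in x) / d_model


d_model = 64
results = {
    "too small": propagate(0.02, d_model),
    "just right": propagate(1.0 / math.sqrt(d_model), d_model),
    "too large": propagate(0.5, d_model),
}
for label, magnitude in results.items():
    print(f"{label:<10} → mean |activation| after 12 layers: {magnitude:.3e}")
```

The "just right" value is `1/sqrt(d_model)`, which makes the per-layer scale factor `sqrt(d_model) * std` equal to 1. Anything much larger or smaller compounds exponentially with depth.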

Xavier (Glorot) Initialization

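Xavier initialization sets the weight variance so that activations in the forward pass and gradients in the backward pass both keep roughly the same variance from layer to layer. For a linear layer with $n_{in}$ inputs and $n_{out}$ outputs (the second argument of $\mathcal{N}$ is the variance, not the standard deviation):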

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in} + n_{out}}\right)$$

This works well for layers with symmetric activation functions (like tanh or no activation).
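
The formula can be motivated by a short variance calculation, sketched here under the usual simplifying assumptions (zero-mean, independent weights and inputs, no activation function). For one output unit $y_i = \sum_j W_{ij} x_j$:

$$\mathrm{Var}(y_i) = n_{in}\,\mathrm{Var}(W)\,\mathrm{Var}(x)$$

Keeping the forward signal stable therefore asks for $\mathrm{Var}(W) = 1/n_{in}$, and the same argument applied to gradients flowing backward asks for $\mathrm{Var}(W) = 1/n_{out}$. Xavier initialization compromises by averaging the two fan sizes:

$$\mathrm{Var}(W) = \frac{2}{n_{in} + n_{out}}$$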

Kaiming (He) Initialization

For layers followed by ReLU (which zeros out negative values), Xavier isn’t quite right because ReLU cuts the variance in half. Kaiming initialization compensates:

$$W \sim \mathcal{N}\left(0, \frac{2}{n_{in}}\right)$$
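
The effect of the extra factor of 2 can be checked numerically. The sketch below is a pure-Python toy (a stack of random linear layers with ReLU, width 128, no biases), not a transformer; the function name and sizes are assumptions for illustration.

```python
import math
import random


def relu_stack_second_moment(weight_std, width=128, n_layers=10, seed=0):
    """Mean squared pre-activation of the last layer in a Linear -> ReLU stack."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    pre = x
    for _ in range(n_layers):
        pre = [sum(rng.gauss(0.0, weight_std) * xj for xj in x) for _ in range(width)]
        x = [max(0.0, v) for v in pre]  # ReLU zeroes the negative half
    return sum(v * v for v in pre) / width


width = 128
kaiming = relu_stack_second_moment(math.sqrt(2.0 / width), width)
no_relu_correction = relu_stack_second_moment(math.sqrt(1.0 / width), width)
print(f"std = sqrt(2/n): {kaiming:.4f}")
print(f"std = sqrt(1/n): {no_relu_correction:.6f}")
```

Under these assumptions the `sqrt(2/n)` stack should keep the signal at roughly the same order of magnitude from layer to layer, while the `sqrt(1/n)` stack loses about half of it at every layer, which compounds to a factor of about 1,000 after 10 layers.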

GPT-Style Initialization

In practice, GPT-2 and most modern transformers use a simpler scheme:

def init_weights(module, n_layers):
    """
    Initialize weights following GPT-2 conventions.

    - Linear layers: normal distribution with std=0.02
    - Embedding layers: normal distribution with std=0.02
    - Layer norm: weight=1, bias=0
    - Biases: zero

    n_layers is not used here. The 1/sqrt(2*n_layers) scaling of the
    residual projections is applied separately by
    init_residual_projections(), which must run after this function.
    """
    if isinstance(module, nn.Linear):
        # Standard initialization
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    elif isinstance(module, nn.LayerNorm):
        torch.nn.init.ones_(module.weight)
        torch.nn.init.zeros_(module.bias)


def init_residual_projections(model, n_layers):
    """
    Scale the output projection of each transformer block.

    The residual stream accumulates contributions from every layer.
    Without scaling, the variance grows with the number of layers.
    Scaling by 1/sqrt(2*n_layers) keeps the variance stable.

    The factor of 2 accounts for two residual connections per layer:
    one after attention and one after the feed-forward network.
    """
    scale = 1.0 / (2 * n_layers) ** 0.5

    for name, param in model.named_parameters():
        # Find the output projection in each attention and FFN block
        if "W_o" in name or "fc2" in name:  # attention output or FFN output
            with torch.no_grad():
                param.mul_(scale)


# Usage
model = GPTModel(config)
model.apply(lambda m: init_weights(m, config["n_layers"]))
init_residual_projections(model, config["n_layers"])
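
To see why the `1/sqrt(2*n_layers)` factor is the right size, the toy simulation below adds up the branch outputs that feed a residual stream. It assumes, purely for illustration, that each attention or feed-forward branch contributes independent unit-variance noise; real branch outputs are not independent, so treat this as a sketch of the scaling argument only. The function name is hypothetical.

```python
import random


def residual_stream_variance(n_layers, scale_outputs, width=2000, seed=0):
    """Variance of a residual stream after adding 2 * n_layers branch outputs.

    Toy assumption: every branch output is independent unit-variance noise.
    """
    rng = random.Random(seed)
    scale = (2 * n_layers) ** -0.5 if scale_outputs else 1.0
    stream = [0.0] * width
    for _ in range(2 * n_layers):  # one attention and one feed-forward branch per layer
        stream = [s + scale * rng.gauss(0.0, 1.0) for s in stream]
    mean = sum(stream) / width
    return sum((s - mean) ** 2 for s in stream) / width


for n_layers in (2, 12, 48):
    unscaled = residual_stream_variance(n_layers, scale_outputs=False)
    scaled = residual_stream_variance(n_layers, scale_outputs=True)
    print(f"n_layers={n_layers:>2}: unscaled variance ≈ {unscaled:.1f}, scaled ≈ {scaled:.2f}")
```

Without scaling, the variance grows in proportion to the number of branches (about `2 * n_layers`). With scaling it stays near 1 regardless of depth.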

What Happens with Bad Initialization

# Demonstration: training with different initializations

def test_initialization(init_std, n_steps=100):
    """Train for a few steps and track loss to show initialization effect."""
    model = GPTModel(config)

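    # Override weight matrices and embeddings with our custom std
    # (leave LayerNorm gains and all biases at their defaults)
    for param in model.parameters():
        if param.dim() >= 2:
            torch.nn.init.normal_(param, mean=0.0, std=init_std)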

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    losses = []
    for step in range(n_steps):
        # Dummy data for demonstration
        x = torch.randint(0, config["vocab_size"], (4, 32))
        logits = model(x)
        loss = loss_fn(logits.view(-1, logits.size(-1)), x.view(-1))

        if torch.isnan(loss) or torch.isinf(loss):
            losses.append(float("inf"))
            break

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())

    return losses

# Compare
for std in [0.001, 0.02, 0.1, 1.0]:
    losses = test_initialization(std)
    final = f"{losses[-1]:.2f}" if losses[-1] != float("inf") else "DIVERGED"
    print(f"init std={std:<6} → final loss: {final}")

# Illustrative output (exact values depend on your config and random seed):
# init std=0.001  → final loss: 10.82 (very slow learning: signal too small)
# init std=0.02   → final loss: 8.45  (good: this is the standard)
# init std=0.1    → final loss: 9.31  (unstable, higher loss)
# init std=1.0    → final loss: DIVERGED (loss goes to NaN immediately)

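The standard value of std=0.02 is not arbitrary: it is the same order of magnitude as $1/\sqrt{d_{model}}$ for typical widths ($1/\sqrt{768} \approx 0.036$), which keeps each layer’s output from growing or shrinking sharply as depth increases.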


9. Putting It All Together

Here’s a complete training script that combines gradient accumulation, mixed precision, checkpointing, and proper initialization — everything you need to train a model larger than the toy version from Chapter 7:

import os
import time
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader


def train_scaled(
    model,
    train_loader,
    config,
    epochs=50,
    lr=3e-4,
    accumulation_steps=8,
    use_mixed_precision=True,
    checkpoint_dir="checkpoints",
    save_every_n_epochs=5,
    resume_from=None,
    device="cuda",
):
    """
    Training loop that combines the chapter's scaling techniques:
    gradient accumulation, mixed precision, gradient clipping, and checkpointing.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler(enabled=use_mixed_precision and device == "cuda")

    start_epoch = 0
    global_step = 0

    # Resume if checkpoint exists
    if resume_from and os.path.exists(resume_from):
        start_epoch, global_step, _ = load_checkpoint(
            resume_from, model, optimizer, scaler=scaler
        )
        start_epoch += 1

    model.train()
    for epoch in range(start_epoch, epochs):
        epoch_loss = 0
        epoch_start = time.time()
        optimizer.zero_grad()

        for step, (input_ids, targets) in enumerate(train_loader):
            input_ids = input_ids.to(device)
            targets = targets.to(device)

            # Mixed precision forward pass
            with autocast(enabled=use_mixed_precision and device == "cuda"):
                logits = model(input_ids)
                logits = logits.view(-1, logits.size(-1))
                targets_flat = targets.view(-1)
                loss = loss_fn(logits, targets_flat)
                loss = loss / accumulation_steps

            # Backward with gradient scaling
            scaler.scale(loss).backward()
            epoch_loss += loss.item() * accumulation_steps

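            # Update weights every accumulation_steps, and also on the last
            # batch of the epoch so leftover gradients are not thrown away
            # by the zero_grad() at the start of the next epoch
            is_last_batch = (step + 1) == len(train_loader)
            if (step + 1) % accumulation_steps == 0 or is_last_batch:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1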

        # Epoch stats
        elapsed = time.time() - epoch_start
        avg_loss = epoch_loss / len(train_loader)
        tokens_per_sec = (
            len(train_loader) * train_loader.batch_size
            * input_ids.size(1) / elapsed
        )
        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Loss: {avg_loss:.4f} | "
            f"Time: {elapsed:.0f}s | "
            f"Tokens/s: {tokens_per_sec:.0f}"
        )

        # Checkpoint
        if (epoch + 1) % save_every_n_epochs == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)  # no-op if it already exists
            save_checkpoint(
                model, optimizer, epoch, global_step, avg_loss,
                path=os.path.join(checkpoint_dir, f"ckpt_epoch_{epoch+1}.pt"),
                scaler=scaler,
            )

    return model

This single function incorporates every scaling technique from this chapter. Start with the “small” config (25M parameters), then graduate to “medium” (124M) when you’re comfortable.
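
Before launching a run, it can help to sanity-check how `accumulation_steps` interacts with your dataset size. The helper below is a hypothetical planning sketch, not part of the training loop. It assumes `dataset_size` counts training sequences and that the data loader keeps the final partial batch.

```python
import math


def accumulation_plan(dataset_size, micro_batch_size, accumulation_steps, seq_len):
    """Summarize how gradient accumulation shapes one epoch of training."""
    micro_batches = math.ceil(dataset_size / micro_batch_size)
    full_updates, leftover = divmod(micro_batches, accumulation_steps)
    effective_batch = micro_batch_size * accumulation_steps
    return {
        "micro_batches_per_epoch": micro_batches,
        "full_optimizer_updates": full_updates,
        "leftover_micro_batches": leftover,
        "effective_batch_size": effective_batch,
        "tokens_per_update": effective_batch * seq_len,
    }


plan = accumulation_plan(
    dataset_size=10_000, micro_batch_size=8, accumulation_steps=8, seq_len=256
)
for key, value in plan.items():
    print(f"{key:<26} {value:>8,}")
```

A non-zero `leftover_micro_batches` means the last few batches of each epoch do not fill a complete accumulation window, so the loop has to decide whether to apply or drop those gradients.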


10. Exercises

Exercise 1: Gradient Accumulation Verification

Verify that gradient accumulation produces the same result as a large batch. Write a script that:

  1. Creates a tiny model and a dataset of 32 examples
  2. Trains for 1 step with batch_size=32 (no accumulation)
  3. Trains for 1 step with batch_size=8 and accumulation_steps=4
  4. Compares the resulting model weights; they should match to within floating-point rounding

Solution
import torch
import torch.nn as nn
import copy

# Tiny model for testing
torch.manual_seed(42)
model = nn.Linear(10, 5)

# Fixed data: 32 examples
torch.manual_seed(0)
X = torch.randn(32, 10)
Y = torch.randint(0, 5, (32,))

loss_fn = nn.CrossEntropyLoss()

# --- Method 1: Single large batch (batch_size=32) ---
model_large = copy.deepcopy(model)
optimizer = torch.optim.SGD(model_large.parameters(), lr=0.01)

optimizer.zero_grad()
logits = model_large(X)
loss = loss_fn(logits, Y)
loss.backward()
optimizer.step()

weights_large = model_large.weight.data.clone()

# --- Method 2: Gradient accumulation (batch_size=8, 4 accumulations) ---
model_accum = copy.deepcopy(model)
optimizer = torch.optim.SGD(model_accum.parameters(), lr=0.01)

optimizer.zero_grad()
accumulation_steps = 4
batch_size = 8

for i in range(accumulation_steps):
    start = i * batch_size
    end = start + batch_size
    X_batch = X[start:end]
    Y_batch = Y[start:end]

    logits = model_accum(X_batch)
    loss = loss_fn(logits, Y_batch) / accumulation_steps  # Scale!
    loss.backward()

optimizer.step()

weights_accum = model_accum.weight.data.clone()

# --- Compare ---
diff = (weights_large - weights_accum).abs().max().item()
print(f"Max weight difference: {diff:.10f}")
# Should be very close to 0 (floating-point precision ~1e-7)

if diff < 1e-5:
    print("PASS: Gradient accumulation produces the same result!")
else:
    print("FAIL: Results differ — check the loss scaling")

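Key insight: The / accumulation_steps scaling is essential. Without it, the accumulated gradient would be 4× too large, producing different weights. The match also relies on every micro-batch having the same size, so that the mean of the micro-batch means equals the full-batch mean.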
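
When micro-batches have different sizes (for example, a short final batch), dividing each loss by the number of micro-batches no longer reproduces the full-batch gradient. The pure-Python sketch below uses a hypothetical one-weight model `y_hat = w * x` with made-up data to show the difference and one possible fix: weighting each micro-batch by its share of the examples.

```python
def mean_grad(w, xs, ys):
    """Gradient of mean squared error for the 1-D model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.9]
w = 0.5

# Two unequal micro-batches: 4 examples, then 1 example
chunks = [(xs[:4], ys[:4]), (xs[4:], ys[4:])]

full = mean_grad(w, xs, ys)
# Naive accumulation: divide each micro-batch gradient by the number of micro-batches
naive = sum(mean_grad(w, cx, cy) / len(chunks) for cx, cy in chunks)
# Size-weighted accumulation: weight each micro-batch by its fraction of all examples
weighted = sum(mean_grad(w, cx, cy) * len(cx) / len(xs) for cx, cy in chunks)

print(f"full batch: {full:.4f}")
print(f"naive:      {naive:.4f}")
print(f"weighted:   {weighted:.4f}")
```

The naive version gives the single-example micro-batch as much influence as the four-example one. The size-weighted version matches the full-batch gradient.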

Exercise 2: Checkpoint Resume Test

Write a script that:

  1. Trains a model for 10 epochs, saving a checkpoint at epoch 5
  2. Creates a new model instance and loads the epoch-5 checkpoint
  3. Trains the resumed model for 5 more epochs
  4. Compares the final loss of the continuous run (10 epochs straight) vs the resumed run (5 + 5)

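They should produce the same final loss if checkpointing is correct. The comparison relies on a deterministic data order, which is why the solution below uses shuffle=False.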

Solution
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import os

torch.manual_seed(42)

# Simple model and data
def make_model():
    torch.manual_seed(42)
    return nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 10))

def make_data():
    torch.manual_seed(0)
    X = torch.randn(256, 20)
    Y = torch.randint(0, 10, (256,))
    return DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=False)

loss_fn = nn.CrossEntropyLoss()

# --- Run 1: Train 10 epochs straight ---
model1 = make_model()
opt1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
loader = make_data()

for epoch in range(10):
    for X, Y in loader:
        opt1.zero_grad()
        loss = loss_fn(model1(X), Y)
        loss.backward()
        opt1.step()
    if epoch == 9:
        print(f"Continuous — epoch {epoch+1}, loss: {loss.item():.6f}")
        continuous_loss = loss.item()

# --- Run 2: Train 5 epochs, checkpoint, load, train 5 more ---
model2 = make_model()
opt2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
loader = make_data()

for epoch in range(5):
    for X, Y in loader:
        opt2.zero_grad()
        loss = loss_fn(model2(X), Y)
        loss.backward()
        opt2.step()

# Save checkpoint at epoch 5
os.makedirs("test_checkpoints", exist_ok=True)
torch.save({
    "model_state_dict": model2.state_dict(),
    "optimizer_state_dict": opt2.state_dict(),
    "epoch": 4,
}, "test_checkpoints/test_ckpt.pt")

# Load into new model
model3 = make_model()
opt3 = torch.optim.Adam(model3.parameters(), lr=1e-3)
ckpt = torch.load("test_checkpoints/test_ckpt.pt", weights_only=False)
model3.load_state_dict(ckpt["model_state_dict"])
opt3.load_state_dict(ckpt["optimizer_state_dict"])

# Train 5 more epochs
loader = make_data()
for epoch in range(5, 10):
    for X, Y in loader:
        opt3.zero_grad()
        loss = loss_fn(model3(X), Y)
        loss.backward()
        opt3.step()
    if epoch == 9:
        print(f"Resumed   — epoch {epoch+1}, loss: {loss.item():.6f}")
        resumed_loss = loss.item()

diff = abs(continuous_loss - resumed_loss)
print(f"\nLoss difference: {diff:.10f}")
if diff < 1e-5:
    print("PASS: Checkpoint resume produces identical results!")
else:
    print("FAIL: Results differ — check optimizer state saving")

# Cleanup
os.remove("test_checkpoints/test_ckpt.pt")
os.rmdir("test_checkpoints")

Key insight: Saving the optimizer state dict is critical. Without it, Adam’s momentum and variance estimates reset to zero, and the resumed training behaves differently: the optimizer loses the gradient history it built up during the first 5 epochs.
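
The effect of losing optimizer state can be seen without PyTorch. The sketch below implements a scalar Adam update by hand and minimizes the toy loss `(w - 1)^2`. The function names, learning rate, and loss are assumptions for illustration, and the dictionary stands in for what `optimizer.state_dict()` would carry.

```python
import math


def fresh_state():
    """Adam state for one scalar weight: step count, first and second moments."""
    return {"t": 0, "m": 0.0, "v": 0.0}


def adam_step(w, grad, state, lr=0.5, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to w, mutating state in place."""
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** state["t"])  # bias-corrected second moment
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


def train(w, state, steps):
    """Run Adam for a number of steps on the loss (w - 1)^2."""
    for _ in range(steps):
        grad = 2 * (w - 1.0)
        w = adam_step(w, grad, state)
    return w


# Continuous run: 10 steps
w_continuous = train(0.0, fresh_state(), 10)

# Interrupted run: 5 steps, then "checkpoint" the weight and optimizer state
state = fresh_state()
w_mid = train(0.0, state, 5)
saved_state = dict(state)

# Resume with the saved optimizer state vs. with a fresh optimizer
w_resumed = train(w_mid, dict(saved_state), 5)
w_reset = train(w_mid, fresh_state(), 5)

print(f"continuous:          {w_continuous:.6f}")
print(f"resumed with state:  {w_resumed:.6f}")
print(f"resumed, state lost: {w_reset:.6f}")
```

Restoring the state reproduces the continuous run, while the fresh optimizer takes a different path from the same weights.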

Exercise 3: Compute Budget Planner

Build a function plan_training() that takes a target model size (in parameters) and a compute budget (in GPU-hours on a specific GPU) and outputs:

  • How many tokens you can afford to train on
  • Whether the training is Chinchilla-optimal (D ≈ 20N)
  • If not optimal, whether you should use a smaller model with more data or a larger model with less data
  • Estimated cost on a cloud provider

Solution
def plan_training(
    target_params,
    gpu_hours_budget,
    gpu_name="A100 (80GB)",
    cost_per_hour=2.0,
    utilization=0.3,
):
    """
    Plan training given a parameter target and compute budget.

    Args:
        target_params: desired model size in parameters
        gpu_hours_budget: number of GPU-hours available
        gpu_name: GPU type
        cost_per_hour: cloud cost per GPU-hour
        utilization: expected fraction of peak GPU throughput
    """
    # Peak FLOPS by GPU type (FP16/BF16)
    gpu_flops = {
        "RTX 3090": 71e12,
        "RTX 4090": 165e12,
        "A100 (80GB)": 312e12,
        "H100": 990e12,
    }

    if gpu_name not in gpu_flops:
        print(f"Unknown GPU: {gpu_name}")
        return

    peak_flops = gpu_flops[gpu_name]
    effective_flops = peak_flops * utilization

    # Total compute budget in FLOPs
    budget_flops = gpu_hours_budget * 3600 * effective_flops

    # Tokens we can afford: C = 6 * N * D → D = C / (6 * N)
    affordable_tokens = budget_flops / (6 * target_params)

    # Chinchilla-optimal tokens
    optimal_tokens = 20 * target_params

    # Chinchilla-optimal model size for our budget
    # C = 6 * N * 20N = 120 * N^2 → N = sqrt(C / 120)
    optimal_params = (budget_flops / 120) ** 0.5

    cost = gpu_hours_budget * cost_per_hour

    print(f"{'='*60}")
    print(f"Training Plan")
    print(f"{'='*60}")
    print(f"Target model:     {target_params/1e6:.0f}M parameters")
    print(f"GPU:              {gpu_name}")
    print(f"Budget:           {gpu_hours_budget:.0f} GPU-hours (${cost:.0f})")
    print(f"")
    print(f"Affordable tokens: {affordable_tokens/1e9:.1f}B")
    print(f"Optimal tokens:    {optimal_tokens/1e9:.1f}B (Chinchilla 20×N)")
    print(f"")

    ratio = affordable_tokens / optimal_tokens
    if 0.8 <= ratio <= 1.2:
        print(f"Status: GOOD — near Chinchilla-optimal ({ratio:.1f}× optimal)")
    elif ratio > 1.2:
        print(f"Status: OVER-TRAINED — you have {ratio:.1f}× optimal tokens")
        print(f"  → Consider a larger model: {optimal_params/1e6:.0f}M params")
        print(f"    would be Chinchilla-optimal for this budget")
    else:
        print(f"Status: UNDER-TRAINED — you can only afford {ratio:.1f}× optimal tokens")
        print(f"  → Consider a smaller model: {optimal_params/1e6:.0f}M params")
        print(f"    would be Chinchilla-optimal for this budget")

    print(f"\nEstimated training time: {gpu_hours_budget:.0f} hours = {gpu_hours_budget/24:.1f} days")


# Example scenarios
print("\n--- Scenario 1: Student with limited budget ---")
plan_training(
    target_params=124_000_000,
    gpu_hours_budget=24,
    gpu_name="RTX 3090",
    cost_per_hour=0.5,
)

print("\n--- Scenario 2: Startup with modest budget ---")
plan_training(
    target_params=350_000_000,
    gpu_hours_budget=200,
    gpu_name="A100 (80GB)",
    cost_per_hour=2.0,
)

print("\n--- Scenario 3: Well-funded research lab ---")
plan_training(
    target_params=1_500_000_000,
    gpu_hours_budget=5000,
    gpu_name="H100",
    cost_per_hour=3.5,
)

Expected output for Scenario 1:

============================================================
Training Plan
============================================================
Target model:     124M parameters
GPU:              RTX 3090
Budget:           24 GPU-hours ($12)

Affordable tokens: 2.5B
Optimal tokens:    2.5B (Chinchilla 20×N)

Status: GOOD — near Chinchilla-optimal (1.0× optimal)

Estimated training time: 24 hours = 1.0 days

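This exercise shows how compute requirements scale. Under the 30% utilization assumption, a day on a single RTX 3090 is roughly Chinchilla-optimal for a 124M model. But Chinchilla-optimal compute grows as 120 × N², so a 7B model needs about 3,000 times more. For models at that scale, fine-tuning a pre-trained model is far more practical for most practitioners.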


Summary

This chapter covered the essential techniques for scaling model training beyond toy experiments:

  • Scaling laws tell us that more parameters + more data + more compute = better models, following predictable power laws.
  • Gradient accumulation lets you simulate large batch sizes without needing more GPU memory — accumulate gradients over multiple micro-batches before updating.
  • Mixed precision training halves memory usage and accelerates computation by using FP16/BF16 for most operations, with FP32 only where precision matters.
  • Checkpointing saves model and optimizer state so training can resume after interruptions — essential for any training run longer than a few hours.
  • Model configurations range from 2M (minutes on CPU) to 350M+ (weeks on multi-GPU), with specific d_model, layers, and heads for each size.
  • Chinchilla scaling recommends ~20 tokens per parameter for optimal training. Data quality and deduplication matter as much as quantity.
  • Compute estimation helps you plan training budgets. At 30% utilization, a 124M model trained on 2.5B tokens takes roughly 2 to 24 hours on a single modern GPU, depending on the card; larger models require multi-GPU setups.
  • Weight initialization with std=0.02 and residual scaling prevents gradient explosion/vanishing and enables stable training.

The gap between a 2M toy model and a production 7B model is large, but it’s a continuous spectrum. Every technique in this chapter — gradient accumulation, mixed precision, checkpointing, proper initialization — applies at every scale. Master them on a 25M model, and you’ll be ready for anything larger.

In the next chapter, we’ll cover fine-tuning — how to take a pre-trained model and adapt it to specific tasks without training from scratch.