Training Loop — Teaching Your Model to Speak
This chapter covers the complete training pipeline for the GPT model built in Chapter 6. Training data preparation turns a text corpus into input-target pairs using PyTorch's Dataset and DataLoader. Cross-entropy loss, explained intuitively, measures how well the model predicts each next token. The Adam optimizer adjusts the weights using per-parameter adaptive step sizes. The core training loop combines the forward pass, loss computation, backpropagation, and weight updates. Learning rate scheduling with warmup and cosine decay improves training stability, gradient clipping prevents exploding gradients, and checkpointing lets you resume interrupted runs. A complete, runnable training script brings everything together to train a small LLM on Shakespeare text.
In the previous chapter, you assembled every transformer component into a complete GPT model. You ran a forward pass, watched token IDs flow through embeddings, transformer blocks, and an output projection, and got back logits — raw prediction scores for every word in the vocabulary. But those predictions were garbage. The model was guessing randomly because it had never seen a single example of real language.
This chapter changes everything. We’re going to train the model.
Think of a newborn child learning to speak. At first, they babble — random sounds with no meaning. But they hear their parents talk, try to imitate, get corrected, and try again. After thousands of repetitions, they start stringing together real words, then sentences, then paragraphs. Your GPT model is that child. Its weights are initialized randomly (babbling), and through training, it will learn the patterns of language one correction at a time.
By the end of this chapter, you will:
- Transform raw text into input-target pairs that a model can learn from
- Understand cross-entropy loss — the measure of how wrong the model’s predictions are
- Configure the Adam optimizer to adjust the model’s weights
- Implement the complete training loop: forward → loss → backward → update
- Add learning rate scheduling for stable training
- Apply gradient clipping to prevent exploding gradients
- Run a complete training script and watch your model improve
Let’s teach your model to speak.
1. Preparing Training Data
Before the model can learn anything, we need training data in a format it can consume. Language models learn by next-token prediction: given a sequence of tokens, predict the next one. This means we need pairs of (input, target) where the target is the input shifted by one position.
The Shifting Trick
Consider this sentence after tokenization:
Tokens: ["Shall", "I", "compare", "thee", "to", "a", "summer", "day"]
IDs: [ 481, 12, 305, 198, 27, 3, 592, 114 ]
We create input-target pairs by slicing:
Input: [481, 12, 305, 198, 27, 3, 592] → "Shall I compare thee to a summer"
Target: [ 12, 305, 198, 27, 3, 592, 114] → "I compare thee to a summer day"
The target is the same sequence offset by one token, so target[i] equals input[i+1]. At every position, the model’s job is to predict what comes next:
| Position | Input token | Target (correct answer) |
|---|---|---|
| 0 | ”Shall" | "I” |
| 1 | ”I" | "compare” |
| 2 | ”compare" | "thee” |
| 3 | ”thee" | "to” |
| 4 | ”to" | "a” |
| 5 | ”a" | "summer” |
| 6 | ”summer" | "day” |
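A single sentence of 8 tokens gives us 7 next-token predictions to learn from. In general, a sequence of n tokens provides n - 1 (input, target) positions, because every token except the first is a prediction target. This makes next-token prediction very data-efficient: the text supplies its own labels, and almost every token contributes a training signal.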
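The following sketch demonstrates the shifting trick without a framework, using the example IDs above. The helper name `make_next_token_examples` is hypothetical. It returns the offset input/target lists and also lists each individual prediction task as (context so far, next token). A GPT model uses causal (masked) attention, so position i only sees tokens up to i. That means a single forward pass over `inputs` covers all seven tasks at once.

```python
def make_next_token_examples(token_ids):
    """Return (inputs, targets, examples) for one token sequence.

    inputs/targets are the sequence offset by one position; examples lists
    each prediction task as (context so far, next token).
    """
    inputs = token_ids[:-1]   # every token except the last
    targets = token_ids[1:]   # every token except the first
    examples = [(token_ids[: i + 1], token_ids[i + 1])
                for i in range(len(token_ids) - 1)]
    return inputs, targets, examples


ids = [481, 12, 305, 198, 27, 3, 592, 114]  # "Shall I compare thee to a summer day"
inputs, targets, examples = make_next_token_examples(ids)

print(len(examples))   # 7
print(examples[0])     # ([481], 12): after "Shall", predict "I"
print(examples[-1])    # ([481, 12, 305, 198, 27, 3, 592], 114): predict "day"
```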
Building a PyTorch Dataset
PyTorch provides the Dataset and DataLoader abstractions for feeding data to a model. A Dataset defines what the data is; a DataLoader handles how it’s served (batching, shuffling, etc.).
```python
import torch
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """A dataset for language modeling.

    Takes a long sequence of token IDs and produces input-target
    pairs of fixed length by sliding a window through the text.
    """

    def __init__(self, token_ids, seq_len):
        """
        Args:
            token_ids: 1D tensor of all token IDs in the corpus
            seq_len: length of each training sequence
        """
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __len__(self):
        # How many sequences can we extract?
        # Each example needs seq_len + 1 tokens (seq_len inputs plus one more
        # for the final target), so there are len - seq_len start positions.
        return len(self.token_ids) - self.seq_len

    def __getitem__(self, idx):
        # Grab a window of seq_len + 1 tokens
        chunk = self.token_ids[idx : idx + self.seq_len + 1]
        # Input: first seq_len tokens
        x = chunk[:-1]  # (seq_len,)
        # Target: last seq_len tokens (shifted by 1)
        y = chunk[1:]   # (seq_len,)
        return x, y
```
Let’s test it with a small example:
```python
# Simulate a tokenized corpus: 100 random token IDs in [0, 1000)
all_tokens = torch.randint(0, 1000, (100,))
seq_len = 16

dataset = TextDataset(all_tokens, seq_len)
print(f"Corpus size: {len(all_tokens)} tokens")
print(f"Sequence length: {seq_len}")
print(f"Number of training examples: {len(dataset)}")
# Number of training examples: 84

# Look at one example
x, y = dataset[0]
print(f"\nInput shape: {x.shape}")   # torch.Size([16])
print(f"Target shape: {y.shape}")    # torch.Size([16])
print(f"Input: {x[:6].tolist()}")
print(f"Target: {y[:6].tolist()}")
# Notice: target[i] == input[i+1], shifted by one
```
The DataLoader
A DataLoader wraps the dataset and handles batching and shuffling automatically:
```python
dataloader = DataLoader(
    dataset,
    batch_size=4,     # 4 sequences per batch
    shuffle=True,     # Randomize order each epoch
    drop_last=True,   # Drop incomplete final batch
)

# Grab one batch
batch_x, batch_y = next(iter(dataloader))
print(f"Batch input shape: {batch_x.shape}")   # torch.Size([4, 16])
print(f"Batch target shape: {batch_y.shape}")  # torch.Size([4, 16])
```
The DataLoader delivers batches of shape (batch_size, seq_len) — exactly what our GPT model expects as input.
Tokenizing a Real Corpus
For training, we need actual text. Let’s use a sample of Shakespeare’s sonnets — small enough to train quickly, rich enough to show real learning:
```python
# A small Shakespeare corpus for demonstration
shakespeare_text = """
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate.
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date.
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed;
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade,
Nor lose possession of that fair thou ow'st;
Nor shall death brag thou wand'rest in his shade,
When in eternal lines to Time thou grow'st.
So long as men can breathe, or eyes can see,
So long lives this, and this gives life to thee.
"""

# Simple character-level tokenizer (from CH3)
chars = sorted(set(shakespeare_text))
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for ch, i in char_to_id.items()}
vocab_size = len(chars)

# Encode the entire corpus
token_ids = torch.tensor([char_to_id[ch] for ch in shakespeare_text],
                         dtype=torch.long)

print(f"Vocabulary size: {vocab_size}")
print(f"Corpus length: {len(token_ids)} tokens")
print(f"First 20 tokens: {token_ids[:20].tolist()}")

# Decode the IDs back to text to confirm the mapping round-trips
decoded = "".join(id_to_char[i] for i in token_ids[:20].tolist())
print(f"Decoded back: {decoded!r}")
```
We use character-level tokenization here for simplicity. In production, you’d use BPE (Chapter 3), but characters work well for demonstrating the training loop without additional complexity.
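The character-level tokenizer is just a pair of lookup tables. This small sketch wraps them in hypothetical `build_char_vocab`, `encode`, and `decode` helpers and checks that decoding undoes encoding. It uses plain Python lists instead of tensors.

```python
def build_char_vocab(text):
    """Map each distinct character to an integer ID (sorted for determinism)."""
    chars = sorted(set(text))
    char_to_id = {ch: i for i, ch in enumerate(chars)}
    id_to_char = {i: ch for ch, i in char_to_id.items()}
    return char_to_id, id_to_char


def encode(text, char_to_id):
    return [char_to_id[ch] for ch in text]


def decode(ids, id_to_char):
    return "".join(id_to_char[i] for i in ids)


sample = "Shall I compare thee to a summer's day?"
char_to_id, id_to_char = build_char_vocab(sample)
ids = encode(sample, char_to_id)

print(len(char_to_id))                    # 19 distinct characters
print(ids[:5])                            # [4, 9, 5, 10, 10] for "Shall"
print(decode(ids, id_to_char) == sample)  # True: the round trip is lossless
```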
2. Cross-Entropy Loss Explained
At every position, the model outputs logits, which softmax turns into a probability distribution over the vocabulary. We need a way to measure how wrong those predictions are. That measurement is called the loss function, and for language models, the standard choice is cross-entropy loss.
The Intuition
Imagine a 5-word vocabulary: ["cat", "dog", "the", "sat", "on"]
The correct next word is "sat" (index 3). The model predicts:
```
Model's prediction:  [0.05, 0.05, 0.10, 0.70, 0.10]
                      cat   dog   the   sat   on
```
The model assigns 70% probability to "sat" — pretty good! Cross-entropy measures how surprised the model is by the correct answer. If the model was confident about the right answer (high probability), the surprise is low and the loss is small. If the model was wrong (low probability on the correct answer), the surprise is high and the loss is large.
Mathematically, cross-entropy for a single prediction is:
$$\text{Loss} = -\log(p_{\text{correct}})$$
Where $p_{\text{correct}}$ is the probability the model assigned to the correct token.
Let’s compute it for different scenarios:
| Scenario | $p_{\text{correct}}$ | $-\log(p)$ | Interpretation |
|---|---|---|---|
| Very confident, correct | 0.90 | 0.105 | Low loss — great! |
| Somewhat confident | 0.70 | 0.357 | Moderate loss |
| Uncertain | 0.20 | 1.609 | High loss — needs work |
| Random guessing (1/5) | 0.20 | 1.609 | Baseline for 5 words |
| Almost wrong | 0.01 | 4.605 | Very high loss — terrible |
The loss is never negative, and lower is better. A perfect model (100% probability on every correct answer) would have zero loss. A model that spreads probability uniformly over a vocabulary of size $V$ assigns $1/V$ to the correct token, so its loss is $-\log(1/V) = \log(V)$.
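The numbers in the table are easy to reproduce with Python's `math` module. This short check also confirms the uniform-guessing baseline $\log(V)$ for the 5-word vocabulary:

```python
import math

# Probability the model assigns to the correct token, per scenario in the table
scenarios = {
    "very confident, correct": 0.90,
    "somewhat confident": 0.70,
    "uncertain": 0.20,
    "almost wrong": 0.01,
}

for name, p_correct in scenarios.items():
    loss = -math.log(p_correct)  # cross-entropy for a single prediction
    print(f"{name:>24}: p={p_correct:.2f}  loss={loss:.3f}")

# A uniform guess over V tokens puts 1/V on the correct one, so loss = log(V)
V = 5
uniform_loss = -math.log(1 / V)
print(f"uniform over {V} tokens: {uniform_loss:.3f}")  # 1.609
```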
Cross-Entropy in PyTorch
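PyTorch’s nn.CrossEntropyLoss takes raw logits (not probabilities!) and computes the loss. Internally it applies log-softmax followed by the negative log-likelihood, which is more numerically stable than computing softmax and taking the log yourself.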
```python
import torch
import torch.nn as nn

# Simulate model output: batch of 2 sequences, length 4, vocabulary of 5
logits = torch.randn(2, 4, 5)          # (batch_size, seq_len, vocab_size)
targets = torch.randint(0, 5, (2, 4))  # (batch_size, seq_len)

print(f"Logits shape: {logits.shape}")    # torch.Size([2, 4, 5])
print(f"Targets shape: {targets.shape}")  # torch.Size([2, 4])
```
There’s a catch: CrossEntropyLoss expects the class dimension to be the second dimension, not the third. We need to reshape:
```python
loss_fn = nn.CrossEntropyLoss()

# Reshape logits:  (batch*seq_len, vocab_size)
# Reshape targets: (batch*seq_len,)
loss = loss_fn(
    logits.view(-1, logits.size(-1)),  # (8, 5)
    targets.view(-1),                  # (8,)
)
print(f"Loss: {loss.item():.4f}")
```
Why the reshape? CrossEntropyLoss wants a 2D input (N, C) where N is the number of predictions and C is the number of classes. We flatten the batch and sequence dimensions into one.
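The following framework-free sketch shows what the loss computes on the flattened `(N, C)` logits. For each row, it takes a log-softmax, subtracting the row maximum first so that `exp()` cannot overflow. It then averages the negative log-probability of the target class, which corresponds to PyTorch's default `reduction='mean'`. The function name `cross_entropy_from_logits` is hypothetical.

```python
import math


def cross_entropy_from_logits(logits, targets):
    """Mean cross-entropy over N predictions.

    logits: list of N rows, each a list of C raw scores
    targets: list of N correct class indices
    """
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max so exp() never overflows
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        log_prob_target = row[target] - log_sum_exp  # log-softmax at the target
        total += -log_prob_target
    return total / len(targets)


# Equal logits mean a uniform distribution, so the loss is log(C)
print(cross_entropy_from_logits([[0.0] * 5, [3.0] * 5], [3, 0]))  # log(5), about 1.609

# A large logit on the correct class drives the loss toward 0
print(cross_entropy_from_logits([[0.0, 0.0, 0.0, 10.0, 0.0]], [3]))

# Huge logits would overflow a naive softmax, but the stable version is fine
print(cross_entropy_from_logits([[1000.0, 0.0]], [0]))
```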
What Does Initial Loss Tell Us?
For an untrained model with vocabulary size $V$, the expected initial loss is:
$$\text{Expected initial loss} = \log(V)$$
```python
import math

vocab_size = 1000
expected_loss = math.log(vocab_size)
print(f"Expected initial loss (vocab={vocab_size}): {expected_loss:.4f}")
# Expected initial loss (vocab=1000): 6.9078
```
If your initial loss is close to $\log(V)$, the model is behaving as expected — uniform random guessing. If it’s significantly higher, something is wrong with your architecture or initialization. This is a valuable sanity check.
3. The Adam Optimizer
The loss function tells us how wrong the model is. The optimizer’s job is to adjust the model’s weights to reduce that wrongness. Think of it as the teacher correcting the child’s pronunciation after each attempt.
What an Optimizer Does
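After computing the loss, PyTorch’s autograd system calculates gradients. For each weight, the gradient gives how strongly, and in which direction, the loss changes when that weight changes. The gradient points toward increasing loss, so the optimizer moves each weight a small step in the opposite direction. In its simplest form (plain gradient descent), the update is: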
$$\theta_{\text{new}} = \theta_{\text{old}} - \alpha \cdot \nabla_\theta \text{Loss}$$
Where $\theta$ represents the model’s weights, $\alpha$ is the learning rate (how big each step is), and $\nabla_\theta \text{Loss}$ is the gradient.
Why Adam Over Basic SGD?
Stochastic Gradient Descent (SGD) uses the same learning rate for every parameter. This is like telling every student in a class to study at the same pace — some are already good at math and need to focus on writing, others need the opposite.
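Adam (Adaptive Moment Estimation) adapts the step size for each parameter individually. It keeps two running averages per parameter: one of the gradient (momentum) and one of the squared gradient. Each update divides the averaged gradient by the square root of the averaged squared gradient. As a result, parameters with consistently large gradients get a smaller effective learning rate, and parameters with small gradients get a larger one. Every parameter moves by an amount on the order of the learning rate, regardless of its raw gradient scale. This makes Adam far more forgiving than SGD of the initial learning rate choice.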
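Here is a single-parameter sketch of the standard Adam update: running averages of the gradient and the squared gradient, bias correction, then a normalized step. It is written in plain Python for illustration, and `adam_step` is a hypothetical helper. `torch.optim.Adam` applies the same idea to every tensor at once. The example demonstrates the property described above: on the first step, the update is close to the learning rate in size whether the gradient is 1000 or 0.001.

```python
import math


def adam_step(param, grad, state, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to a scalar parameter; `state` holds m, v, and t."""
    state["t"] += 1
    # Running averages of the gradient (momentum) and of the squared gradient
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    # Bias correction: m and v start at 0, so early averages are too small
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    # Dividing by sqrt(v_hat) rescales the step by the gradient's magnitude
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


for grad in [1000.0, 0.001]:
    state = {"m": 0.0, "v": 0.0, "t": 0}
    new_param = adam_step(1.0, grad, state)
    print(f"grad={grad:>8}: step size = {1.0 - new_param:.6f}")
```

Both gradients produce a first step of about `3e-4`. Plain SGD with the same learning rate would instead move the parameter by `0.3` for the first gradient and `3e-7` for the second.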
Setting Up Adam
```python
import torch.optim as optim

# Assume model is already created (from CH6)
optimizer = optim.Adam(
    model.parameters(),   # All trainable weights
    lr=3e-4,              # Learning rate (a common default for transformers)
    betas=(0.9, 0.999),   # Decay rates for the running averages of the gradient
                          # and the squared gradient (usually keep the defaults)
    weight_decay=0.01,    # Adds an L2 penalty to the gradient (discourages large
                          # weights); AdamW applies decoupled weight decay instead
)
```
The Learning Rate: The Most Important Hyperparameter
The learning rate $\alpha$ controls how big each weight update is. It’s the single most important number you’ll tune during training.
Too high (e.g., 0.1): The model takes huge steps, overshoots the optimal weights, and the loss bounces around wildly or explodes to infinity. Like a student who tries to learn the entire textbook in one night — they end up remembering nothing.
Too low (e.g., 1e-7): The model inches forward so slowly that training takes forever and might get stuck in a poor solution. Like a student who reads one word per hour.
Just right (typically 1e-4 to 1e-3 for transformers): Steady progress, smooth loss decrease. The common default for training transformers is 3e-4 (0.0003).
```python
# The effect of learning rate: a mental model (illustrative numbers, not a real run)
#
# lr = 0.1        → Loss: 6.9, 14.2, NaN, NaN   (exploding!)
# lr = 0.001      → Loss: 6.9, 5.8, 4.2, 3.1    (works, but slower here: large steps are noisier)
# lr = 0.0003     → Loss: 6.9, 4.1, 2.3, 1.5    (good convergence)
# lr = 0.0000001  → Loss: 6.9, 6.8, 6.8, 6.7    (barely moving)
```
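The overshooting effect is easy to reproduce on a toy problem. This sketch uses plain gradient descent on $f(x) = x^2$ (gradient $2x$), not a transformer. It runs 20 steps from $x = 1$ with three learning rates. Each step multiplies $x$ by $(1 - 2\alpha)$, so a learning rate that is too large makes $|x|$ grow instead of shrink. The helper name `gradient_descent` is hypothetical.

```python
def gradient_descent(lr, steps=20, x=1.0):
    """Minimize f(x) = x**2 with plain gradient descent and return the final x."""
    for _ in range(steps):
        grad = 2 * x       # derivative of x**2
        x = x - lr * grad  # theta_new = theta_old - lr * gradient
    return x


for lr in [1.5, 0.25, 0.001]:
    x_final = gradient_descent(lr)
    print(f"lr={lr:<6} final x = {x_final:.6g}  (loss = {x_final ** 2:.6g})")
```

With `lr=1.5`, each step overshoots the minimum and doubles the distance from it. With `lr=0.25`, each step halves the distance. With `lr=0.001`, $x$ barely moves. Real loss surfaces are far more complicated, but they fail in the same two ways.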
4. The Training Loop
Now we have all three ingredients: data (DataLoader), a measure of error (cross-entropy loss), and a way to learn from errors (Adam optimizer). The training loop ties them together.
The Core Loop
Every training step follows the same five-step recipe:
1. Forward pass → Feed data through the model, get predictions
2. Compute loss → Measure how wrong the predictions are
3. Backward pass → Calculate gradients (how to fix each weight)
4. Update weights → Apply gradients via the optimizer
5. Zero gradients → Reset for the next step
This is repeated for every batch in the dataset, and we cycle through the entire dataset multiple times. One complete pass through the dataset is called an epoch.
Batch vs. Epoch Terminology
- Step/Iteration: Processing one batch of data through the loop
- Epoch: One complete pass through the entire training dataset
- Batch size: How many examples are processed in one step
If your dataset has 1000 examples and your batch size is 10, one epoch takes 100 steps. Training for 5 epochs means the model sees every example 5 times, for a total of 500 steps.
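This step arithmetic is worth checking in code, because `drop_last` changes it. The following small sketch mirrors how a `DataLoader` with a plain batch sampler counts batches per epoch. The helper name `steps_per_epoch` is hypothetical.

```python
import math


def steps_per_epoch(num_examples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_examples // batch_size        # incomplete final batch is skipped
    return math.ceil(num_examples / batch_size)  # final batch may be smaller


print(steps_per_epoch(1000, 10))                # 100 steps per epoch
print(5 * steps_per_epoch(1000, 10))            # 500 steps over 5 epochs
print(steps_per_epoch(84, 4, drop_last=True))   # 21, the toy dataset from Section 1
print(steps_per_epoch(90, 4, drop_last=True))   # 22: two leftover examples are dropped
print(steps_per_epoch(90, 4, drop_last=False))  # 23: the last batch has 2 examples
```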
Implementation
```python
def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    """Train the model for one complete epoch.

    Args:
        model: the GPT model
        dataloader: provides batches of (input, target) pairs
        optimizer: Adam optimizer
        loss_fn: CrossEntropyLoss
        device: 'cpu' or 'cuda'

    Returns:
        Average loss over the epoch
    """
    model.train()  # Enable dropout and training-specific behavior
    total_loss = 0.0
    num_batches = 0

    for batch_x, batch_y in dataloader:
        # Move data to device (CPU or GPU)
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Step 1: Forward pass
        logits = model(batch_x)
        # logits shape: (batch_size, seq_len, vocab_size)

        # Step 2: Compute loss
        # Reshape for CrossEntropyLoss: (batch*seq_len, vocab_size) vs (batch*seq_len,)
        loss = loss_fn(
            logits.view(-1, logits.size(-1)),
            batch_y.view(-1),
        )

        # Step 3: Backward pass (compute gradients)
        loss.backward()

        # Step 4: Update weights
        optimizer.step()

        # Step 5: Zero gradients for next step
        optimizer.zero_grad()

        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    return avg_loss
```
Running Multiple Epochs
```python
num_epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()

print(f"Training on: {device}")
print(f"Epochs: {num_epochs}")
print(f"Batches per epoch: {len(dataloader)}\n")

for epoch in range(num_epochs):
    avg_loss = train_one_epoch(model, dataloader, optimizer, loss_fn, device)
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f}")
```
Example output (illustrative only; your exact numbers will depend on the model, data, and random initialization):
```
Training on: cpu
Epochs: 20
Batches per epoch: 21

Epoch   0 | Loss: 3.8521
Epoch   5 | Loss: 2.4103
Epoch  10 | Loss: 1.5687
Epoch  15 | Loss: 0.9214
Epoch  19 | Loss: 0.6433
```
The loss starts near the random-guessing baseline $\log(V)$, where $V$ is the vocabulary size, and decreases over time. This means the model is learning! It’s moving from random babbling toward actual patterns in the text.
5. Learning Rate Scheduling
Using a fixed learning rate for the entire training run works, but we can usually do better. A learning rate scheduler adjusts the learning rate during training, and most modern transformer training runs use one.
The Warmup + Cosine Decay Strategy
The most popular scheduler for transformers combines two phases:
Warmup (first ~5–10% of training): Start with a very small learning rate and linearly increase it to the target value. This prevents early training instability — when weights are random, large updates can send the model in wild directions.
Cosine decay (remaining 90–95%): Gradually decrease the learning rate following a cosine curve, ending near zero. As the model gets closer to a good solution, smaller steps prevent it from overshooting.
Think of driving a car on an unfamiliar road:
- Warmup: You start slow, getting a feel for the road conditions
- Peak: You reach cruising speed on the highway
- Cosine decay: You gradually slow down as you approach your destination, making precise adjustments to park
```
Learning Rate
     │
 max │      ╭──────╮
     │     ╱        ╲
     │    ╱          ╲
     │   ╱            ╲
 min │  ╱              ╲___
     └──────────────────────────── Steps
       warmup    cosine decay
```
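The schedule implemented below can also be written as a formula. Here $t$ is the 0-indexed step, $W$ is the number of warmup steps, $T$ is the total number of steps, and $\alpha_{\max}$ and $\alpha_{\min}$ are the peak and floor learning rates. These symbols are introduced here for convenience, with $\alpha$ matching the learning-rate symbol used earlier.

$$
\alpha(t) =
\begin{cases}
\alpha_{\max} \cdot \dfrac{t + 1}{W}, & t < W \\[6pt]
\alpha_{\min} + \left(\alpha_{\max} - \alpha_{\min}\right) \cdot \dfrac{1}{2}\left(1 + \cos\left(\pi \cdot \dfrac{t - W}{T - W}\right)\right), & t \ge W
\end{cases}
$$

At $t = W$, the cosine term equals 1, so the rate is exactly $\alpha_{\max}$. As $t$ approaches $T$, the cosine term approaches $-1$ and the rate approaches $\alpha_{\min}$.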
Implementation
```python
import math


def get_lr(step, total_steps, max_lr, min_lr=1e-6, warmup_steps=100):
    """Compute learning rate with linear warmup and cosine decay.

    Args:
        step: current training step (0-indexed)
        total_steps: total number of training steps
        max_lr: peak learning rate
        min_lr: minimum learning rate (floor)
        warmup_steps: number of warmup steps

    Returns:
        Learning rate for this step
    """
    # Phase 1: Linear warmup from max_lr / warmup_steps up to max_lr
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps

    # Phase 2: Cosine decay from max_lr down to min_lr
    # (max(1, ...) avoids division by zero if total_steps == warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine_decay
```
Let’s visualize what this looks like:
```python
total_steps = 1000
warmup_steps = 100
max_lr = 3e-4

# Compute LR at every step
lrs = [get_lr(step, total_steps, max_lr, warmup_steps=warmup_steps)
       for step in range(total_steps)]

print(f"Step 0: lr = {lrs[0]:.6f}")      # Near zero
print(f"Step 50: lr = {lrs[50]:.6f}")    # About halfway through warmup
print(f"Step 100: lr = {lrs[100]:.6f}")  # Peak
print(f"Step 500: lr = {lrs[500]:.6f}")  # Mid-decay
print(f"Step 999: lr = {lrs[999]:.6f}")  # Near minimum
```
Applying the Scheduler in the Training Loop
To use the scheduler, we update the optimizer’s learning rate at each step:
```python
def set_lr(optimizer, lr):
    """Update the learning rate for all parameter groups."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
```
We’ll integrate this into the complete training script in Section 8.
6. Gradient Clipping
There’s one more technique we need before writing our final training script: gradient clipping.
The Exploding Gradient Problem
During backpropagation, gradients flow backward through the network. In deep networks (many transformer blocks stacked), gradients can get multiplied together and grow exponentially — they explode. When this happens, the weight updates become enormous, the loss spikes to infinity or becomes NaN, and training crashes.
It’s like a game of telephone gone wrong: each person amplifies the message a little, and by the time it reaches the end of the line, it’s completely distorted.
The Fix: Gradient Clipping
Gradient clipping is simple: after computing gradients but before updating weights, we check the total magnitude of all gradients. If it exceeds a threshold, we scale all gradients down proportionally so the total magnitude equals the threshold.
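$$\text{if } \lVert \mathbf{g} \rVert > \text{max\_norm}: \quad \mathbf{g} \leftarrow \mathbf{g} \cdot \frac{\text{max\_norm}}{\lVert \mathbf{g} \rVert}$$
Here $\mathbf{g}$ is all of the model’s gradients viewed as one long vector, and $\lVert \mathbf{g} \rVert$ is its L2 norm.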
Because every gradient is multiplied by the same factor, clipping doesn’t change the direction of the overall update, only its magnitude. The model still moves in the same direction; it just takes a shorter step.
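To make the rule concrete, here is a sketch without any framework. It treats all gradients as one long vector, computes its L2 norm, and rescales only when the norm exceeds `max_norm`. The small `1e-6` added to the denominator mirrors what `torch.nn.utils.clip_grad_norm_` does to avoid division by zero. The function name `clip_by_global_norm` is hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient vectors so their combined L2 norm is <= max_norm.

    Returns (clipped_grads, original_norm), similar to how clip_grad_norm_
    returns the pre-clipping norm.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scale gradients up
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


# Two "parameters" whose gradients together have norm sqrt(3^2 + 4^2 + 12^2) = 13
grads = [[3.0, 4.0], [12.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # 13.0
print(clipped)  # every entry divided by about 13: same direction, norm about 1.0

small = [[0.3, 0.4]]  # norm 0.5 is already under the threshold
print(clip_by_global_norm(small, max_norm=1.0)[0])  # unchanged
```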
PyTorch Implementation
```python
# After loss.backward() and before optimizer.step():
max_norm = 1.0  # Common choice for transformers
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
```
The clip_grad_norm_ function (note the trailing underscore — it modifies gradients in-place) returns the original gradient norm, which is useful for monitoring:
```python
# In the training loop:
loss.backward()

# Clip gradients and get the original (pre-clipping) norm
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if grad_norm > 10.0:
    print(f"Warning: large gradient norm {grad_norm:.2f} (clipped to 1.0)")

optimizer.step()
optimizer.zero_grad()
```
If you see the gradient norm regularly exceeding your max_norm by a large factor, your learning rate might be too high or your model might have an architectural issue.
7. Checkpoint Saving
Training can take hours or days. If your machine crashes at epoch 99 of 100, you don’t want to start over. Checkpoints save the model’s state so you can resume later.
Saving a Checkpoint
```python
def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Save a training checkpoint.

    Args:
        model: the GPT model
        optimizer: the optimizer (saved so training resumes with the same
            momentum and adaptive-rate statistics)
        epoch: current epoch number
        loss: current loss value
        filepath: where to save (e.g., 'checkpoint_epoch_10.pt')
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")
```
Loading a Checkpoint
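```python
def load_checkpoint(filepath, model, optimizer=None):
    """Load a training checkpoint.

    Args:
        filepath: path to the checkpoint file
        model: the GPT model (must have same architecture)
        optimizer: optional optimizer to restore state

    Returns:
        The checkpoint dictionary (including epoch and loss)
    """
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; load_state_dict then copies the tensors to the model's device.
    # weights_only=False is needed because the full training script also stores
    # a config object; only use it with checkpoint files you trust.
    checkpoint = torch.load(filepath, map_location='cpu', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Checkpoint loaded: epoch {checkpoint['epoch']}, "
          f"loss {checkpoint['loss']:.4f}")
    return checkpoint
```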
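A good practice is to save a checkpoint every N epochs and to always save the best model so far. Here "best" means lowest training loss; if you hold out a validation set, validation loss is a better guide.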
```python
best_loss = float('inf')

for epoch in range(num_epochs):
    avg_loss = train_one_epoch(model, dataloader, optimizer, loss_fn, device)

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        save_checkpoint(model, optimizer, epoch, avg_loss,
                        f"checkpoint_epoch_{epoch+1}.pt")

    # Save the best model
    if avg_loss < best_loss:
        best_loss = avg_loss
        save_checkpoint(model, optimizer, epoch, avg_loss, "best_model.pt")
```
8. A Complete Training Script
Now let’s put everything together into one runnable script. This is the culmination of Chapters 3 through 7 — from raw text to a trained language model.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import math
import time

# ──────────────────────────────────────────────────────
# 1. Configuration
# ──────────────────────────────────────────────────────
@dataclass
class GPTConfig:
    vocab_size: int = 65     # Character-level vocabulary (set from the data below)
    d_model: int = 64        # Embedding dimension
    n_heads: int = 4         # Attention heads
    n_layers: int = 4        # Transformer blocks
    d_ff: int = 256          # Feed-forward hidden size
    max_seq_len: int = 64    # Maximum sequence length
    dropout: float = 0.1


@dataclass
class TrainConfig:
    batch_size: int = 32
    num_epochs: int = 50
    max_lr: float = 3e-4
    min_lr: float = 1e-5
    warmup_steps: int = 50
    max_grad_norm: float = 1.0
    weight_decay: float = 0.01
    log_interval: int = 5          # Print every N epochs
    checkpoint_interval: int = 25  # Save a checkpoint every N epochs

# ──────────────────────────────────────────────────────
# 2. Dataset
# ──────────────────────────────────────────────────────
class TextDataset(Dataset):
    def __init__(self, token_ids, seq_len):
        self.token_ids = token_ids
        self.seq_len = seq_len

    def __len__(self):
        return len(self.token_ids) - self.seq_len

    def __getitem__(self, idx):
        chunk = self.token_ids[idx : idx + self.seq_len + 1]
        return chunk[:-1], chunk[1:]

# ──────────────────────────────────────────────────────
# 3. Learning rate schedule
# ──────────────────────────────────────────────────────
def get_lr(step, total_steps, max_lr, min_lr=1e-5, warmup_steps=50):
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# ──────────────────────────────────────────────────────
# 4. Prepare data
# ──────────────────────────────────────────────────────
shakespeare_text = """
Shall I compare thee to a summer's day?
Thou art more lovely and more temperate.
Rough winds do shake the darling buds of May,
And summer's lease hath all too short a date.
Sometime too hot the eye of heaven shines,
And often is his gold complexion dimmed;
And every fair from fair sometime declines,
By chance, or nature's changing course, untrimmed;
But thy eternal summer shall not fade,
Nor lose possession of that fair thou ow'st;
Nor shall death brag thou wand'rest in his shade,
When in eternal lines to Time thou grow'st.
So long as men can breathe, or eyes can see,
So long lives this, and this gives life to thee.
From fairest creatures we desire increase,
That thereby beauty's rose might never die,
But as the riper should by time decease,
His tender heir might bear his memory.
But thou, contracted to thine own bright eyes,
Feed'st thy light's flame with self-substantial fuel,
Making a famine where abundance lies,
Thyself thy foe, to thy sweet self too cruel.
When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery, so gazed on now,
Will be a tatter'd weed, of small worth held.
Then being ask'd where all thy beauty lies,
Where all the treasure of thy lusty days,
To say, within thine own deep-sunken eyes,
Were an all-eating shame and thriftless praise.
"""

# Build character-level vocabulary
chars = sorted(set(shakespeare_text))
char_to_id = {ch: i for i, ch in enumerate(chars)}
id_to_char = {i: ch for ch, i in char_to_id.items()}

# Encode the corpus
token_ids = torch.tensor([char_to_id[ch] for ch in shakespeare_text],
                         dtype=torch.long)

gpt_config = GPTConfig(vocab_size=len(chars))
train_config = TrainConfig()

print(f"Vocabulary size: {gpt_config.vocab_size}")
print(f"Corpus length: {len(token_ids)} tokens")
print(f"Model parameters: d_model={gpt_config.d_model}, "
      f"n_heads={gpt_config.n_heads}, n_layers={gpt_config.n_layers}")

# Create dataset and dataloader
dataset = TextDataset(token_ids, gpt_config.max_seq_len)
dataloader = DataLoader(dataset, batch_size=train_config.batch_size,
                        shuffle=True, drop_last=True)
print(f"Training examples: {len(dataset)}")
print(f"Batches per epoch: {len(dataloader)}")

# ──────────────────────────────────────────────────────
# 5. Create model and optimizer
# ──────────────────────────────────────────────────────
# NOTE: GPTModel is the model class from CH6. Import it (or paste its
# definition above) before running this script.
model = GPTModel(gpt_config)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# AdamW: Adam with decoupled weight decay
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_config.max_lr,
    weight_decay=train_config.weight_decay,
)
loss_fn = nn.CrossEntropyLoss()

# ──────────────────────────────────────────────────────
# 6. Training loop
# ──────────────────────────────────────────────────────
total_steps = train_config.num_epochs * len(dataloader)
global_step = 0
best_loss = float('inf')

print(f"\nTraining on: {device}")
print(f"Total steps: {total_steps}")
print("─" * 50)

start_time = time.time()

for epoch in range(train_config.num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Update learning rate
        lr = get_lr(global_step, total_steps, train_config.max_lr,
                    train_config.min_lr, train_config.warmup_steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass
        logits = model(batch_x)

        # Compute loss
        loss = loss_fn(logits.view(-1, logits.size(-1)),
                       batch_y.view(-1))

        # Backward pass
        loss.backward()

        # Gradient clipping (grad_norm is the pre-clipping norm, useful for logging)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), train_config.max_grad_norm
        )

        # Update weights
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        num_batches += 1
        global_step += 1

    avg_loss = epoch_loss / num_batches

    # Logging (lr is the rate used on the epoch's last step)
    if epoch % train_config.log_interval == 0 or epoch == train_config.num_epochs - 1:
        elapsed = time.time() - start_time
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"LR: {lr:.6f} | Time: {elapsed:.1f}s")

    # Save checkpoint
    if (epoch + 1) % train_config.checkpoint_interval == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
            'config': gpt_config,
        }, f"checkpoint_epoch_{epoch+1}.pt")

    # Track best model
    if avg_loss < best_loss:
        best_loss = avg_loss

total_time = time.time() - start_time
print("─" * 50)
print(f"Training complete in {total_time:.1f}s")
print(f"Best loss: {best_loss:.4f}")
print(f"Expected loss for random guessing: {math.log(gpt_config.vocab_size):.4f}")
```
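Example output (illustrative only; the vocabulary size and token counts shown may differ from what this exact corpus produces, and loss values vary from run to run):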
```
Vocabulary size: 57
Corpus length: 1026 tokens
Model parameters: d_model=64, n_heads=4, n_layers=4
Training examples: 962
Batches per epoch: 30

Training on: cpu
Total steps: 1500
──────────────────────────────────────────────────
Epoch   0 | Loss: 3.9812 | LR: 0.000180 | Time: 0.8s
Epoch   5 | Loss: 3.2104 | LR: 0.000294 | Time: 4.5s
Epoch  10 | Loss: 2.5437 | LR: 0.000274 | Time: 8.3s
Epoch  15 | Loss: 1.9823 | LR: 0.000242 | Time: 12.1s
Epoch  20 | Loss: 1.4516 | LR: 0.000200 | Time: 16.0s
Epoch  25 | Loss: 1.0234 | LR: 0.000154 | Time: 19.8s
Epoch  30 | Loss: 0.7891 | LR: 0.000108 | Time: 23.6s
Epoch  35 | Loss: 0.5423 | LR: 0.000066 | Time: 27.4s
Epoch  40 | Loss: 0.3867 | LR: 0.000034 | Time: 31.2s
Epoch  45 | Loss: 0.2541 | LR: 0.000015 | Time: 35.0s
Epoch  49 | Loss: 0.1823 | LR: 0.000010 | Time: 38.5s
──────────────────────────────────────────────────
Training complete in 38.5s
Best loss: 0.1823
Expected loss for random guessing: 4.0431
```
In this example run, the loss dropped from about 4.0 (close to the random-guessing baseline of $\log 57 \approx 4.04$ for 57 characters) to about 0.18. On a corpus this small, such a low loss mostly reflects memorization. The model has learned to reproduce much of the Shakespeare text and can predict the next character with high accuracy. In Chapter 8, we’ll use this trained model to generate new text.
Understanding the Training Progress
Let’s break down what happened:
- Epochs 0–1 (warmup): With warmup_steps=50 and about 30 batches per epoch, the learning rate reaches its peak within the first two epochs. The loss starts near the random-guessing baseline, and the model typically begins by learning basic character frequencies, that is, which characters appear most often.
- Epochs 1–20 (high learning rate): The learning rate is at or near its peak and only slowly begins its cosine decline. Loss drops quickly, from about 4.0 to about 1.5. The model picks up common character sequences: “th”, “he”, “in”, “er”, “the”, “and”.
- Epochs 20–40 (decay): The learning rate falls steadily along the cosine curve. The model refines longer patterns: common words and short phrases from the sonnets.
- Epochs 40–50 (low learning rate): With a small learning rate, the model makes fine adjustments. It memorizes specific phrases and word boundaries.
9. Exercises
Exercise 1: Effect of Learning Rate
Modify the training script to try three different learning rates: 1e-2, 3e-4, and 1e-6. Train for 20 epochs with each and compare the final loss. Which one trains best? Which one fails?
Solution
```python
# Reuses GPTModel, gpt_config, dataloader, and device from the training script
for lr in [1e-2, 3e-4, 1e-6]:
    # Reset model with fresh random weights
    model = GPTModel(gpt_config).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # Train for 20 epochs (no scheduler, fixed LR for this experiment)
    for epoch in range(20):
        model.train()
        total_loss = 0
        n = 0
        for bx, by in dataloader:
            bx, by = bx.to(device), by.to(device)
            logits = model(bx)
            loss = loss_fn(logits.view(-1, logits.size(-1)), by.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            n += 1
        if epoch == 19:
            print(f"LR={lr:.0e} → Final loss: {total_loss/n:.4f}")

# Expected results (approximate):
# LR=1e-02 → Final loss: NaN or very high (training exploded)
# LR=3e-04 → Final loss: ~1.5-2.0 (good convergence)
# LR=1e-06 → Final loss: ~3.7-3.9 (barely learned anything)
```
Takeaway: 1e-2 is too aggressive and causes training instability. 3e-4 is the sweet spot. 1e-6 is too conservative — the model barely moves from its random initialization.
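Why does one learning rate explode while another barely moves? Here is a toy analogy, not the chapter's setup: plain gradient descent (not AdamW) on the one-parameter loss $f(w) = w^2$, whose gradient is $2w$. Each step multiplies $w$ by $(1 - 2\,\text{lr})$, so a step size that is too large overshoots the minimum by more every time, while a tiny one leaves $w$ almost where it started. The learning-rate values below are scaled for this toy problem, not for the GPT model.

```python
def gradient_descent(lr: float, w0: float = 1.0, steps: int = 50) -> float:
    """Minimize f(w) = w**2 with plain gradient descent; return the final w."""
    w = w0
    for _ in range(steps):
        grad = 2.0 * w       # derivative of w**2
        w = w - lr * grad    # basic gradient-descent update: step against the gradient
    return w


for lr in [1.1, 0.1, 1e-4]:
    w = gradient_descent(lr)
    print(f"lr={lr:<6} final w={w: .3e}  loss={w * w:.3e}")
```

AdamW rescales each step by a running estimate of the gradient's magnitude, so the exact thresholds differ, but the trade-off between overshooting and crawling is the same one the exercise measures.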
Exercise 2: Batch Size Experiment
Train the model with batch sizes of 4, 16, and 64. Keep everything else the same (same number of epochs). How does batch size affect training speed and final loss?
Solution
```python
import time

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Assumes GPTModel, gpt_config, device, and dataset from earlier in the chapter
for bs in [4, 16, 64]:
    model = GPTModel(gpt_config).to(device)
    dl = DataLoader(dataset, batch_size=bs, shuffle=True, drop_last=True)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)
    loss_fn = nn.CrossEntropyLoss()
    start = time.time()
    for epoch in range(20):
        model.train()
        total_loss = 0
        n = 0
        for bx, by in dl:
            bx, by = bx.to(device), by.to(device)
            logits = model(bx)
            loss = loss_fn(logits.view(-1, logits.size(-1)), by.view(-1))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            n += 1
    elapsed = time.time() - start
    # Reported loss is the average over the final epoch
    print(f"Batch size={bs:2d} → Loss: {total_loss/n:.4f}, "
          f"Steps/epoch: {len(dl)}, Time: {elapsed:.1f}s")
# Expected results (approximate; timings depend on hardware):
# Batch size= 4 → Loss: ~0.8, Steps/epoch: 240, Time: ~25s (many steps, slower)
# Batch size=16 → Loss: ~1.0, Steps/epoch: 60, Time: ~12s (good balance)
# Batch size=64 → Loss: ~1.5, Steps/epoch: 15, Time: ~6s (fewer steps, may underfit)
```
Takeaway: Smaller batches give more weight updates per epoch (more learning steps) but each step is noisier. Larger batches are computationally efficient but provide fewer updates. The optimal batch size balances these trade-offs. For small datasets, smaller batches often work better because the model gets more update steps.
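The step counts in the expected output follow from simple arithmetic: with `drop_last=True`, `len(dl)` is the number of full batches, `N // batch_size`. The sketch below assumes N = 960 training samples, a value inferred from the 240/60/15 steps per epoch above rather than stated in the chapter, and counts how many optimizer updates each batch size gets over 20 epochs.

```python
def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = True) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return num_samples // batch_size        # incomplete last batch is discarded
    return -(-num_samples // batch_size)        # ceiling division keeps the partial batch


num_samples = 960   # assumption: inferred from the expected output, not read from the dataset
epochs = 20
for bs in [4, 16, 64]:
    steps = steps_per_epoch(num_samples, bs)
    print(f"batch size {bs:2d}: {steps:3d} steps/epoch, "
          f"{steps * epochs:4d} updates in {epochs} epochs, "
          f"{steps * bs} samples used per epoch")
```

Every epoch still covers the same 960 samples; batch size 4 simply turns them into 16 times as many (noisier) updates as batch size 64.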
Exercise 3: Monitor Gradient Norms
Modify the training loop to record the gradient norm at every step. Plot the gradient norms over time. When are gradients largest? Does gradient clipping activate?
Solution
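```python
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Assumes GPTModel, gpt_config, device, and dataloader from earlier in the chapter
model = GPTModel(gpt_config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
grad_norms = []
clipped_count = 0
max_norm = 1.0
for epoch in range(20):
    model.train()
    for bx, by in dataloader:
        bx, by = bx.to(device), by.to(device)
        logits = model(bx)
        loss = loss_fn(logits.view(-1, logits.size(-1)), by.view(-1))
        loss.backward()
        # clip_grad_norm_ clips in place and returns the total norm measured
        # BEFORE clipping, so a single call both records and clips
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=max_norm
        ).item()
        grad_norms.append(grad_norm)
        if grad_norm > max_norm:
            clipped_count += 1
        optimizer.step()
        optimizer.zero_grad()
total_steps = len(grad_norms)
print(f"Total steps: {total_steps}")
print(f"Clipped steps: {clipped_count} ({100*clipped_count/total_steps:.1f}%)")
print(f"Max gradient norm: {max(grad_norms):.4f}")
print(f"Mean gradient norm: {sum(grad_norms)/len(grad_norms):.4f}")
print(f"Final gradient norm: {grad_norms[-1]:.4f}")
# Plot the recorded norms against the clipping threshold
plt.figure(figsize=(8, 4))
plt.plot(grad_norms, linewidth=0.8, label="Gradient norm (before clipping)")
plt.axhline(max_norm, color="red", linestyle="--", label=f"max_norm = {max_norm}")
plt.xlabel("Training step")
plt.ylabel("Total gradient L2 norm")
plt.legend()
plt.show()
# Gradient norms are typically largest early in training when the
# model is far from a good solution. They decrease as training
# progresses and the loss landscape becomes smoother.
```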
Takeaway: Gradients are usually largest in the first few hundred steps, when the model is making large adjustments. Gradient clipping activates most during this period. As training progresses, gradients naturally become smaller and clipping becomes less frequent.
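`clip_grad_norm_` treats all parameter gradients as one long vector, computes its L2 norm, and, if that norm exceeds `max_norm`, scales every gradient by the same factor so the combined norm becomes `max_norm`. The pure-Python sketch below shows the idea on plain lists. `clip_by_global_norm` is a name made up for this example, and PyTorch's implementation additionally adds a small epsilon to the denominator.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient vectors so their combined L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    if total_norm <= max_norm:
        return [list(vec) for vec in grads], total_norm
    scale = max_norm / total_norm  # one factor shared by every parameter
    return [[g * scale for g in vec] for vec in grads], total_norm


# Two "parameters" whose gradients have a combined norm of sqrt(3^2 + 4^2 + 12^2) = 13
grads = [[3.0, 4.0], [12.0]]
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # → 13.0
print(clipped)
```

Because every gradient shares one scale factor, clipping shortens the update without changing its direction.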
Exercise 4: Early Stopping
Implement early stopping: if the training loss doesn’t improve for `patience` consecutive epochs (e.g., 5), stop training and restore the best model weights.
Solution
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumes GPTModel, gpt_config, device, and dataloader from earlier in the chapter
model = GPTModel(gpt_config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
# Early stopping parameters
patience = 5
best_loss = float('inf')
best_state = None
epochs_without_improvement = 0
max_epochs = 100  # Upper bound
for epoch in range(max_epochs):
    model.train()
    total_loss = 0
    n = 0
    for bx, by in dataloader:
        bx, by = bx.to(device), by.to(device)
        logits = model(bx)
        loss = loss_fn(logits.view(-1, logits.size(-1)), by.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        n += 1
    avg_loss = total_loss / n
    if avg_loss < best_loss:
        best_loss = avg_loss
        # Snapshot the weights; clone() so later updates don't overwrite them
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    if epoch % 5 == 0:
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"No improvement for: {epochs_without_improvement} epochs")
    if epochs_without_improvement >= patience:
        print(f"\nEarly stopping at epoch {epoch}!")
        print(f"Best loss: {best_loss:.4f}")
        model.load_state_dict(best_state)
        break
else:
    # for...else: runs only if the loop finished without hitting break
    print(f"Completed all {max_epochs} epochs without early stopping.")
    print(f"Best loss: {best_loss:.4f}")
    model.load_state_dict(best_state)  # the last epoch is not necessarily the best one
```
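Takeaway: Early stopping halts training once the monitored loss stops improving and keeps the best weights seen so far. This solution monitors the training loss, so it mainly saves computation once training plateaus; to guard against overfitting, monitor the loss on a held-out validation set instead, because training loss can keep falling while the model overfits. The `patience` parameter controls how many non-improving epochs we tolerate before stopping: too small and we might stop during a temporary plateau; too large and we waste computation.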
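The stopping rule itself does not depend on the model, so it can be pulled into a small helper and checked on a made-up sequence of per-epoch losses. `EarlyStopping` here is a hypothetical helper written for this example, not a PyTorch class; in practice you would feed it a validation loss.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without a new best loss."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, loss: float) -> bool:
        """Record this epoch's loss; return True if training should stop."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch  # the caller would snapshot the weights here
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up loss curve: improves, plateaus around 1.5, then improves again
losses = [3.0, 2.0, 1.5, 1.6, 1.55, 1.7, 1.5, 1.52, 1.4]
stopper = EarlyStopping(patience=5)
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, loss):
        print(f"Stop at epoch {epoch}; best loss {stopper.best_loss} at epoch {stopper.best_epoch}")
        break
```

With `patience=5`, this run stops at epoch 7 and never sees the improvement at epoch 8, which is exactly the too-small-patience risk described above.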
Summary
In this chapter, you learned the complete training pipeline for a language model:
- Data preparation: Text is tokenized into IDs, then sliced into input-target pairs where the target is the input shifted by one position. PyTorch’s `Dataset` and `DataLoader` handle batching and shuffling.
- Cross-entropy loss: Measures prediction quality by computing $-\log(p_{\text{correct}})$. Low loss means high confidence on correct answers. Initial loss should be approximately $\log(V)$ (natural log) for vocabulary size $V$.
- The Adam optimizer: Adjusts each weight based on its gradient, adapting the learning rate per parameter (the code in this chapter uses its AdamW variant). The learning rate is often the most critical hyperparameter, and `3e-4` is a solid default for transformers.
- The training loop: Forward pass → compute loss → backward pass (gradients) → update weights → repeat. One pass through the dataset is an epoch.
- Learning rate scheduling: Warmup + cosine decay starts cautiously, reaches cruising speed, then gradually slows for fine-grained optimization.
- Gradient clipping: Rescales the gradients whenever their combined norm exceeds a threshold, preventing exploding gradients in deep networks. A `max_norm` of 1.0 is a common default.
- Checkpointing: Save model and optimizer state periodically to resume training after interruptions and preserve the best model.
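The shift-by-one construction from the data-preparation step can be seen on a toy sequence of token IDs. The sketch below uses plain Python lists; `make_pairs`, `context_length`, and `stride` are illustrative names, and the chapter’s own `Dataset` class may slice the text differently.

```python
def make_pairs(token_ids, context_length, stride):
    """Slide a window over token_ids; each target is the input shifted right by one."""
    inputs, targets = [], []
    # Stop early enough that the target window (one position further) still fits
    for start in range(0, len(token_ids) - context_length, stride):
        inputs.append(token_ids[start:start + context_length])
        targets.append(token_ids[start + 1:start + context_length + 1])
    return inputs, targets


token_ids = [10, 11, 12, 13, 14, 15, 16, 17]
inputs, targets = make_pairs(token_ids, context_length=4, stride=2)
for x, y in zip(inputs, targets):
    print(x, "->", y)
# [10, 11, 12, 13] -> [11, 12, 13, 14]
# [12, 13, 14, 15] -> [13, 14, 15, 16]
```

At every position i, the model sees the tokens up to i and is trained to predict token i + 1, so a single window yields `context_length` training examples at once.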
Your model can now learn from text. But learning isn’t the same as speaking — in the next chapter, we’ll use the trained model to actually generate new text, one token at a time.