The Transformer Architecture — Attention Is All You Need
Summary
This chapter breaks down the transformer architecture from the 2017 'Attention is All You Need' paper into digestible components. Starting with the intuition behind attention as a spotlight mechanism, it walks through self-attention with concrete numerical examples, scaled dot-product attention, and multi-head attention. Causal masking for autoregressive language models prevents future token leakage. Layer normalization stabilizes training, feed-forward networks provide non-linear transformation, and residual connections ensure gradient flow. Each component is implemented from scratch with shape annotations before showing PyTorch equivalents.
In the previous chapter, you learned how to turn token IDs into rich embedding vectors — dense numerical representations where similar words sit near each other in a high-dimensional space. You also added positional encodings so the model knows word order.
But embeddings alone don’t understand language. The embedding for “bank” is the same whether the sentence is “I sat by the river bank” or “I went to the bank to deposit money.” Embeddings are static — they don’t change based on context.
This chapter is about the mechanism that makes embeddings context-aware: attention. It is the core idea behind the transformer architecture, which powers GPT, BERT, LLaMA, and virtually every modern language model. By the end of this chapter, you will understand every component inside a transformer block and implement each one from scratch.
Let’s begin.
1. Why Transformers?
The Problem with Reading One Word at a Time
Before transformers, the dominant approach to processing language was the Recurrent Neural Network (RNN). An RNN reads text one word at a time, left to right, carrying a hidden state that acts as a “memory” of what it has read so far.
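Think of it like reading a long novel with amnesia. You read the first page and remember some of it. Then you read the second page, and some of the first page fades. By chapter 10, you’ve lost almost everything from chapter 1. RNNs suffer from a similar problem: information from early in a sequence gradually fades from the hidden state as the network processes more tokens. A closely related training difficulty is the vanishing gradient problem. The learning signal that links a late prediction to an early token shrinks as it is propagated back through many time steps, so the network struggles to learn long-range dependencies.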
There’s a second problem: RNNs are slow. Because they process words one at a time, you can’t parallelize the computation. Word 5 depends on word 4, which depends on word 3, and so on. On modern GPUs designed for massive parallel computation, this is a devastating bottleneck.
The Transformer Revolution
In 2017, researchers at Google published a paper titled “Attention Is All You Need” (Vaswani et al.). Their key insight was radical: throw away recurrence entirely. Instead of reading words one at a time, process all words simultaneously and let each word decide which other words are relevant to it.
This is like having the entire book open in front of you at once, with the ability to glance at any page whenever you need to. No forgetting, no sequential bottleneck.
The mechanism that enables this is called self-attention, and the architecture built around it is called the transformer. It is the foundation of nearly every large language model today.
Let’s build it piece by piece.
2. Intuition for Attention
The Spotlight Analogy
Consider this sentence:
“The cat sat on the mat because it was tired.”
What does “it” refer to? You know immediately — “it” refers to “the cat.” But how did you figure that out? You didn’t process words one at a time. You looked back at the whole sentence and connected “it” to “cat” based on meaning.
Attention works much like a spotlight. When the model is processing the word “it,” it shines a spotlight across all the other words in the sentence and asks: “Which of you are most relevant to me right now?” The word “cat” lights up brightly. The word “mat” stays dim. The word “the” barely registers.
The strength of the spotlight — how brightly each word lights up — is called an attention weight. Words that are more relevant get higher weights.
The Query-Key-Value Metaphor
Attention uses three components, and the best metaphor is a library search:
- Query (Q): “What am I looking for?” This is the question posed by the current word. When processing “it,” the query is something like “I’m a pronoun. Who am I referring to?”
- Key (K): “What do I contain?” Each word advertises a description of itself. The word “cat” might advertise “I’m a noun, an animal, the subject of this sentence.”
- Value (V): “Here’s my actual information.” Once the query finds a matching key, the value is the content that gets passed along. Think of it as the actual book you pull off the library shelf.
The attention mechanism computes how well each query matches each key (a compatibility score), then uses those scores to create a weighted combination of all values.
Let’s make this concrete with numbers.
3. Self-Attention Step by Step
Setting Up a Tiny Example
We’ll work with three words: “The”, “cat”, “sat”. Imagine each word has already been embedded into a 4-dimensional vector (in practice, dimensions are much larger — 768, 1024, or more — but 4 keeps things readable).
import torch
import torch.nn as nn
import torch.nn.functional as F
# Three words, each embedded into 4 dimensions
# In real models, these come from the embedding layer (CH4)
x = torch.tensor([
[1.0, 0.0, 1.0, 0.0], # "The"
[0.0, 2.0, 0.0, 2.0], # "cat"
[1.0, 1.0, 1.0, 1.0], # "sat"
])
print(f"Input shape: {x.shape}")
# Input shape: torch.Size([3, 4])
# 3 tokens, each with 4-dimensional embedding
Creating Q, K, V with Weight Matrices
To produce queries, keys, and values, we multiply the input by three separate weight matrices. These matrices are learned parameters — the model adjusts them during training to get better at finding relevant connections.
# Embedding dimension
d_model = 4
# For this example, Q, K, V have the same dimension as the input
d_k = 4 # dimension of keys and queries
d_v = 4 # dimension of values
# Weight matrices (in practice, these are learned)
# Here we fill them with random values, seeded for reproducibility
torch.manual_seed(42)
W_Q = torch.randn(d_model, d_k) # (4, 4) — projects input to queries
W_K = torch.randn(d_model, d_k) # (4, 4) — projects input to keys
W_V = torch.randn(d_model, d_v) # (4, 4) — projects input to values
print(f"W_Q shape: {W_Q.shape}") # (4, 4)
print(f"W_K shape: {W_K.shape}") # (4, 4)
print(f"W_V shape: {W_V.shape}") # (4, 4)
# Compute Q, K, V by multiplying input by weight matrices
Q = x @ W_Q # (3, 4) @ (4, 4) = (3, 4) — one query per token
K = x @ W_K # (3, 4) @ (4, 4) = (3, 4) — one key per token
V = x @ W_V # (3, 4) @ (4, 4) = (3, 4) — one value per token
print(f"\nQ (queries) shape: {Q.shape}") # (3, 4)
print(f"K (keys) shape: {K.shape}") # (3, 4)
print(f"V (values) shape: {V.shape}") # (3, 4)
print(f"\nQ (queries):\n{Q}")
print(f"\nK (keys):\n{K}")
print(f"\nV (values):\n{V}")
Each row of Q is a query for one token. Each row of K is a key for one token. Each row of V is a value for one token.
Step 1: Compute Attention Scores (Q × K^T)
The first step is to measure how well each query matches each key. We do this with a dot product — the same operation we used in Chapter 2 to measure similarity between vectors. Higher dot product means higher similarity.
# Compute raw attention scores: how well each query matches each key
# Q: (3, 4), K^T: (4, 3) → scores: (3, 3)
attention_scores = Q @ K.T
print(f"Attention scores shape: {attention_scores.shape}") # (3, 3)
print(f"\nRaw attention scores:\n{attention_scores}")
The result is a 3×3 matrix. The entry at position (i, j) tells us how much token i should “attend to” token j. For example:
- Row 0 tells us how much “The” attends to [“The”, “cat”, “sat”]
- Row 1 tells us how much “cat” attends to [“The”, “cat”, “sat”]
- Row 2 tells us how much “sat” attends to [“The”, “cat”, “sat”]
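To see that `Q @ K.T` really computes every query-key dot product at once, the sketch below recomputes the score matrix with an explicit double loop. It repeats the setup above (same seed, same order of random draws, same inputs) so it runs on its own.

```python
import torch

torch.manual_seed(42)
x = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],  # "The"
    [0.0, 2.0, 0.0, 2.0],  # "cat"
    [1.0, 1.0, 1.0, 1.0],  # "sat"
])
W_Q = torch.randn(4, 4)
W_K = torch.randn(4, 4)
Q, K = x @ W_Q, x @ W_K

scores = Q @ K.T  # (3, 3): all query-key pairs in one matrix product

# Entry (1, 2): how strongly the query for "cat" matches the key for "sat"
manual = torch.dot(Q[1], K[2])
print(scores[1, 2].item(), manual.item())

# Loop version of the whole matrix, one dot product per (i, j) pair
loop_scores = torch.zeros(3, 3)
for i in range(3):
    for j in range(3):
        loop_scores[i, j] = torch.dot(Q[i], K[j])
print(torch.allclose(scores, loop_scores, atol=1e-5))  # True
```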
Step 2: Apply Softmax to Get Attention Weights
Raw scores can be any number — positive, negative, large, small. We need to turn them into probabilities that sum to 1 for each token. Softmax does exactly this.
# Convert raw scores into probabilities (attention weights)
# Each row sums to 1
attention_weights = F.softmax(attention_scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}") # (3, 3)
print(f"\nAttention weights (each row sums to 1):\n{attention_weights}")
print(f"\nRow sums: {attention_weights.sum(dim=-1)}") # Should be [1, 1, 1]
Now each row is a probability distribution. If “cat” has weight 0.7 for itself and 0.2 for “sat” and 0.1 for “The,” it means “cat” pays 70% of its attention to itself, 20% to “sat,” and 10% to “The.”
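Softmax itself is short enough to write by hand: exponentiate each score, then divide by the row total. The sketch below uses a made-up score matrix (not the values printed above) and checks the result against `F.softmax`. Subtracting the row maximum first is a standard guard against overflow; it does not change the result, because adding the same constant to every entry of a row leaves softmax unchanged.

```python
import torch
import torch.nn.functional as F

# Hypothetical raw scores for 3 tokens (rows = queries, columns = keys)
scores = torch.tensor([
    [2.0, 1.0, 0.1],
    [0.5, 3.0, 1.0],
    [1.0, 1.0, 1.0],
])

def softmax_rows(s):
    # Subtract each row's max for numerical stability (result is unchanged)
    s = s - s.max(dim=-1, keepdim=True).values
    exp_s = torch.exp(s)
    # Divide by the row total so every row sums to 1
    return exp_s / exp_s.sum(dim=-1, keepdim=True)

manual = softmax_rows(scores)
print(manual)
print(manual.sum(dim=-1))  # each row sums to 1
print(torch.allclose(manual, F.softmax(scores, dim=-1), atol=1e-6))  # True
```

The last row shows a useful edge case: equal scores give equal weights (one third each).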
Step 3: Multiply by V to Get the Output
The final step is to use these weights to create a weighted combination of values. Each token’s output is a weighted sum of all value vectors, where the weights come from the attention distribution.
# Weighted sum of values using attention weights
# attention_weights: (3, 3), V: (3, 4) → output: (3, 4)
output = attention_weights @ V
print(f"Output shape: {output.shape}") # (3, 4) — same shape as input
print(f"\nSelf-attention output:\n{output}")
The output has the same shape as the input — each token still has a 4-dimensional representation. But now each representation has been enriched with context from other tokens. The embedding for “cat” is no longer just about “cat” — it now carries information about “The” and “sat” too, weighted by relevance.
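Written out for a single token, the product `weights @ V` is just a weighted average of the value rows. The sketch below uses small hand-picked weights and values (illustrative only, not taken from the example above) so the arithmetic can be checked by eye.

```python
import torch

# Hypothetical attention weights for one query token over 3 tokens (sums to 1)
w = torch.tensor([0.1, 0.7, 0.2])

# Hypothetical value vectors for "The", "cat", "sat" (3 tokens, 4 dims each)
V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
])

# Explicit weighted sum: 0.1 * V[0] + 0.7 * V[1] + 0.2 * V[2]
manual = w[0] * V[0] + w[1] * V[1] + w[2] * V[2]
as_matmul = w @ V
print(manual)                             # tensor([0.1000, 0.7000, 0.2000, 0.2000])
print(torch.allclose(manual, as_matmul))  # True
```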
The Complete Self-Attention in One Function
def self_attention_naive(x, W_Q, W_K, W_V):
"""
Naive self-attention (without scaling).
Args:
x: Input tensor of shape (seq_len, d_model)
W_Q, W_K, W_V: Weight matrices, each of shape (d_model, d_k)
Returns:
Output tensor of shape (seq_len, d_v)
"""
# Step 1: Project input into queries, keys, values
Q = x @ W_Q # (seq_len, d_k)
K = x @ W_K # (seq_len, d_k)
V = x @ W_V # (seq_len, d_v)
# Step 2: Compute attention scores
scores = Q @ K.T # (seq_len, seq_len)
# Step 3: Softmax to get weights
weights = F.softmax(scores, dim=-1) # (seq_len, seq_len)
# Step 4: Weighted sum of values
output = weights @ V # (seq_len, d_v)
return output, weights
output, weights = self_attention_naive(x, W_Q, W_K, W_V)
print(f"Output shape: {output.shape}") # (3, 4)
print(f"Weights shape: {weights.shape}") # (3, 3)
4. Scaled Dot-Product Attention
The Problem with Large Dimensions
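Our example used 4-dimensional embeddings. Real models use 768, 1024, or even 12288 dimensions, and a single attention head typically works with 64 or 128 dimensions per query and key. When you compute dot products between high-dimensional vectors, the resulting values tend to grow in magnitude as the dimension increases, because the dot product adds up one product term per dimension.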
Why is this a problem? Because of how softmax behaves with large inputs:
# Softmax with small values — spread out, informative distribution
small_scores = torch.tensor([1.0, 2.0, 3.0])
print(f"Small scores softmax: {F.softmax(small_scores, dim=0)}")
# Something like [0.09, 0.24, 0.67] — a useful distribution
# Softmax with large values — extremely peaky, almost one-hot
large_scores = torch.tensor([10.0, 20.0, 30.0])
print(f"Large scores softmax: {F.softmax(large_scores, dim=0)}")
# Something like [0.0000, 0.0000, 1.0000] — all weight on one element
When softmax becomes too “peaky” — concentrating almost all weight on a single token — the model can’t learn nuanced attention patterns. It also causes problems during training because gradients become extremely small (nearly zero) for most positions.
The Fix: Divide by √d_k
The solution from the original paper is elegant: divide the attention scores by the square root of the key dimension before applying softmax. This keeps the scores in a reasonable range regardless of how large d_k is.
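In formula form, using the same symbols as the code below (queries Q, keys K, values V, key dimension d_k), with softmax applied to each row of the score matrix:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$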
The mathematical intuition: if Q and K elements are independent random variables with mean 0 and variance 1, the dot product of two d_k-dimensional vectors has variance d_k. Dividing by √d_k brings the variance back to 1.
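The variance argument is easy to check empirically. The sketch below draws query and key vectors with independent standard-normal entries (exactly the assumption in the argument; learned projections will not satisfy it precisely) and compares the variance of raw and scaled dot products. The choices d_k = 512 and 10,000 samples are arbitrary.

```python
import math
import torch

torch.manual_seed(0)
d_k = 512
n_samples = 10_000

# Independent standard-normal queries and keys: mean 0, variance 1 per entry
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

dots = (q * k).sum(dim=-1)      # one dot product per sample
scaled = dots / math.sqrt(d_k)  # divide by sqrt(d_k), as in the paper

print(f"Variance of raw dot products:    {dots.var().item():.1f}")    # close to 512
print(f"Variance of scaled dot products: {scaled.var().item():.3f}")  # close to 1
```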
import math
def scaled_dot_product_attention(Q, K, V):
"""
Scaled dot-product attention as described in
"Attention Is All You Need" (Vaswani et al., 2017).
Args:
Q: Queries of shape (..., seq_len, d_k)
K: Keys of shape (..., seq_len, d_k)
V: Values of shape (..., seq_len, d_v)
Returns:
output: Weighted values of shape (..., seq_len, d_v)
weights: Attention weights of shape (..., seq_len, seq_len)
"""
d_k = Q.shape[-1]
# Compute scaled attention scores
# Q @ K^T: (..., seq_len, seq_len)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
# Convert to probabilities
weights = F.softmax(scores, dim=-1) # (..., seq_len, seq_len)
# Weighted sum of values
output = weights @ V # (..., seq_len, d_v)
return output, weights
# Test it
Q = x @ W_Q
K = x @ W_K
V = x @ W_V
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Scaled attention output shape: {output.shape}") # (3, 4)
print(f"Attention weights:\n{weights}")
# Notice: weights are more spread out than without scaling
This is the core attention function used inside every transformer. Multi-head attention and masking, covered next, both build on top of it.
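PyTorch 2.0 and later also ship a built-in `torch.nn.functional.scaled_dot_product_attention`, which returns only the output (not the weights). Assuming PyTorch 2.0 or newer is installed, a quick sanity check is to compare it with a from-scratch version on random inputs; the batch size, sequence length, and d_k below are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Same computation as the from-scratch function above
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(2, 5, 16)  # (batch, seq_len, d_k)
K = torch.randn(2, 5, 16)
V = torch.randn(2, 5, 16)

ours, _ = scaled_dot_product_attention(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)  # requires PyTorch >= 2.0

print(torch.allclose(ours, builtin, atol=1e-5))  # True
```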
5. Multi-Head Attention
Why One Spotlight Isn’t Enough
Consider the sentence: “The animal didn’t cross the street because it was too wide.”
When processing “it,” the model needs to figure out that “it” refers to “the street” (because streets can be wide, animals can’t). But there are multiple types of relationships the model might want to track simultaneously:
- Grammatical: “it” is a pronoun — what noun does it replace?
- Semantic: “wide” describes a physical property — what things can be wide?
- Positional: “it” is close to “street” and far from “animal”
A single attention mechanism computes one set of weights — one spotlight. Multi-head attention uses multiple spotlights, each looking for a different type of relationship.
How Multi-Head Attention Works
The idea is simple:
- Project the input into queries, keys, and values, then split each of them into multiple “heads.” If d_model = 512 and we use 8 heads, each head works with 512 / 8 = 64 dimensions.
- Each head performs its own independent attention computation.
- Concatenate all heads’ outputs back together.
- Project the concatenated result with a final linear layer.
Each head learns to attend to different aspects of the input. One head might learn grammatical relationships, another might learn semantic similarity, another might track coreference.
Implementation from Scratch
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention from scratch.
Args:
d_model: Dimension of the model (embedding size)
num_heads: Number of attention heads
"""
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, \
f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # dimension per head
# Linear projections for Q, K, V
# Each projects from d_model to d_model, but we'll reshape into heads
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
# Final output projection
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, mask=None):
"""
Args:
x: Input of shape (batch_size, seq_len, d_model)
mask: Optional mask of shape (seq_len, seq_len)
Returns:
output: Shape (batch_size, seq_len, d_model)
"""
batch_size, seq_len, _ = x.shape
# Step 1: Project to Q, K, V
# Each: (batch_size, seq_len, d_model)
Q = self.W_Q(x)
K = self.W_K(x)
V = self.W_V(x)
# Step 2: Reshape into multiple heads
# (batch_size, seq_len, d_model) → (batch_size, num_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
print(f" Q after split into heads: {Q.shape}")
# (batch_size, num_heads, seq_len, d_k)
# Step 3: Scaled dot-product attention for each head
scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
# scores: (batch_size, num_heads, seq_len, seq_len)
# Apply mask if provided (we'll cover masking in the next section)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
# weights: (batch_size, num_heads, seq_len, seq_len)
attn_output = weights @ V
# attn_output: (batch_size, num_heads, seq_len, d_k)
# Step 4: Concatenate heads
# (batch_size, num_heads, seq_len, d_k) → (batch_size, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, self.d_model
)
print(f" After concatenating heads: {attn_output.shape}")
# Step 5: Final linear projection
output = self.W_O(attn_output)
# output: (batch_size, seq_len, d_model)
return output
# ---- Test it ----
d_model = 8 # small for demonstration
num_heads = 2 # 2 heads, each working with d_k = 4
mha = MultiHeadAttention(d_model, num_heads)
# Create a batch of 1 sequence with 5 tokens, each 8-dimensional
torch.manual_seed(42)
x = torch.randn(1, 5, d_model) # (batch=1, seq_len=5, d_model=8)
print(f"Input shape: {x.shape}")
output = mha(x)
print(f"Output shape: {output.shape}") # (1, 5, 8) — same as input
Notice the key insight: the input and output shapes are identical. A multi-head attention layer takes a sequence of embeddings in and produces a same-shaped sequence of contextualized embeddings out. That is what makes stacking possible: the output of one attention layer becomes the input to the next.
Using PyTorch’s Built-in Multi-Head Attention
PyTorch provides nn.MultiheadAttention, which does the same thing:
# PyTorch built-in version
mha_pytorch = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
batch_first=True # so input is (batch, seq_len, d_model)
)
output_pt, weights_pt = mha_pytorch(x, x, x) # self-attention: Q=K=V=x
print(f"PyTorch MHA output shape: {output_pt.shape}") # (1, 5, 8)
print(f"PyTorch MHA weights shape: {weights_pt.shape}") # (1, 5, 5)
The three arguments (x, x, x) mean we’re doing self-attention: queries, keys, and values all come from the same input. In encoder-decoder models, keys and values can come from a different source (cross-attention), but decoder-only language models like GPT use only self-attention.
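For contrast, here is a sketch of the encoder-decoder case with the same module: the queries come from one sequence, while keys and values come from another. The sequence lengths (5 and 7) are arbitrary. The output length follows the queries, while the weight matrix has one column per key.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads = 8, 2
cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

decoder_states = torch.randn(1, 5, d_model)  # queries: 5 target tokens
encoder_states = torch.randn(1, 7, d_model)  # keys and values: 7 source tokens

out, w = cross_attn(decoder_states, encoder_states, encoder_states)
print(out.shape)  # torch.Size([1, 5, 8]): one output per query token
print(w.shape)    # torch.Size([1, 5, 7]): each query attends over 7 source tokens
```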
6. Causal Masking
Why Language Models Must Not See the Future
When training a language model, the task is: given the words so far, predict the next word. For the sentence “The cat sat on the mat”:
| Input so far | Predict next |
|---|---|
| The | cat |
| The cat | sat |
| The cat sat | on |
| The cat sat on | the |
| The cat sat on the | mat |
If the model could look at the entire sentence while predicting, it would just copy the answer — that’s cheating. The model must only attend to words at the current position and before it, never to future words.
This restriction is called causal masking (also called autoregressive masking).
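In code, the table above usually comes from a single token sequence shifted by one position: the model reads positions 0 to n-2 and is scored on predicting positions 1 to n-1. All positions are trained in parallel, which is exactly why a mask is needed to stop each position from peeking ahead. The token IDs below are made up for illustration.

```python
import torch

# Hypothetical token IDs for "The cat sat on the mat" (one ID per word)
tokens = torch.tensor([5, 17, 42, 8, 3, 99])

inputs = tokens[:-1]   # what the model reads
targets = tokens[1:]   # what it must predict at each position

for pos in range(len(inputs)):
    context = inputs[: pos + 1].tolist()
    print(f"context={context} -> target={targets[pos].item()}")
```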
The Triangular Mask
Causal masking is implemented with a simple triangular matrix. For a 5-token sequence:
Position 0 can see: [✓ ✗ ✗ ✗ ✗] → only itself
Position 1 can see: [✓ ✓ ✗ ✗ ✗] → itself and position 0
Position 2 can see: [✓ ✓ ✓ ✗ ✗] → positions 0, 1, 2
Position 3 can see: [✓ ✓ ✓ ✓ ✗] → positions 0, 1, 2, 3
Position 4 can see: [✓ ✓ ✓ ✓ ✓] → all positions
This is a lower triangular matrix: 1s on and below the diagonal, 0s above.
Implementation
def create_causal_mask(seq_len):
"""
Creates a causal (autoregressive) mask.
Returns a lower-triangular matrix of shape (seq_len, seq_len)
where entry (i, j) is 1 if position i can attend to position j
(i.e., j <= i), and 0 otherwise.
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask
# Create mask for 5 tokens
mask = create_causal_mask(5)
print(f"Causal mask shape: {mask.shape}")
print(f"\nCausal mask:\n{mask}")
# tensor([[1., 0., 0., 0., 0.],
# [1., 1., 0., 0., 0.],
# [1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1.]])
How the Mask Is Applied
The mask is applied to the attention scores before softmax. Positions that should be invisible are set to negative infinity (-inf), so that softmax assigns them a weight of exactly zero.
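A short check shows why -inf is the right fill value: exp(-inf) is 0, so the blocked entry gets exactly zero weight and the remaining entries are renormalized among themselves. One caveat is worth knowing: if every entry in a row were -inf, softmax would return NaN for that row. A causal mask never does this, because each position can always attend to itself.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, float('-inf')])  # last position is blocked
weights = F.softmax(scores, dim=0)
print(weights)  # tensor([0.7311, 0.2689, 0.0000])

# The unblocked weights match a softmax over only the unblocked scores
print(torch.allclose(weights[:2], F.softmax(torch.tensor([2.0, 1.0]), dim=0)))  # True

# Edge case: a fully blocked row produces NaN
all_blocked = F.softmax(torch.full((3,), float('-inf')), dim=0)
print(torch.isnan(all_blocked).all())  # tensor(True)
```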
def scaled_dot_product_attention_with_mask(Q, K, V, mask=None):
"""
Scaled dot-product attention with optional causal mask.
Args:
Q: Queries (..., seq_len, d_k)
K: Keys (..., seq_len, d_k)
V: Values (..., seq_len, d_v)
mask: Binary mask (seq_len, seq_len), 1 = attend, 0 = block
Returns:
output, weights
"""
d_k = Q.shape[-1]
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
# Apply causal mask: set blocked positions to -inf
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
output = weights @ V
return output, weights
# Demonstrate: 4 tokens, 8-dimensional embeddings
torch.manual_seed(0)
seq_len = 4
d_k = 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)
# Without mask — every token attends to every other token
output_no_mask, weights_no_mask = scaled_dot_product_attention_with_mask(Q, K, V)
print("Attention weights WITHOUT mask:")
print(weights_no_mask)
print()
# With causal mask — each token only attends to itself and previous tokens
mask = create_causal_mask(seq_len)
output_masked, weights_masked = scaled_dot_product_attention_with_mask(Q, K, V, mask)
print("Attention weights WITH causal mask:")
print(weights_masked)
# Notice: upper triangle is all zeros
Look at the masked weights carefully. The upper triangle (future positions) is zero. Token 0 only attends to itself (100% weight). Token 1 attends to tokens 0 and 1. Token 2 attends to tokens 0, 1, and 2. And so on.
This is exactly what we want for autoregressive language models like GPT.
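The built-in `nn.MultiheadAttention` from Section 5 also accepts a causal mask through its `attn_mask` argument, but with the opposite boolean convention from `create_causal_mask`: for a boolean `attn_mask`, True marks a position that must not be attended to. A sketch of building that mask and checking that future positions receive zero weight (sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, num_heads = 4, 8, 2
mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
x = torch.randn(1, seq_len, d_model)

# True above the diagonal means "block this position" (PyTorch's convention)
blocked = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# The 1/0 lower-triangular mask used in this chapter is the logical inverse
keep = torch.tril(torch.ones(seq_len, seq_len))
assert torch.equal(blocked, keep == 0)

_, weights = mha(x, x, x, attn_mask=blocked)  # weights averaged over heads
print(weights[0])  # upper triangle is all zeros
```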
7. Layer Normalization
Why Normalization Matters
As data flows through a deep network — layer after layer of matrix multiplications and non-linear activations — the numbers can drift. Sometimes they grow very large. Sometimes they shrink toward zero. When numbers drift out of a reasonable range, two bad things happen:
- Training becomes unstable. Large values cause large gradients, which cause large parameter updates, which cause even larger values. This is a vicious cycle called “exploding gradients.”
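- Learning slows down. When the input distribution to each layer keeps shifting, the layer may spend effort adapting to the shifting input rather than learning useful patterns. This effect is called “internal covariate shift.” It was the original motivation for normalization layers, although later research has questioned how much of their benefit it actually explains.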
Layer normalization (LayerNorm) addresses this by normalizing each token’s vector across its d_model features to zero mean and unit variance, then applying a learned scale and shift. It’s like constantly recalibrating a thermometer so 0 always means average and 1 always means one standard deviation above.
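In symbols, for a single token vector x with d_model entries, the normalization implemented below computes the mean and the (biased) variance over the features, then normalizes and applies the learned scale γ and shift β, with a small constant ε for numerical stability:

$$\mu = \frac{1}{d_{\text{model}}}\sum_{i=1}^{d_{\text{model}}} x_i, \qquad \sigma^2 = \frac{1}{d_{\text{model}}}\sum_{i=1}^{d_{\text{model}}} (x_i - \mu)^2$$

$$\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$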
Implementation from Scratch
class LayerNorm(nn.Module):
"""
Layer Normalization from scratch.
For each input vector, normalizes to zero mean and unit variance,
then applies learned scale (gamma) and shift (beta).
Args:
d_model: Dimension of the input
eps: Small constant for numerical stability (avoid division by zero)
"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.eps = eps
# Learnable parameters: the network can undo normalization if needed
self.gamma = nn.Parameter(torch.ones(d_model)) # scale
self.beta = nn.Parameter(torch.zeros(d_model)) # shift
def forward(self, x):
"""
Args:
x: Input tensor of shape (..., d_model)
Returns:
Normalized tensor of same shape
"""
# Compute mean and variance along the last dimension
mean = x.mean(dim=-1, keepdim=True) # (..., 1)
var = x.var(dim=-1, keepdim=True, unbiased=False) # (..., 1)
# Normalize: (x - mean) / sqrt(var + eps)
x_norm = (x - mean) / torch.sqrt(var + self.eps) # (..., d_model)
# Apply learned scale and shift
output = self.gamma * x_norm + self.beta # (..., d_model)
return output
# ---- Test it ----
d_model = 8
ln = LayerNorm(d_model)
# Random input: batch of 2 sequences, 4 tokens each, 8 dimensions
x = torch.randn(2, 4, d_model) * 10 + 5 # deliberately shifted and scaled
print(f"Input mean (per token): {x[0, 0].mean().item():.2f}")
print(f"Input std (per token): {x[0, 0].std().item():.2f}")
output = ln(x)
print(f"\nAfter LayerNorm mean: {output[0, 0].mean().item():.4f}") # ≈ 0
print(f"After LayerNorm std: {output[0, 0].std(unbiased=False).item():.4f}") # ≈ 1
print(f"Output shape: {output.shape}") # Same as input: (2, 4, 8)
Comparison with PyTorch’s nn.LayerNorm
# PyTorch's built-in LayerNorm
ln_pytorch = nn.LayerNorm(d_model)
# Both should produce normalized output with mean ≈ 0, std ≈ 1
output_ours = ln(x)
output_pytorch = ln_pytorch(x)
print(f"Our LayerNorm output mean: {output_ours[0, 0].mean().item():.6f}")
print(f"PyTorch LayerNorm output mean: {output_pytorch[0, 0].mean().item():.6f}")
print(f"Our LayerNorm output std: {output_ours[0, 0].std(unbiased=False).item():.4f}")
print(f"PyTorch LayerNorm output std: {output_pytorch[0, 0].std(unbiased=False).item():.4f}")
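The two outputs should be nearly identical. Both versions initialize gamma (PyTorch calls it weight) to ones and beta (bias) to zeros, so the only difference is the default eps: 1e-6 in our version versus 1e-5 in nn.LayerNorm. For inputs like this one, that changes the output only negligibly.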
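To confirm that the two implementations really agree once eps matches, the sketch below writes the same normalization as a plain function (a hypothetical helper, `layer_norm_scratch`) and compares it with `nn.LayerNorm` constructed with `eps=1e-6`. Both start from gamma = 1 and beta = 0.

```python
import torch
import torch.nn as nn

def layer_norm_scratch(x, gamma, beta, eps=1e-6):
    # Same steps as the LayerNorm class above: biased variance over the last dim
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
d_model = 8
x = torch.randn(2, 4, d_model) * 10 + 5  # deliberately shifted and scaled

ln_pytorch = nn.LayerNorm(d_model, eps=1e-6)  # weight = 1, bias = 0 at init
ours = layer_norm_scratch(x, torch.ones(d_model), torch.zeros(d_model))
theirs = ln_pytorch(x)

print(torch.allclose(ours, theirs, atol=1e-5))  # True
```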
8. Feed-Forward Network
The “Thinking” Layer
After attention aggregates contextual information, the model needs to process that information — apply non-linear transformations that extract higher-level patterns. This is the job of the position-wise feed-forward network (FFN).
It’s called “position-wise” because it applies the same transformation independently to each token position. While attention mixes information across tokens, the FFN processes each token’s representation independently.
The structure is simple: two linear layers with a non-linear activation in between.
Input (d_model) → Linear (d_model → d_ff) → Activation → Linear (d_ff → d_model) → Output
The intermediate dimension d_ff is typically 4 times larger than d_model. If your embeddings are 768-dimensional, the FFN expands them to 3072 dimensions, applies a non-linear function, then compresses back to 768. This “expand-then-compress” pattern gives the network room to represent more complex features in the middle.
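A quick way to see what “position-wise” means: running the FFN on a whole sequence gives the same result as running it on each token separately. The sketch builds the expand, activate, compress stack with `nn.Sequential` (an illustrative stand-in for the FeedForward class defined below) and uses the 4x expansion described above; d_model = 8 is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8
d_ff = 4 * d_model  # expand by 4x, then compress back

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),  # expand: 8 -> 32
    nn.GELU(),
    nn.Linear(d_ff, d_model),  # compress: 32 -> 8
)

x = torch.randn(1, 5, d_model)  # 5 tokens

whole_sequence = ffn(x)  # all tokens at once
one_at_a_time = torch.stack([ffn(x[:, t]) for t in range(5)], dim=1)  # token by token

print(whole_sequence.shape)                                      # torch.Size([1, 5, 8])
print(torch.allclose(whole_sequence, one_at_a_time, atol=1e-5))  # True
```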
Why GELU over ReLU?
Older networks used ReLU (Rectified Linear Unit): max(0, x). It’s simple — negative values become 0, positive values stay as they are. But ReLU has a sharp corner at 0, and it completely kills negative values.
GELU (Gaussian Error Linear Unit) is a smooth alternative that became the standard choice in transformers such as BERT and GPT-2. Instead of a hard cutoff at 0, GELU gradually tapers negative values. It’s defined as:
$$\text{GELU}(x) = x \cdot \Phi(x)$$
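where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution. Unlike ReLU, GELU is smooth everywhere (including at 0) and lets small negative inputs through with reduced magnitude. In practice it has worked well for training transformers.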
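The formula translates directly into code through the error function, since $\Phi(x) = \tfrac{1}{2}\left(1 + \operatorname{erf}(x/\sqrt{2})\right)$. The sketch compares this exact form with `F.gelu`, and with the tanh-based approximation exposed as `F.gelu(x, approximate='tanh')` (assuming a recent PyTorch version that supports the `approximate` argument).

```python
import math
import torch
import torch.nn.functional as F

def gelu_exact(x):
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF written via erf
    phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * phi

x = torch.linspace(-4, 4, 9)
print(torch.allclose(gelu_exact(x), F.gelu(x), atol=1e-6))  # True

# Tanh approximation: close to, but not exactly equal to, the exact GELU
approx = F.gelu(x, approximate='tanh')
print((approx - F.gelu(x)).abs().max())
```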
# Visualize ReLU vs GELU
import matplotlib
matplotlib.use('Agg') # non-interactive backend
import matplotlib.pyplot as plt
x_vals = torch.linspace(-4, 4, 200)
relu_vals = F.relu(x_vals)
gelu_vals = F.gelu(x_vals)
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax.plot(x_vals.numpy(), relu_vals.numpy(), label='ReLU', linewidth=2)
ax.plot(x_vals.numpy(), gelu_vals.numpy(), label='GELU', linewidth=2)
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Input')
ax.set_ylabel('Output')
ax.set_title('ReLU vs GELU Activation Functions')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('relu_vs_gelu.png', dpi=100)
print("Saved relu_vs_gelu.png")
Implementation
class FeedForward(nn.Module):
"""
Position-wise Feed-Forward Network.
Two linear transformations with GELU activation in between.
Expands to d_ff dimensions, then compresses back to d_model.
Args:
d_model: Input and output dimension
d_ff: Hidden dimension (typically 4 * d_model)
"""
def __init__(self, d_model, d_ff=None):
super().__init__()
if d_ff is None:
d_ff = 4 * d_model # standard multiplier
self.linear1 = nn.Linear(d_model, d_ff) # expand
self.linear2 = nn.Linear(d_ff, d_model) # compress
self.activation = nn.GELU()
def forward(self, x):
"""
Args:
x: Input of shape (batch_size, seq_len, d_model)
Returns:
Output of shape (batch_size, seq_len, d_model)
"""
# Expand: (batch, seq, d_model) → (batch, seq, d_ff)
x = self.linear1(x)
print(f" After expand: {x.shape}")
# Non-linear activation
x = self.activation(x)
# Compress: (batch, seq, d_ff) → (batch, seq, d_model)
x = self.linear2(x)
print(f" After compress: {x.shape}")
return x
# ---- Test it ----
d_model = 8
ffn = FeedForward(d_model)
x = torch.randn(1, 5, d_model) # batch=1, seq_len=5, d_model=8
print(f"Input shape: {x.shape}")
output = ffn(x)
print(f"Output shape: {output.shape}") # Same as input: (1, 5, 8)
9. Residual Connections
The Information Highway
Imagine a highway with multiple exits leading to small towns. Even if every town is a dead end, you can always get back on the highway and keep driving. The highway guarantees forward progress regardless of what happens at the exits.
Residual connections (also called skip connections) do the same thing for neural networks. Instead of the output of a layer being just the processed result, it’s the processed result plus the original input:
$$\text{output} = x + \text{sublayer}(x)$$
This simple addition has profound effects:
- Gradient flow. During training, gradients flow backward through the network. Without residual connections, gradients must travel through every transformation, and they often shrink to near-zero in deep networks (vanishing gradients). The residual connection provides a direct path, a “highway,” for gradients to flow through, enabling the training of very deep networks.
- Safe to add layers. If a sublayer learns nothing useful (outputs near zero), the residual connection ensures the output is still approximately equal to the input, so the layer does little harm. This makes it much safer to stack many layers: a layer can always “choose” to behave like an identity function.
- Easier optimization. Instead of learning the full transformation from input to output, each layer only needs to learn the difference: what should be added to or subtracted from the input. This is often a simpler function to learn.
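The gradient-flow point can be checked with autograd. Take a hypothetical sublayer f(x) = 0.01·x that barely does anything. The derivative of f alone is 0.01, while the derivative of x + f(x) is 1 + 0.01 = 1.01: the residual path contributes a constant 1 however small the sublayer’s own gradient is. Stacking 10 such layers (an arbitrary depth) makes the difference dramatic.

```python
import torch

def sublayer(x):
    # Hypothetical sublayer that barely does anything
    return 0.01 * x

num_layers = 10

# Plain stack: f(f(...f(x)))
x = torch.ones(3, requires_grad=True)
h = x
for _ in range(num_layers):
    h = sublayer(h)
h.sum().backward()
grad_plain = x.grad.clone()

# Residual stack: h = h + f(h) at every layer
x = torch.ones(3, requires_grad=True)
h = x
for _ in range(num_layers):
    h = h + sublayer(h)
h.sum().backward()
grad_residual = x.grad.clone()

print(grad_plain)     # about 1e-20 per element: the gradient has vanished
print(grad_residual)  # about 1.01**10, roughly 1.105 per element
```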
Implementation
Residual connections are almost trivially simple to implement:
class ResidualConnection(nn.Module):
"""
Wraps a sublayer with a residual connection.
output = x + sublayer(x)
"""
def __init__(self):
super().__init__()
def forward(self, x, sublayer):
"""
Args:
x: Input tensor
sublayer: A callable (nn.Module or function) to apply
Returns:
x + sublayer(x)
"""
return x + sublayer(x)
# ---- Demonstrate the effect ----
d_model = 8
residual = ResidualConnection()
x = torch.randn(1, 5, d_model)
# A sublayer that does almost nothing (outputs near zero)
zero_sublayer = lambda x: x * 0.01
# Without residual: output ≈ 0 (information destroyed)
output_no_residual = zero_sublayer(x)
print(f"Without residual — mean magnitude: {output_no_residual.abs().mean():.4f}")
# With residual: output ≈ x (information preserved)
output_with_residual = residual(x, zero_sublayer)
print(f"With residual — mean magnitude: {output_with_residual.abs().mean():.4f}")
print(f"Original input — mean magnitude: {x.abs().mean():.4f}")
# Output is 1.01 * x, so it nearly matches the original input: the signal survives
10. Putting It Together — The Complete Transformer Block
Now we have all the pieces. A single transformer block combines them in a specific pattern:
Input
│
├──────────────────┐
│ │ (Residual Connection 1)
▼ │
LayerNorm │
│ │
▼ │
Multi-Head Attention │
│ │
▼ │
+ ◄────────────────┘
│
├──────────────────┐
│ │ (Residual Connection 2)
▼ │
LayerNorm │
│ │
▼ │
Feed-Forward │
│ │
▼ │
+ ◄────────────────┘
│
▼
Output
This is the Pre-Norm variant (LayerNorm before each sublayer), which is used in GPT-2 and most modern transformers. The original paper placed LayerNorm after each sublayer (Post-Norm), but Pre-Norm has been found to train more stably.
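The two orderings differ only in where the normalization sits relative to the residual addition. A minimal sketch, with a hypothetical linear layer standing in for the attention or feed-forward sublayer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN

def post_norm(x):
    # Original paper: add the residual first, then normalize the sum
    return norm(x + sublayer(x))

def pre_norm(x):
    # Pre-Norm: normalize only the sublayer's input; the residual path is untouched
    return x + sublayer(norm(x))

x = torch.randn(2, 5, d_model) * 3 + 1

# Post-Norm output is normalized per token (mean about 0 at initialization)
print(post_norm(x).mean(dim=-1).abs().max())

# Pre-Norm carries the raw input straight through and only adds the sublayer's output
print(torch.allclose(pre_norm(x) - x, sublayer(norm(x)), atol=1e-4))  # True
```

In Pre-Norm the residual stream itself is never normalized inside the block, which is one reason the gradient highway from Section 9 stays clean.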
Complete Implementation
```python
class TransformerBlock(nn.Module):
    """
    A single transformer block (Pre-Norm variant).

    Combines:
        1. LayerNorm → Multi-Head Attention + Residual
        2. LayerNorm → Feed-Forward Network + Residual

    Args:
        d_model: Model dimension (embedding size)
        num_heads: Number of attention heads
        d_ff: Feed-forward hidden dimension (default: 4 * d_model)
    """

    def __init__(self, d_model, num_heads, d_ff=None):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model

        # Sub-layers
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)

        # Layer norms (one for each sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input of shape (batch_size, seq_len, d_model)
            mask: Optional causal mask of shape (seq_len, seq_len)
        Returns:
            Output of shape (batch_size, seq_len, d_model)
        """
        # Sub-layer 1: LayerNorm → Multi-Head Attention + Residual
        residual = x
        x = self.norm1(x)                   # Normalize
        x = self.attention(x, mask=mask)    # Attend
        x = residual + x                    # Add residual

        # Sub-layer 2: LayerNorm → Feed-Forward + Residual
        residual = x
        x = self.norm2(x)                   # Normalize
        x = self.feed_forward(x)            # Transform
        x = residual + x                    # Add residual

        return x
```
```python
# ---- Test the complete transformer block ----
d_model = 64      # small embedding size for a quick demo (GPT-2 Small uses 768)
num_heads = 4     # 4 heads, each with d_k = 16
seq_len = 10      # 10 tokens
batch_size = 2    # 2 sequences

block = TransformerBlock(d_model, num_heads)

# Count parameters
total_params = sum(p.numel() for p in block.parameters())
print(f"Transformer block parameters: {total_params:,}")

# Create input: batch of random embeddings
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, d_model)
print(f"\nInput shape: {x.shape}")      # (2, 10, 64)

# Create causal mask
mask = create_causal_mask(seq_len)

# Forward pass
output = block(x, mask=mask)
print(f"Output shape: {output.shape}")  # (2, 10, 64), same as input!

# Verify shapes are identical
assert x.shape == output.shape, "Input and output shapes must match!"
print("\n✓ Input and output shapes match, so the block can be stacked!")
```
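The parameter count printed above can also be derived by hand. The sketch below makes several assumptions. The attention layer holds four `d_model × d_model` projections (Q, K, V, and output), each with a bias. The feed-forward network is two `nn.Linear` layers with biases. Each LayerNorm has a scale vector and a shift vector. If your `MultiHeadAttention` or `FeedForward` omits biases, drop the corresponding terms, and the earlier printout will differ from the formula. To keep the check self-contained, the sketch compares the formula against PyTorch's own `nn.MultiheadAttention`, `nn.Linear`, and `nn.LayerNorm` rather than against this chapter's classes. `block_param_count` is a hypothetical helper name.

```python
import torch.nn as nn


def block_param_count(d_model: int, d_ff: int) -> int:
    """Closed-form parameter count for one Pre-Norm transformer block (biases assumed)."""
    attention = 4 * (d_model * d_model + d_model)                        # W_Q, W_K, W_V, W_O + biases
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)  # two Linear layers
    layer_norms = 2 * (2 * d_model)                                      # two LayerNorms: gamma and beta
    return attention + feed_forward + layer_norms


d_model, num_heads, d_ff = 64, 4, 256

# Reference modules from PyTorch with the same shapes
reference = nn.ModuleList([
    nn.MultiheadAttention(d_model, num_heads),   # packs W_Q, W_K, W_V into one matrix
    nn.Linear(d_model, d_ff),
    nn.Linear(d_ff, d_model),
    nn.LayerNorm(d_model),
    nn.LayerNorm(d_model),
])
reference_count = sum(p.numel() for p in reference.parameters())

print(f"Formula:   {block_param_count(d_model, d_ff):,}")   # 49,984
print(f"Reference: {reference_count:,}")
```

With `d_ff = 4 * d_model`, the weight matrices contribute $4d^2 + 8d^2 = 12d^2$ parameters per block. This is why the feed-forward network holds roughly two thirds of a block's parameters.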
Stacking Transformer Blocks
In practice, transformer models stack many blocks on top of each other. GPT-2 Small uses 12 blocks. GPT-3 uses 96. Each block refines the representations, adding more context and extracting higher-level patterns.
```python
class TransformerStack(nn.Module):
    """
    Stack multiple transformer blocks.
    This is the core of any transformer model.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff=None):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        # Final layer norm (standard in GPT-style models)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.final_norm(x)
        return x


# Stack 4 transformer blocks
stack = TransformerStack(
    num_layers=4,
    d_model=64,
    num_heads=4,
)

total_params = sum(p.numel() for p in stack.parameters())
print(f"\n4-layer transformer stack parameters: {total_params:,}")

output = stack(x, mask=mask)
print(f"Stacked output shape: {output.shape}")  # (2, 10, 64)
```
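PyTorch provides a built-in layer with the same Pre-Norm structure: `nn.TransformerEncoderLayer` with `norm_first=True`, stacked by `nn.TransformerEncoder` with an optional final `norm`. The sketch below is a rough equivalent of `TransformerStack(num_layers=4, d_model=64, num_heads=4)` under a few assumptions:

- Dropout is set to 0 to match the from-scratch block.
- GELU is assumed as the feed-forward activation. Use whatever your `FeedForward` used.
- The causal mask is built directly as a float mask with `-inf` above the diagonal, instead of calling this chapter's `create_causal_mask`.

Because the weights are initialized independently, do not expect the two stacks to produce the same outputs.

```python
import torch
import torch.nn as nn

d_model, num_heads, num_layers, seq_len, batch_size = 64, 4, 4, 10, 2

layer = nn.TransformerEncoderLayer(
    d_model=d_model,
    nhead=num_heads,
    dim_feedforward=4 * d_model,
    dropout=0.0,          # no dropout, so the forward pass is deterministic
    activation="gelu",    # assumption: match the activation used in FeedForward
    batch_first=True,     # inputs are (batch, seq, d_model), like ours
    norm_first=True,      # Pre-Norm: LayerNorm before attention and feed-forward
)
builtin_stack = nn.TransformerEncoder(
    layer,                        # deep-copied num_layers times
    num_layers=num_layers,
    norm=nn.LayerNorm(d_model),   # final LayerNorm, as in TransformerStack
    enable_nested_tensor=False,   # nested-tensor fast path is unused with norm_first=True
)

# Causal mask: -inf above the diagonal blocks attention to future positions
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, d_model)
with torch.no_grad():
    out = builtin_stack(x, mask=causal_mask)

num_params = sum(p.numel() for p in builtin_stack.parameters())
print(f"Built-in stack parameters: {num_params:,}")
print(f"Built-in stack output shape: {out.shape}")  # torch.Size([2, 10, 64])
```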
11. Summary
Let’s recap what each component does and why it exists:
| Component | What It Does | Why It’s Needed |
|---|---|---|
| Self-Attention | Lets each token gather information from all other tokens | Words get meaning from context |
| Scaling (÷√d_k) | Stops attention scores from growing with d_k | Keeps softmax from saturating into near one-hot weights with tiny gradients |
| Multi-Head Attention | Runs multiple attention patterns in parallel | Captures different types of relationships |
| Causal Mask | Prevents tokens from seeing future tokens | Required for next-word prediction |
| Layer Normalization | Keeps values in a stable range | Prevents exploding/vanishing values |
| Feed-Forward Network | Applies the same two-layer MLP to each token independently | Adds non-linear, per-token processing capacity |
| Residual Connections | Adds input directly to sublayer output | Enables training deep networks |
| Transformer Block | Combines all of the above into a reusable unit | The building block of modern LLMs |
In the next chapter, we’ll use these components to build a complete language model that can actually generate text.
12. Exercises
Exercise 1: Attention Weight Visualization
Write a function that takes a sentence, runs it through a simple self-attention computation, and prints the attention weight matrix with word labels. Use the sentence “I love cats and dogs”.
Hint: You’ll need to create synthetic embeddings for the words and compute Q, K, V.
Solution
```python
import math

import torch
import torch.nn.functional as F


def visualize_attention(sentence):
    """
    Compute and display attention weights for a sentence.
    Uses random embeddings for demonstration.
    """
    words = sentence.split()
    seq_len = len(words)
    d_model = 16

    # Create random embeddings (in a real model, these come from nn.Embedding)
    torch.manual_seed(42)
    embeddings = torch.randn(seq_len, d_model)

    # Create random weight matrices
    # (V is only needed for the attention *output*; the weights need just Q and K)
    W_Q = torch.randn(d_model, d_model)
    W_K = torch.randn(d_model, d_model)

    # Compute Q and K
    Q = embeddings @ W_Q
    K = embeddings @ W_K

    # Scaled attention scores
    scores = Q @ K.T / math.sqrt(d_model)
    weights = F.softmax(scores, dim=-1)

    # Print the attention matrix with labels
    print(f"{'':>10}", end="")
    for w in words:
        print(f"{w:>10}", end="")
    print()
    for i, word in enumerate(words):
        print(f"{word:>10}", end="")
        for j in range(seq_len):
            print(f"{weights[i, j].item():>10.3f}", end="")
        print()

    print(f"\nRow sums: {weights.sum(dim=-1).tolist()}")


visualize_attention("I love cats and dogs")
```
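For a decoder-style model, the same matrix looks different once the causal mask is applied. Every entry above the diagonal becomes exactly zero, and each row renormalizes over the tokens at or before its own position. The sketch below reuses the random-embedding setup from the solution. `causal_attention_weights` is a hypothetical helper name.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention_weights(sentence: str, d_model: int = 16) -> torch.Tensor:
    """Scaled dot-product attention weights with a causal mask (random embeddings)."""
    words = sentence.split()
    seq_len = len(words)
    torch.manual_seed(42)
    embeddings = torch.randn(seq_len, d_model)
    W_Q = torch.randn(d_model, d_model)
    W_K = torch.randn(d_model, d_model)
    scores = (embeddings @ W_Q) @ (embeddings @ W_K).T / math.sqrt(d_model)

    # Block attention to future tokens: True strictly above the diagonal
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1)


weights = causal_attention_weights("I love cats and dogs")
print(weights)
# Row 0 ("I") can only attend to itself, so its weight on "I" is exactly 1
```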
Exercise 2: Effect of Scaling on Attention
Create an experiment that demonstrates the effect of scaling on attention distributions. Generate high-dimensional Q and K vectors (d_k = 512), compute attention scores with and without scaling, and compare the resulting softmax distributions. Use entropy to measure how “spread out” each distribution is.
Hint: Entropy is computed as $H = -\sum p_i \log(p_i)$. Higher entropy means a more uniform distribution.
Solution
```python
import math

import torch
import torch.nn.functional as F


def compute_entropy(probs):
    """Compute entropy of a probability distribution."""
    # Add small epsilon to avoid log(0)
    log_probs = torch.log(probs + 1e-10)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy


# High-dimensional scenario
d_k = 512
seq_len = 20
torch.manual_seed(42)
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

# Unscaled scores
scores_unscaled = Q @ K.T
weights_unscaled = F.softmax(scores_unscaled, dim=-1)

# Scaled scores
scores_scaled = Q @ K.T / math.sqrt(d_k)
weights_scaled = F.softmax(scores_scaled, dim=-1)

# Compare entropy (higher = more spread out)
entropy_unscaled = compute_entropy(weights_unscaled).mean()
entropy_scaled = compute_entropy(weights_scaled).mean()

# Maximum possible entropy for uniform distribution
max_entropy = math.log(seq_len)

print(f"Dimension d_k: {d_k}")
print(f"Sequence length: {seq_len}")
print(f"Maximum entropy (uniform): {max_entropy:.4f}")
print(f"Unscaled entropy: {entropy_unscaled:.4f}")
print(f"Scaled entropy: {entropy_scaled:.4f}")
print(f"\nMax weight (unscaled): {weights_unscaled.max():.6f}")
print(f"Max weight (scaled): {weights_scaled.max():.6f}")
print("\nScaled attention is more spread out → healthier gradients!")
```
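Why is the fix $\sqrt{d_k}$ specifically? Suppose the components of $q$ and $k$ are independent with mean 0 and variance 1. Then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ is a sum of $d_k$ terms, each with variance 1, so $\operatorname{Var}(q \cdot k) = d_k$. Dividing by $\sqrt{d_k}$ brings the variance back to 1. The following is a quick empirical check, assuming standard-normal Q and K as in the exercise:

```python
import math

import torch

torch.manual_seed(0)
num_samples = 10_000
scaled_variances = {}

for d_k in [16, 64, 512]:
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)            # one dot product per row
    scaled = dots / math.sqrt(d_k)
    scaled_variances[d_k] = scaled.var().item()
    print(f"d_k={d_k:>4}: Var(q·k) = {dots.var():8.1f}   Var(q·k/√d_k) = {scaled.var():.3f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays near 1 for every dimension.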
Exercise 3: Implement Multi-Head Attention with Different Head Counts
Implement an experiment that creates multi-head attention layers with 1, 2, 4, and 8 heads (keeping d_model = 64 constant), runs the same input through each, and prints the number of parameters in each configuration. Verify that they all have the same number of parameters.
Solution
```python
import torch

d_model = 64
seq_len = 10
batch_size = 1
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, d_model)

param_counts = []
for num_heads in [1, 2, 4, 8]:
    mha = MultiHeadAttention(d_model, num_heads)
    output = mha(x)
    num_params = sum(p.numel() for p in mha.parameters())
    param_counts.append(num_params)
    d_k = d_model // num_heads
    print(f"\nHeads: {num_heads}, d_k per head: {d_k}")
    print(f"  Parameters: {num_params:,}")
    print(f"  Output shape: {output.shape}")

# Verify the claim instead of only printing it
assert len(set(param_counts)) == 1, "Parameter counts differ across head counts!"
print("\n✓ All configurations have the same number of parameters!")
print("  The total computation is the same; it's just organized differently.")
print("  More heads = more parallel 'perspectives' on the data,")
print("  but each perspective sees fewer dimensions.")
```
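The parameter count is independent of the number of heads because the heads split `d_model` rather than adding to it. Each head's Q, K, and V projections are `d_model × d_k` with `d_k = d_model / num_heads`, so across all heads they add back up to `d_model × d_model`. The same split applies to computing the attention scores. The small arithmetic sketch below assumes projections with biases. `head_arithmetic` is a hypothetical helper name.

```python
def head_arithmetic(d_model: int, num_heads: int, seq_len: int) -> tuple[int, int]:
    """Return (projection parameters, multiply-adds for Q @ K^T), summed over all heads."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads
    # Q, K, V: each head owns a (d_model x d_k) slice plus d_k biases
    qkv_params = 3 * num_heads * (d_model * d_k + d_k)
    # Output projection is shared across heads
    out_params = d_model * d_model + d_model
    # Each head builds a (seq_len x seq_len) score matrix from length-d_k dot products
    score_macs = num_heads * seq_len * seq_len * d_k
    return qkv_params + out_params, score_macs


for num_heads in [1, 2, 4, 8]:
    params, macs = head_arithmetic(d_model=64, num_heads=num_heads, seq_len=10)
    print(f"heads={num_heads}: params={params:,}  score multiply-adds={macs:,}")
```

Adding heads changes two things: the number of separate `seq_len × seq_len` weight matrices (one per head) and the dimensionality each head's dot products work in.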
Exercise 4: Build and Test a Complete Transformer Block with Masking
Create a transformer block, pass a batch of sequences through it with causal masking, and verify:
- The output shape matches the input shape.
- Changing a future token does NOT affect earlier positions’ outputs.
This second check is crucial — it confirms that the causal mask is actually preventing information leakage.
Solution
```python
import torch

d_model = 32
num_heads = 4
seq_len = 6
batch_size = 1

# Build block and set to eval mode (deterministic)
block = TransformerBlock(d_model, num_heads)
block.eval()

# Create causal mask
mask = create_causal_mask(seq_len)

# Create input
torch.manual_seed(42)
x1 = torch.randn(batch_size, seq_len, d_model)

# Forward pass with original input
with torch.no_grad():
    out1 = block(x1, mask=mask)

# Modify the LAST token (position 5) in the input
x2 = x1.clone()
x2[0, 5, :] = torch.randn(d_model) * 100  # dramatically different

# Forward pass with modified input
with torch.no_grad():
    out2 = block(x2, mask=mask)

# Check 1: Output shape matches input
assert out1.shape == x1.shape
print(f"✓ Output shape matches input: {out1.shape}")

# Check 2: Earlier positions should be UNCHANGED
for pos in range(seq_len):
    diff = (out1[0, pos] - out2[0, pos]).abs().max().item()
    changed = "CHANGED" if diff > 1e-5 else "unchanged"
    print(f"  Position {pos}: max difference = {diff:.8f} → {changed}")

# Fail loudly if any of positions 0-4 saw the modified token
assert torch.allclose(out1[0, :5], out2[0, :5], atol=1e-5), "Causal mask leaked future information!"
print("\n✓ Positions 0-4 are unchanged even though position 5 was modified!")
print("  The causal mask correctly prevents later tokens from influencing earlier ones.")
```
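It is worth running the same check with the mask removed, to confirm that the test can actually fail. To keep the sketch below self-contained, it substitutes PyTorch's `nn.TransformerEncoderLayer` (Pre-Norm, dropout 0) for the chapter's `TransformerBlock`. That substitution is an assumption made only so the example runs on its own. `max_change_before_last` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

d_model, num_heads, seq_len = 32, 4, 6

torch.manual_seed(0)
block = nn.TransformerEncoderLayer(
    d_model, num_heads, dim_feedforward=4 * d_model,
    dropout=0.0, batch_first=True, norm_first=True,
)
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

x1 = torch.randn(1, seq_len, d_model)
x2 = x1.clone()
x2[0, -1, :] = torch.randn(d_model) * 100   # change only the last token


def max_change_before_last(mask):
    """Largest output change at positions 0..seq_len-2 when only the last token changes."""
    with torch.no_grad():
        out1 = block(x1, src_mask=mask)
        out2 = block(x2, src_mask=mask)
    return (out1[0, :-1] - out2[0, :-1]).abs().max().item()


masked = max_change_before_last(causal_mask)
unmasked = max_change_before_last(None)
print(f"With causal mask:    max change at earlier positions = {masked:.2e}")
print(f"Without causal mask: max change at earlier positions = {unmasked:.2e}")
```

With the mask, the earlier positions are unaffected. Without it, they shift noticeably, which shows that the check in Exercise 4 really does detect leakage.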