The transformer architecture, introduced in the landmark 2017 paper "Attention Is All You Need," fundamentally changed how we approach sequence modeling in deep learning. While the paper's mathematical notation can feel imposing, the underlying ideas are elegant and surprisingly intuitive once you work through them step by step. In this deep dive, we'll implement a transformer from scratch using PyTorch, building each component carefully so you understand not just what transformers do, but why they work so well.
The Core Insight: Why Attention Matters
Before transformers, researchers used recurrent neural networks (RNNs) and their variants to process sequences. RNNs process tokens sequentially, passing hidden state from one timestep to the next. This design has two critical flaws: information from distant tokens is diluted as it propagates through many timesteps (and gradients vanish or explode along that same long path), and the dependency of each timestep on the previous one makes it impossible to parallelize computation across the positions of a sequence.
The transformer approach asks a simpler question: why do we need sequential processing at all? What if we could let each token directly look at every other token in the sequence? This is the essence of the attention mechanism. Instead of relying on a hidden state bottleneck, attention allows tokens to gather information from relevant parts of the sequence directly, and it does so in parallel.
Consider translating a sentence from English to French. When translating a particular word, you don't need to re-read the entire sentence in order; you can focus on the words that are semantically related to it, regardless of their position. Attention formalizes this intuition mathematically.
Scaled Dot-Product Attention: The Building Block
The foundation of the transformer is the scaled dot-product attention mechanism. The concept is straightforward: given a query, we score how relevant every key is to that query, then use those scores as weights to aggregate values.
The formula is:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
Let's break this down. We have three matrix inputs:
- Queries (Q): What we're looking for. Each token produces a query vector.
- Keys (K): What's available to look at. Each token produces a key vector.
- Values (V): The actual information to aggregate. Each token produces a value vector.
The attention mechanism works in three steps:
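- Compute scores: QK^T produces a matrix where entry (i, j) is the dot product of query i and key j, an unnormalized score for how much query i should attend to key j.
- Scale and normalize: We divide by sqrt(d_k), where d_k is the dimension of the key vectors. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance d_k, so unscaled scores grow with the dimension and push the softmax toward a nearly one-hot distribution with very small gradients. Dividing by sqrt(d_k) brings the variance back to about 1. Softmax then converts each row of scores into attention weights that sum to 1.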
- Aggregate values: The softmax weights are multiplied with the value matrix to produce the output.
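The variance argument behind the sqrt(d_k) factor is easy to check by simulation. This stdlib-only sketch draws random unit-variance query and key vectors and measures the variance of their dot products before and after scaling; d_k = 64, the sample count, and the function name are arbitrary illustrative choices.

```python
import math
import random
import statistics

def dot_product_variances(d_k, num_samples=5_000, seed=0):
    """Return (variance of raw q.k, variance of q.k / sqrt(d_k)) for random unit-variance vectors."""
    rng = random.Random(seed)
    raw = []
    for _ in range(num_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(qi * ki for qi, ki in zip(q, k)))
    scaled = [s / math.sqrt(d_k) for s in raw]
    return statistics.pvariance(raw), statistics.pvariance(scaled)

raw_var, scaled_var = dot_product_variances(d_k=64)
print(f"raw variance ~ {raw_var:.1f}, scaled variance ~ {scaled_var:.2f}")
```

The raw variance lands near 64 and the scaled variance near 1; increasing d_k moves the first number but not the second.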
Let's implement the full attention mechanism:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        d_k = Q.shape[-1]
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply mask if provided (e.g. causal masking in a decoder, or padding);
        # positions where mask == 0 are blocked
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # Apply softmax and dropout
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Aggregate values
        output = torch.matmul(attention_weights, V)
        return output, attention_weights
This implementation captures the essence of attention. The entire operation is differentiable and is computed for all sequence positions at once with batched matrix multiplications, which is a major reason transformers train efficiently on parallel hardware.
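As a sanity check, the sketch below restates the same formula as a plain function (dropout omitted) and compares it with PyTorch's built-in F.scaled_dot_product_attention, which assumes PyTorch 2.0 or newer. It also shows one way to build a causal mask in the convention used above, where entries with mask == 0 are blocked. The tensor sizes and the function name are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # Same computation as ScaledDotProductAttention.forward, without dropout
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 5, 8
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Causal mask: 1 where query i may attend to key j (j <= i), 0 elsewhere.
# Shape (seq_len, seq_len) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

out, weights = attention(Q, K, V, causal_mask)
reference = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(out, reference, atol=1e-5))
print(weights[0, 0])  # entries above the diagonal are exactly zero
```

Because exp(-inf) is exactly 0, masked positions receive exactly zero weight rather than merely a small one.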
Multi-Head Attention: Looking from Multiple Perspectives
Using attention from a single perspective is limiting. Different types of relationships in the sequence might benefit from different representational spaces. The solution is multi-head attention: run several attention operations in parallel, each with different learned linear projections, then concatenate the results.
The math is elegantly simple:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Each head learns different projection matrices (W^Q, W^K, W^V) for transforming the input. One head might learn to attend to syntactic relationships while another focuses on semantic meaning. The output projection W^O recombines all heads.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]
        # Linear projections
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)
        # Reshape for multi-head: (batch, seq_len, d_model)
        # -> (batch, seq_len, num_heads, d_k)
        # -> (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        output, weights = self.attention(Q, K, V, mask)
        # Concatenate heads
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.num_heads * self.d_k)
        # Final linear projection
        output = self.W_o(output)
        return output, weights
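Multi-head attention is where much of the model's expressive power comes from. Because the heads operate independently, the model can capture several different types of relationships in a single forward pass. Note that the code packs the per-head matrices W_i^Q, W_i^K, W_i^V into single d_model-by-d_model linear layers and splits their outputs with the reshape; this is equivalent to separate per-head projections but runs as one large matrix multiplication.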
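The reshaping in forward is easy to get wrong, so here is a small standalone check (sizes arbitrary, helper names hypothetical) that splitting into heads and merging back, with the same view/transpose calls used around the attention call, returns the original tensor. It also counts the parameters of the four projection layers: since all heads share d_model-by-d_model projections, the count depends on d_model only, not on num_heads.

```python
import torch
import torch.nn as nn

def split_heads(x, num_heads):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    batch, seq_len, d_model = x.shape
    return x.view(batch, seq_len, num_heads, d_model // num_heads).transpose(1, 2)

def merge_heads(x):
    # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads * d_k)
    batch, num_heads, seq_len, d_k = x.shape
    return x.transpose(1, 2).contiguous().view(batch, seq_len, num_heads * d_k)

x = torch.randn(2, 10, 512)
heads = split_heads(x, num_heads=8)
print(heads.shape)  # torch.Size([2, 8, 10, 64])
restored = merge_heads(heads)

def projection_params(d_model):
    # W_q, W_k, W_v and W_o: each a d_model x d_model weight plus a bias of size d_model
    layers = [nn.Linear(d_model, d_model) for _ in range(4)]
    return sum(p.numel() for layer in layers for p in layer.parameters())

print(projection_params(512))  # 4 * (512*512 + 512) = 1050624
```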
Positional Encoding and the Feed-Forward Network
Unlike RNNs, transformers have no inherent notion of order. Self-attention is permutation-equivariant: reordering the input tokens just reorders the outputs in the same way, so the attention mechanism by itself cannot distinguish [A, B, C] from [C, B, A]. We need to inject position information explicitly through positional encoding.
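The order-blindness claim can be checked directly. This sketch uses plain self-attention with no positional encoding and, as a simplifying assumption, no learned projections (Q = K = V = x). Permuting the input rows permutes the output rows in exactly the same way.

```python
import math
import torch
import torch.nn.functional as F

def self_attention(x):
    # Plain self-attention with Q = K = V = x and no positional information
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.shape[-1])
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(3, 16)          # three tokens, e.g. [A, B, C]
perm = torch.tensor([2, 1, 0])  # reorder to [C, B, A]

out_original = self_attention(x)
out_permuted = self_attention(x[perm])

# Each token gets the same output vector as before, just in the new order
print(torch.allclose(out_permuted, out_original[perm], atol=1e-5))
```

Learned projections do not change this, since they are applied to every position identically.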
The original paper uses sinusoidal positional encodings:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
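The encoding for position pos combines sine and cosine functions at geometrically spaced frequencies, with wavelengths ranging from 2π to 10000 · 2π. The authors chose it because, for any fixed offset k, PE(pos + k) can be written as a linear function of PE(pos), which they hypothesized would make it easy for the model to attend by relative position; they also suggested it may let the model extrapolate to sequences longer than those seen during training.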
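One way to see the relative-position property: each sin/cos pair with frequency w_i = 1/10000^(2i/d_model) rotates by a fixed angle w_i k when pos advances by k, so by the identity cos(a - b) = cos a cos b + sin a sin b, the dot product PE(pos) · PE(pos + k) equals the sum over i of cos(w_i k), which depends only on k. The stdlib-only sketch below checks this numerically; d_model = 16, k = 3, and the positions are arbitrary.

```python
import math

def positional_encoding(pos, d_model):
    # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(same argument)
    pe = []
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

d_model, k = 16, 3
for pos in (0, 10, 250):
    similarity = dot(positional_encoding(pos, d_model), positional_encoding(pos + k, d_model))
    print(pos, round(similarity, 6))

# Closed form: sum over frequencies of cos(w_i * k), independent of pos
closed_form = sum(math.cos(k / 10000 ** (2 * i / d_model)) for i in range(d_model // 2))
print(round(closed_form, 6))
```

All three positions print the same similarity as the closed form.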
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Create positional encoding matrix (assumes d_model is even)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed as exp(-2i * ln(10000) / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer, not a parameter: saved in state_dict and moved by .to(), but not trained
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encodings for the first seq_len positions
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
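The div_term line is the formula's 1/10000^(2i/d_model) computed through exp and log. The standalone check below rebuilds the table the same way (d_model = 8 and 50 positions are arbitrary, and d_model is assumed even, which the construction implicitly requires) and compares every entry with the formula evaluated directly in double precision.

```python
import math
import torch

d_model, max_seq_len = 8, 50

# Same construction as PositionalEncoding.__init__
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe = torch.zeros(max_seq_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Compare every entry with the closed-form formula
max_err = 0.0
for pos in range(max_seq_len):
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        max_err = max(max_err,
                      abs(pe[pos, 2 * i].item() - math.sin(angle)),
                      abs(pe[pos, 2 * i + 1].item() - math.cos(angle)))
print(f"largest deviation from the closed-form formula: {max_err:.1e}")
```

The remaining deviation is float32 rounding, well below anything that matters for training.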
Beyond attention, each transformer layer includes a feed-forward network (FFN) applied to each position independently. This consists of two linear layers with a non-linearity between them:
FFN(x) = max(0, xW_1 + b_1) W_2 + b_2
The intermediate dimension is typically 4 times the model dimension, providing capacity for non-linear transformations. While attention handles communication between positions, the FFN provides position-wise feature transformation.
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
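To make "applied to each position independently" concrete, this sketch uses a plain nn.Sequential with the same two linear layers and ReLU (dropout left out so the output is deterministic; d_model = 512 and d_ff = 2048 follow the original paper). Running the FFN on the whole sequence gives the same result as running it on each position separately. Its parameter count, 2 · d_model · d_ff + d_ff + d_model, is roughly twice the 4 · (512² + 512) = 1,050,624 parameters of the four attention projections at this size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 6, d_model)  # (batch, seq_len, d_model)

with torch.no_grad():
    whole_sequence = ffn(x)
    # Apply the FFN to one position at a time and stack the results
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)

print(torch.allclose(whole_sequence, per_position, atol=1e-5))
num_params = sum(p.numel() for p in ffn.parameters())
print(num_params)  # 2*512*2048 + 2048 + 512 = 2099712
```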
Assembling the Transformer: Layer Normalization and Residual Connections
The final pieces are layer normalization and residual connections. These are crucial for stable training of deep networks.
Layer normalization normalizes each sample independently across the feature dimension, stabilizing the distribution of hidden states. Residual connections (skip connections) allow gradients to flow directly through the network, enabling much deeper models.
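Concretely, nn.LayerNorm(d_model) computes the mean and biased variance of each position's feature vector, normalizes with a small epsilon (1e-5 by default), then applies a learnable scale and shift initialized to 1 and 0. The quick standalone check below (arbitrary sizes) reproduces that by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 8) * 3 + 5  # (batch, seq_len, d_model), deliberately off-center
layer_norm = nn.LayerNorm(8)      # weight starts at 1, bias at 0

# Manual version: statistics over the last (feature) dimension only,
# using the biased variance (divide by d_model, not d_model - 1)
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)

with torch.no_grad():
    out = layer_norm(x)

print(torch.allclose(out, manual, atol=1e-5))
print(out.mean(dim=-1))  # per-position means, all close to 0
```

Unlike batch normalization, nothing here depends on other samples in the batch, so the same computation is used in training and inference.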
A transformer encoder layer combines everything:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Multi-head attention with residual and normalization
        attn_output, _ = self.attention(x, x, x, mask)
        x = x + self.dropout1(attn_output)
        x = self.norm1(x)
        # Feed-forward with residual and normalization
        ff_output = self.feed_forward(x)
        x = x + self.dropout2(ff_output)
        x = self.norm2(x)
        return x
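The pattern is consistent: apply a sub-layer, add the residual connection, then apply layer normalization, i.e. LayerNorm(x + Sublayer(x)). This is the "post-norm" arrangement used in the original paper. Many later models instead use "pre-norm", x + Sublayer(LayerNorm(x)), which normalizes the input to each sub-layer and is generally reported to train more stably, especially in deep stacks.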
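For comparison, here is a minimal pre-norm variant. To keep it self-contained it swaps in PyTorch's built-in nn.MultiheadAttention (with batch_first=True) for the article's MultiHeadAttention; its attn_mask convention is different (a boolean True marks positions that may not be attended to), so masking is left out of this sketch. The class name and sizes are illustrative.

```python
import torch
import torch.nn as nn

class PreNormEncoderLayer(nn.Module):
    """Hypothetical pre-norm layer: x + Sublayer(LayerNorm(x)) for both sub-layers."""

    def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Normalize first, run the sub-layer, then add the un-normalized residual
        h = self.norm1(x)
        attn_output, _ = self.attention(h, h, h, need_weights=False)
        x = x + self.dropout1(attn_output)
        x = x + self.dropout2(self.feed_forward(self.norm2(x)))
        return x

layer = PreNormEncoderLayer(d_model=64, num_heads=4, d_ff=256)
layer.eval()  # disable dropout for a deterministic forward pass
x = torch.randn(2, 10, 64)
print(layer(x).shape)  # torch.Size([2, 10, 64])
```

Because the residual stream is never normalized inside a pre-norm layer, pre-norm stacks usually end with one final LayerNorm, like the self.norm at the end of TransformerEncoder.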
Stacking multiple encoder layers creates a full transformer encoder:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # x: (batch, seq_len) token ids
        # Token embeddings + positional encoding
        x = self.embedding(x)
        x = self.pos_encoding(x)
        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, mask)
        # Final normalization
        x = self.norm(x)
        return x
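The encoder's mask argument is passed down unchanged to ScaledDotProductAttention, where it must broadcast against scores of shape (batch, num_heads, seq_len, seq_len). For padded batches, one common way to build it is sketched below; the pad id 0, the token ids, and the helper name are hypothetical. The sketch applies the mask to random scores to confirm that padded key positions get zero attention weight.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # hypothetical padding token id

def make_padding_mask(token_ids, pad_id=PAD_ID):
    # (batch, seq_len) -> (batch, 1, 1, seq_len): 1 for real tokens, 0 for padding.
    # The singleton dims broadcast over heads and query positions.
    return (token_ids != pad_id).long().unsqueeze(1).unsqueeze(2)

token_ids = torch.tensor([
    [5, 17, 42, 8, 0, 0],  # 4 real tokens, 2 padding
    [9, 3, 11, 27, 6, 2],  # no padding
])
mask = make_padding_mask(token_ids)
print(mask.shape)  # torch.Size([2, 1, 1, 6])

# Apply it the same way ScaledDotProductAttention does
torch.manual_seed(0)
scores = torch.randn(2, 8, 6, 6)  # (batch, num_heads, seq_len, seq_len)
weights = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
print(weights[0, 0, :, 4:])  # zero weight on the two padded keys
```

This mask only hides padded keys; the padded query rows still produce outputs, which are normally ignored downstream (for example, excluded from the loss).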
Why This Works: The Power of Attention at Scale
The transformer architecture succeeds because attention provides a mechanism for direct information flow between any two positions in the sequence, unmediated by intermediate hidden states. This stands in stark contrast to RNNs, where information must flow sequentially.
This design choice has profound consequences. First, transformers can be trained in parallel across the entire sequence, whereas RNNs must process it one timestep at a time. Second, any two positions are connected by a path of length one within each layer, so long-range dependencies can be captured directly instead of suffering the vanishing gradients that plague RNNs over long sequences. Third, attention weights can be inspected, which offers a partial (and still debated) view of what the model attends to.
When you scale up the number of layers, heads, and model dimensions, and train on massive amounts of data, this architecture produces remarkably capable systems. The building blocks we've just implemented underlie models such as BERT (an encoder stack like this one) and GPT (a stack that adds causal masking so each token attends only to earlier positions).
Implementing transformers from scratch builds intuition that is hard to get from diagrams alone. You now understand not just what attention is, but the role each component plays: attention moves information between positions, the feed-forward network transforms each position, positional encoding supplies order, and residual connections with layer normalization keep deep stacks trainable.