
Attention Mechanism Explained: How Transformers Learn to Focus

From the query-key-value formulation to multi-head attention — the mechanism that made transformers dominate AI

Abstract Algorithms · 28 min read

TLDR: Attention lets every token in a sequence ask "what else is relevant to me?" — dynamically weighting relationships across all positions simultaneously. It replaced the fixed-size hidden-state bottleneck of RNNs and is the engine behind every GPT, BERT, ViT, and AlphaFold model you have ever used.


📖 The Bottleneck That Broke Machine Translation

In 2014, Google's sequence-to-sequence translation model had a serious problem. To translate a long French sentence into English, its LSTM encoder read every word one by one and compressed the entire sentence into a single fixed-size vector — a 512-dimensional hidden state. The decoder then reconstructed the English translation from that one vector alone.

For short sentences like "Le chat boit du lait" → "The cat drinks milk", this worked fine. For a long, clause-heavy sentence from Proust, it failed catastrophically. The encoder was being asked to stuff everything a 40-word sentence means — grammar, named entities, clause structure, word order — into 512 numbers. By the time the decoder reached the end of the English output, the beginning of the French input had essentially been forgotten.

This is called the context bottleneck problem, and it was the hard ceiling on RNN-based NLP performance through the mid-2010s. The fix, described in Bahdanau et al. (2015) and later crystallised in the Transformer paper (Vaswani et al., 2017), is the attention mechanism: instead of compressing the source into a single vector, let the decoder reach back and directly inspect every encoder state, weighting each by relevance to the current decoding step.

That single change unlocked parallelism, long-range dependency modelling, and scaling to billion-parameter models. Every GPT, BERT, ViT, Whisper, and AlphaFold model descends from this idea.


🔍 The Database Query Analogy: Queries, Keys, and Values

Before touching a single matrix multiply, let us build the right mental model.

Imagine a key-value database — like Redis. You send a query, the database compares it against all stored keys, finds the best matches, and returns the associated values. The difference between attention and a real database is that attention returns a weighted blend of all values rather than a single exact match. High-similarity key-query pairs contribute more to the output; low-similarity pairs contribute almost nothing.
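The contrast between an exact lookup and attention's soft lookup fits in a few lines of plain Python. The keys, values, and query below are made-up toy data, purely for illustration:

```python
import math

# Toy key/value store; the vectors and numbers are invented for illustration.
keys = {"paris": [1.0, 0.0], "tokyo": [0.0, 1.0], "lyon": [0.9, 0.1]}
values = {"paris": 100.0, "tokyo": 200.0, "lyon": 300.0}

def exact_lookup(key: str) -> float:
    """A real database: one key matches exactly, one value comes back."""
    return values[key]

def soft_lookup(query: list[float]) -> tuple[float, dict[str, float]]:
    """Attention-style lookup: a similarity-weighted blend of ALL values."""
    scores = {k: sum(q * ki for q, ki in zip(query, vec)) for k, vec in keys.items()}
    exps = {k: math.exp(s) for k, s in scores.items()}   # softmax numerators
    total = sum(exps.values())
    weights = {k: e / total for k, e in exps.items()}    # weights sum to 1.0
    blend = sum(weights[k] * values[k] for k in keys)
    return blend, weights

blend, weights = soft_lookup([1.0, 0.0])  # a query that resembles "paris"
# "paris" gets the largest weight, but "tokyo" still contributes a little
```

Unlike `exact_lookup`, every value contributes to the blend, which is exactly what makes the operation differentiable.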

In a transformer:

  • Query (Q): The token asking "what information do I need right now?"
  • Key (K): What each token advertises — "here is the kind of information I hold."
  • Value (V): The actual content each token contributes when it is selected.

Every token generates its own Q, K, and V by multiplying its embedding through three separate learned weight matrices. The model learns these matrices end-to-end during training — it discovers what to query for and what to advertise purely from data.
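In code, these projections are just three matrix multiplies applied to the same embeddings. A minimal sketch with toy dimensions — the weight matrices here are random stand-ins for the learned parameters:

```python
import torch

torch.manual_seed(0)
d_model, d_k = 8, 4   # toy sizes; real models use e.g. 768 and 64

# Random stand-ins for the three learned projection matrices
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_k)

x = torch.randn(5, d_model)  # embeddings for a 5-token sequence

# Every token derives its own query, key, and value from the SAME matrices
Q, K, V = x @ W_Q, x @ W_K, x @ W_V
print(Q.shape, K.shape, V.shape)  # torch.Size([5, 4]) three times
```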

The diagram below shows this flow for a single query token attending over a short sequence. Each source token's key is compared against the query to produce a raw score, softmax normalises those scores into a probability distribution, and the final output is a weighted sum of the values.

flowchart TD
    Q[Query token embedding] --> QW[Multiply by W_Q]
    E1[Source token 1 embedding] --> KW1[Multiply by W_K]
    E2[Source token 2 embedding] --> KW2[Multiply by W_K]
    E3[Source token 3 embedding] --> KW3[Multiply by W_K]
    E1 --> VW1[Multiply by W_V]
    E2 --> VW2[Multiply by W_V]
    E3 --> VW3[Multiply by W_V]
    QW --> DOT[Dot products Q dot K_i]
    KW1 --> DOT
    KW2 --> DOT
    KW3 --> DOT
    DOT --> SCALE[Divide by sqrt of d_k]
    SCALE --> SM[Softmax - attention weights]
    SM --> WSUM[Weighted sum of values]
    VW1 --> WSUM
    VW2 --> WSUM
    VW3 --> WSUM
    WSUM --> OUT[Output vector for query token]

The diagram shows one complete attention pass for a single query token. Notice that the value vectors from all source tokens feed into the final weighted sum — even tokens with near-zero attention weight still participate at a vanishingly small contribution. This soft, differentiable selection is what makes the mechanism trainable end-to-end via backpropagation.


⚙️ Scaled Dot-Product Attention: The Complete Mathematical Picture

Now that the intuition is clear, let us state the formula precisely. For a single attention head operating on queries Q, keys K, and values V — all matrices with rows representing individual tokens:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Breaking this down symbol by symbol:

| Symbol | Shape | Meaning |
| --- | --- | --- |
| $Q$ | $(n \times d_k)$ | Query matrix — one row per query token |
| $K$ | $(m \times d_k)$ | Key matrix — one row per source token |
| $V$ | $(m \times d_v)$ | Value matrix — one row per source token |
| $d_k$ | scalar | Dimension of query/key vectors |
| $QK^T$ | $(n \times m)$ | Raw compatibility scores between every query-key pair |
| $\sqrt{d_k}$ | scalar | Scaling factor to prevent softmax saturation |
| $\text{softmax}(\cdot)$ | $(n \times m)$ | Row-wise normalisation — each row sums to 1.0 |
| Output | $(n \times d_v)$ | Contextualised representation for each query token |

Why divide by $\sqrt{d_k}$? When $d_k$ is large (e.g., 64 or 512), the dot product of two vectors whose components have roughly unit variance has a standard deviation that grows as $\sqrt{d_k}$. Without scaling, these large values push softmax into regions where gradients are nearly zero — the model stops learning. Dividing by $\sqrt{d_k}$ keeps the pre-softmax scores in a comfortable range regardless of head dimension.
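The saturation effect is easy to observe directly. This sketch compares softmax over unscaled and scaled scores for random vectors, with a deliberately large $d_k$ to exaggerate the effect:

```python
import torch

torch.manual_seed(0)
d_k = 512  # large head dimension exaggerates the problem

q = torch.randn(d_k)
keys = torch.randn(10, d_k)

raw = keys @ q                 # typical magnitude grows like sqrt(d_k)
scaled = raw / d_k ** 0.5

unscaled_weights = torch.softmax(raw, dim=-1)
scaled_weights = torch.softmax(scaled, dim=-1)

# Unscaled scores usually collapse softmax toward a one-hot (argmax-like)
# vector; scaled scores keep meaningful mass on several keys.
print(unscaled_weights.max().item(), scaled_weights.max().item())
```

Because dividing the scores by a constant only flattens the softmax, the maximum weight of the scaled distribution can never exceed the unscaled one.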

The attention weight matrix $A = \text{softmax}(QK^T / \sqrt{d_k})$ is the most interpretable artefact in the entire model. Each row $A_i$ is a probability distribution over source tokens, answering: "for query token $i$, what fraction of attention does it assign to each source token?" Visualising these matrices reveals what the model actually learned to focus on — and it is often strikingly human-interpretable.


🧪 Worked Example: Watching "The Cat" Track "It" Through Attention Scores

This example makes the abstract formula concrete. We take a simple six-token sentence and trace how attention scores distribute for one specific query token, so you can see the mechanism selecting information rather than just passing it through.

Consider the sentence: "The cat sat on the mat"

After tokenisation and embedding, each of the six tokens has been projected into Q, K, and V vectors. We focus on computing the attention output for the token "sat" — asking: which other tokens does "sat" attend to most?

Assume we have already computed the dot products $Q_{\text{sat}} \cdot K_j$ for all $j$:

| Source token | Raw dot product | After ÷ √d_k (d_k = 64) | Softmax weight |
| --- | --- | --- | --- |
| The | 2.1 | 0.263 | 0.12 |
| cat | 8.4 | 1.050 | 0.25 |
| sat | 6.2 | 0.775 | 0.19 |
| on | 3.5 | 0.438 | 0.14 |
| the | 2.8 | 0.350 | 0.13 |
| mat | 5.3 | 0.663 | 0.17 |

"sat" attends most heavily to "cat" (25%), the subject performing the sitting, and secondarily to itself (19%) and "mat" (17%), the destination. The articles ("The", "the") and the preposition ("on") receive the lowest weights; note how the ÷√d_k scaling compresses the raw score gaps and keeps the distribution soft. The output vector for "sat" is therefore a weighted average: 25% of "cat"'s value vector + 19% of its own value vector + 17% of "mat"'s value vector + ... This output vector carries a contextualised representation of the sitting event enriched by its subject and location — far more useful for downstream tasks than the raw embedding of "sat" alone.

This is precisely what lets BERT resolve pronouns: when the input is "The animal didn't cross the street because it was tired", the token "it" computes high attention scores toward "animal" and low scores toward "street", correctly encoding the coreference relationship into the output vector for "it".
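The softmax step of the worked example is easy to reproduce. This snippet recomputes the attention weights for "sat" from its scaled scores:

```python
import torch

tokens = ["The", "cat", "sat", "on", "the", "mat"]
# Scaled scores for "sat": raw dot products divided by sqrt(64) = 8
scaled_scores = torch.tensor([0.263, 1.050, 0.775, 0.438, 0.350, 0.663])

weights = torch.softmax(scaled_scores, dim=-1)
for tok, w in zip(tokens, weights.tolist()):
    print(f"{tok:>4s}: {w:.2f}")

# "cat" receives the largest share, and the weights sum to 1.0
assert weights.argmax().item() == tokens.index("cat")
```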


🧠 Deep Dive: Multi-Head Attention Internals and What Each Head Learns

A single attention head can only learn one type of relationship at a time. A head tuned for syntactic subject-verb agreement operates in a subspace where syntactic distance is salient — it cannot simultaneously represent long-range pronoun coreference or semantic role structure, because those patterns live in orthogonal directions of the embedding space. Multi-head attention escapes this constraint by running $h$ independent attention operations in parallel, each projecting the input into its own learned subspace.

Attention Internals: How Independent Projection Matrices Split the Representation Space

The multi-head formula is:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where each head computes:

$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

The projections $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K$, $W_i^V$ are unique per head. Each head therefore sees the same token embeddings but through a different linear lens — producing different Q, K, V vectors that score token relevance along completely different axes. The concatenated head outputs are then linearly projected back to $d_{\text{model}}$ dimensions by $W^O$, merging all the captured signals into a single unified representation.

What do different heads learn? Empirical probing studies on BERT and GPT have found consistent specialisation across heads:

| Head type | What it tracks | Example |
| --- | --- | --- |
| Syntactic dependency | Subject → verb, adjective → noun | "cats" → "sat" |
| Coreference resolution | Pronoun → antecedent | "it" → "animal" |
| Positional adjacency | Attending to the next or previous token | Bigram-like pattern |
| Semantic role | Agent, patient, location relationships | "cat" → theme of "sat" |
| Long-range dependency | Matching open/close brackets or quotes | "(" → matching ")" |

In BERT-base (12 heads, 768 dimensions), each head operates in a 64-dimensional subspace. GPT-3 uses 96 heads across 12,288 dimensions, giving each head a 128-dimensional subspace with enormous parallel capacity. The diversity of what these heads learn is why simply increasing model size — and therefore head count — keeps improving performance: you are literally giving the model more independent "lenses" through which to read the input.

The diagram below shows three heads attending to the same input sequence simultaneously, then having their outputs concatenated before the final linear projection.

flowchart LR
    IN[Input token embeddings] --> H1[Head 1 - syntax]
    IN --> H2[Head 2 - coreference]
    IN --> H3[Head 3 - positional]
    H1 --> OUT1[head_1 output]
    H2 --> OUT2[head_2 output]
    H3 --> OUT3[head_3 output]
    OUT1 --> CONCAT[Concatenate all heads]
    OUT2 --> CONCAT
    OUT3 --> CONCAT
    CONCAT --> PROJ[Linear projection W_O]
    PROJ --> FINAL[Final contextualised output]

Each of the three attention heads independently transforms the input through its own learned Q, K, V projections, capturing different relationship types. Their outputs are concatenated along the feature dimension and then compressed back to the model's hidden size through $W^O$, combining all captured relationships into a single unified representation.

Performance Analysis: Computational Cost and Memory Footprint Per Attention Head

Each individual head operates in a $d_k = d_{\text{model}} / h$ dimensional space. For BERT-base with $d_{\text{model}} = 768$ and $h = 12$ heads, each head dimension is $d_k = 64$. This split is deliberate: the total parameter count for the Q, K, V, and output projections stays constant regardless of the number of heads — the model gains representational diversity without an increase in compute.

The per-head attention matrix has shape $(n \times n)$. Across all heads the combined attention tensor is $(h \times n \times n)$. For a 512-token sequence with 12 heads, that is $12 \times 512 \times 512 = 3{,}145{,}728$ float values per attention layer — before considering the value aggregation and output projection. Training backpropagation doubles this footprint because activations must be retained for the backward pass.

The total compute cost scales as $O(h \cdot n^2 \cdot d_k) = O(n^2 \cdot d_{\text{model}})$ — the head count cancels with the reduction in $d_k$ per head. Adding more heads at constant $d_{\text{model}}$ does not increase compute. What drives cost up is growing sequence length $n$ (quadratic) or model dimension $d_{\text{model}}$ (linear). This is why shortening the sequence is the most effective lever for efficient inference.
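These figures are easy to verify with a throwaway helper (not a library function):

```python
def attention_footprint(num_heads: int, seq_len: int, bytes_per_float: int = 4):
    """Floats and bytes in the (heads, n, n) attention tensor for one layer,
    one batch element, forward pass only."""
    floats = num_heads * seq_len * seq_len
    return floats, floats * bytes_per_float

# BERT-base: 12 heads on a 512-token sequence
floats, nbytes = attention_footprint(12, 512)
print(f"{floats:,} floats (~{nbytes / 1e6:.1f} MB in fp32)")
# 3,145,728 floats (~12.6 MB in fp32)
```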

| Configuration | Heads | d_model | d_k per head | Attn matrix size at n=512 | W_Q+K+V+O params |
| --- | --- | --- | --- | --- | --- |
| BERT-base | 12 | 768 | 64 | 3.1M floats | ~2.4M |
| BERT-large | 16 | 1024 | 64 | 4.2M floats | ~4.2M |
| GPT-2 medium | 16 | 1024 | 64 | 4.2M floats | ~4.2M |
| GPT-3 | 96 | 12288 | 128 | 25.2M floats | ~603M |

🔄 Self-Attention vs. Cross-Attention: Same Mechanism, Different Data Sources

The attention formula is identical in both cases — the only difference is where Q, K, and V come from.

Self-attention: All three matrices — Q, K, and V — are derived from the same sequence. Every token looks at every other token in the same sequence. This is what encoder blocks in BERT use: the word "bank" in "river bank" attends to "river" and "water" to resolve its meaning. It is also what decoder-only GPT models use for the attended-to prefix.

Cross-attention: The queries Q come from one sequence while the keys K and values V come from a different sequence. This is the bridge that connects encoder and decoder in translation models. The decoder (generating English) sends queries; the encoder (processing French) provides keys and values. The decoder can directly "look up" which French words are most relevant to the English word currently being generated.

| Property | Self-Attention | Cross-Attention |
| --- | --- | --- |
| Q source | Current sequence | Decoder sequence |
| K, V source | Same current sequence | Encoder output |
| Primary use | Encoder stacks, decoder causal masking | Encoder-decoder bridge |
| Models that use it | BERT, GPT, ViT | T5, Whisper, original Transformer |
| Sequence lengths | Q len = K/V len | Q len ≠ K/V len (often) |

There is also causal self-attention (also called masked self-attention), used in GPT-style decoders. Here, each token can only attend to tokens that come before it in the sequence — future positions are masked to $-\infty$ before softmax, forcing the model to predict the next token without cheating by looking ahead. This is the mechanism that enables autoregressive text generation.
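Because the formula never requires Q and K/V to share a length, cross-attention needs no special handling in code. A sketch using PyTorch's built-in nn.MultiheadAttention, with arbitrary 4-token decoder and 7-token encoder lengths:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 64
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True)

decoder_states = torch.randn(1, 4, d_model)  # 4 target tokens (queries)
encoder_states = torch.randn(1, 7, d_model)  # 7 source tokens (keys/values)

# Cross-attention: Q from the decoder, K and V from the encoder
out, weights = mha(decoder_states, encoder_states, encoder_states)

print(out.shape)      # torch.Size([1, 4, 64]) - one output per query token
print(weights.shape)  # torch.Size([1, 4, 7]) - 4 queries over 7 source tokens
```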


📊 Visualizing the Complete Attention Pipeline: From Raw Tokens to Contextualised Representations

The individual steps — embedding lookup, Q/K/V projection, score computation, softmax, value aggregation — are each straightforward, but their composition into a full forward pass is worth seeing end to end. The diagram below traces a single token's journey through one complete multi-head self-attention layer, from the initial embedding to the residual-normalised output that feeds into the next layer.

flowchart TD
    TK[Raw input tokens] --> EMB[Token embeddings plus positional encodings]
    EMB --> QP[W_Q projection per head]
    EMB --> KP[W_K projection per head]
    EMB --> VP[W_V projection per head]
    QP --> QV[Query vectors Q]
    KP --> KV[Key vectors K]
    VP --> VV[Value vectors V]
    QV --> SCORES[QK-T divided by sqrt d_k - raw scores]
    KV --> SCORES
    SCORES --> MASK{Causal mask needed?}
    MASK -->|decoder yes| MASKED[Apply minus-inf to future positions]
    MASK -->|encoder no| SM[Softmax over key dimension]
    MASKED --> SM
    SM --> ATTN[Attention weight matrix A]
    ATTN --> WSUM[Weighted sum A times V]
    VV --> WSUM
    WSUM --> WO[Output projection W_O - merge all heads]
    WO --> RESID[Add residual connection plus LayerNorm]
    RESID --> NEXT[Feed-forward sublayer or next encoder block]

The diagram highlights two paths at the masking decision point: encoder self-attention skips the causal mask and every token can attend to every other token freely, while decoder self-attention applies the upper-triangular mask so each position can only aggregate information from its past. The residual connection and LayerNorm at the end are not part of the attention mechanism itself but are inseparable from how it is deployed in practice — they stabilise gradient flow through the dozens of stacked layers that make up deep transformer models.


🌍 Where Attention Powers Modern AI: From Text to Proteins

Attention is not a niche text trick. In the eight years since the original Transformer paper, it has become the dominant mechanism across virtually every AI domain.

BERT (Google, 2018): Uses 12 stacked self-attention encoder layers. The bidirectional attention — every token sees every other token simultaneously — enabled state-of-the-art results on 11 NLP benchmarks on release. BERT is still the backbone of Google Search's query understanding.

GPT series (OpenAI, 2018–2023): Causal self-attention in a decoder-only architecture. GPT-3's 96-layer, 175 billion-parameter model is entirely composed of stacked multi-head causal self-attention blocks interleaved with feedforward layers. No convolution, no recurrence — attention all the way down.

Vision Transformer (ViT, Google 2020): Images are split into 16×16 pixel patches, each flattened into a vector and treated as a "token." Self-attention then runs over these patch tokens identically to text. ViT achieved competitive performance with convolutional neural networks on ImageNet while being more scalable — larger ViTs trained on more data consistently outperform CNNs.

AlphaFold 2 (DeepMind, 2021): Uses a specialised attention mechanism called "Evoformer" that attends simultaneously over amino acid sequences and the 2D pairwise distance map between residues. Cross-attention between these two representations is what lets AlphaFold reason about how distant parts of a protein chain interact in 3D space. Its 0.96 Å median accuracy essentially solved protein structure prediction.

Whisper (OpenAI, 2022): An encoder-decoder transformer for speech-to-text. The encoder runs self-attention over spectrogram patches; the decoder uses cross-attention to the encoder states while generating text tokens. Trained on 680,000 hours of labelled audio, it achieves near-human transcription accuracy across 99 languages.


⚖️ The O(n²) Problem: Cost, Memory, and Sparse Alternatives

Attention is powerful, but its complexity profile creates real engineering constraints.

Time complexity: Computing $QK^T$ for a sequence of $n$ tokens requires $O(n^2 \cdot d_k)$ multiplications — quadratic in sequence length. For $n = 1000$ tokens that is one million score computations per head per layer. For $n = 100{,}000$ (a long document), it becomes ten billion — impractical on standard hardware.

Memory complexity: Storing the $n \times n$ attention weight matrix $A$ requires $O(n^2)$ space. With 12 layers and 12 heads, a BERT-base model operating on a 512-token sequence materialises 12 × 12 × 512 × 512 = ~38 million floats per batch element just for attention weights.

Gradient throughput: During training, the $n^2$ attention matrix must be stored for the backward pass, doubling memory pressure. FlashAttention (Dao et al., 2022) addresses this by reordering the computation to process attention in tiles, never materialising the full matrix, reducing memory to $O(n)$ while maintaining exact output — a pure engineering win with no quality trade-off.
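In PyTorch 2.0+, this fused computation is exposed directly as torch.nn.functional.scaled_dot_product_attention, which dispatches to a FlashAttention kernel when the hardware and inputs allow it and falls back to exact dense attention otherwise:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_k = 2, 8, 128, 64
q = torch.randn(batch, heads, seq, d_k)
k = torch.randn(batch, heads, seq, d_k)
v = torch.randn(batch, heads, seq, d_k)

# The fused backends never materialise the full (seq x seq) attention
# matrix; the output matches dense attention numerically.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```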

| Approach | Complexity | Quality | Notes |
| --- | --- | --- | --- |
| Standard (dense) attention | $O(n^2)$ | Exact | Baseline; impractical for n > 4096 |
| FlashAttention v2 | $O(n^2)$ compute, $O(n)$ memory | Exact | Tile-based reordering; no approximation |
| Sparse attention (Longformer) | $O(n \cdot w)$, w = window | Approx | Local window + global tokens |
| Linear attention (Performer) | $O(n)$ | Approx | Random feature kernel approximation |
| Flash-Decoding (inference) | $O(n)$ memory | Exact | Parallelises the KV-cache decode loop |

For most practical applications — sequences under 8,192 tokens — FlashAttention solves the performance problem without compromising quality. Sparse attention variants are worth considering for document-level tasks with sequences in the tens of thousands.


🧭 Choosing the Right Attention Variant for Your Architecture

Not every task needs full dense self-attention, and picking the wrong variant adds complexity without benefit.

| Situation | Recommendation | Reason |
| --- | --- | --- |
| Sequence-to-sequence (translation, summarisation) | Self-attention encoder + cross-attention decoder | Cross-attention bridges two sequences efficiently; encoder gives decoder global source context |
| Language model pretraining or generation | Causal self-attention decoder only | Autoregressive generation requires masking; encoder overhead adds nothing for left-to-right tasks |
| Bidirectional text understanding (classification, NER) | Full (non-causal) self-attention encoder | Every token needs full context; masking future tokens would discard useful signal |
| Long documents (> 4096 tokens) | FlashAttention + sliding window (Longformer/BigBird) | Dense $O(n^2)$ becomes impractical; local + global token patterns preserve most relevant signal |
| Image understanding | Patch-based self-attention (ViT) | Treats patches as tokens; captures non-local correlations CNNs miss without inductive locality bias |
| Protein / structured data with pairwise relations | Specialised cross-attention (Evoformer pattern) | Attending over both sequence and pairwise feature matrix enables 3D relational reasoning |

🛠️ PyTorch and HuggingFace: How They Implement Attention

Building Scaled Dot-Product Attention from Scratch

The best way to internalise the mechanism is to implement it yourself before relying on library abstractions. This PyTorch implementation is fully runnable; PyTorch is its only dependency beyond the standard library.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Scaled Dot-Product Attention.

    Args:
        Q: (batch, heads, seq_q, d_k)
        K: (batch, heads, seq_k, d_k)
        V: (batch, heads, seq_k, d_v)
        mask: (batch, 1, 1, seq_k) or (batch, 1, seq_q, seq_k) — 0 = keep, 1 = mask out

    Returns:
        output: (batch, heads, seq_q, d_v)
        attn_weights: (batch, heads, seq_q, seq_k)
    """
    d_k = Q.size(-1)

    # Step 1: raw attention scores — (batch, heads, seq_q, seq_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: apply mask (e.g., causal mask for decoder, padding mask)
    if mask is not None:
        scores = scores.masked_fill(mask == 1, float("-inf"))

    # Step 3: softmax over key dimension
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: weighted sum of values
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention from 'Attention Is All You Need' (Vaswani et al., 2017).
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # head dimension

        # Separate projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)"""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query: (batch, seq_q, d_model)
            key:   (batch, seq_k, d_model)
            value: (batch, seq_k, d_model)
            mask:  optional attention mask

        Returns:
            output: (batch, seq_q, d_model)
            attn_weights: (batch, heads, seq_q, seq_k)
        """
        batch = query.size(0)

        # Project inputs to Q, K, V and split into heads
        Q = self.split_heads(self.W_q(query))  # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_k(key))    # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_v(value))  # (batch, heads, seq_k, d_k)

        # Scaled dot-product attention across all heads in parallel
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # attn_output: (batch, heads, seq_q, d_k)

        # Merge heads: (batch, heads, seq_q, d_k) -> (batch, seq_q, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, -1, self.d_model)

        # Final linear projection
        output = self.W_o(attn_output)

        return output, attn_weights

# ── Quick smoke test ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    torch.manual_seed(42)

    batch, seq_len, d_model, num_heads = 2, 6, 64, 8
    x = torch.randn(batch, seq_len, d_model)

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
    output, weights = mha(x, x, x)  # self-attention: Q=K=V=x

    print(f"Input shape:         {x.shape}")          # (2, 6, 64)
    print(f"Output shape:        {output.shape}")      # (2, 6, 64)
    print(f"Attention weights:   {weights.shape}")     # (2, 8, 6, 6)
    print(f"Weights sum to 1.0:  {weights[0, 0].sum(dim=-1).allclose(torch.ones(seq_len))}")

What to observe when you run this:

  • output.shape equals input.shape — the attention mechanism is shape-preserving; it contextualises each token but keeps dimensionality constant.
  • weights.shape is (batch, heads, seq_q, seq_k) — you can visualise row 0 of any head as the attention distribution for token 0 over all tokens.
  • The sum of each row in weights is exactly 1.0, confirming softmax normalisation.

Building a Causal Mask for Decoder Self-Attention

Decoder-only models like GPT require a causal mask so token $i$ cannot attend to token $j > i$ (future positions):

def causal_mask(seq_len: int) -> torch.Tensor:
    """
    Upper-triangular mask: positions where mask == 1 are masked out.
    Shape: (1, 1, seq_len, seq_len) — broadcasts over batch and heads.
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask.unsqueeze(0).unsqueeze(0).long()

# Usage — decoder self-attention on 6 tokens
mask = causal_mask(6)
output, weights = mha(x, x, x, mask=mask)

# Token 0 can only attend to itself; token 5 can attend to all 6.
print("Causal weights row 0 (only position 0 non-zero):", weights[0, 0, 0])
print("Causal weights row 5 (all 6 positions non-zero):", weights[0, 0, 5])

Using nn.MultiheadAttention in One Line

PyTorch ships a fused, optimised attention implementation. After understanding the from-scratch version above, this is what you should use in production:

# PyTorch built-in — CUDA-optimised, supports FlashAttention backend
mha_builtin = nn.MultiheadAttention(
    embed_dim=64,
    num_heads=8,
    dropout=0.1,
    batch_first=True,   # (batch, seq, features) — matches our convention
)

output_bt, weights_bt = mha_builtin(x, x, x)
print(f"nn.MultiheadAttention output: {output_bt.shape}")  # (2, 6, 64)

Pass attn_mask (additive) or key_padding_mask (boolean) for masking. PyTorch >= 2.0 automatically dispatches to FlashAttention when the inputs are on CUDA and no custom mask is provided.
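Both conventions can be exercised in a short sketch. The module and inputs are re-created here so the snippet runs standalone; both masks are boolean, with True meaning "block this position":

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 6, 64)

# Key-padding mask (batch, seq_k): the last 2 tokens of element 0 are padding
key_padding = torch.zeros(2, 6, dtype=torch.bool)
key_padding[0, 4:] = True

out, w = mha(x, x, x, key_padding_mask=key_padding)
print(w[0, 0, 4:])  # padded positions get exactly zero attention weight

# Causal mask (seq_q, seq_k), broadcast over batch and heads
causal = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)
out_c, w_c = mha(x, x, x, attn_mask=causal)
print(w_c[0, 0])    # token 0 attends only to itself
```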

HuggingFace Transformers: Accessing Attention Weights from BERT

In production, you often want to inspect what a pretrained model attends to rather than implement attention from scratch:

from transformers import BertTokenizer, BertModel
import torch

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True)
model.eval()

text = "The cat sat on the mat"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions: tuple of 12 tensors, one per layer
# Each tensor shape: (batch=1, heads=12, seq_len, seq_len)
layer_0_weights = outputs.attentions[0]
print(f"Layer 0 attention shape: {layer_0_weights.shape}")
# (1, 12, 8, 8) — 8 tokens including [CLS] and [SEP]

# Head 0, token 2 ("cat") attending to all tokens
cat_attention = layer_0_weights[0, 0, 2]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
for tok, weight in zip(tokens, cat_attention.tolist()):
    print(f"  {tok:12s}: {weight:.4f}")

For a full deep-dive on BERT's architecture and pretraining objective, see how-transformer-architecture-works-a-deep-dive.


📚 Common Bugs That Will Waste Your Weekend

Attention is simple in theory but has several sharp edges in practice. These are the mistakes that show up most often in from-scratch implementations and fine-tuning pipelines.

1. Forgetting to scale by $\sqrt{d_k}$. This is the most common first-implementation bug. Without the scaling factor, dot products between 64-dimensional vectors routinely reach magnitudes in the tens, which pushes softmax to output near-zero weights for all tokens except the one with the highest score (essentially argmax). The model appears to train but converges to a mode where it rigidly attends to one token. Fix: always divide raw scores by math.sqrt(d_k) before softmax.

2. Wrong mask shape — silent broadcasting errors. PyTorch masks broadcast silently. If your causal mask is (seq, seq) but attention scores are (batch, heads, seq, seq), PyTorch will broadcast the mask instead of raising an error — but the result may be wrong if the mask has a batch or head dimension that does not match. Always shape masks as (batch, 1, 1, seq_k) for key-padding masks or (1, 1, seq_q, seq_k) for causal masks so the broadcast is explicit and intentional.

3. Wrong Q, K, V projection shapes in multi-head attention. A common mistake is projecting Q, K, and V to d_model before splitting into heads, but accidentally projecting to num_heads × d_model (off by a factor of num_heads). Always verify: each head should receive a d_k = d_model / num_heads dimensional vector, and split_heads should reshape (batch, seq, d_model) into (batch, heads, seq, d_k) — not (batch, seq, heads, d_model).

4. Using additive masking when the API expects boolean (or vice versa). nn.MultiheadAttention accepts additive float masks — you pass -inf or a large negative value for positions to ignore — as well as boolean masks. HuggingFace's models often use boolean key-padding masks where True means "ignore this position." Mixing these conventions produces gradients that are numerically valid but semantically wrong: the model silently attends to padding tokens or future tokens.

5. Not detaching attention weights before accumulating them. If you collect attn_weights across multiple forward passes for visualisation (e.g., building an attention heatmap), forgetting to call .detach() keeps a reference to the full computation graph. Memory usage grows linearly with the number of forward passes until the process OOMs. Always call weights.detach().cpu().numpy() when storing for analysis.

6. Positional encoding missing or added after projection. Positional encodings must be added to token embeddings before the Q/K/V projections, not after. Adding them after means Q and K do not carry positional information, making all positions look the same to the attention mechanism — the model degrades to a bag-of-words behaviour for any position-sensitive task.
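To make the last fix concrete, here is a sketch of the fixed sinusoidal encoding from the original Transformer paper; the key point is that the addition happens before any W_Q/W_K/W_V projection:

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Fixed sinusoidal encodings (Vaswani et al., 2017), shape (seq_len, d_model)."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
    return pe

embeddings = torch.randn(6, 64)
# Correct order: positions added to embeddings BEFORE any Q/K/V projection
x = embeddings + sinusoidal_positional_encoding(6, 64)
# x, not the raw embeddings, is what feeds W_Q, W_K, and W_V
```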


📌 TLDR & Key Takeaways

  • Attention is a differentiable database lookup. Q finds matching Ks; V provides the content. The softmax output is a weighted blend — not a hard selection — so gradients flow everywhere.
  • The $\sqrt{d_k}$ scaling factor is non-negotiable. Skip it and your softmax saturates, gradients vanish, and training stalls — even though the loss may still decrease slowly.
  • Multi-head attention gives the model parallel "lenses." Each head learns different relationship types (syntax, coreference, adjacency). More heads = more capacity to capture diverse relationships simultaneously.
  • Self-attention and cross-attention use the same formula. The difference is whether Q, K, V come from one sequence or two. Causal masking is self-attention with an upper-triangular mask applied.
  • The $O(n^2)$ complexity is the practical ceiling. Dense attention is impractical for sequences beyond ~8k tokens. FlashAttention eliminates the memory problem while keeping exact arithmetic; sparse variants trade some quality for linear scaling.
  • Every major modern AI system uses attention. GPT, BERT, ViT, AlphaFold, Whisper — the mechanism is universal because it makes no assumptions about input modality, locality, or sequence length structure.
  • The most actionable skill is being able to read attention weights. Visualising outputs.attentions from a HuggingFace model will immediately show you what your model has and has not learned to focus on.

📝 Practice Quiz

  1. Why is the raw dot product between Q and K divided by $\sqrt{d_k}$ before the softmax?

    • A) To normalise the output to the range [0, 1] before the weighted sum
    • B) To prevent large dot products from pushing softmax into near-zero gradient regions
    • C) To reduce the memory required by the attention weight matrix
    • D) To ensure Q and K have the same L2 norm after projection

    Correct Answer: B
  2. In an encoder-decoder transformer used for translation, the decoder uses cross-attention in the middle sub-layers. Which of the following describes the correct source of Q, K, and V in that cross-attention layer?

    • A) Q from encoder output, K and V from decoder previous layer
    • B) Q, K, and V all from decoder previous layer
    • C) Q from decoder previous layer, K and V from encoder output
    • D) Q from decoder previous layer, K from encoder output, V from decoder previous layer

    Correct Answer: C
  3. A model is trained on sequences of length 512 and then deployed to process documents of length 4096. After switching to FlashAttention, which problem does FlashAttention directly solve, and which problem does it NOT solve?

    • A) Solves: quadratic compute time. Does NOT solve: output quality degradation from longer sequences
    • B) Solves: quadratic memory by tiling the attention computation. Does NOT solve: quadratic compute time — that remains O(n²)
    • C) Solves: both compute and memory — FlashAttention is O(n) in both dimensions
    • D) Solves: positional encoding generalisation to longer sequences. Does NOT solve: memory usage

    Correct Answer: B
  4. Open-ended: You are debugging a multi-head attention module you wrote from scratch. The training loss decreases, but when you visualise the attention weights across all heads and all layers, almost every head appears to attend nearly uniformly to every token (all weights ≈ 1/n). What are two distinct root causes that could produce this behaviour, and what would you check in the code for each?


Written by Abstract Algorithms (@abstractalgorithms)