
Transformer Architecture and Attention: Why Every Modern LLM Is Built This Way

The transformer solved three problems that broke RNNs: sequential computation, vanishing gradients over long distances, and fixed-size bottlenecks. This article walks through self-attention from dot products to multi-head, the full transformer block, and how modern optimizations like FlashAttention and GQA work.


Krunal Kanojiya


Article 5 ended with a specific promise. RNNs broke for three reasons: sequential computation that GPUs could not parallelize, gradients that vanished over long sequences, and a fixed-size bottleneck that compressed all context into one vector. The transformer removes all three at once.

This is the architecture that every modern LLM runs on: GPT-4, Claude, Gemini, LLaMA, Mistral, DeepSeek. The open models publish their architectures, and the closed ones are widely understood to follow the same pattern: variations on the basic blueprint introduced in the 2017 paper "Attention Is All You Need." Understanding this blueprint is not optional background knowledge. It is the mechanism you are working with every time you build on top of an LLM.

This is Article 6 in the series. Article 7 covers pre-training and language modeling, where the transformer you build here gets trained on text and becomes a language model. The embedding vectors from Article 4 flow directly into this architecture. The gradient problems from Articles 3 and 5 get solved here by residual connections.


The core idea: everything attends to everything

In an RNN, the token at position 5 connects to the token at position 1 only by passing information through positions 2, 3, and 4 in sequence. Each step degrades the signal.

In a transformer, token 5 connects to token 1 directly. There is no chain. There is a matrix of dot products between every pair of tokens, and every token can read every other token in a single operation. That is the structural change that made everything else possible.

No sequential dependency. Within a layer, the gradient from any output position reaches any input position in a single attention step instead of through a chain of per-position multiplications. And the whole sequence can be processed in parallel.


Self-attention: dot products as relevance scores

Start with a single attention head to understand what is actually happening.

Each token's embedding vector gets projected into three vectors:

  • Query (Q): what this token is looking for
  • Key (K): what this token has to offer
  • Value (V): what information this token sends when attended to

The attention score between two tokens is the dot product of one token's query and the other token's key. High dot product means high relevance. The scores get scaled and passed through softmax to become weights. Those weights are applied to value vectors to produce the output.
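Stacking every token's query, key and value as the rows of matrices Q, K and V, with d_k the per-head dimension (head_dim in the code below), the whole operation is the formula from the 2017 paper. Softmax is applied to each row, so row i of the result is token i's weighted mix of value vectors:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```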

python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    q, k, v: (..., seq_len, head_dim), e.g. (batch, seq_len, head_dim)
    mask:    boolean, broadcastable to (..., seq_len, seq_len); True = may attend.
             None means no masking.
    returns: output (..., seq_len, head_dim), weights (..., seq_len, seq_len)
    """
    head_dim = q.shape[-1]

    # dot product between every query and every key
    # (B, T, D) @ (B, D, T) -> (B, T, T)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

    # causal mask: future tokens get -inf so softmax makes them zero
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # softmax turns scores into probabilities summing to 1
    weights = F.softmax(scores, dim=-1)

    # weighted sum of values
    output = torch.matmul(weights, v)

    return output, weights

# concrete example: 3 tokens, head_dim=4
torch.manual_seed(42)
q = torch.randn(1, 3, 4)
k = torch.randn(1, 3, 4)
v = torch.randn(1, 3, 4)

out, weights = scaled_dot_product_attention(q, k, v)
print("attention weights:")
print(weights.squeeze(0))
# Each row sums to 1: row i holds how much token i attends to each token j
print("\noutput shape:", out.shape)   # (1, 3, 4)

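The scaling by sqrt(head_dim) matters. If the components of q and k are roughly independent with mean 0 and variance 1, their dot product is a sum of head_dim terms, each with variance 1, so its variance is head_dim and its typical magnitude grows like sqrt(head_dim). Large scores push softmax into a saturation zone where one weight is close to 1, the rest are close to 0, and gradients become tiny. Dividing by sqrt(head_dim) brings the scores back to roughly unit variance regardless of how large head_dim is.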
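A quick experiment makes this concrete, assuming query and key entries are independent standard normals (a simplification; real activations are not): the spread of the raw dot product grows like sqrt(head_dim), the scaled version stays near 1, and the unscaled scores at head_dim=1024 make softmax close to one-hot.

```python
import torch

torch.manual_seed(0)
N = 4096  # random (query, key) pairs per head_dim setting

raw_std, scaled_std = {}, {}
for d in (16, 64, 256, 1024):
    q = torch.randn(N, d)                # entries ~ N(0, 1)
    k = torch.randn(N, d)
    raw = (q * k).sum(dim=-1)            # N independent dot products
    raw_std[d] = raw.std().item()        # grows like sqrt(d)
    scaled_std[d] = (raw / d ** 0.5).std().item()  # stays near 1
    print(f"head_dim={d:5d}  raw std={raw_std[d]:6.2f}  scaled std={scaled_std[d]:.2f}")

# Effect on softmax at head_dim=1024: the same scores with and without 1/sqrt(d)
s = torch.randn(8)                                    # pretend these are scaled scores
p_scaled = torch.softmax(s, dim=-1)
p_unscaled = torch.softmax(s * 1024 ** 0.5, dim=-1)   # undo the 1/sqrt(1024) factor
print(f"largest weight with scaling:    {p_scaled.max().item():.3f}")
print(f"largest weight without scaling: {p_unscaled.max().item():.3f}")  # typically near 1
```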


The causal mask: why this is a language model

The attention above lets every token attend to every other token in both directions. That works for tasks like sentiment classification or text encoding. It does not work for language generation, because at generation time, token 5 does not know what token 6 will be.

The causal mask enforces this constraint during training by blocking every position from attending to any later position, so those attention weights are exactly zero.

python
def make_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Returns a lower-triangular boolean mask.
    True = allowed to attend, False = blocked.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, T, T) for broadcasting

seq_len = 5
mask = make_causal_mask(seq_len)
print(mask.squeeze())
# tensor([[True, False, False, False, False],
#         [True,  True, False, False, False],
#         [True,  True,  True, False, False],
#         [True,  True,  True,  True, False],
#         [True,  True,  True,  True,  True]])

Token 0 attends only to itself. Token 3 attends to tokens 0 through 3. Token 4 attends to everything. After masking, softmax never assigns weight to future positions because their scores were set to negative infinity before softmax ran. The model cannot cheat.
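You can check the "cannot cheat" claim directly. This self-contained sketch re-implements masked attention in a few lines (random tensors, no learned weights), confirms that no weight lands above the diagonal, and shows that changing the last token's key and value leaves every earlier output unchanged.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, D = 5, 8
q, k, v = (torch.randn(1, T, D) for _ in range(3))
allowed = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True = may attend

def causal_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / D ** 0.5
    scores = scores.masked_fill(~allowed, float("-inf"))  # block future keys
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights

out, weights = causal_attention(q, k, v)
print(torch.triu(weights, diagonal=1).abs().sum().item())  # 0.0: no weight on the future

# Perturb only the last token; outputs at positions 0..T-2 must not move.
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 5.0
v2[:, -1] += 5.0
out2, _ = causal_attention(q, k2, v2)
print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True
```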

This is the structural guarantee that makes GPT-style models work for generation. During training the model sees the full sequence but is forced to predict each token using only past context. At inference time it generates one token, appends it to the sequence, and repeats.
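The inference-time loop can be sketched as follows. The model is a hypothetical toy stand-in (any module mapping token IDs of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab) would work), decoding is greedy argmax for simplicity, and the whole sequence is recomputed at every step; production systems usually sample and reuse a KV cache instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB = 100

# Hypothetical stand-in for a trained decoder: (B, T) ids -> (B, T, VOCAB) logits
toy_model = nn.Sequential(nn.Embedding(VOCAB, 32), nn.Linear(32, VOCAB))

@torch.no_grad()
def generate_greedy(model: nn.Module, token_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    for _ in range(max_new_tokens):
        logits = model(token_ids)                                 # (B, T, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # last position predicts the next token
        token_ids = torch.cat([token_ids, next_id], dim=1)        # append it and repeat
    return token_ids

prompt = torch.tensor([[1, 2, 3]])
result = generate_greedy(toy_model, prompt, max_new_tokens=5)
print(result.shape)  # torch.Size([1, 8])
```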


Multi-head attention: learning different relationships in parallel

One attention head learns one pattern of relationships between tokens. Multi-head attention runs several in parallel, each in a lower-dimensional subspace. Outputs get concatenated and projected back.

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0,
                 max_seq_len: int = 1024):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim  = embed_dim // num_heads
        self.embed_dim = embed_dim

        # project input to Q, K, V for all heads at once
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # output projection: combines all head outputs
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.dropout  = nn.Dropout(dropout)

        # causal mask registered as a buffer (moves with the module, not trained);
        # it caps the supported sequence length at max_seq_len
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
                 .view(1, 1, max_seq_len, max_seq_len)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape

        # compute Q, K, V
        q = self.q_proj(x)  # (B, T, embed_dim)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # split embed_dim into num_heads * head_dim and rearrange
        # (B, T, embed_dim) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # scaled dot-product attention with causal mask
        scale   = math.sqrt(self.head_dim)
        scores  = (q @ k.transpose(-2, -1)) / scale          # (B, H, T, T)
        scores  = scores.masked_fill(
            ~self.causal_mask[:, :, :T, :T], float('-inf')
        )
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # weighted combination of values
        out = weights @ v                                      # (B, H, T, head_dim)

        # concatenate heads and project back
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, embed_dim)
        return self.out_proj(out)

# test
mha = MultiHeadAttention(embed_dim=128, num_heads=4)
x   = torch.randn(2, 16, 128)   # batch=2, seq_len=16, dim=128
out = mha(x)
print(out.shape)   # torch.Size([2, 16, 128])

params = sum(p.numel() for p in mha.parameters())
print(f"attention parameters: {params:,}")   # 4 linear layers * 128*128 = 65,536

Each head operates in 32 dimensions (128 / 4). Head 0 might learn to track subject-verb agreement. Head 1 might focus on nearby context. Head 2 might handle long-range pronoun references. None of this is programmed explicitly. It emerges from training.

The key insight about the shape math: embed_dim = num_heads * head_dim. The total computation is the same as a single large attention head, but the model gets to learn multiple different attention patterns simultaneously rather than forcing everything through one.
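A quick way to check the shape claim, using PyTorch's built-in nn.MultiheadAttention rather than the class above (bias disabled to match it): as the number of heads changes, the projection parameter count and the multiply-adds for the Q·K scores stay the same. Only the way embed_dim is split changes.

```python
import torch.nn as nn

E, T = 128, 256   # embed_dim and an example sequence length

results = {}
for H in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=E, num_heads=H, bias=False)
    params = sum(p.numel() for p in mha.parameters())   # Q, K, V and output projections
    head_dim = E // H
    score_macs = H * T * T * head_dim   # H heads, each (T x head_dim) @ (head_dim x T)
    results[H] = (params, score_macs)
    print(f"heads={H}  head_dim={head_dim:3d}  params={params:,}  score MACs={score_macs:,}")
```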


The full transformer block

One attention layer is not a transformer. The full block adds a feed-forward network, residual connections, and LayerNorm.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        self.ln1  = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)

        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff  = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # pre-norm: normalize BEFORE the sublayer, then add residual
        x = x + self.attn(self.ln1(x))   # attention sublayer
        x = x + self.ff(self.ln2(x))     # feed-forward sublayer
        return x

# one block
block = TransformerBlock(embed_dim=128, num_heads=4)
x    = torch.randn(2, 16, 128)
out  = block(x)
print(out.shape)   # torch.Size([2, 16, 128])

Two things worth stopping on here.

Pre-norm vs post-norm. The original 2017 paper applied LayerNorm after the residual addition. Modern practice applies it to the sublayer's input instead. Pre-norm trains more stably, especially at large depth. In LLaMA and Mistral source code you will see the norm (RMSNorm, in their case) called inside the residual computation rather than after it; Meta's reference LLaMA code, for example, computes x + self.attention(self.attention_norm(x), ...).
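A stripped-down comparison of the two orderings, with a plain nn.Linear standing in for each sublayer (an assumption for brevity, not a real transformer). In post-norm every block ends in LayerNorm, so the output stays at unit scale; in pre-norm the residual stream is never normalized inside the stack, which is why pre-norm models add one final LayerNorm before the head (ln_f in the decoder below).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, DEPTH = 64, 24

class PostNormBlock(nn.Module):
    """2017 ordering: run the sublayer, add the residual, then normalize."""
    def __init__(self, d: int):
        super().__init__()
        self.sublayer = nn.Linear(d, d)   # stand-in for attention or the FFN
        self.norm = nn.LayerNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    """Modern ordering: normalize the sublayer input; the residual path is untouched."""
    def __init__(self, d: int):
        super().__init__()
        self.sublayer = nn.Linear(d, d)
        self.norm = nn.LayerNorm(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))

x = torch.randn(8, D)
post = nn.Sequential(*[PostNormBlock(D) for _ in range(DEPTH)])
pre = nn.Sequential(*[PreNormBlock(D) for _ in range(DEPTH)])
with torch.no_grad():
    y_post, y_pre = post(x), pre(x)

print(f"post-norm output std: {y_post.std().item():.2f}")  # about 1: LayerNorm is the last op
print(f"pre-norm output std:  {y_pre.std().item():.2f}")   # raw residual stream, not normalized
```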

Residual connections. The x + sublayer(x) pattern is not incidental. It gives gradients a direct path to early layers during backpropagation. Without residuals, a 12-layer transformer would suffer from the same vanishing gradient problem Article 3 introduced. The gradient from the final layer's loss flows back through the residual addition at each block, bypassing the sublayer entirely if needed. This is what makes deep networks trainable.

The feed-forward network is often described as where much of the model's "knowledge" is stored. Attention decides which tokens to look at. The FFN decides what to do with that information, processing each position independently. Many modern LLMs, LLaMA and Mistral among them, replace the simple GELU FFN here with SwiGLU, as covered in Article 3.
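For comparison with the GELU FFN above, a minimal SwiGLU feed-forward sketch. The class name is hypothetical, and the hidden size follows a common convention (about 2/3 of 4 × embed_dim, so three matrices cost roughly the same as the GELU version's two); real models round this to hardware-friendly sizes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """SwiGLU FFN sketch: down(silu(gate(x)) * up(x)), applied per position."""
    def __init__(self, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.gate = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.up   = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the SiLU-activated gate scales the "up" projection element-wise
        return self.down(F.silu(self.gate(x)) * self.up(x))

embed_dim = 128
hidden_dim = int(2 * 4 * embed_dim / 3)   # 341
ffn = SwiGLUFeedForward(embed_dim, hidden_dim)
x = torch.randn(2, 16, embed_dim)
print(ffn(x).shape)                                # torch.Size([2, 16, 128])
print(sum(p.numel() for p in ffn.parameters()))    # 130944, vs 131072 weights for the GELU FFN
```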


The full decoder-only transformer

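Stack several blocks, add token and positional embeddings at the front, and add a language model head at the end. This is the core structure of published decoder-only models such as GPT-2, LLaMA, and Mistral. Closed models such as Claude do not publish their architectures, but they are generally understood to follow the same decoder-only pattern.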

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderOnlyTransformer(nn.Module):
    def __init__(
        self,
        vocab_size:  int,
        embed_dim:   int,
        num_heads:   int,
        num_layers:  int,
        max_seq_len: int,
        dropout:     float = 0.1,
    ):
        super().__init__()

        # token and positional embeddings (Article 4)
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb   = nn.Embedding(max_seq_len, embed_dim)
        self.drop      = nn.Dropout(dropout)

        # stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, dropout)
            for _ in range(num_layers)
        ])

        # final layer norm before the head
        self.ln_f = nn.LayerNorm(embed_dim)

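        # language model head: maps embed_dim back to vocabulary.
        # Weight tying shares one (vocab_size, embed_dim) matrix with token_emb,
        # saving vocab_size * embed_dim parameters; it often helps quality too.
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        # GPT-2/nanoGPT-style init: N(0, 0.02) weights, zero biases. PyTorch's
        # default N(0, 1) embedding init would make the tied head's initial
        # logits large and push the starting loss well above ln(vocab_size).
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)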

    def forward(
        self,
        token_ids: torch.Tensor,
        targets:   torch.Tensor = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        B, T = token_ids.shape

        # embed tokens and positions
        tok = self.token_emb(token_ids)
        pos = self.pos_emb(torch.arange(T, device=token_ids.device))
        x   = self.drop(tok + pos)

        # pass through all transformer blocks
        for block in self.blocks:
            x = block(x)

        x      = self.ln_f(x)
        logits = self.lm_head(x)   # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # next-token loss: targets[:, t] should be token_ids[:, t + 1]
            # (the caller shifts them); flatten batch and seq dims for cross_entropy
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

# NanoGPT-scale config from the series
model = DecoderOnlyTransformer(
    vocab_size  = 50_257,
    embed_dim   = 128,
    num_heads   = 4,
    num_layers  = 4,
    max_seq_len = 256,
)

total_params = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total_params:,}")   # 7,256,960 (tied lm_head weight counted once)

# forward pass with loss
token_ids = torch.randint(0, 50_257, (2, 64))
targets   = torch.randint(0, 50_257, (2, 64))
logits, loss = model(token_ids, targets)

print(f"logits shape: {logits.shape}")   # (2, 64, 50257)
print(f"loss: {loss.item():.4f}")        # ~10.82 = ln(50257), random weights

That loss starting at 10.82 is the same number from the NanoGPT article. Maximum uncertainty across a 50,257-token vocabulary. Article 7 covers what happens when you train this on actual text and watch that number fall.


What is actually expensive and how modern models fix it

Standard attention has O(N²) compute and memory complexity in sequence length N. For N=512 that is 262,144 attention scores per head per layer. For N=32,768 it is over a billion. This is why attention was cheap for short sequences but became the bottleneck for long ones.
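The same arithmetic in memory terms, assuming fp16 scores (2 bytes each) and a naive implementation that stores the full N×N matrix for one head in one layer:

```python
def score_matrix_bytes(seq_len: int, bytes_per_value: int = 2) -> int:
    """Memory for one materialized N x N score matrix (one head, one layer).
    bytes_per_value=2 assumes fp16/bf16 scores."""
    return seq_len * seq_len * bytes_per_value

for n in (512, 4096, 32_768):
    mib = score_matrix_bytes(n) / 2**20
    print(f"N={n:6,d}  scores={n * n:>13,d}  fp16 memory={mib:>8,.1f} MiB per head per layer")
```

At N=32,768 that is 2 GiB for a single head in a single layer, before multiplying by heads, layers, and batch size.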

Three optimizations changed the economics. You will see all three in modern production LLMs.

FlashAttention

The standard attention implementation materializes the full N×N score matrix in GPU high-bandwidth memory (HBM). HBM reads and writes are the bottleneck, not the computation.

FlashAttention is an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads and writes between GPU HBM and on-chip SRAM. Instead of computing the full matrix at once, it processes it in tiles that fit in SRAM. Intermediate results never go back to HBM. The final output is the same as standard attention, no approximation, but memory usage is linear in sequence length rather than quadratic.
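Tiling can stay exact because of the "online softmax" trick: keep a running row maximum and a running denominator, and rescale earlier partial results whenever the maximum grows. Here is a minimal single-head PyTorch sketch on the CPU, assuming no mask and ignoring everything that makes the real kernel fast (SRAM management, fused CUDA code, the backward pass):

```python
import torch

def tiled_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tile: int = 16) -> torch.Tensor:
    """Exact (unmasked) attention, visiting K/V one tile at a time.

    q: (Tq, d); k, v: (Tk, d). Only a (Tq, tile) block of scores exists at once.
    Educational sketch of the online-softmax idea, not the FlashAttention kernel.
    """
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0], 1), float("-inf"))   # running max of each score row
    denom = torch.zeros(q.shape[0], 1)                 # running softmax denominator
    acc = torch.zeros(q.shape[0], v.shape[-1])         # running unnormalized output
    for start in range(0, k.shape[0], tile):
        s = (q @ k[start:start + tile].T) * scale      # scores for this tile only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        correction = torch.exp(m - m_new)              # rescales earlier partial sums
        p = torch.exp(s - m_new)
        denom = denom * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v[start:start + tile]
        m = m_new
    return acc / denom

torch.manual_seed(0)
q, k, v = torch.randn(32, 8), torch.randn(50, 8), torch.randn(50, 8)
reference = torch.softmax(q @ k.T / 8 ** 0.5, dim=-1) @ v   # full 32 x 50 score matrix
print(torch.allclose(tiled_attention(q, k, v), reference, atol=1e-5))  # True
```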

FlashAttention 3, released in 2024, specifically targets NVIDIA's Hopper architecture (H100 GPUs) and exploits the asynchronous nature of Hopper's Tensor Cores and TMA (Tensor Memory Accelerator) to overlap computation and data movement.

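In practice you rarely touch FlashAttention directly. Since PyTorch 2.0, F.scaled_dot_product_attention dispatches to a FlashAttention kernel automatically when the inputs qualify (CUDA GPU, fp16 or bf16, supported head size), and a context manager can restrict it to that backend.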

python
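import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # newer PyTorch 2.x releases

# FlashAttention kernels need a CUDA GPU and fp16/bf16 inputs
# q, k, v: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# restrict SDPA to the FlashAttention backend (raises an error if it cannot run)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # is_causal=True applies the causal mask inside the kernel
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Without the context manager, SDPA picks the fastest eligible backend itself.
# Older PyTorch 2.x exposed the same control as torch.backends.cuda.sdp_kernel(...).
# The speedup over naive attention grows with sequence length, and the extra
# memory is linear rather than quadratic in seq_len.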

You do not need to implement FlashAttention yourself. But knowing what it does and why it matters tells you something important: attention's quadratic complexity is an engineering problem, not a fundamental limit. The math is the same. The implementation changes what is tractable.
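"The math is the same" is easy to verify: on the CPU, in float32, PyTorch's fused F.scaled_dot_product_attention with is_causal=True matches the manual masked-softmax computation used earlier in this article (PyTorch picks whichever backend is available on the device).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# manual causal attention, as in the earlier sections
scores = q @ k.transpose(-2, -1) / D ** 0.5
mask = torch.ones(T, T, dtype=torch.bool).tril()
manual = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ v

# fused kernel: default scale is 1/sqrt(head_dim), is_causal builds the same mask
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```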

Grouped Query Attention (GQA)

Standard multi-head attention stores one key vector and one value vector per head, per layer, per token in the KV cache. For a 70B model with 64 attention heads, that adds up fast. During inference with long contexts or many concurrent sequences, the KV cache can consume more GPU memory than the model weights.
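A back-of-the-envelope sizing, assuming an fp16 cache and a LLaMA 2 70B-like shape (80 layers, 64 heads of dimension 128) with full multi-head attention. The real LLaMA 2 70B uses 8 shared KV heads, so its actual cache is 8x smaller; this counts the hypothetical full-MHA version.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_value: int = 2) -> int:
    # factor 2: a key vector and a value vector per KV head, per layer, per token
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_value

GiB = 2**30
per_seq = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)
weights = 70e9 * 2   # ~70B parameters stored in fp16

print(f"KV cache per 4,096-token sequence: {per_seq / GiB:.1f} GiB")   # 10.0 GiB
print(f"weights: {weights / GiB:.0f} GiB")                              # 130 GiB
print(f"sequences before the cache outgrows the weights: {weights / per_seq:.0f}")  # 13
```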

Grouped Query Attention shares a single key head and value head across each group of query heads. If you have 32 query heads grouped into 8 groups, you cache 8 key/value pairs per token per layer instead of 32. The quality loss is small. The memory savings are large.

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class GroupedQueryAttention(nn.Module):
    """
    GQA: multiple query heads share one key/value head per group.
    Used in LLaMA 2 70B, Mistral 7B, and most large open models.
    """
    def __init__(
        self,
        embed_dim:  int,
        num_q_heads: int,   # e.g. 32
        num_kv_heads: int,  # e.g. 8  (must divide num_q_heads)
    ):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads  = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim     = embed_dim // num_q_heads
        self.groups       = num_q_heads // num_kv_heads

        self.q_proj  = nn.Linear(embed_dim, num_q_heads  * self.head_dim, bias=False)
        self.k_proj  = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj  = nn.Linear(embed_dim, num_kv_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape

        q = self.q_proj(x).view(B, T, self.num_q_heads,  self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # repeat KV heads to match Q head count
        k = k.repeat_interleave(self.groups, dim=1)   # (B, num_q_heads, T, head_dim)
        v = v.repeat_interleave(self.groups, dim=1)

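        # standard causal scaled dot-product attention from here
        scale   = math.sqrt(self.head_dim)
        scores  = (q @ k.transpose(-2, -1)) / scale        # (B, num_q_heads, T, T)
        causal  = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        scores  = scores.masked_fill(~causal, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        out     = weights @ v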

        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out_proj(out)

# 32 query heads, 8 KV heads, embed_dim 4096: the Mistral 7B attention shape
gqa = GroupedQueryAttention(embed_dim=4096, num_q_heads=32, num_kv_heads=8)
x   = torch.randn(1, 128, 4096)
out = gqa(x)
print(out.shape)   # torch.Size([1, 128, 4096])

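# KV cache size for this 128-token sequence, per layer.
# Keys and values are both cached (factor 2); head_dim = 4096 / 32 = 128.
full_mha_kv = 2 * 32 * 128 * gqa.head_dim   # K+V * num_heads    * seq_len * head_dim
gqa_kv      = 2 * 8  * 128 * gqa.head_dim   # K+V * num_kv_heads * seq_len * head_dim
print(f"MHA KV cache: {full_mha_kv:,} values")   # 1,048,576
print(f"GQA KV cache: {gqa_kv:,} values")         # 262,144 (4x reduction)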

Multi-Head Latent Attention (MLA): what DeepSeek did

GQA reduces KV heads. DeepSeek took a different approach and asked: why are we storing the full, high-dimensional key and value vectors at all? The core insight behind Multi-Head Latent Attention is that the high-dimensional key and value vectors might contain redundant information, and the essence of what the model needs to remember from a past token could be captured in a much smaller space.

MLA compresses the key and value representations into a shared low-rank latent vector. During inference, only that latent vector is cached. When computing attention, it is decompressed back to full key and value heads. DeepSeek-V2 with MLA reduces the KV cache size by 93.3% compared to their dense MHA-based model.
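A minimal sketch of the compress-then-expand idea, assuming illustrative sizes (embed_dim 1024, 16 heads of size 64, a 256-dimensional latent) and a hypothetical LatentKVCompression module. It deliberately omits parts of DeepSeek's actual design, such as the separate rotary-position key path and query compression, so treat it as the shape of the idea rather than DeepSeek's implementation.

```python
import torch
import torch.nn as nn

class LatentKVCompression(nn.Module):
    """Hypothetical, simplified MLA-style KV path: cache one small latent vector
    per token and expand it to per-head keys and values when attending."""
    def __init__(self, embed_dim: int, num_heads: int, head_dim: int, latent_dim: int):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim
        self.down = nn.Linear(embed_dim, latent_dim, bias=False)              # x -> latent c
        self.up_k = nn.Linear(latent_dim, num_heads * head_dim, bias=False)   # c -> keys
        self.up_v = nn.Linear(latent_dim, num_heads * head_dim, bias=False)   # c -> values

    def compress(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(x)   # (B, T, latent_dim): the only tensor kept in the cache

    def expand(self, c: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        B, T, _ = c.shape
        k = self.up_k(c).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.up_v(c).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        return k, v   # each (B, num_heads, T, head_dim)

mla = LatentKVCompression(embed_dim=1024, num_heads=16, head_dim=64, latent_dim=256)
x = torch.randn(2, 10, 1024)
cache = mla.compress(x)
k, v = mla.expand(cache)
print(cache.shape, k.shape)   # torch.Size([2, 10, 256]) torch.Size([2, 16, 10, 64])
print(f"cached values per token: 256 latent vs {2 * 16 * 64} for full K and V")
```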

The approach is conceptually similar to how LoRA works for fine-tuning: instead of storing a large matrix, store a low-rank factorization and reconstruct when needed.


Encoder-only, decoder-only, encoder-decoder: what the difference is

Three transformer variants exist and they are suited to different tasks.

Decoder-only (GPT, LLaMA, Mistral, Claude): causal masking everywhere, generates text autoregressively. The architecture you built in this article. Best for: text generation, code, instruction following.

Encoder-only (BERT, RoBERTa): no causal mask, every token attends to every other token in both directions. No generation. Best for: classification, named entity recognition, sentence embeddings.

Encoder-decoder (T5, BART): encoder processes the input with full bidirectional attention, decoder generates the output with causal attention and cross-attention to the encoder. Best for: translation, summarization, tasks where input and output are distinct sequences.
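Stripped of projections, FFNs and normalization, the three variants differ mainly in how attention is called. A sketch with random stand-in tensors already split into heads (no learned weights), using PyTorch's F.scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, D = 1, 2, 16
T_enc, T_dec = 7, 4
enc = torch.randn(B, H, T_enc, D)   # stand-in encoder states
dec = torch.randn(B, H, T_dec, D)   # stand-in decoder states

# Encoder-only: bidirectional self-attention, no mask
enc_out = F.scaled_dot_product_attention(enc, enc, enc)

# Decoder-only, and the decoder's own self-attention: causal mask
dec_self = F.scaled_dot_product_attention(dec, dec, dec, is_causal=True)

# Encoder-decoder cross-attention: queries from the decoder, keys and values
# from the encoder; no causal mask, because the whole input is already known
cross = F.scaled_dot_product_attention(dec_self, enc_out, enc_out)

print(enc_out.shape, dec_self.shape, cross.shape)
# torch.Size([1, 2, 7, 16]) torch.Size([1, 2, 4, 16]) torch.Size([1, 2, 4, 16])
```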

If you are building a product that generates text, you are using a decoder-only model. If you are building a retrieval system with dense embeddings, you might be using an encoder-only model fine-tuned for embeddings. Most of what developers encounter is decoder-only.


What residual connections actually solve

This is something I did not appreciate until I tried removing them.

In a 12-layer transformer, the gradient from the final layer's loss needs to reach the first layer's weights. Without residuals it travels through 12 attention sublayers and 12 FFN sublayers, each multiplying the gradient by the sublayer's local Jacobian. Small Jacobians cause vanishing. Large ones cause exploding.

With residuals, the gradient can bypass every sublayer through the addition path. Because each block outputs x + sublayer(x), its Jacobian with respect to x is I + ∂sublayer/∂x: there is always an identity term, regardless of what the sublayer does. Early layers always get a meaningful gradient signal.
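A small experiment makes the identity term visible. Assumptions: a stack of 24 tanh layers with weights scaled so each layer roughly halves its input, standing in for sublayers with small local Jacobians. The same layers run with and without the residual addition, and we measure the gradient norm that reaches the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, DEPTH = 64, 24

# Stand-in sublayers: Linear + tanh, scaled so each roughly halves its input
layers = []
for _ in range(DEPTH):
    lin = nn.Linear(D, D, bias=False)
    nn.init.normal_(lin.weight, std=0.5 / D ** 0.5)
    layers.append(nn.Sequential(lin, nn.Tanh()))

def input_grad_norm(residual: bool) -> float:
    x = torch.randn(8, D, requires_grad=True)
    h = x
    for f in layers:
        h = h + f(h) if residual else f(h)   # x + sublayer(x) vs sublayer(x)
    h.sum().backward()
    return x.grad.norm().item()

plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"gradient norm at the input, no residuals:   {plain:.2e}")   # vanishingly small
print(f"gradient norm at the input, with residuals: {res:.2e}")
```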

This is the same principle that made residual networks (ResNets) work in computer vision. Transformers inherited it and it is why you can train 96-layer models without gradient collapse.


The connection to pre-training and Article 7

The model you built here is a function. It takes token IDs and produces logit distributions over the vocabulary. Right now its weights are random and its loss is ln(50257) ≈ 10.82.

Article 7 covers what happens when you train this on hundreds of billions of tokens of text: how the next-token prediction objective shapes the weights, what scaling laws say about how much data and compute you need, and how a model trained this way develops capabilities that nobody explicitly programmed.

The architecture stays the same. The training is what turns it from a random function into something useful.

The one thing worth verifying yourself

Run the full model forward pass above and check the initial loss. Then look at the model's logits before training. With small random initial weights, every token gets roughly equal probability. That uniform distribution is the baseline. Every improvement in training is movement away from that starting point. Watching the loss fall from 10.82 is watching the model move away from ignorance, one gradient step at a time.
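The "roughly uniform" baseline depends on the initial logits being small. A quick synthetic check, assuming Gaussian logits as a stand-in for a real model's output: a spread of 0.2 keeps the loss near ln(50257), while a spread of about 11.3 (roughly sqrt(128), what you would expect if the tied output embedding kept PyTorch's default N(0, 1) initialization at embed_dim=128) starts far higher.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, N = 50_257, 128                      # vocab size, number of predicted positions
targets = torch.randint(0, V, (N,))

losses = {}
for std in (0.2, 11.3):                 # small init vs roughly sqrt(128)
    logits = torch.randn(N, V) * std    # synthetic Gaussian logits, not a real model's
    losses[std] = F.cross_entropy(logits, targets).item()
    print(f"logit std {std:4.1f}: loss {losses[std]:5.2f}  (ln V = {math.log(V):.2f})")
```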


Next in the series

Article 7 covers pre-training and language modeling. You will see how the decoder-only transformer from this article gets trained on raw text using next-token prediction, what loss curves actually mean at scale, and how capabilities like in-context learning and reasoning emerge from this training process. The question of why LLMs can do things nobody taught them explicitly starts there.


Krunal Kanojiya

Technical Content Writer

Technical Content Writer and former software developer from India. I write in-depth articles on blockchain, AI/ML, data engineering, web development, and developer careers. Currently at Lucent Innovation, previously at Cromtek Solution and freelance.
