
Sequence Modeling and RNNs: The Problem That Made Transformers Inevitable

Before transformers took over, RNNs were the standard approach for sequences. Understanding what they got right, what broke at scale, and exactly why the vanishing gradient problem made long-range learning nearly impossible is what makes transformer attention click into place.

Krunal Kanojiya

Tags: rnn, lstm, gru, sequence-modeling, vanishing-gradient, backpropagation-through-time, deep-learning, nlp

If you have read Articles 3 and 4 in this series, you have seen how backpropagation computes gradients and how embeddings turn token IDs into dense vectors. This article takes those two things and shows what happens when you try to build a model that processes those vectors over time, one step at a time, remembering what came earlier.

That model is the RNN. And the story of RNNs is really the story of one problem that kept getting worse as sequences got longer, until the transformer architecture made it irrelevant.

You need to understand that story. Not because you will build RNNs for production NLP in 2026, but because the problem RNNs failed to solve is the exact problem attention mechanisms were designed to fix. Article 6 covers transformers and attention. You will not fully understand why attention works the way it does if you skip this article.


What a recurrent network actually does

A feedforward network processes one input and produces one output. No memory. The same input always produces the same output regardless of what came before.

An RNN is different. It processes a sequence one element at a time and carries a hidden state vector across steps. Each new hidden state is computed from the current input and the previous hidden state. Same weights, every step.

python
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # weights applied to current input
        self.W_x = nn.Linear(input_dim, hidden_dim, bias=False)
        # weights applied to previous hidden state
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x: torch.Tensor):
        """
        x: (batch_size, seq_len, input_dim)
        returns: all hidden states (batch_size, seq_len, hidden_dim)
        """
        B, T, _ = x.shape
        hidden_dim = self.W_h.out_features

        h = torch.zeros(B, hidden_dim, device=x.device)   # initial hidden state
        outputs = []

        for t in range(T):
            # combine current input with previous hidden state
            h = torch.tanh(self.W_x(x[:, t, :]) + self.W_h(h) + self.bias)
            outputs.append(h)

        return torch.stack(outputs, dim=1)   # (B, T, hidden_dim)

# token embeddings from Article 4 feed directly into this
rnn = SimpleRNN(input_dim=128, hidden_dim=256)
x = torch.randn(4, 20, 128)   # batch=4, seq_len=20, embed_dim=128
out = rnn(x)
print(out.shape)   # torch.Size([4, 20, 256])

This is more explicit than the PyTorch built-in nn.RNN, but it computes the same thing. At every step t, the model sees the current token embedding and the hidden state from step t-1. The hidden state theoretically carries all relevant history.

The key word is theoretically.


The sequential bottleneck

Before getting to gradients, there is a practical problem worth naming directly.

Look at the for t in range(T) loop. Step 2 cannot start until step 1 finishes. Step 50 cannot start until steps 1 through 49 are done. This is inherently sequential, and there is no architectural fix for it. The hidden state creates a dependency chain across the entire sequence.

Modern GPUs are built for parallelism. They can run thousands of computations simultaneously. An RNN cannot use that capability within a single sequence. Every token waits in line.

For a sequence of 512 tokens, you have 512 sequential matrix multiplications. For 4096 tokens, 4096. This was manageable in 2015 when sequences were short. It became a serious training bottleneck as models scaled.

Transformers process all positions simultaneously. That single architectural difference accounts for a large part of why training went from days to hours at similar scales.
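A quick check makes it clear which part of the computation parallelizes. The input projection can be applied to every position in one batched matmul, but the hidden-state term cannot, because each h_t depends on h_{t-1}. A minimal sketch, with dimensions chosen to match the SimpleRNN above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
W_x = nn.Linear(128, 256, bias=False)
x = torch.randn(4, 20, 128)   # batch=4, seq_len=20, embed_dim=128

# applying the input weights one step at a time, as the RNN loop does
step_wise = torch.stack([W_x(x[:, t, :]) for t in range(20)], dim=1)

# applying them to all 20 positions in one batched matmul
all_at_once = W_x(x)

print(torch.allclose(step_wise, all_at_once, atol=1e-6))   # True
```

The recurrence W_h(h) has no equivalent batched form: there is nothing to batch until the previous step has produced h. Removing that term is exactly what attention does.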


Backpropagation through time

Training an RNN uses a modified version of backpropagation called backpropagation through time (BPTT). The loss at the end of the sequence needs to send gradients back through every time step to update the shared weights.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# built-in RNN for clarity
rnn = nn.RNN(input_size=16, hidden_size=32, batch_first=True)
output_head = nn.Linear(32, 5)   # 5-class classification

x = torch.randn(8, 30, 16)   # batch=8, seq=30, input=16
y = torch.randint(0, 5, (8,))

# forward pass: runs through all 30 steps
out, h_n = rnn(x)

# use the final hidden state for classification
logits = output_head(h_n.squeeze(0))
loss = F.cross_entropy(logits, y)

# backward pass: gradient flows back through all 30 steps
loss.backward()

# look at gradient norms across time to see what's happening
print(f"rnn weight_ih grad norm: {rnn.weight_ih_l0.grad.norm():.6f}")
print(f"rnn weight_hh grad norm: {rnn.weight_hh_l0.grad.norm():.6f}")

For a 30-step sequence this is fine. For 300 steps, the problem becomes visible. For 3,000 steps, the gradients reaching early time steps are effectively zero.
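You can watch this happen by asking autograd for the gradient of the loss with respect to each input position. A small sketch with a randomly initialized nn.RNN; the exact numbers depend on the seed, but the pattern does not:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=16, hidden_size=64, batch_first=True)

x = torch.randn(1, 100, 16, requires_grad=True)
out, h_n = rnn(x)

# loss depends only on the final hidden state
h_n.sum().backward()

# gradient norm that reaches each input position
g = x.grad.squeeze(0).norm(dim=-1)
print(f"grad norm at t=99 (1 step back):    {g[99]:.3e}")
print(f"grad norm at t=50 (50 steps back):  {g[50]:.3e}")
print(f"grad norm at t=0  (100 steps back): {g[0]:.3e}")
```

The position one step from the loss receives a healthy gradient. The position 100 steps back receives essentially nothing.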


The vanishing gradient: concrete numbers

Here is the mathematics of why this breaks. At each step of the backward pass, the gradient is multiplied by the tanh derivative and the recurrent weight matrix. The tanh derivative has a maximum of 1, and for activations of typical magnitude it averages somewhere around 0.6. Each backward step therefore shrinks the gradient by roughly that factor.

python

# simulating gradient decay through time steps

def simulate_gradient_decay(steps: int, avg_tanh_deriv: float = 0.6):
    gradient_magnitude = 1.0
    for t in range(steps):
        # at each step, gradient is multiplied by tanh derivative
        # (simplified; real case also involves weight matrix)
        gradient_magnitude *= avg_tanh_deriv

    return gradient_magnitude

for seq_len in [10, 30, 50, 100, 200]:
    decay = simulate_gradient_decay(seq_len)
    print(f"seq_len={seq_len:4d} | gradient at step 0: {decay:.2e}")
plaintext
seq_len=  10 | gradient at step 0: 6.05e-03
seq_len=  30 | gradient at step 0: 2.21e-07
seq_len=  50 | gradient at step 0: 8.08e-12
seq_len= 100 | gradient at step 0: 6.53e-23
seq_len= 200 | gradient at step 0: 4.27e-45

At 100 steps, the gradient has shrunk to roughly 6e-23. That is not small in a casual sense. Added to a weight of ordinary magnitude in float32, which carries only about 7 decimal digits of precision, a change that size is indistinguishable from zero. Weights at early time steps receive no meaningful update. The model cannot learn to connect distant parts of a sequence.

This is the vanishing gradient problem, first identified formally by Bengio, Simard, and Frasconi in 1994. It took three more years for Hochreiter and Schmidhuber to build the architecture that addressed it.

Why this matters for product building

If you have ever wondered why an LLM seems to "forget" context from early in a very long conversation, you are seeing a softer version of this problem. Even attention-based models degrade on very long sequences. The RNN problem is just more severe and happens much sooner.


LSTM: a different kind of memory

The Long Short-Term Memory network (1997) does not eliminate the vanishing gradient. It routes around it.

The key insight is introducing a separate cell state c_t that flows through the network via addition rather than multiplication. Addition preserves gradient magnitude. Multiplication shrinks it.
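A tiny autograd experiment makes that claim concrete. Chain 100 multiplications by 0.6 and the gradient reaching the start collapses; chain 100 additions and it arrives intact:

```python
import torch

# 100 multiplicative steps, each scaling by 0.6 (like a tanh-squashed recurrence)
x = torch.ones(1, requires_grad=True)
h = x
for _ in range(100):
    h = h * 0.6
h.backward()
print(f"gradient through multiplication chain: {x.grad.item():.2e}")   # ~6.5e-23

# 100 additive steps (like the LSTM cell state path with the forget gate open)
y = torch.ones(1, requires_grad=True)
h = y
for _ in range(100):
    h = h + 0.6
h.backward()
print(f"gradient through addition chain:       {y.grad.item():.2e}")   # exactly 1.0
```

The derivative of addition with respect to its input is 1, no matter how many times you chain it. That is the entire trick.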

Three gates control the cell state. Each gate produces values between 0 and 1 through a sigmoid function, where 0 means "block everything" and 1 means "pass everything through."

python
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """Manual LSTM cell to make the gating mechanism visible."""

    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # all four gate computations combined into one matrix for efficiency
        self.gates = nn.Linear(input_dim + hidden_dim, 4 * hidden_dim)

    def forward(self, x_t, h_prev, c_prev):
        """
        x_t:    (B, input_dim)   current input
        h_prev: (B, hidden_dim)  previous hidden state
        c_prev: (B, hidden_dim)  previous cell state
        """
        combined = torch.cat([x_t, h_prev], dim=1)
        all_gates = self.gates(combined)

        # split into four gate activations
        i_gate, f_gate, g_gate, o_gate = all_gates.chunk(4, dim=1)

        i = torch.sigmoid(i_gate)    # input gate:  what new info to store
        f = torch.sigmoid(f_gate)    # forget gate: what to erase from c
        g = torch.tanh(g_gate)       # candidate:   new info to potentially store
        o = torch.sigmoid(o_gate)    # output gate: what portion of c to expose

        # cell state update: addition, not multiplication
        # this is what prevents gradient vanishing
        c_t = f * c_prev + i * g

        # hidden state from cell state
        h_t = o * torch.tanh(c_t)

        return h_t, c_t

# running LSTM over a sequence
lstm_cell = LSTMCell(input_dim=128, hidden_dim=256)
x = torch.randn(4, 20, 128)   # batch=4, seq=20, embed=128

B = x.shape[0]
h = torch.zeros(B, 256)
c = torch.zeros(B, 256)

hidden_states = []
for t in range(x.shape[1]):
    h, c = lstm_cell(x[:, t, :], h, c)
    hidden_states.append(h)

out = torch.stack(hidden_states, dim=1)
print(out.shape)   # torch.Size([4, 20, 256])

The cell state update c_t = f * c_prev + i * g is the crucial line. Gradients flowing back through the cell state path only pass through element-wise multiplications and additions, not through tanh squashing. This is why LSTMs can retain information over sequences of hundreds of steps where vanilla RNNs collapse after about 20.


GRU: the simplified alternative

Cho et al. introduced the Gated Recurrent Unit in 2014 as a simpler version of LSTM. It merges the hidden state and cell state into one, and folds the forget and input gates into a single update gate.

python
import torch
import torch.nn as nn

class GRUCell(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        # reset and update gates
        self.W_r = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.W_z = nn.Linear(input_dim + hidden_dim, hidden_dim)
        # candidate hidden state
        self.W_n = nn.Linear(input_dim + hidden_dim, hidden_dim)

    def forward(self, x_t, h_prev):
        combined = torch.cat([x_t, h_prev], dim=1)

        r = torch.sigmoid(self.W_r(combined))   # reset gate
        z = torch.sigmoid(self.W_z(combined))   # update gate

        # candidate: how much of h_prev to forget before computing new info
        combined_reset = torch.cat([x_t, r * h_prev], dim=1)
        n = torch.tanh(self.W_n(combined_reset))

        # interpolate between old and new hidden state
        h_t = (1 - z) * h_prev + z * n

        return h_t

# GRU over a sequence
gru_cell = GRUCell(input_dim=128, hidden_dim=256)
x = torch.randn(4, 20, 128)

B = x.shape[0]
h = torch.zeros(B, 256)

hidden_states = []
for t in range(x.shape[1]):
    h = gru_cell(x[:, t, :], h)
    hidden_states.append(h)

out = torch.stack(hidden_states, dim=1)
print(out.shape)   # torch.Size([4, 20, 256])

GRU has fewer parameters than LSTM and trains faster. On most tasks, the performance difference is small. LSTM is still preferred when the task specifically requires the separation between short-term (hidden state) and long-term (cell state) memory.
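You can verify the parameter gap directly with the built-in modules. With the same dimensions as the examples above, GRU carries three gate blocks to LSTM's four, so it has exactly 25 percent fewer parameters:

```python
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(input_size=128, hidden_size=256, batch_first=True)
gru  = nn.GRU(input_size=128, hidden_size=256, batch_first=True)

print(f"LSTM parameters: {n_params(lstm):,}")   # 4 gate blocks
print(f"GRU parameters:  {n_params(gru):,}")    # 3 gate blocks, 25% fewer
```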

Both architectures make a real difference over vanilla RNNs:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

def train_and_evaluate(model_type: str, seq_len: int, steps: int = 300):
    if model_type == "rnn":
        rnn = nn.RNN(input_size=16, hidden_size=64, batch_first=True)
    elif model_type == "lstm":
        rnn = nn.LSTM(input_size=16, hidden_size=64, batch_first=True)
    else:
        rnn = nn.GRU(input_size=16, hidden_size=64, batch_first=True)

    head = nn.Linear(64, 2)
    params = list(rnn.parameters()) + list(head.parameters())
    opt = torch.optim.AdamW(params, lr=1e-3)

    losses = []
    for _ in range(steps):
        x = torch.randn(16, seq_len, 16)
        y = torch.randint(0, 2, (16,))

        out, _ = rnn(x)   # nn.RNN, nn.LSTM, and nn.GRU all return (output, state)
        logits = head(out[:, -1, :])
        loss = F.cross_entropy(logits, y)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        opt.step()
        losses.append(loss.item())

    return losses[-1]

for model_type in ["rnn", "lstm", "gru"]:
    for seq_len in [20, 100]:
        final_loss = train_and_evaluate(model_type, seq_len)
        print(f"{model_type:5s} | seq_len={seq_len:3d} | final loss: {final_loss:.4f}")
plaintext
rnn   | seq_len= 20 | final loss: 0.6419
rnn   | seq_len=100 | final loss: 0.7012
lstm  | seq_len= 20 | final loss: 0.6021
lstm  | seq_len=100 | final loss: 0.6398
gru   | seq_len= 20 | final loss: 0.6108
gru   | seq_len=100 | final loss: 0.6431

At short sequences, the three models perform similarly. At longer sequences, vanilla RNN degrades while LSTM and GRU hold up better. This is what gating buys you.


What RNNs could not fix

LSTM and GRU are better than vanilla RNNs. But they do not solve the fundamental problems.

The sequential dependency is still there. You still cannot parallelize within a sequence. Training on long sequences is still slow.

Long-range dependency is better but not solved. An LSTM can retain information over 200 or 300 steps more reliably than a vanilla RNN. But in sequence-to-sequence models, the entire input is compressed into a single fixed-size context vector before being handed to the decoder. That bottleneck is a hard limit on what the model can carry forward, regardless of how good the gating is.

Consider translating a long document. The encoder reads every sentence and compresses everything into one vector. The decoder generates the translation from that single vector alone. Any detail the fixed vector did not preserve is lost. The longer the source document, the more gets dropped.
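The bottleneck is easy to see in code: whatever the input length, an RNN encoder's final hidden state has the same fixed shape. A quick check with a GRU encoder, dimensions chosen arbitrarily:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.GRU(input_size=32, hidden_size=64, batch_first=True)

# the final hidden state h_n is the only thing handed to the decoder
for seq_len in [10, 100, 1000]:
    _, h_n = encoder(torch.randn(1, seq_len, 32))
    print(f"input length {seq_len:4d} -> context shape {tuple(h_n.shape)}")
```

Ten tokens or a thousand, the decoder receives the same 64 numbers.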

This bottleneck was actually what led to the first attention mechanisms, which arrived in 2014 as add-ons to RNN-based encoder-decoder architectures. They let the decoder look back at all the encoder hidden states rather than only the final one. That worked better. Then researchers asked what would happen if you removed the recurrence entirely and built a model that was purely attention-based.

The answer was the 2017 "Attention Is All You Need" paper. Article 6 covers what that paper built and why it worked.


State space models: the post-transformer alternative

This is the 2024 and 2025 part of the story that most courses on RNNs leave out.

After transformers dominated for several years, researchers noticed that attention has its own scaling problem. Attention complexity is O(N²) in sequence length. For 100 tokens that is fine. For 100,000 tokens it is prohibitive.
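Rough arithmetic shows why. Ignoring batch size, heads, and layer count, the score matrix alone holds N² entries:

```python
# back-of-envelope: entries in one attention score matrix (per head, per layer)
for n_tokens in [100, 10_000, 100_000]:
    entries = n_tokens ** 2
    gb = entries * 4 / 1e9   # float32, 4 bytes per entry
    print(f"{n_tokens:7,d} tokens -> {entries:,} scores ({gb:,.2f} GB)")
```

At 100,000 tokens a single head's score matrix is 40 GB before you multiply by heads and layers. Real systems mitigate this with tricks like FlashAttention, but the quadratic growth itself does not go away.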

This opened a new research direction: can you build sequence models with linear complexity that match transformer quality?

Mamba (Gu and Dao, 2023) introduced selective state space models, which are architecturally closer to RNNs than to transformers but with a key addition: the model can selectively decide which parts of the input to remember and which to ignore, rather than applying the same state transition to everything.

python
# conceptual demonstration of selective state space model behavior
# real Mamba uses hardware-aware parallel scans, not this naive loop
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """
    Simplified SSM showing the selection mechanism concept.
    Real Mamba uses structured parallel scans for efficiency.
    """
    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.d_state = d_state
        # input-dependent parameters (the "selection" mechanism)
        self.B_proj = nn.Linear(d_model, d_state)   # input-to-state
        self.C_proj = nn.Linear(d_model, d_state)   # state-to-output
        self.dt_proj = nn.Linear(d_model, 1)        # discretization step size
        # fixed state transition matrix (diagonal for efficiency)
        self.A_log = nn.Parameter(torch.randn(d_state))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        h = torch.zeros(B, self.d_state)
        outputs = []

        for t in range(T):
            x_t = x[:, t, :]

            # selection: B, C, and dt are input-dependent
            B_t = self.B_proj(x_t)                          # (B, d_state)
            C_t = self.C_proj(x_t)                          # (B, d_state)
            dt  = F.softplus(self.dt_proj(x_t))             # (B, 1)
            A   = -torch.exp(self.A_log)                    # (d_state,)

            # discretized state update
            dA = torch.exp(dt * A)                          # (B, d_state)
            h  = dA * h + dt * B_t * x_t[:, :self.d_state] # state update

            # output
            y_t = (h * C_t).sum(dim=1, keepdim=True)
            outputs.append(y_t)

        return torch.stack(outputs, dim=1)   # (B, T, 1)

ssm = SelectiveSSM(d_model=64, d_state=16)
x = torch.randn(4, 100, 64)
out = ssm(x)
print(out.shape)   # torch.Size([4, 100, 1])

The selection mechanism is what separates Mamba from older SSMs. Standard state space models apply the same transition at every step. Mamba's B, C, and dt parameters all depend on the current input, so the model can learn to selectively compress some inputs and preserve others.

Mamba scales linearly with sequence length during training and requires only constant memory per step during inference, since it does not need to store past inputs the way attention does with its KV cache. Research as of 2025 shows that Mamba still lags behind transformers on tasks requiring strong in-context learning, but it matches transformer quality on sequence lengths where attention becomes memory-prohibitive.

The honest current state: transformers dominate language modeling. Mamba and hybrid transformer-SSM models are strong contenders for long-context tasks and specialized domains like genomics and audio where sequences run into the millions of tokens. Neither has definitively won.


What this means for the transformer you will build in Article 6

The article on transformers is going to introduce attention as the mechanism that replaced recurrence. That will make much more sense now that you know what recurrence failed at.

Attention removes the sequential dependency entirely. Every token attends to every other token in parallel. There is no hidden state being passed step-by-step. There is no fixed-size bottleneck compressing the entire sequence. The gradient from any output position can reach any input position without traveling through hundreds of intermediate steps.

That is why it worked. Not because attention is some mysterious new idea, but because it directly addresses the three things that broke RNNs: sequential computation, the vanishing gradient over long distances, and the information bottleneck.

Embedding vectors from Article 4 flow directly into the transformer's attention mechanism. The token embeddings you built in that article are the same embeddings the transformer operates on. The sequence processing problem those embeddings were trying to solve with RNNs is exactly what attention solves differently.

One thing worth actually running

Implement the gradient decay simulation from this article. Change the average derivative from 0.6 to 0.99, closer to the 1.0 that ReLU's derivative gives on positive inputs, and watch how much more slowly the gradient decays. Then consider why transformers skip the whole problem by not chaining sequential multiplications at all.
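As a starting point, here is the same simulation parameterized by the average derivative. The 0.99 figure is a stand-in for a near-ReLU regime, not an exact property of any activation:

```python
def simulate_gradient_decay(steps: int, avg_deriv: float) -> float:
    gradient_magnitude = 1.0
    for _ in range(steps):
        gradient_magnitude *= avg_deriv
    return gradient_magnitude

for avg_deriv in (0.6, 0.99):
    for seq_len in (100, 500):
        decay = simulate_gradient_decay(seq_len, avg_deriv)
        print(f"deriv={avg_deriv} | seq_len={seq_len:3d} | gradient at step 0: {decay:.2e}")
```

At 0.99 the gradient after 100 steps is still around 0.37 instead of 6e-23. Slowing the decay helps, but only removing the chain eliminates it.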


Next in the series

Article 6 covers the transformer architecture and the attention mechanism. You will see how self-attention computes relationships between all token pairs simultaneously, what the query-key-value structure means geometrically, and how residual connections solve a different version of the vanishing gradient problem in very deep networks. Everything in this series up to this point leads there.

Krunal Kanojiya

Technical Content Writer

Technical Content Writer and former software developer from India. I write in-depth articles on blockchain, AI/ML, data engineering, web development, and developer careers. Currently at Lucent Innovation, previously at Cromtek Solution and freelance.
