Understanding RNNs and LSTMs: The Foundation of Sequence Modeling

A deep dive into Recurrent Neural Networks and Long Short-Term Memory networks. Learn how they process sequences, why vanilla RNNs fail, and how LSTMs solve the vanishing gradient problem.


Imagine you're reading a book. You don't process each word in isolation—you remember what came before, building context as you go. This is exactly what Recurrent Neural Networks (RNNs) do for machines. They're designed to process sequences by maintaining a "memory" of previous inputs.

In this guide, we'll explore how RNNs work, why they struggle with long sequences, and how Long Short-Term Memory (LSTM) networks elegantly solve these problems.

Traditional feedforward neural networks have a fundamental limitation: they treat each input independently. Consider these two sentences:

  • "The cat sat on the mat"
  • "The mat sat on the cat"

A feedforward network would see the same words and might produce similar outputs, completely missing that these sentences have opposite meanings. The problem? No memory of word order.

python
# Feedforward network treats each word independently
class FeedforwardClassifier(nn.Module):
    def forward(self, words):
        # Process each word separately
        word_features = [self.process(word) for word in words]
        # Average them (loses order!)
        sentence_feature = torch.stack(word_features).mean(dim=0)
        return self.classify(sentence_feature)

# Result: "cat sat mat" = "mat sat cat" = "sat cat mat"
# All produce the same output! 😱

RNNs solve this by introducing recurrence—the output at each step depends not just on the current input, but also on the previous hidden state. Think of it as a neural network with memory.

An RNN maintains a hidden state that gets updated at each time step. This hidden state acts as the network's memory, encoding information about everything it has seen so far.

python
# Simplified RNN cell
def rnn_cell(input_t, hidden_t_minus_1):
    """
    Process one time step of an RNN
    
    Args:
        input_t: Current input (e.g., word embedding)
        hidden_t_minus_1: Previous hidden state (memory)
    
    Returns:
        hidden_t: New hidden state
    """
    # Combine current input with previous memory
    combined = torch.cat([input_t, hidden_t_minus_1], dim=1)
    
    # Apply transformation and activation
    hidden_t = torch.tanh(W @ combined + b)
    
    return hidden_t

Let's walk through processing the sentence "I love pizza" word by word:

python
# Initialize hidden state (all zeros)
hidden = torch.zeros(hidden_size)

# Step 1: Process "I"
word_1 = embed("I")  # Convert word to vector
hidden = rnn_cell(word_1, hidden)
print(f"After 'I': hidden = {hidden}")
# hidden now encodes: "I"

# Step 2: Process "love"
word_2 = embed("love")
hidden = rnn_cell(word_2, hidden)  # Uses previous hidden state!
print(f"After 'love': hidden = {hidden}")
# hidden now encodes: "I love"

# Step 3: Process "pizza"
word_3 = embed("pizza")
hidden = rnn_cell(word_3, hidden)
print(f"After 'pizza': hidden = {hidden}")
# hidden now encodes: "I love pizza"

# Final hidden state represents the entire sentence!

Key Insight: The hidden state is like a running summary. At each step, it combines the new word with the summary of everything before it.

The RNN update equation is surprisingly simple:

python
# At time step t:
# h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b_h)

# Where:
# h_t = new hidden state
# h_{t-1} = previous hidden state
# x_t = current input
# W_hh = hidden-to-hidden weight matrix
# W_xh = input-to-hidden weight matrix
# b_h = bias
# tanh = activation function (squashes values to [-1, 1])

Let's implement a simple RNN from scratch:

python
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        
        # Weight matrices
        self.W_xh = nn.Linear(input_size, hidden_size)   # input to hidden
        self.W_hh = nn.Linear(hidden_size, hidden_size)  # hidden to hidden
        self.W_hy = nn.Linear(hidden_size, output_size)  # hidden to output
    
    def forward(self, x):
        """
        Args:
            x: input sequence, shape (batch, seq_len, input_size)
        Returns:
            outputs: predictions at each time step
            hidden: final hidden state
        """
        batch_size, seq_len, _ = x.shape
        
        # Initialize hidden state
        hidden = torch.zeros(batch_size, self.hidden_size, device=x.device)
        
        outputs = []
        
        # Process sequence one step at a time
        for t in range(seq_len):
            # Get input at time t
            x_t = x[:, t, :]  # (batch, input_size)
            
            # Update hidden state
            hidden = torch.tanh(
                self.W_xh(x_t) + self.W_hh(hidden)
            )
            
            # Compute output
            output = self.W_hy(hidden)
            outputs.append(output)
        
        # Stack outputs
        outputs = torch.stack(outputs, dim=1)  # (batch, seq_len, output_size)
        
        return outputs, hidden

RNNs sound perfect, right? Unfortunately, they have a critical flaw: they can't learn long-range dependencies. This is called the vanishing gradient problem.

During backpropagation through time, gradients must flow backward through many time steps. At each step, they get multiplied by the weight matrix and the derivative of tanh.

python
# Gradient flow in RNN
# To update weights based on word 1, gradient must flow:
# word_10 -> word_9 -> word_8 -> ... -> word_2 -> word_1

# At each step, gradient gets multiplied by:
# - Weight matrix W_hh
# - Derivative of tanh (which is < 1)

# After 10 steps:
# gradient = gradient * W_hh * tanh' * W_hh * tanh' * ... (10 times)

# If W_hh has values < 1 and tanh' < 1:
# gradient ≈ 0.9^10 ≈ 0.35  (already quite small)
# gradient ≈ 0.9^50 ≈ 0.005 (almost zero!)

# Result: Early words don't get updated effectively

This means RNNs struggle with sentences like:

"The chef, who trained in Paris for five years and later opened a restaurant in Tokyo, **makes** amazing sushi."

The RNN needs to connect "chef" (at the start) with "makes" (at the end), but the gradient signal is too weak by the time it reaches back to "chef".

The Problem: Vanilla RNNs can only effectively learn dependencies spanning 5-10 time steps. Beyond that, gradients vanish and learning fails.
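We can watch this happen numerically. The following toy sketch (not the article's model, just the bare recurrence) pushes a gradient back through T repeated tanh steps and measures how much of it survives:

```python
import torch

# Toy sketch: push a gradient back through T repeated steps of
# h -> tanh(w * h) and watch its magnitude shrink.
def gradient_after_steps(T, w=0.9):
    h0 = torch.tensor(0.5, requires_grad=True)
    h = h0
    for _ in range(T):
        h = torch.tanh(w * h)
    h.backward()
    return h0.grad.abs().item()

print(gradient_after_steps(10))  # noticeably smaller than 1
print(gradient_after_steps(50))  # nearly zero
```

Each step multiplies the gradient by w times the tanh derivative, both below 1, so the product decays geometrically with sequence length.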

LSTMs were specifically designed to solve the vanishing gradient problem. They do this through a clever architecture with gates that control information flow.

LSTMs introduce a cell state—a separate memory channel that runs through the entire sequence with minimal modifications. Think of it as a "memory highway" where information can flow unchanged.

LSTMs use three gates to control the cell state:

The forget gate decides what information to remove from the cell state.

python
# Forget gate
f_t = sigmoid(W_f @ [h_{t-1}, x_t] + b_f)

# Output: values between 0 and 1
# 0 = "completely forget this"
# 1 = "completely keep this"

# Example: Processing "The cat, which was black, sat on the mat"
# When we reach "sat", we might forget details about "black"
# because color is less relevant to the action

The input gate decides what new information to store in the cell state.

python
# Input gate
i_t = sigmoid(W_i @ [h_{t-1}, x_t] + b_i)

# Candidate values to add
C_tilde = tanh(W_C @ [h_{t-1}, x_t] + b_C)

# Combine: what to add and how much
new_info = i_t * C_tilde

# Example: When we see "sat", we want to remember this action
# Input gate opens to let this information in

The output gate decides what parts of the cell state to expose as the hidden state.

python
# Output gate
o_t = sigmoid(W_o @ [h_{t-1}, x_t] + b_o)

# Hidden state (what we output)
h_t = o_t * tanh(C_t)

# Example: We might remember many details in cell state,
# but only output the most relevant ones for the current task

Putting it all together:

python
def lstm_cell(x_t, h_prev, C_prev):
    """
    One step of LSTM computation
    
    Args:
        x_t: current input
        h_prev: previous hidden state
        C_prev: previous cell state
    
    Returns:
        h_t: new hidden state
        C_t: new cell state
    """
    # Concatenate input and previous hidden state
    combined = torch.cat([h_prev, x_t], dim=1)
    
    # 1. Forget gate: what to forget from cell state
    f_t = torch.sigmoid(W_f @ combined + b_f)
    
    # 2. Input gate: what new info to add
    i_t = torch.sigmoid(W_i @ combined + b_i)
    C_tilde = torch.tanh(W_C @ combined + b_C)
    
    # 3. Update cell state
    C_t = f_t * C_prev + i_t * C_tilde
    
    # 4. Output gate: what to output
    o_t = torch.sigmoid(W_o @ combined + b_o)
    h_t = o_t * torch.tanh(C_t)
    
    return h_t, C_t

The cell state provides a direct path for gradients to flow backward through time:

python
# Cell state update:
# C_t = f_t * C_{t-1} + i_t * C_tilde

# Gradient flow:
# dC_t/dC_{t-1} = f_t

# Key insight: f_t is learned!
# If the network needs to remember something,
# it can set f_t ≈ 1, allowing gradients to flow unchanged

# Compare to vanilla RNN:
# h_t = tanh(W @ h_{t-1} + ...)
# dh_t/dh_{t-1} = W * tanh'  (typically < 1, so repeated products vanish)

# LSTM can maintain gradient flow over 100+ time steps!

The Magic: LSTMs can learn to keep the forget gate open (f_t ≈ 1) for important information, creating a gradient highway that prevents vanishing.
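This gradient highway can be sketched numerically. The toy recurrences below (illustrative only, not a full LSTM) backpropagate through 100 steps of each update rule and compare what survives:

```python
import torch

T = 100  # number of time steps to backpropagate through

# Vanilla-RNN-style recurrence: h_t = tanh(0.9 * h_{t-1})
h0 = torch.tensor(0.5, requires_grad=True)
h = h0
for _ in range(T):
    h = torch.tanh(0.9 * h)
h.backward()
rnn_grad = h0.grad.abs().item()  # vanishes: a product of 100 factors < 1

# LSTM-style cell-state recurrence with the forget gate held open (f ≈ 1):
# C_t = f * C_{t-1} + new_info
f = 0.99
C0 = torch.tensor(0.5, requires_grad=True)
C = C0
for _ in range(T):
    C = f * C + 0.01  # a little new information each step
C.backward()
lstm_grad = C0.grad.abs().item()  # = f**T ≈ 0.37, still a usable signal

print(f"RNN-style gradient:  {rnn_grad:.2e}")
print(f"LSTM-style gradient: {lstm_grad:.2f}")
```

Because the cell-state update is additive, its gradient is just the product of the forget-gate values, which the network can learn to keep near 1.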

PyTorch provides a built-in LSTM implementation that's highly optimized:

python
import torch.nn as nn

# Create LSTM layer
lstm = nn.LSTM(
    input_size=100,      # dimension of input features
    hidden_size=256,     # dimension of hidden state
    num_layers=2,        # stack multiple LSTM layers
    batch_first=True,    # input shape: (batch, seq, features)
    dropout=0.3,         # dropout between layers
    bidirectional=False  # unidirectional (left-to-right only)
)

# Example usage
batch_size = 32
seq_len = 50
input_size = 100

# Input: (batch, sequence_length, input_size)
x = torch.randn(batch_size, seq_len, input_size)

# Forward pass
output, (h_n, c_n) = lstm(x)

print(f"Output shape: {output.shape}")  # (32, 50, 256)
print(f"Final hidden state: {h_n.shape}")  # (2, 32, 256) - 2 layers
print(f"Final cell state: {c_n.shape}")    # (2, 32, 256)
python
# LSTM returns two things:
# 1. output: hidden states at ALL time steps
#    shape: (batch, seq_len, hidden_size)
#    output[i, t, :] = hidden state for batch i at time t

# 2. (h_n, c_n): final hidden and cell states
#    shape: (num_layers, batch, hidden_size)
#    h_n[-1] = final hidden state of last layer

# For classification, typically use:
# - output[:, -1, :] (last time step's hidden state)
# OR
# - h_n[-1] (final layer's final hidden state)
# These are identical for a unidirectional LSTM (but not a bidirectional one)
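That equivalence is easy to verify with a quick check (sizes here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1, batch_first=True)
x = torch.randn(4, 10, 8)  # (batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)

# The last time step of `output` is the final hidden state of the last layer
print(torch.allclose(output[:, -1, :], h_n[-1]))  # True
```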

A standard LSTM only reads left-to-right. But for many tasks, we want to see the full context. Bidirectional LSTMs process the sequence in both directions:

python
# Bidirectional LSTM
lstm = nn.LSTM(
    input_size=100,
    hidden_size=256,
    batch_first=True,    # keep the (batch, seq, features) layout
    bidirectional=True   # This is the key!
)

# Now output has double the hidden size
output, (h_n, c_n) = lstm(x)
print(f"Output shape: {output.shape}")  # (32, 50, 512)
#                                         256 forward + 256 backward

# h_n contains both directions
print(f"h_n shape: {h_n.shape}")  # (2, 32, 256)
#                                   [forward_final, backward_final]
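One subtlety worth verifying: the two directions finish at opposite ends of the sequence, so their final states live at different positions in `output` (a quick check with arbitrary sizes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 16
lstm = nn.LSTM(input_size=8, hidden_size=hidden,
               batch_first=True, bidirectional=True)
x = torch.randn(4, 10, 8)

output, (h_n, c_n) = lstm(x)  # output: (4, 10, 32), h_n: (2, 4, 16)

# The forward direction finishes at the LAST time step...
fwd_final = output[:, -1, :hidden]
# ...but the backward direction finishes at the FIRST time step
bwd_final = output[:, 0, hidden:]

print(torch.allclose(fwd_final, h_n[0]))  # True
print(torch.allclose(bwd_final, h_n[1]))  # True
```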

Let's build a complete sentiment classifier using LSTM:

python
class SentimentLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, n_classes=2):
        super().__init__()
        
        # Embedding layer (convert word indices to vectors)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )
        
        # Classification layer
        self.fc = nn.Linear(hidden_dim * 2, n_classes)  # *2 for bidirectional
    
    def forward(self, x):
        # x shape: (batch, seq_len) - word indices
        
        # Embed words
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        
        # Pass through LSTM
        lstm_out, (h_n, c_n) = self.lstm(embedded)
        # lstm_out: (batch, seq_len, hidden_dim * 2)
        
        # Take the last time step (forward final state concatenated with
        # the backward direction's state at the last position)
        final_hidden = lstm_out[:, -1, :]  # (batch, hidden_dim * 2)
        
        # Classify
        logits = self.fc(final_hidden)  # (batch, n_classes)
        
        return logits

# Training
model = SentimentLSTM(vocab_size=10000, n_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for batch_x, batch_y in train_loader:
        # Forward pass
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

While LSTMs solve vanishing gradients, they can still suffer from exploding gradients. Solution: gradient clipping.

python
# Clip gradients to prevent explosion
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

# In training loop:
for batch_x, batch_y in train_loader:
    loss = criterion(model(batch_x), batch_y)
    optimizer.zero_grad()
    loss.backward()
    
    # Clip gradients before optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    
    optimizer.step()

LSTMs process sequences sequentially, which is slow. Solutions:

  • Use packed sequences to skip padding
  • Use larger batch sizes
  • Consider using GRU (simpler, faster variant of LSTM)
  • For very long sequences, consider Transformers instead
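The packed-sequence trick from the first bullet looks like this in PyTorch (sizes here are arbitrary):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

# A batch of 3 padded sequences with true lengths 5, 3, 2
# (lengths must be sorted descending unless enforce_sorted=False)
x = torch.randn(3, 5, 8)
lengths = torch.tensor([5, 3, 2])

# Pack so the LSTM skips the padding positions entirely
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack back into a padded tensor for downstream layers
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(output.shape)  # torch.Size([3, 5, 16])
print(out_lengths)   # tensor([5, 3, 2])
```

A useful side effect: `h_n` now holds each sequence's hidden state at its true last step, not at a padded position.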

LSTMs have many parameters and can overfit. Solutions:

python
# 1. Dropout between LSTM layers
lstm = nn.LSTM(..., dropout=0.3, num_layers=2)

# 2. Dropout after LSTM
self.dropout = nn.Dropout(0.5)
final_hidden = self.dropout(lstm_out[:, -1, :])

# 3. Weight decay (L2 regularization)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

# 4. Early stopping: stop when validation loss hasn't improved
#    for a set number of epochs (e.g., 5)
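The early-stopping rule can be made concrete with a small helper (a sketch; the class name and API here are ours for illustration, not from a library):

```python
# Minimal early-stopping helper
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=5)
val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75]
stopped_epoch = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_epoch = epoch
        break
print(stopped_epoch)  # 7: five epochs without improvement after epoch 2
```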
| Feature | LSTM | GRU | Transformer |
|---|---|---|---|
| Gates | 3 (forget, input, output) | 2 (reset, update) | None (attention) |
| Parameters | Most | Fewer (~25% less) | Most |
| Training speed | Slow | Faster | Fastest (parallel) |
| Long sequences | Good (100+ tokens) | Good (100+ tokens) | Excellent (1000+ tokens) |
| Memory | Cell state + hidden | Hidden only | Attention weights |
| Use case | General sequence modeling | When speed matters | State-of-the-art NLP |
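The "~25% fewer parameters" claim for GRU can be checked directly: a GRU layer carries 3 gates' worth of weights where an LSTM carries 4, so with matching sizes the ratio is exactly 3/4:

```python
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(input_size=100, hidden_size=256)
gru = nn.GRU(input_size=100, hidden_size=256)

# GRU: 3 sets of (input, hidden, bias) weights vs the LSTM's 4
print(n_params(gru) / n_params(lstm))  # 0.75
```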
Use LSTMs when:

  • You need to model sequential dependencies
  • Word order matters (it almost always does in NLP)
  • You have moderate amounts of data (10K+ examples)
  • Sequences are moderate length (< 500 tokens)
  • You want a good balance of performance and interpretability
Avoid LSTMs when:

  • You have very little data (< 1K examples) → use simpler models
  • You need state-of-the-art results → use Transformers
  • Sequences are very long (> 1000 tokens) → use Transformers with efficient attention
  • Training time is critical → consider GRU or simpler models

RNNs and LSTMs represent a fundamental breakthrough in sequence modeling. While Transformers have largely replaced them in state-of-the-art NLP, understanding LSTMs is crucial because:

  1. They introduce core concepts (hidden state, sequential processing) that appear in all sequence models
  2. They're still practical for many real-world applications with limited compute
  3. They're more interpretable than Transformers
  4. Understanding why they fail (vanishing gradients) helps you understand why Transformers succeed

Key Takeaway: LSTMs solve the vanishing gradient problem through gating mechanisms that create a "memory highway" for gradients. This allows them to learn long-range dependencies that vanilla RNNs cannot.

Master LSTMs, and you'll have a solid foundation for understanding modern sequence models like Transformers, which build upon these same core ideas.
