BiLSTM for Text Classification: Understanding Sequential Deep Learning
Learn how Bidirectional LSTM networks process text sequentially to capture context, word order, and meaning. A complete guide to building your first sequence model for NLP.
Imagine you're reading a sentence: "I don't want to cancel my flight." As a human, you understand that the word "don't" completely changes the meaning. But what if we told you that many machine learning models would treat "I want to cancel my flight" and "I don't want to cancel my flight" almost identically?
This is the fundamental problem with bag-of-words approaches and even frozen sentence embeddings—they lose the sequential structure of language. Enter Bidirectional Long Short-Term Memory (BiLSTM) networks, a powerful architecture that reads text word by word, understanding context, word order, and compositional meaning.
Let's understand why sequence matters with a concrete example:

```python
# Two sentences with opposite meanings
sentence1 = "I want to cancel my flight"
sentence2 = "I don't want to cancel my flight"

# Using frozen embeddings (e.g., a sentence-transformers model)
# Both sentences get collapsed into a single 384-dim vector
embed1 = model.encode(sentence1)  # [0.23, -0.45, 0.67, ...]
embed2 = model.encode(sentence2)  # [0.21, -0.43, 0.65, ...]

# Cosine similarity is very high!
similarity = cosine_similarity(embed1, embed2)  # ~0.95

# The model thinks they're almost the same! 😱
```

The problem? These models process the entire sentence at once, creating a single vector representation. The word "don't" gets averaged out with all the other words, losing its critical negation role.
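You can reproduce this failure mode without any pretrained model: a plain bag-of-words representation (a minimal sketch, assuming scikit-learn is installed) also scores the two sentences as nearly identical, because it throws away word order and treats "don't" as just one more token.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

sentences = [
    "I want to cancel my flight",
    "I don't want to cancel my flight",
]

# Each sentence becomes a vector of word counts; order is discarded
vectors = CountVectorizer().fit_transform(sentences)

sim = cosine_similarity(vectors[0], vectors[1])[0, 0]
print(f"Bag-of-words cosine similarity: {sim:.2f}")  # very high, despite opposite meanings
```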
Recurrent Neural Networks are designed to process sequences by maintaining a hidden state that gets updated at each time step. Think of it like reading a book—you don't forget what you read in previous sentences; you carry that context forward.
Let's process the sentence "I love pizza" word by word:

```python
# Schematic RNN processing (rnn_cell stands in for the actual update rule,
# and the numbers are illustrative)
sentence = ["I", "love", "pizza"]
hidden_state = [0, 0, 0]  # Initial state (all zeros)

# Step 1: Process "I"
word_embedding_1 = [0.1, 0.2, 0.3]  # Vector for "I"
hidden_state = rnn_cell(word_embedding_1, hidden_state)
# hidden_state = [0.15, 0.22, 0.18]  # Updated with info about "I"

# Step 2: Process "love"
word_embedding_2 = [0.4, 0.5, 0.6]  # Vector for "love"
hidden_state = rnn_cell(word_embedding_2, hidden_state)
# hidden_state = [0.35, 0.48, 0.52]  # Now knows about "I love"

# Step 3: Process "pizza"
word_embedding_3 = [0.7, 0.8, 0.9]  # Vector for "pizza"
hidden_state = rnn_cell(word_embedding_3, hidden_state)
# hidden_state = [0.62, 0.75, 0.81]  # Final state: "I love pizza"
```

At each step, the RNN combines the current word with the previous hidden state, producing a new hidden state that encodes everything seen so far. The final hidden state represents the entire sentence.
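The schematic `rnn_cell` above can be made concrete with the classic Elman update, h_t = tanh(W_x x_t + W_h h_{t-1}). A runnable sketch (the weights here are random stand-ins for learned parameters, so the resulting values are illustrative only):

```python
import torch

torch.manual_seed(0)
embed_dim, hidden_dim = 3, 3

# Random weights stand in for learned parameters
W_x = torch.randn(hidden_dim, embed_dim)   # input-to-hidden
W_h = torch.randn(hidden_dim, hidden_dim)  # hidden-to-hidden

def rnn_cell(x, h):
    # Elman RNN update: combine the current input with the previous hidden state
    return torch.tanh(W_x @ x + W_h @ h)

h = torch.zeros(hidden_dim)  # initial state
for x in [torch.tensor([0.1, 0.2, 0.3]),   # "I"
          torch.tensor([0.4, 0.5, 0.6]),   # "love"
          torch.tensor([0.7, 0.8, 0.9])]:  # "pizza"
    h = rnn_cell(x, h)

print(h.shape)  # one hidden state summarizing the whole sentence
```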
Simple RNNs have a fatal flaw: they can't remember long-range dependencies. When processing long sentences, the gradient signal gets weaker and weaker as it propagates backward through time. This is called the vanishing gradient problem.
```python
# Example: a long sentence with a long-range dependency
sentence = "The chef who trained in Paris and later opened a restaurant in Tokyo makes amazing sushi"

# A simple RNN struggles to connect "chef" with "makes":
# by the time it reaches "makes", it has mostly forgotten about "chef",
# and the gradient from "makes" barely reaches back to "chef".
```

This is where LSTMs come to the rescue.
LSTMs solve the vanishing gradient problem through a clever architecture with gates that control information flow. Think of gates as smart filters that decide what to remember, what to forget, and what to output.
- Forget Gate: Decides what information to throw away from the cell state. "Should I forget that we're talking about a chef?"
- Input Gate: Decides what new information to store in the cell state. "Should I remember that we're now talking about sushi?"
- Output Gate: Decides what to output based on the cell state. "What information is relevant for the next step?"
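To make the gating concrete, here's one runnable cell-state update with tiny hand-picked numbers (a sketch: a real LSTM computes gate values from learned weight matrices over vectors; scalars are used here for readability):

```python
import torch

# One LSTM update with hand-set scalar values
prev_cell = torch.tensor(0.8)   # memory so far, e.g. "we're talking about a chef"
candidate = torch.tensor(0.5)   # proposed new information

forget_gate = torch.sigmoid(torch.tensor(2.0))   # ≈ 0.88: mostly keep the old memory
input_gate = torch.sigmoid(torch.tensor(-1.0))   # ≈ 0.27: admit a little new info

cell = forget_gate * prev_cell + input_gate * candidate
print(f"new cell state: {cell.item():.2f}")  # ≈ 0.84

output_gate = torch.sigmoid(torch.tensor(0.5))   # ≈ 0.62: expose part of the memory
hidden = output_gate * torch.tanh(cell)
print(f"new hidden state: {hidden.item():.2f}")
```

Because the gates are sigmoids, each one outputs a value between 0 and 1, smoothly scaling how much information is kept, added, or exposed.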
Here's a simplified view of how an LSTM processes one word (schematic pseudocode: `sigmoid`, `tanh`, and the weight matrices `W_f`, `W_i`, `W_c`, `W_o` are left undefined, and bias terms are omitted):

```python
def lstm_cell(word_embedding, prev_hidden, prev_cell_state):
    """Simplified LSTM cell (the actual implementation is more complex)."""
    # Forget gate: what to forget from the previous cell state?
    forget_gate = sigmoid(W_f @ [word_embedding, prev_hidden])
    # Values between 0 (forget everything) and 1 (keep everything)

    # Input gate: what new information to add?
    input_gate = sigmoid(W_i @ [word_embedding, prev_hidden])
    candidate = tanh(W_c @ [word_embedding, prev_hidden])

    # Update the cell state
    cell_state = forget_gate * prev_cell_state + input_gate * candidate

    # Output gate: what to output?
    output_gate = sigmoid(W_o @ [word_embedding, prev_hidden])
    hidden_state = output_gate * tanh(cell_state)

    return hidden_state, cell_state
```

A regular LSTM only reads text left-to-right. But humans understand language by considering context from both directions. Consider this sentence:
"The bank was steep and covered with grass."
Is "bank" a financial institution or a riverbank? You need to read ahead to "steep" and "grass" to know. This is why Bidirectional LSTMs are so powerful—they process the sequence in both directions simultaneously.
The BiLSTM creates two hidden states for each word:
- Forward hidden state: Encodes everything from the start up to this word
- Backward hidden state: Encodes everything from the end back to this word
These are concatenated to give each word full context from both directions.
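You can see both directions directly in PyTorch: with `bidirectional=True`, the output's last dimension doubles, and its two halves are the forward and backward hidden states (a minimal sketch with random inputs):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, hidden_dim, seq_len = 8, 16, 5

lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
x = torch.randn(1, seq_len, embed_dim)  # one sequence of 5 "word vectors"

out, (h_n, c_n) = lstm(x)
print(out.shape)  # (1, 5, 32): hidden_dim * 2 per word

# Split each word's representation into its two directions
forward_states = out[:, :, :hidden_dim]   # left-to-right context
backward_states = out[:, :, hidden_dim:]  # right-to-left context

# The forward state at the LAST step matches h_n's forward slot,
# and the backward state at the FIRST step matches h_n's backward slot
assert torch.allclose(forward_states[:, -1], h_n[0])
assert torch.allclose(backward_states[:, 0], h_n[1])
```

Note where each direction "finishes": the forward pass at the last time step, the backward pass at the first. This detail matters when we pool the outputs later.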
Let's build a complete BiLSTM classifier step by step. We'll classify customer service queries into intents (like "cancel_flight", "book_hotel", etc.).
Before we can feed text into a neural network, we need to convert words to numbers. This involves building a vocabulary—a mapping from words to integer indices.
```python
from collections import Counter

class Vocabulary:
    def __init__(self, min_freq=2, max_vocab=10000):
        # Special tokens
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.min_freq = min_freq
        self.max_vocab = max_vocab

    def build_from_texts(self, texts):
        """Build vocabulary from training texts."""
        # Count word frequencies
        word_counts = Counter()
        for text in texts:
            words = text.lower().split()
            word_counts.update(words)

        # Keep words that appear at least min_freq times
        valid_words = [
            word for word, count in word_counts.items()
            if count >= self.min_freq
        ]

        # Sort by frequency and take the top max_vocab
        valid_words = sorted(
            valid_words,
            key=lambda w: word_counts[w],
            reverse=True
        )[:self.max_vocab - 2]  # -2 for <PAD> and <UNK>

        # Assign indices (starting from 2)
        for idx, word in enumerate(valid_words, start=2):
            self.word2idx[word] = idx
            self.idx2word[idx] = word

# Example usage
texts = [
    "I want to cancel my flight",
    "Can you help me book a hotel",
    "I need to cancel my reservation"
]

vocab = Vocabulary(min_freq=1, max_vocab=100)
vocab.build_from_texts(texts)

print(f"Vocabulary size: {len(vocab.word2idx)}")
print(f"Sample mappings: {list(vocab.word2idx.items())[:10]}")
```

Neural networks require fixed-size inputs, but sentences have variable lengths. We solve this with padding—adding special tokens to make all sequences the same length.
```python
def encode_text(text, vocab, max_len=50):
    """Convert text to a padded sequence of indices."""
    words = text.lower().split()

    # Convert words to indices (use 1 for unknown words)
    indices = [
        vocab.word2idx.get(word, 1)  # 1 is <UNK>
        for word in words
    ]

    # Truncate if too long
    if len(indices) > max_len:
        indices = indices[:max_len]

    # Pad if too short (0 is <PAD>)
    if len(indices) < max_len:
        indices = indices + [0] * (max_len - len(indices))

    return indices

# Example
text = "I want to cancel my flight"
encoded = encode_text(text, vocab, max_len=10)

print(f"Original: {text}")
print(f"Encoded: {encoded}")
print(f"Length: {len(encoded)}")

# Output (indices are illustrative; exact values depend on the vocabulary):
# Original: I want to cancel my flight
# Encoded: [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]
# Length: 10
```

Now we need to convert word indices to dense vectors. Unlike frozen embeddings, we'll learn these embeddings from scratch during training. This allows the model to learn task-specific word representations.
```python
import torch
import torch.nn as nn

# Create the embedding layer
vocab_size = 10000
embed_dim = 128

embedding = nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embed_dim,
    padding_idx=0  # Don't update the embedding for padding tokens
)

# Example: convert word indices to embeddings
word_indices = torch.tensor([[5, 8, 3, 12, 7, 15, 0, 0, 0, 0]])
embedded = embedding(word_indices)

print(f"Input shape: {word_indices.shape}")  # (1, 10)
print(f"Output shape: {embedded.shape}")     # (1, 10, 128)
print(f"Each word is now a {embed_dim}-dimensional vector!")
```

The embedding layer is essentially a lookup table. Each word index maps to a learnable vector. During training, backpropagation updates these vectors to be more useful for the task.
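Both behaviors, the lookup table and the untouched padding row, can be verified directly (a small sketch):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(10000, 128, padding_idx=0)

# 1) It's a lookup table: indexing the weight matrix by hand
#    gives exactly the vector the layer returns
assert torch.equal(embedding(torch.tensor([5]))[0], embedding.weight[5])

# 2) The padding row starts at zero and receives no gradient,
#    even when index 0 appears in a batch
batch = torch.tensor([[3, 5, 0, 0]])  # sequence padded with <PAD> (index 0)
embedding(batch).sum().backward()

print(float(embedding.weight[0].abs().sum()))       # 0.0: the pad row is all zeros
print(float(embedding.weight.grad[0].abs().sum()))  # 0.0: and it is never updated
print(float(embedding.weight.grad[3].abs().sum()))  # > 0: real tokens do get gradients
```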
Now we can build the complete BiLSTM classifier:

```python
class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, n_classes, embed_dim=128,
                 hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()

        # 1. Embedding layer (learnable word vectors)
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0
        )

        # 2. Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,  # Input shape: (batch, seq_len, features)
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True  # This is the key!
        )

        # 3. Classification head
        # The BiLSTM outputs hidden_dim * 2 (forward + backward)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        """
        Args:
            x: word indices, shape (batch, seq_len)
        Returns:
            logits, shape (batch, n_classes)
        """
        # 1. Embed words: (batch, seq_len) -> (batch, seq_len, embed_dim)
        embedded = self.embedding(x)

        # 2. Pass through the BiLSTM
        # lstm_out shape: (batch, seq_len, hidden_dim * 2)
        lstm_out, (h_n, c_n) = self.lstm(embedded)

        # 3. Take the final hidden state (last time step)
        # This represents the entire sequence
        # (Caveat: with right-padded batches this is the state at a padding
        # position; the packed-sequence variant later in the article avoids that.)
        final_hidden = lstm_out[:, -1, :]  # (batch, hidden_dim * 2)

        # 4. Classify
        out = self.dropout(final_hidden)
        logits = self.fc(out)  # (batch, n_classes)
        return logits
```

Let's trace through what happens to a single sentence:
```python
# Example: "I want to cancel my flight"
# After encoding and padding: [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]

# Step 1: Embedding
# Input:  (1, 10) - batch of 1, sequence length 10
# Output: (1, 10, 128) - each word is now a 128-dim vector

# Step 2: BiLSTM
# The LSTM processes this sequence in both directions:
# - Forward LSTM:  reads [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]
# - Backward LSTM: reads [0, 0, 0, 0, 15, 7, 12, 3, 8, 5]
# Output: (1, 10, 512) - 256 from forward + 256 from backward

# Step 3: Take the final hidden state
# We take the last time step: lstm_out[:, -1, :]
# Output: (1, 512) - this represents the entire sentence

# Step 4: Classification
# Linear layer: (1, 512) -> (1, 151) for 151 classes
# Each value is a score for that class
```

Training a BiLSTM is similar to training any neural network, but with some sequence-specific considerations:
```python
import torch.optim as optim

# Initialize the model
model = BiLSTMClassifier(
    vocab_size=10000,
    n_classes=151,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3
)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Training loop (train_loader, X_val, and y_val are assumed to be prepared beforehand)
for epoch in range(30):
    model.train()
    total_loss = 0

    for batch_x, batch_y in train_loader:
        # batch_x shape: (batch_size, seq_len)
        # batch_y shape: (batch_size,)

        # Forward pass
        logits = model(batch_x)
        loss = criterion(logits, batch_y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Validation
    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_preds = val_logits.argmax(dim=1)
        val_acc = (val_preds == y_val).float().mean()

    print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Val Acc={val_acc:.4f}")
```

Let's compare the two approaches on our negation example:
| Aspect | Frozen Embeddings | BiLSTM |
|---|---|---|
| Processing | Entire sentence → single vector | Word by word → sequence of vectors |
| Word order | Lost (bag of words) | Preserved (sequential) |
| Negation handling | Poor ("don't" gets averaged out) | Good ("don't" affects subsequent words) |
| Context | Global only | Local + global (bidirectional) |
| Embeddings | Fixed (pretrained) | Learned (task-specific) |
| Parameters | ~100K (classifier only) | ~2-5M (embeddings + LSTM + classifier) |
| Training time | Fast (minutes) | Slower (hours) |
| Typical accuracy | 85-88% | 88-92% |
Here's how the BiLSTM keeps the negation visible, step by step:

```python
# Sentence: "I don't want to cancel"

# Forward LSTM processing:
# "I"      -> hidden_state_1 (knows about "I")
# "don't"  -> hidden_state_2 (knows about "I don't")
# "want"   -> hidden_state_3 (knows about "I don't want")
# "to"     -> hidden_state_4 (knows about "I don't want to")
# "cancel" -> hidden_state_5 (knows about "I don't want to cancel")

# Backward LSTM processing:
# "cancel" -> hidden_state_1 (knows about "cancel")
# "to"     -> hidden_state_2 (knows about "to cancel")
# "want"   -> hidden_state_3 (knows about "want to cancel")
# "don't"  -> hidden_state_4 (knows about "don't want to cancel")
# "I"      -> hidden_state_5 (knows about "I don't want to cancel")

# Final representation: concatenate forward_5 + backward_5
# This captures the full context with the negation preserved!
```

BiLSTMs have several important hyperparameters that significantly affect performance:
**Embedding dimension (`embed_dim`):**

- Too small (32-64): words can't capture enough semantic information
- Sweet spot (128-256): good balance of expressiveness and efficiency
- Too large (512+): overfitting, slower training, diminishing returns

**Hidden dimension (`hidden_dim`):**

- Too small (64-128): can't capture complex patterns
- Sweet spot (256-512): sufficient capacity for most tasks
- Too large (1024+): overfitting, memory issues

**Number of layers (`num_layers`):**

- 1 layer: simple patterns only
- 2 layers: good for most tasks (recommended starting point)
- 3+ layers: deeper hierarchies, but harder to train
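These choices translate directly into model size. A quick sketch, reusing the same building blocks as the classifier above, shows how parameter count grows with the hyperparameters:

```python
import torch.nn as nn

def count_params(embed_dim, hidden_dim, num_layers, vocab_size=10000, n_classes=151):
    # Same components as BiLSTMClassifier: embedding + BiLSTM + linear head
    model = nn.ModuleDict({
        "embedding": nn.Embedding(vocab_size, embed_dim, padding_idx=0),
        "lstm": nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
                        batch_first=True, bidirectional=True),
        "fc": nn.Linear(hidden_dim * 2, n_classes),
    })
    return sum(p.numel() for p in model.parameters())

for embed_dim, hidden_dim, num_layers in [(64, 128, 1), (128, 256, 2), (256, 512, 3)]:
    n = count_params(embed_dim, hidden_dim, num_layers)
    print(f"embed={embed_dim}, hidden={hidden_dim}, layers={num_layers}: {n / 1e6:.1f}M params")
```

The middle configuration (the article's defaults) lands in the low millions of parameters, consistent with the 2-5M range in the comparison table.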
```python
# Analyze your data first!
import numpy as np

sentence_lengths = [len(text.split()) for text in train_texts]
print(f"Mean length: {np.mean(sentence_lengths):.1f}")
print(f"95th percentile: {np.percentile(sentence_lengths, 95):.0f}")
print(f"Max length: {np.max(sentence_lengths)}")

# Set max_len to cover 95-99% of sentences
# For the CLINC150 dataset, max_len=50 covers 99% of sentences
```

```python
# ❌ Wrong: unidirectional LSTM
lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=False)
# Output: (batch, seq_len, hidden_dim)
# Only sees left context!

# ✅ Correct: bidirectional LSTM
lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True)
# Output: (batch, seq_len, hidden_dim * 2)
# Sees both left and right context!
```

```python
# ❌ Wrong: using h_n[-1]
lstm_out, (h_n, c_n) = self.lstm(embedded)
final_hidden = h_n[-1]  # Only gets the last layer's backward direction

# ✅ Correct: using the last time step of lstm_out
lstm_out, (h_n, c_n) = self.lstm(embedded)
final_hidden = lstm_out[:, -1, :]  # Both directions concatenated
# Note: the backward half at the last time step has only seen the final
# token; torch.cat([h_n[-2], h_n[-1]], dim=1) is a common alternative that
# takes each direction's own final (full-context) state.
```

```python
# ❌ Wrong: padding embeddings get updated during training
embedding = nn.Embedding(vocab_size, embed_dim)
# The padding token (index 0) will receive non-zero gradients!

# ✅ Correct: the padding embedding stays at zero
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
# Padding tokens don't contribute to gradients
```

On a typical intent classification dataset (like CLINC150 with 151 classes):
| Model | Accuracy | Parameters | Training Time | Inference Speed |
|---|---|---|---|---|
| TF-IDF + Logistic Regression | 82-85% | ~150K | < 1 min | Very fast |
| Frozen Embeddings + MLP | 85-88% | ~100K | 2-5 min | Fast |
| BiLSTM (this approach) | 88-92% | 2-5M | 30-60 min | Medium |
| BERT (fine-tuned) | 93-96% | 110M | 2-4 hours | Slow |
Use a BiLSTM when:

- Word order and sequence structure are critical (negation, temporal relationships)
- You have moderate amounts of training data (10K+ examples)
- You need better accuracy than bag-of-words but can't afford transformer training time
- Interpretability matters (you can visualize attention over time steps)
- You're working with sequences of moderate length (< 100 tokens)
Consider an alternative when:

- You have very little data (< 1K examples) → use frozen embeddings
- You need state-of-the-art accuracy and have the compute budget → use transformers
- Sequences are very long (> 500 tokens) → LSTMs struggle with very long sequences
- Real-time inference is critical → simpler models are faster
When sentences have very different lengths, you can use packed sequences to avoid wasting computation on padding:
```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def forward(self, x, lengths):
    """Forward pass with packed sequences."""
    embedded = self.embedding(x)

    # Pack the sequences (removes padding)
    packed = pack_padded_sequence(
        embedded, lengths, batch_first=True, enforce_sorted=False
    )

    # The LSTM only processes actual words, not padding
    packed_out, (h_n, c_n) = self.lstm(packed)

    # Unpack back to padded format
    lstm_out, _ = pad_packed_sequence(packed_out, batch_first=True)

    # Take each sequence's last REAL time step, not the padded position
    # (lstm_out[:, -1, :] would be zeros for the shorter sequences)
    idx = (lengths - 1).to(lstm_out.device).view(-1, 1, 1)
    idx = idx.expand(-1, 1, lstm_out.size(2))
    final_hidden = lstm_out.gather(1, idx).squeeze(1)

    return self.fc(self.dropout(final_hidden))
```

Instead of just using the final hidden state, you can use attention to weight all time steps:
```python
class BiLSTMWithAttention(nn.Module):
    def __init__(self, vocab_size, n_classes, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)

        # Attention layer
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden*2)

        # Compute attention weights
        # (for padded batches you would typically mask the <PAD> positions
        # before the softmax so they receive zero weight)
        attention_scores = self.attention(lstm_out)  # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)

        # Weighted sum of all hidden states
        context = (lstm_out * attention_weights).sum(dim=1)  # (batch, hidden*2)
        return self.fc(context)
```

You can initialize embeddings with pretrained vectors (like GloVe or Word2Vec) instead of random initialization:
```python
# Load pretrained embeddings (load_glove_embeddings is a placeholder for
# your own loading code); shape: (vocab_size, embed_dim)
pretrained_embeddings = load_glove_embeddings()

# Initialize the embedding layer with pretrained weights
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))

# Optionally freeze the embeddings (don't update during training)
embedding.weight.requires_grad = False  # Frozen
# OR
embedding.weight.requires_grad = True   # Fine-tune
```

BiLSTM networks represent a significant step up from bag-of-words and frozen-embedding approaches. By processing text sequentially and bidirectionally, they capture the compositional nature of language—understanding that "I don't want to cancel" is fundamentally different from "I want to cancel."
Key takeaways:
- LSTMs solve the vanishing gradient problem through gating mechanisms
- Bidirectional processing gives each word full context from both directions
- Learned embeddings allow task-specific word representations
- Sequential processing preserves word order and handles negations correctly
- BiLSTMs offer a sweet spot between simple baselines and heavy transformers