BiLSTM for Text Classification: Understanding Sequential Deep Learning

Learn how Bidirectional LSTM networks process text sequentially to capture context, word order, and meaning. A complete guide to building your first sequence model for NLP.

18 min read

Imagine you're reading a sentence: "I don't want to cancel my flight." As a human, you understand that the word "don't" completely changes the meaning. But what if we told you that many machine learning models would treat "I want to cancel my flight" and "I don't want to cancel my flight" almost identically?

This is the fundamental problem with bag-of-words approaches and even frozen sentence embeddings—they lose the sequential structure of language. Enter Bidirectional Long Short-Term Memory (BiLSTM) networks, a powerful architecture that reads text word by word, understanding context, word order, and compositional meaning.

Let's understand why sequence matters with a concrete example:

python
# Two sentences with opposite meanings
sentence1 = "I want to cancel my flight"
sentence2 = "I don't want to cancel my flight"

# Using frozen embeddings (like sentence-transformers)
# Both get collapsed into a single 384-dim vector
embed1 = model.encode(sentence1)  # [0.23, -0.45, 0.67, ...]
embed2 = model.encode(sentence2)  # [0.21, -0.43, 0.65, ...]

# Cosine similarity is very high!
similarity = cosine_similarity(embed1, embed2)  # 0.95
# The model thinks they're almost the same! 😱

The problem? These models process the entire sentence at once, creating a single vector representation. The word "don't" gets averaged out with all other words, losing its critical negation role.

Key Insight: Language is inherently sequential. The order of words matters. "Dog bites man" is very different from "Man bites dog"—same words, completely different meaning.
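This order-blindness is easy to demonstrate: a bag-of-words representation (word counts, sketched here with Python's Counter) is identical for both sentences.

```python
from collections import Counter

# Bag-of-words reduces a sentence to word counts, discarding order entirely
bow1 = Counter("dog bites man".split())
bow2 = Counter("man bites dog".split())

print(bow1 == bow2)  # True: identical representations, opposite meanings
```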

Recurrent Neural Networks are designed to process sequences by maintaining a hidden state that gets updated at each time step. Think of it like reading a book—you don't forget what you read in previous sentences; you carry that context forward.

Let's process the sentence "I love pizza" word by word:

python
# Simplified RNN processing
sentence = ["I", "love", "pizza"]
hidden_state = [0, 0, 0]  # Initial state (all zeros)

# Step 1: Process "I"
word_embedding_1 = [0.1, 0.2, 0.3]  # Vector for "I"
hidden_state = rnn_cell(word_embedding_1, hidden_state)
# hidden_state = [0.15, 0.22, 0.18]  # Updated with info about "I"

# Step 2: Process "love"
word_embedding_2 = [0.4, 0.5, 0.6]  # Vector for "love"
hidden_state = rnn_cell(word_embedding_2, hidden_state)
# hidden_state = [0.35, 0.48, 0.52]  # Now knows about "I love"

# Step 3: Process "pizza"
word_embedding_3 = [0.7, 0.8, 0.9]  # Vector for "pizza"
hidden_state = rnn_cell(word_embedding_3, hidden_state)
# hidden_state = [0.62, 0.75, 0.81]  # Final state: "I love pizza"

At each step, the RNN combines the current word with the previous hidden state, creating a new hidden state that encodes everything seen so far. The final hidden state represents the entire sentence.
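The update sketched above can be written as a runnable vanilla RNN cell. This is a toy version with random weights (not the actual PyTorch implementation), computing h_t = tanh(W_x · x_t + W_h · h_{t-1}):

```python
import numpy as np

def rnn_cell(x, h_prev, W_x, W_h):
    """One vanilla RNN step: mix the current input with the previous hidden state."""
    return np.tanh(W_x @ x + W_h @ h_prev)

rng = np.random.default_rng(0)
W_x = rng.normal(size=(3, 3)) * 0.5  # input-to-hidden weights (toy values)
W_h = rng.normal(size=(3, 3)) * 0.5  # hidden-to-hidden weights (toy values)

h = np.zeros(3)  # initial hidden state
for word_vec in [np.array([0.1, 0.2, 0.3]),   # "I"
                 np.array([0.4, 0.5, 0.6]),   # "love"
                 np.array([0.7, 0.8, 0.9])]:  # "pizza"
    h = rnn_cell(word_vec, h, W_x, W_h)

print(h.shape)  # (3,) — one fixed-size state now encodes the whole sentence
```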

Simple RNNs have a fatal flaw: they can't remember long-range dependencies. When processing long sentences, the gradient signal gets weaker and weaker as it propagates backward through time. This is called the vanishing gradient problem.

python
# Example: Long sentence
sentence = "The chef who trained in Paris and later opened a restaurant in Tokyo makes amazing sushi"

# Simple RNN struggles to connect "chef" with "makes"
# By the time it reaches "makes", it has mostly forgotten about "chef"
# The gradient from "makes" barely reaches back to "chef"

This is where LSTMs come to the rescue.

LSTMs solve the vanishing gradient problem through a clever architecture with gates that control information flow. Think of gates as smart filters that decide what to remember, what to forget, and what to output.

  1. Forget Gate: Decides what information to throw away from the cell state. "Should I forget that we're talking about a chef?"
  2. Input Gate: Decides what new information to store in the cell state. "Should I remember that we're now talking about sushi?"
  3. Output Gate: Decides what to output based on the cell state. "What information is relevant for the next step?"

Here's a simplified view of how LSTM processes one word:

python
def lstm_cell(word_embedding, prev_hidden, prev_cell_state):
    """
    Simplified LSTM cell (actual implementation is more complex)
    """
    # Forget gate: What to forget from previous cell state?
    forget_gate = sigmoid(W_f @ [word_embedding, prev_hidden])
    # Values between 0 (forget everything) and 1 (keep everything)
    
    # Input gate: What new information to add?
    input_gate = sigmoid(W_i @ [word_embedding, prev_hidden])
    candidate = tanh(W_c @ [word_embedding, prev_hidden])
    
    # Update cell state
    cell_state = forget_gate * prev_cell_state + input_gate * candidate
    
    # Output gate: What to output?
    output_gate = sigmoid(W_o @ [word_embedding, prev_hidden])
    hidden_state = output_gate * tanh(cell_state)
    
    return hidden_state, cell_state

Why LSTMs Work: The cell state acts like a "memory highway" that information can flow through with minimal modification. Gates can choose to let information pass unchanged, solving the vanishing gradient problem.
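The "memory highway" claim can be made concrete with hand-picked gate values: if the forget gate stays near 1 and the input gate near 0, the cell state survives many time steps nearly unchanged.

```python
cell_state = 5.0                      # a remembered value ("we're talking about a chef")
forget_gate, input_gate = 0.99, 0.01  # keep almost everything, add almost nothing
candidate = 0.3                       # new information offered at each step

for _ in range(20):  # 20 intervening words
    cell_state = forget_gate * cell_state + input_gate * candidate

# Most of the original 5.0 survives all 20 steps
print(round(cell_state, 2))
```

A simple RNN, by contrast, squashes the state through a nonlinearity at every step, so early information fades much faster.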

A regular LSTM only reads text left-to-right. But humans understand language by considering context from both directions. Consider this sentence:

"The bank was steep and covered with grass."

Is "bank" a financial institution or a riverbank? You need to read ahead to "steep" and "grass" to know. This is why Bidirectional LSTMs are so powerful—they process the sequence in both directions simultaneously.

The BiLSTM creates two hidden states for each word:

  • Forward hidden state: Encodes everything from the start up to this word
  • Backward hidden state: Encodes everything from the end back to this word

These are concatenated to give each word full context from both directions.
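In PyTorch this concatenation is automatic: with bidirectional=True the output's last dimension is 2 * hidden_size, and the forward and backward halves can be sliced apart. A toy-sized shape check:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True, bidirectional=True)

x = torch.randn(1, 5, 8)           # batch of 1, sequence of 5 words, 8-dim embeddings
out, (h_n, c_n) = lstm(x)

print(out.shape)                   # torch.Size([1, 5, 32]) — 16 forward + 16 backward
forward_half = out[:, :, :16]      # forward hidden state at each word
backward_half = out[:, :, 16:]     # backward hidden state at each word
```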

Let's build a complete BiLSTM classifier step by step. We'll classify customer service queries into intents (like "cancel_flight", "book_hotel", etc.).

Before we can feed text into a neural network, we need to convert words to numbers. This involves building a vocabulary—a mapping from words to integer indices.

python
class Vocabulary:
    def __init__(self, min_freq=2, max_vocab=10000):
        # Special tokens
        self.word2idx = {"<PAD>": 0, "<UNK>": 1}
        self.idx2word = {0: "<PAD>", 1: "<UNK>"}
        self.min_freq = min_freq
        self.max_vocab = max_vocab
    
    def build_from_texts(self, texts):
        """Build vocabulary from training texts"""
        from collections import Counter
        
        # Count word frequencies
        word_counts = Counter()
        for text in texts:
            words = text.lower().split()
            word_counts.update(words)
        
        # Keep words that appear at least min_freq times
        valid_words = [
            word for word, count in word_counts.items()
            if count >= self.min_freq
        ]
        
        # Sort by frequency and take top max_vocab
        valid_words = sorted(
            valid_words,
            key=lambda w: word_counts[w],
            reverse=True
        )[:self.max_vocab - 2]  # -2 for PAD and UNK
        
        # Assign indices (starting from 2)
        for idx, word in enumerate(valid_words, start=2):
            self.word2idx[word] = idx
            self.idx2word[idx] = word

# Example usage
texts = [
    "I want to cancel my flight",
    "Can you help me book a hotel",
    "I need to cancel my reservation"
]

vocab = Vocabulary(min_freq=1, max_vocab=100)
vocab.build_from_texts(texts)

print(f"Vocabulary size: {len(vocab.word2idx)}")
print(f"Sample mappings: {list(vocab.word2idx.items())[:10]}")

Important: Always build vocabulary only from the training set, never from validation or test data. This prevents data leakage.

To process sentences in batches, all sequences in a batch must have the same length, but sentences naturally vary. We solve this with padding—adding special &lt;PAD&gt; tokens to bring every sequence up to a common length.

python
def encode_text(text, vocab, max_len=50):
    """Convert text to padded sequence of indices"""
    words = text.lower().split()
    
    # Convert words to indices (use 1 for unknown words)
    indices = [
        vocab.word2idx.get(word, 1)  # 1 is <UNK>
        for word in words
    ]
    
    # Truncate if too long
    if len(indices) > max_len:
        indices = indices[:max_len]
    
    # Pad if too short (0 is <PAD>)
    if len(indices) < max_len:
        indices = indices + [0] * (max_len - len(indices))
    
    return indices

# Example
text = "I want to cancel my flight"
encoded = encode_text(text, vocab, max_len=10)
print(f"Original: {text}")
print(f"Encoded: {encoded}")
print(f"Length: {len(encoded)}")

# Output:
# Original: I want to cancel my flight
# Encoded: [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]
# Length: 10

Now we need to convert word indices to dense vectors. Unlike frozen embeddings, we'll learn these embeddings from scratch during training. This allows the model to learn task-specific word representations.

python
import torch
import torch.nn as nn

# Create embedding layer
vocab_size = 10000
embed_dim = 128

embedding = nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embed_dim,
    padding_idx=0  # Don't update embeddings for padding tokens
)

# Example: Convert word indices to embeddings
word_indices = torch.tensor([[5, 8, 3, 12, 7, 15, 0, 0, 0, 0]])
embedded = embedding(word_indices)

print(f"Input shape: {word_indices.shape}")  # (1, 10)
print(f"Output shape: {embedded.shape}")     # (1, 10, 128)
print(f"Each word is now a {embed_dim}-dimensional vector!")

The embedding layer is essentially a lookup table. Each word index maps to a learnable vector. During training, backpropagation updates these vectors to be more useful for the task.
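That the embedding layer is literally a lookup table is easy to verify: the vector returned for index i is exactly row i of the layer's weight matrix, and the padding row stays at zero.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=8, padding_idx=0)

idx = torch.tensor([5])
vec = embedding(idx)  # look up the vector for word index 5

# The result is exactly row 5 of the weight matrix
print(torch.equal(vec[0], embedding.weight[5]))  # True

# And padding_idx=0 keeps that row at zeros
print(embedding.weight[0].abs().sum().item())    # 0.0
```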

Now we can build the complete BiLSTM classifier:

python
class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, n_classes, embed_dim=128, 
                 hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        
        # 1. Embedding layer (learnable word vectors)
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0
        )
        
        # 2. Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,  # Input shape: (batch, seq_len, features)
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True  # This is the key!
        )
        
        # 3. Classification head
        # BiLSTM outputs hidden_dim * 2 (forward + backward)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim * 2, n_classes)
    
    def forward(self, x):
        """
        Args:
            x: word indices, shape (batch, seq_len)
        Returns:
            logits, shape (batch, n_classes)
        """
        # 1. Embed words: (batch, seq_len) -> (batch, seq_len, embed_dim)
        embedded = self.embedding(x)
        
        # 2. Pass through BiLSTM
        # lstm_out shape: (batch, seq_len, hidden_dim * 2)
        lstm_out, (h_n, c_n) = self.lstm(embedded)
        
        # 3. Take final hidden state (last time step)
        # This represents the entire sequence
        # (with padded batches this position may be <PAD>; see the
        # packed-sequence variant later for a stricter treatment)
        final_hidden = lstm_out[:, -1, :]  # (batch, hidden_dim * 2)
        
        # 4. Classify
        out = self.dropout(final_hidden)
        logits = self.fc(out)  # (batch, n_classes)
        
        return logits

Let's trace through what happens to a single sentence:

python
# Example: "I want to cancel my flight"
# After encoding and padding: [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]

# Step 1: Embedding
# Input: (1, 10) - batch of 1, sequence length 10
# Output: (1, 10, 128) - each word is now a 128-dim vector

# Step 2: BiLSTM
# The LSTM processes this sequence in both directions:
# - Forward LSTM: reads [5, 8, 3, 12, 7, 15, 0, 0, 0, 0]
# - Backward LSTM: reads [0, 0, 0, 0, 15, 7, 12, 3, 8, 5]
# Output: (1, 10, 512) - 256 from forward + 256 from backward

# Step 3: Take final hidden state
# We take the last time step: lstm_out[:, -1, :]
# Output: (1, 512) - this represents the entire sentence

# Step 4: Classification
# Linear layer: (1, 512) -> (1, 151) for 151 classes
# Each value is a score for that class
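The trace above can be checked end to end by rebuilding the same layers standalone and pushing the encoded sentence through them (a sketch of the shapes only, not the full class):

```python
import torch
import torch.nn as nn

# Same layer sizes as BiLSTMClassifier, built standalone for a shape check
embedding = nn.Embedding(10000, 128, padding_idx=0)
lstm = nn.LSTM(128, 256, num_layers=2, batch_first=True, bidirectional=True)
fc = nn.Linear(256 * 2, 151)

x = torch.tensor([[5, 8, 3, 12, 7, 15, 0, 0, 0, 0]])  # (1, 10)
embedded = embedding(x)                                 # (1, 10, 128)
lstm_out, _ = lstm(embedded)                            # (1, 10, 512)
final_hidden = lstm_out[:, -1, :]                       # (1, 512)
logits = fc(final_hidden)                               # (1, 151)

print(logits.shape)  # torch.Size([1, 151])
```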

Training a BiLSTM is similar to training any neural network, but with some sequence-specific considerations:

python
import torch.optim as optim

# Initialize model
model = BiLSTMClassifier(
    vocab_size=10000,
    n_classes=151,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3
)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Training loop (assumes train_loader yields encoded (batch_x, batch_y)
# batches, and X_val / y_val hold the encoded validation set)
for epoch in range(30):
    model.train()
    total_loss = 0
    
    for batch_x, batch_y in train_loader:
        # batch_x shape: (batch_size, seq_len)
        # batch_y shape: (batch_size,)
        
        # Forward pass
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    # Validation
    model.eval()
    with torch.no_grad():
        val_logits = model(X_val)
        val_preds = val_logits.argmax(dim=1)
        val_acc = (val_preds == y_val).float().mean()
    
    print(f"Epoch {epoch+1}: Loss={total_loss:.4f}, Val Acc={val_acc:.4f}")
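One sequence-specific detail the loop above omits: recurrent nets can also suffer exploding gradients, and clipping the gradient norm between loss.backward() and optimizer.step() is a standard safeguard. A minimal sketch on a stand-in model (the same two lines slot into the BiLSTM loop unchanged):

```python
import torch
import torch.nn as nn

model = nn.LSTM(8, 16, batch_first=True)  # stand-in for the classifier
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 5, 8)
out, _ = model(x)
loss = out.sum()

optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their combined L2 norm is at most max_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Verify: the global gradient norm is now <= 1.0
total_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
```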

Let's compare the two approaches on our negation example:

| Aspect | Frozen Embeddings | BiLSTM |
|---|---|---|
| Processing | Entire sentence → single vector | Word by word → sequence of vectors |
| Word order | Lost (bag of words) | Preserved (sequential) |
| Negation handling | Poor ("don't" gets averaged out) | Good ("don't" affects subsequent words) |
| Context | Global only | Local + global (bidirectional) |
| Embeddings | Fixed (pretrained) | Learned (task-specific) |
| Parameters | ~100K (classifier only) | ~2-5M (embeddings + LSTM + classifier) |
| Training time | Fast (minutes) | Slower (hours) |
| Typical accuracy | 85-88% | 88-92% |
Here's how the BiLSTM handles the negation example step by step:

python
# Sentence: "I don't want to cancel"

# Forward LSTM processing:
# "I"     -> hidden_state_1 (knows about "I")
# "don't" -> hidden_state_2 (knows about "I don't")
# "want"  -> hidden_state_3 (knows about "I don't want")
# "to"    -> hidden_state_4 (knows about "I don't want to")
# "cancel"-> hidden_state_5 (knows about "I don't want to cancel")

# Backward LSTM processing:
# "cancel"-> hidden_state_1 (knows about "cancel")
# "to"    -> hidden_state_2 (knows about "to cancel")
# "want"  -> hidden_state_3 (knows about "want to cancel")
# "don't" -> hidden_state_4 (knows about "don't want to cancel")
# "I"     -> hidden_state_5 (knows about "I don't want to cancel")

# Final representation: concatenate forward_5 + backward_5
# This captures the full context with negation preserved!

BiLSTMs have several important hyperparameters that significantly affect performance:

Embedding dimension (embed_dim):

  • Too small (32-64): Words can't capture enough semantic information
  • Sweet spot (128-256): Good balance of expressiveness and efficiency
  • Too large (512+): Overfitting, slower training, diminishing returns

Hidden dimension (hidden_dim):

  • Too small (64-128): Can't capture complex patterns
  • Sweet spot (256-512): Sufficient capacity for most tasks
  • Too large (1024+): Overfitting, memory issues

Number of layers (num_layers):

  • 1 layer: Simple patterns only
  • 2 layers: Good for most tasks (recommended starting point)
  • 3+ layers: Deeper hierarchies, but harder to train
Maximum sequence length (max_len): choose it from your data's length distribution rather than guessing:

python
# Analyze your data first!
import numpy as np

sentence_lengths = [len(text.split()) for text in train_texts]
print(f"Mean length: {np.mean(sentence_lengths):.1f}")
print(f"95th percentile: {np.percentile(sentence_lengths, 95):.0f}")
print(f"Max length: {np.max(sentence_lengths)}")

# Set max_len to cover 95-99% of sentences
# For CLINC150 dataset: max_len=50 covers 99% of sentences
A few implementation pitfalls are worth calling out. The first is forgetting to enable bidirectionality:

python
# ❌ Wrong: Unidirectional LSTM
lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=False)
# Output: (batch, seq_len, hidden_dim)
# Only sees left context!

# ✅ Correct: Bidirectional LSTM
lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True)
# Output: (batch, seq_len, hidden_dim * 2)
# Sees both left and right context!
Another pitfall is pooling the wrong output for classification:

python
# ❌ Wrong: Using h_n[-1] alone
lstm_out, (h_n, c_n) = self.lstm(embedded)
final_hidden = h_n[-1]  # Only the last layer's backward direction

# ✅ Correct: Using the last time step from lstm_out
lstm_out, (h_n, c_n) = self.lstm(embedded)
final_hidden = lstm_out[:, -1, :]  # Both directions concatenated
# Note: at the last time step the backward half has only seen one token;
# concatenating h_n[-2] and h_n[-1] (final forward + backward states) is
# a common alternative
Padding handling is another easy thing to get wrong:

python
# ❌ Wrong: Padding tokens get updated during training
embedding = nn.Embedding(vocab_size, embed_dim)
# Padding tokens (index 0) will have non-zero gradients!

# ✅ Correct: Padding tokens stay at zero
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
# Padding tokens don't contribute to gradients

On a typical intent classification dataset (like CLINC150 with 151 classes):

| Model | Accuracy | Parameters | Training Time | Inference Speed |
|---|---|---|---|---|
| TF-IDF + Logistic Regression | 82-85% | ~150K | < 1 min | Very fast |
| Frozen Embeddings + MLP | 85-88% | ~100K | 2-5 min | Fast |
| BiLSTM (this approach) | 88-92% | 2-5M | 30-60 min | Medium |
| BERT (fine-tuned) | 93-96% | 110M | 2-4 hours | Slow |

Sweet Spot: BiLSTM offers a great balance—significantly better than simple baselines, much faster than transformers, and still very interpretable.

Use a BiLSTM when:

  • Word order and sequence structure are critical (negation, temporal relationships)
  • You have moderate amounts of training data (10K+ examples)
  • You need better accuracy than bag-of-words but can't afford transformer training time
  • Interpretability matters (you can visualize attention over time steps)
  • You're working with sequences of moderate length (< 100 tokens)

Consider alternatives when:

  • You have very little data (< 1K examples) → use frozen embeddings
  • You need state-of-the-art accuracy and have compute budget → use transformers
  • Sequences are very long (> 500 tokens) → LSTMs struggle with very long sequences
  • Real-time inference is critical → simpler models are faster

When sentences have very different lengths, you can use packed sequences to avoid wasting computation on padding:

python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def forward(self, x, lengths):
    """Forward pass with packed sequences"""
    embedded = self.embedding(x)
    
    # Pack sequences (removes padding)
    packed = pack_padded_sequence(
        embedded, lengths, batch_first=True, enforce_sorted=False
    )
    
    # LSTM only processes actual words, not padding
    packed_out, (h_n, c_n) = self.lstm(packed)
    
    # Unpack back to padded format
    lstm_out, _ = pad_packed_sequence(packed_out, batch_first=True)
    
    # After unpacking, position -1 is zero padding for shorter sequences,
    # so gather the last *real* time step of each example using lengths
    last = torch.as_tensor(lengths, device=lstm_out.device) - 1
    idx = last.view(-1, 1, 1).expand(-1, 1, lstm_out.size(2))
    final_hidden = lstm_out.gather(1, idx).squeeze(1)
    return self.fc(self.dropout(final_hidden))
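The lengths argument doesn't need to be tracked separately: since &lt;PAD&gt; is index 0 by our convention, it can be recovered from the padded batch itself.

```python
import torch

batch = torch.tensor([
    [5, 8, 3, 12, 7, 15, 0, 0, 0, 0],  # 6 real tokens, 4 PAD
    [4, 9, 2, 0, 0, 0, 0, 0, 0, 0],    # 3 real tokens, 7 PAD
])

# Count non-PAD positions per row (PAD is index 0)
lengths = (batch != 0).sum(dim=1)
print(lengths.tolist())  # [6, 3]
```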

Instead of just using the final hidden state, you can use attention to weight all time steps:

python
class BiLSTMWithAttention(nn.Module):
    def __init__(self, vocab_size, n_classes, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
        
        # Attention layer
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc = nn.Linear(hidden_dim * 2, n_classes)
    
    def forward(self, x):
        embedded = self.embedding(x)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden*2)
        
        # Compute attention weights
        attention_scores = self.attention(lstm_out)  # (batch, seq_len, 1)
        attention_weights = torch.softmax(attention_scores, dim=1)
        
        # Weighted sum of all hidden states
        context = (lstm_out * attention_weights).sum(dim=1)  # (batch, hidden*2)
        
        return self.fc(context)
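A side benefit of the attention variant is interpretability: because the scores are softmaxed over the sequence dimension, they form a distribution over words and can be read as per-word importance. A quick sketch of that property with toy scores:

```python
import torch

# Toy scores: one scalar per time step, as produced by the attention layer
scores = torch.randn(1, 6, 1)           # (batch, seq_len, 1)
weights = torch.softmax(scores, dim=1)  # normalize across the sequence

# The weights form a distribution over the 6 words, so each row sums to 1
total = weights.sum(dim=1)                    # (batch, 1), all ones
top_word = weights.squeeze(-1).argmax(dim=1)  # index of the most-attended word
```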

You can initialize embeddings with pretrained vectors (like GloVe or Word2Vec) instead of random initialization:

python
# Load pretrained embeddings (e.g., GloVe); load_glove_embeddings() is a
# placeholder for your own loading code returning a numpy array
pretrained_embeddings = load_glove_embeddings()  # Shape: (vocab_size, embed_dim)

# Initialize embedding layer with pretrained weights
embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))

# Optionally freeze embeddings (don't update during training)
embedding.weight.requires_grad = False  # Frozen
# OR
embedding.weight.requires_grad = True   # Fine-tune

BiLSTM networks represent a significant step up from bag-of-words and frozen embedding approaches. By processing text sequentially and bidirectionally, they capture the compositional nature of language—understanding that "I don't want to cancel" is fundamentally different from "I want to cancel."

Key takeaways:

  • LSTMs solve the vanishing gradient problem through gating mechanisms
  • Bidirectional processing gives each word full context from both directions
  • Learned embeddings allow task-specific word representations
  • Sequential processing preserves word order and handles negations correctly
  • BiLSTMs offer a sweet spot between simple baselines and heavy transformers

While transformers have largely replaced LSTMs in state-of-the-art NLP, BiLSTMs remain valuable for understanding sequence modeling fundamentals and for practical applications where compute budget is limited.

Next Steps: To deepen your understanding, explore attention mechanisms, transformer architectures, and how they build upon the sequential processing ideas introduced by RNNs and LSTMs.
