Building a BiLSTM Intent Classifier in PyTorch: Vocab, Packing, and Pooling
The practical layer between sequence-model theory and a working PyTorch classifier — building a vocabulary, batching variable-length sentences, packing sequences for the LSTM, and pooling the output back into a single vector.
Most tutorials on sequence models stop at the architecture diagram. You get a clean picture of a BiLSTM reading a sentence left-to-right and right-to-left, and the explanation ends there. Then you sit down to actually build one and immediately hit a wall: your sentences are different lengths, PyTorch wants tensors, the LSTM is faster if you tell it where the padding is, and somewhere along the way you have to turn a sequence of hidden states back into a single vector for classification.
This post is about that middle layer. The plumbing. The part that turns a clean theoretical model into something that trains in 30 seconds per epoch instead of 5 minutes, and that doesn't silently scramble your labels through a subtle indexing bug. We'll work through the full pipeline for a word-level BiLSTM intent classifier — vocabulary, dataset, collate function, packing, and pooling — and explain why each piece exists.
What you should already know
An MLP on frozen sentence embeddings is simple to feed: one sentence in, one 384-dimensional vector out, classify. The sentence transformer does the heavy lifting before your model sees anything. You can ignore words and word order entirely because they've already been baked into the vector.
A BiLSTM sees the words. That changes everything about the input pipeline:
- Words are strings, but neural networks need integers. You need a vocabulary — a deterministic mapping from word to integer ID.
- Sentences have different lengths, but a tensor in a batch must be rectangular. You need padding.
- Padded positions are fake — they shouldn't influence the model. You need packing or masking.
- An LSTM emits a hidden state at every timestep, but a classifier wants one vector per sentence. You need a pooling strategy.
Each of these is a small problem in isolation, but they interact. Get the order wrong — say, pad and then forget to mask — and your model trains on noise. The rest of the post is one solution per problem, in the order they show up.
A vocabulary is a dictionary word -> int. Every word you want your model to recognize gets a unique integer ID. The embedding layer then uses that integer to look up a learned vector. If you haven't seen the tokenization primer, the short version is: split the text on whitespace, lowercase it, count word frequencies, and assign IDs to the most common ones.
Neural networks are matrices. nn.Embedding(vocab_size, embed_dim) is literally a (vocab_size, embed_dim) weight matrix. To get the vector for the word flight, you index into row 234 (or whichever row you assigned). Strings have no order and no row index — integers do.
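To make the "it's just a row lookup" claim concrete, here is a minimal sketch. The sizes (1000 words, 8 dimensions) and the ID 234 for flight are arbitrary illustration values, not anything the model requires.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=8)

flight_id = 234  # whichever row the vocabulary assigned to "flight"
looked_up = embedding(torch.tensor([flight_id]))  # shape (1, 8)
row = embedding.weight[flight_id]                 # shape (8,)

# The layer's output is exactly the corresponding row of the weight matrix.
same = torch.equal(looked_up[0], row)
```

Calling the layer and indexing the weight matrix give the same vector, which is why the word-to-ID mapping has to stay stable for the lifetime of the model.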
Before any real word goes into the vocabulary, two slots are reserved:
| Token | Index | Purpose |
|---|---|---|
| <pad> | 0 | Fills empty positions when sentences in a batch have different lengths. Its embedding is forced to zero and never receives gradient updates. |
| <unk> | 1 | Stands in for any word that wasn't in the training data. Without it, your model crashes the first time it sees a new word at inference time. |
<pad> at index 0 is a convention worth following — nn.Embedding takes a padding_idx argument that pins that row to zero, and most utility code in PyTorch defaults to 0 as the pad value.
The vocabulary is built from training texts. If you include validation or test words, you've leaked information. A real deployment sees brand-new words constantly, so simulating that by mapping unseen words to <unk> is the point.
```python
from collections import Counter


class Vocabulary:
    PAD, UNK = "<pad>", "<unk>"
    PAD_IDX, UNK_IDX = 0, 1

    def __init__(self, word2idx: dict[str, int]):
        self.word2idx = word2idx
        self.idx2word = {i: w for w, i in word2idx.items()}

    @classmethod
    def build(cls, texts, max_size=20_000, min_freq=2):
        counter = Counter()
        for text in texts:
            counter.update(text.lower().split())
        # Reserve indices 0 and 1 for the special tokens.
        word2idx = {cls.PAD: cls.PAD_IDX, cls.UNK: cls.UNK_IDX}
        # most_common is sorted by frequency, so we can stop at the
        # first word that falls below min_freq.
        for word, freq in counter.most_common(max_size - 2):
            if freq < min_freq:
                break
            word2idx[word] = len(word2idx)
        return cls(word2idx)

    def encode(self, text: str) -> list[int]:
        return [self.word2idx.get(w, self.UNK_IDX)
                for w in text.lower().split()]
```

Two knobs decide the size of the vocabulary. max_size caps the total number of entries, special tokens included, which matters because the embedding matrix grows linearly with this number. min_freq drops words that appeared fewer than min_freq times in training (with the default of 2, words seen only once); these are almost always typos, names, or rare items that the model can't learn anything useful about. Mapping them to <unk> is the honest move.
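The effect of min_freq and the <unk> fallback is easiest to see on a toy corpus. The sketch below is a stripped-down, function-only variant of the same logic (the names build_word2idx and encode_text are made up for this example), and it round-trips the mapping through JSON the way you might when saving it next to a checkpoint.

```python
import json
from collections import Counter

PAD_IDX, UNK_IDX = 0, 1


def build_word2idx(texts, max_size=20_000, min_freq=2):
    """Hypothetical helper mirroring the build logic above."""
    counter = Counter()
    for text in texts:
        counter.update(text.lower().split())
    word2idx = {"<pad>": PAD_IDX, "<unk>": UNK_IDX}
    for word, freq in counter.most_common(max_size - 2):
        if freq < min_freq:
            break
        word2idx[word] = len(word2idx)
    return word2idx


def encode_text(word2idx, text):
    return [word2idx.get(w, UNK_IDX) for w in text.lower().split()]


train_texts = ["book a flight", "book a hotel", "cancel my flight"]
word2idx = build_word2idx(train_texts)

# "hotel", "cancel" and "my" appear once, so min_freq=2 drops them.
ids = encode_text(word2idx, "Book a train flight")

# Round-trip through JSON: keys are strings and values are ints,
# so the mapping survives serialization unchanged.
restored = json.loads(json.dumps(word2idx))
```

Only book, a, and flight survive the frequency filter; train was never seen, so it encodes to index 1. In practice you would write the JSON string to a file instead of keeping it in memory.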
Save the vocabulary
If the vocabulary is rebuilt and flight moves from index 234 to 198, the old embedding matrix becomes garbage. Persist word2idx to disk (JSON works) alongside your model checkpoint.

PyTorch has two abstractions for handling data: Dataset and DataLoader. They're independent of any model. Once you wire them up, the same code pattern works for images, audio, tabular data, or text.
A Dataset only needs two methods: __len__ (how many examples?) and __getitem__(idx) (give me example number idx). That's the entire interface.
```python
from torch.utils.data import Dataset


class IntentDataset(Dataset):
    def __init__(self, texts, labels, vocab, max_seq_len=64):
        self.examples = []
        for text, label in zip(texts, labels):
            ids = vocab.encode(text)[:max_seq_len]  # truncate long ones
            self.examples.append((ids, label))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ids, label = self.examples[idx]
        return ids, label
```

Notice what's NOT here: padding, batching, conversion to tensors. The Dataset returns a Python list of integers and a Python int. Single example, raw. The DataLoader will handle the rest.
Wrap the Dataset in a DataLoader and you get batching, shuffling, and optional multi-process loading. The default behavior is to stack each item with torch.stack, which assumes every item is the same shape. For variable-length text, it isn't — so you provide a collate_fn that controls how the batch gets assembled.
```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,  # only True for training
    collate_fn=collate_fn,
)
val_loader = DataLoader(val_dataset, batch_size=64,
                        shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64,
                         shuffle=False, collate_fn=collate_fn)
```

Shuffle once, in the right place
Set shuffle=True only on the training loader. Validation and test loaders should be deterministic, so that metrics and per-example predictions stay comparable from one epoch to the next and you don't waste hours chasing phantom regressions.

Sentences in a batch will have different lengths: 5 tokens, 9 tokens, 22 tokens. To put them into a single tensor, the short ones get padded with the <pad> index until they match the longest. The question is: padded to what length?
Static padding pads every sentence to a fixed global maximum — say, max_seq_len = 64. Simple, but wasteful: if your batch happens to contain only short sentences, you're doing 64 timesteps of LSTM work on 90% padding.
Dynamic padding pads to the longest sentence in the current batch. A batch of mostly short sentences pads to maybe 12 tokens. A batch with one long outlier pads to 40. Across an epoch, this can cut training time in half.
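A quick back-of-the-envelope count shows where the savings come from. The sentence lengths and the helper name padded_timesteps below are invented for illustration; the real ratio depends entirely on your data's length distribution.

```python
def padded_timesteps(lengths, batch_size, static_max=None):
    """Count the timesteps the LSTM would process without packing.

    With static_max set, every batch is padded to that length;
    otherwise each batch is padded to its own longest sentence.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        pad_to = static_max if static_max is not None else max(batch)
        total += pad_to * len(batch)
    return total


# Made-up sentence lengths, two batches of four.
lengths = [5, 9, 22, 7, 6, 12, 40, 8]

static_cost = padded_timesteps(lengths, batch_size=4, static_max=64)
dynamic_cost = padded_timesteps(lengths, batch_size=4)
real_tokens = sum(lengths)
```

On this toy input, static padding processes 512 timesteps, dynamic padding 248, and only 109 of those are real tokens. The remaining gap is what packing removes.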
```python
import torch


def collate_fn(batch):
    # batch is a list of (token_ids, label) tuples
    token_ids_list, labels = zip(*batch)
    lengths = torch.tensor(
        [len(ids) for ids in token_ids_list], dtype=torch.long
    )
    max_len = lengths.max().item()  # longest in THIS batch

    # Zero is the pad index (see the vocab).
    padded = torch.zeros(len(batch), max_len, dtype=torch.long)
    for i, ids in enumerate(token_ids_list):
        padded[i, :len(ids)] = torch.tensor(ids, dtype=torch.long)

    labels = torch.tensor(labels, dtype=torch.long)
    return padded, lengths, labels
```

The function returns three tensors: padded of shape (batch, max_len), lengths of shape (batch,), and labels of shape (batch,). The lengths tensor is the key: without it, the model has no way to tell where real tokens end and padding begins.
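If you'd rather not write the padding loop by hand, torch.nn.utils.rnn.pad_sequence does the same job. The variant below is a sketch of an equivalent collate function; the name collate_with_pad_sequence and the toy batch are illustrative only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_with_pad_sequence(batch, pad_idx=0):
    """Same contract as collate_fn above, using pad_sequence for padding."""
    token_ids_list, labels = zip(*batch)
    lengths = torch.tensor(
        [len(ids) for ids in token_ids_list], dtype=torch.long
    )
    seqs = [torch.tensor(ids, dtype=torch.long) for ids in token_ids_list]
    # Pads every sequence to the longest one in this batch.
    padded = pad_sequence(seqs, batch_first=True, padding_value=pad_idx)
    return padded, lengths, torch.tensor(labels, dtype=torch.long)


# A toy batch of (token_ids, label) pairs with lengths 3, 1 and 2.
toy_batch = [([4, 7, 9], 0), ([5], 2), ([3, 8], 1)]
padded, lengths, labels = collate_with_pad_sequence(toy_batch)
```

The batch comes out as a (3, 3) tensor with zeros in the padded slots, and lengths records 3, 1, 2 in the original order.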
Padding solves the shape problem but creates a compute problem. If your batch is padded to 40 timesteps and the average real length is 10, you're paying 4× the LSTM cost for nothing. Worse, the hidden state at timestep 40 of a 10-token sentence is the state after the LSTM has processed 30 padding tokens. If you use that as a sentence representation, you've corrupted the signal.
PyTorch's pack_padded_sequence solves both. It rearranges the padded tensor into a special PackedSequence object that the LSTM processes step by step, skipping padded positions automatically. The output comes back compressed in the same format; you unpack it with pad_packed_sequence to get a normal tensor again.
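Here's the catch that trips up almost everyone the first time: with enforce_sorted=True (the default), pack_padded_sequence requires the batch to be sorted by length in descending order. That means you have to sort, pack, run, unpack, and then put everything back in the original order so it lines up with the labels. Passing enforce_sorted=False makes PyTorch handle the sorting internally; the manual version is worth seeing once because it makes the mechanics explicit.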
```python
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def forward(self, x, lengths):
    # 1. Sort the batch by length, descending.
    sorted_lengths, sort_idx = lengths.sort(descending=True)
    x_sorted = x[sort_idx]

    # 2. Embed and pack.
    embedded = self.embed_dropout(self.embedding(x_sorted))
    packed = pack_padded_sequence(
        embedded, sorted_lengths.cpu(),
        batch_first=True, enforce_sorted=True,
    )

    # 3. Run the LSTM on packed input.
    packed_out, (h_n, c_n) = self.lstm(packed)

    # 4. Unpack back to (batch, max_len, hidden*2).
    output, _ = pad_packed_sequence(packed_out, batch_first=True)

    # 5. UNSORT: restore the original batch order.
    _, unsort_idx = sort_idx.sort()
    output = output[unsort_idx]
    h_n = h_n[:, unsort_idx, :]
    lengths_original = sorted_lengths[unsort_idx]

    # 6. Pool, dropout, classify (next section).
    pooled = self.pool(output, h_n, lengths_original)
    return self.classifier(self.dropout(pooled))
```

The bug you'll hit if you skip the unsort
Skip step 5 and the logits come back in length-sorted order while the labels are still in the original batch order. Nothing crashes, every shape matches, and accuracy just hovers near random; the diagnosis can take hours.

Two small details: pack_padded_sequence wants sorted_lengths.cpu() even if the rest of the tensors are on a GPU, because it expects the lengths as a CPU tensor. And h_n (the final hidden state) is indexed differently from output: its shape is (num_layers * num_directions, batch, hidden_dim), so you unsort along dim 1.
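For comparison, here is a sketch of the same packing step with enforce_sorted=False, where PyTorch keeps track of the permutation itself. The tensor sizes are arbitrary, and the check against running each sentence alone is only there to show that padding is ignored and the original order is preserved.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Arbitrary small sizes for illustration.
lstm = nn.LSTM(input_size=4, hidden_size=3,
               batch_first=True, bidirectional=True)

lengths = torch.tensor([2, 5, 3])   # deliberately NOT sorted
x = torch.randn(3, 5, 4)            # positions past each length act as padding

packed = pack_padded_sequence(
    x, lengths, batch_first=True, enforce_sorted=False
)
packed_out, (h_n, _) = lstm(packed)
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Reference: run each sentence on its own, without any padding.
checks = []
for i, n in enumerate(lengths.tolist()):
    ref_out, (ref_h, _) = lstm(x[i:i + 1, :n])
    checks.append(torch.allclose(output[i, :n], ref_out[0], atol=1e-5))
    checks.append(torch.allclose(h_n[:, i], ref_h[:, 0], atol=1e-5))
```

Both output and h_n line up with the original batch order, so no manual unsort is needed in this variant. The explicit sort/unsort above does the same bookkeeping by hand.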
After unpacking, you have a tensor of shape (batch, seq_len, hidden_dim * 2). The * 2 is because the BiLSTM concatenates forward and backward hidden states at every timestep. A classifier head wants a single vector per sentence, so the sequence dimension has to collapse. Two strategies are common.
Take the final hidden state of the LSTM. For a unidirectional LSTM, this is the state after reading the entire sentence — a natural summary. For a bidirectional LSTM, you want the forward direction's last state (which has read the whole sentence left-to-right) AND the backward direction's last state (which has read it right-to-left).
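It helps to see where those two "last" states actually sit. The sketch below uses arbitrary sizes and full-length sentences (no padding) to check that the forward direction finishes at the last timestep while the backward direction finishes at the first one, and that nn.LSTM hands both back as the last two rows of h_n.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_dim = 5  # arbitrary, for illustration
lstm = nn.LSTM(input_size=4, hidden_size=hidden_dim, num_layers=2,
               batch_first=True, bidirectional=True)

x = torch.randn(3, 7, 4)      # 3 sentences, each exactly 7 tokens
output, (h_n, _) = lstm(x)    # output: (3, 7, 10), h_n: (4, 3, 5)

# Forward half of the top layer, read at the LAST timestep.
forward_last = output[:, -1, :hidden_dim]
# Backward half of the top layer, read at the FIRST timestep.
backward_last = output[:, 0, hidden_dim:]

fwd_matches = torch.allclose(forward_last, h_n[-2], atol=1e-6)
bwd_matches = torch.allclose(backward_last, h_n[-1], atol=1e-6)
```

This is also why slicing output[:, -1] is not a good summary for a BiLSTM: its backward half has only seen one token.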
The shape of h_n is (num_layers * 2, batch, hidden_dim). For a 2-layer BiLSTM, that's 4 rows. The layout is [layer_0_forward, layer_0_backward, layer_1_forward, layer_1_backward]. The last two — h_n[-2] and h_n[-1] — are what you want:
```python
def pool_last(self, h_n):
    # h_n shape: (num_layers * 2, batch, hidden_dim)
    h_forward = h_n[-2]   # last layer, forward direction
    h_backward = h_n[-1]  # last layer, backward direction
    return torch.cat([h_forward, h_backward], dim=1)
    # result shape: (batch, hidden_dim * 2)
```

Off-by-one indexing pitfall
It's tempting to use h_n[0] and h_n[1] for forward and backward. That gives you layer 0, not the last layer. For a single-layer LSTM they happen to be equivalent. For 2+ layers, you've thrown away the deeper representations.

Mean pooling averages the hidden states across all real timesteps. The wrinkle is that padded positions are still in the output tensor after unpacking. They're just zeros, but if you average over them you're dividing by the wrong denominator. You need a mask.
```python
def pool_mean(self, output, lengths):
    # output:  (batch, seq_len, hidden_dim * 2)
    # lengths: (batch,) - real lengths in original order
    # Keep lengths on the same device as output (no-op if it already is).
    lengths = lengths.to(output.device)

    # Build a mask: True where the position is a real token.
    seq_len = output.size(1)
    mask = (torch.arange(seq_len, device=output.device).unsqueeze(0)
            < lengths.unsqueeze(1))
    # mask shape: (batch, seq_len)
    mask = mask.unsqueeze(2).float()  # (batch, seq_len, 1)

    summed = (output * mask).sum(dim=1)  # (batch, hidden_dim * 2)
    pooled = summed / lengths.unsqueeze(1).float()
    return pooled
```

The mask is built by comparing each position index to the sentence's real length. Positions 0, 1, 2, ... up to lengths[i] - 1 are True; everything after is False. Multiply by the mask, sum, divide by the real length, and you have an honest mean over non-padded positions.
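A tiny hand-checkable example makes the denominator issue visible. The standalone function masked_mean below is an illustrative copy of the logic above, and the numbers are made up so the averages can be verified by eye.

```python
import torch


def masked_mean(output, lengths):
    """Mean over real timesteps only (standalone copy for illustration)."""
    seq_len = output.size(1)
    mask = (torch.arange(seq_len, device=output.device).unsqueeze(0)
            < lengths.unsqueeze(1))
    mask = mask.unsqueeze(2).to(output.dtype)
    return (output * mask).sum(dim=1) / lengths.unsqueeze(1).to(output.dtype)


# Two sentences, padded to 3 timesteps, 2 features each.
# The first sentence has only 2 real timesteps; its third row is padding.
output = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],
    [[2.0, 2.0], [4.0, 6.0], [6.0, 10.0]],
])
lengths = torch.tensor([2, 3])

pooled = masked_mean(output, lengths)  # divides by 2 and 3
naive = output.mean(dim=1)             # divides by 3 for both rows
```

For the first sentence the masked mean is [2, 3], while the naive mean shrinks it to [4/3, 2] because the padded row is counted in the denominator. The full-length second sentence is unaffected.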
| Strategy | Strengths | Weaknesses |
|---|---|---|
| Last hidden | Cheap, uses the LSTM's own summary, works well when the end of the sentence carries the meaning. | Sensitive to where the important word lives. A short prefix that flips intent can be muted by what follows. |
| Mean pool | Every timestep contributes equally, robust to long-range information, often more stable. | Slightly more compute, dilutes very strong single-token signals. |
On short utterances like intent classification, the two are usually within a percentage point of each other. On longer documents, mean pooling tends to be more reliable. Treat it as a hyperparameter worth flipping during your sweep.
The full forward pass — tokens to logits — looks like this:
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim,
                 num_layers, n_classes, dropout=0.3,
                 pool_strategy="last", pad_idx=0):
        super().__init__()
        self.pool_strategy = pool_strategy
        self.embedding = nn.Embedding(vocab_size, embed_dim,
                                      padding_idx=pad_idx)
        self.embed_dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            embed_dim, hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim * 2, n_classes)

    def forward(self, x, lengths):
        # Sort by length (descending), embed, and pack.
        sorted_lengths, sort_idx = lengths.sort(descending=True)
        x_sorted = x[sort_idx]
        embedded = self.embed_dropout(self.embedding(x_sorted))
        packed = pack_padded_sequence(
            embedded, sorted_lengths.cpu(),
            batch_first=True, enforce_sorted=True,
        )
        packed_out, (h_n, _) = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_out, batch_first=True)

        # Restore the original batch order so outputs match the labels.
        _, unsort_idx = sort_idx.sort()
        output = output[unsort_idx]
        h_n = h_n[:, unsort_idx, :]
        lengths_orig = sorted_lengths[unsort_idx].to(output.device)

        if self.pool_strategy == "last":
            pooled = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:  # mean pool
            mask = (torch.arange(output.size(1), device=output.device)
                    .unsqueeze(0) < lengths_orig.unsqueeze(1))
            mask = mask.unsqueeze(2).float()
            pooled = ((output * mask).sum(dim=1)
                      / lengths_orig.unsqueeze(1).float())
        return self.classifier(self.dropout(pooled))
```

A couple of details worth noting: nn.LSTM's dropout argument only applies between stacked layers, which is why it's gated behind num_layers > 1 (passing dropout to a single-layer LSTM is a no-op and triggers a warning). And padding_idx=pad_idx on the embedding layer pins row 0 to zero and freezes it: no gradient updates, no drift.
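The padding_idx behaviour is easy to verify in isolation. This is a small sketch with arbitrary sizes: it looks up a toy batch that contains pad tokens, backpropagates a dummy loss, and inspects the pad row's value and gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes; index 0 plays the role of <pad>.
emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)

ids = torch.tensor([[2, 4, 0, 0]])  # one sentence: two real tokens, two pads
emb(ids).sum().backward()           # dummy loss, just to produce gradients

pad_row_is_zero = bool((emb.weight[0] == 0).all())
pad_grad_is_zero = bool((emb.weight.grad[0] == 0).all())
real_grad_nonzero = bool((emb.weight.grad[2] != 0).all())
```

The pad row starts at zero and receives zero gradient, while rows for real tokens get updated as usual.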
Compared to the manual mini-batch loop you might have used with a frozen-embedding MLP, the DataLoader version is shorter. No manual shuffling, no manual slicing — the loader yields batches, you iterate.
```python
# model, criterion, optimizer, evaluate, early_stopper and epochs
# are assumed to be defined elsewhere.
for epoch in range(epochs):
    model.train()
    for x_batch, lengths_batch, y_batch in train_loader:
        logits = model(x_batch, lengths_batch)
        loss = criterion(logits, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    val_loss, val_acc = evaluate(model, val_loader)
    if early_stopper.should_stop(val_loss):
        break
```

The optimizer is typically Adam with lr=1e-3 and a small weight decay, and you wrap the loop in early stopping on validation loss with a patience of 5 or so. None of that is specific to BiLSTMs; these are the same training conventions you've used for every PyTorch model.
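The loop calls early_stopper.should_stop(val_loss) without showing what is behind it. One plausible shape for such a helper is sketched below; the class name EarlyStopper, the min_delta parameter, and the example losses are assumptions for illustration, not the post's actual implementation.

```python
class EarlyStopper:
    """Hypothetical patience-based early stopping on validation loss."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, with patience=2 to keep the trace short.
stopper = EarlyStopper(patience=2)
decisions = [stopper.should_stop(v) for v in [1.0, 0.8, 0.9, 0.85]]
```

With a patience of 2, the two epochs after the best loss (0.8) both fail to improve on it, so the fourth call returns True and the training loop would break there. A fuller version would usually also keep a copy of the best model weights.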
- Accuracy hovers near random: almost always a missing unsort step after pad_packed_sequence. Labels and predictions are misaligned.
- Loss is NaN immediately: usually a learning rate problem, but check that padding_idx is set on the embedding; otherwise the pad embedding drifts during training and can blow up.
- Training is far slower than expected: you forgot to pack the sequence, or you're padding to a global max_seq_len instead of per-batch dynamic padding.
- Validation accuracy is wildly noisy across epochs: first rule out shuffle=True on the validation loader (set it to False), then check that evaluation runs under model.eval() so dropout is switched off.
- Inference breaks on new sentences: a word at inference is missing from the vocabulary. Make sure encode returns UNK_IDX for unknown words instead of raising a KeyError.
- h_n indexed wrongly in pooling: you indexed h_n[0] and h_n[1] thinking they were forward/backward of the last layer. For multi-layer LSTMs use h_n[-2] and h_n[-1].
A working sequence classifier is mostly plumbing on top of a small model. The BiLSTM itself is a few lines — the work is in the pipeline around it: a vocabulary that maps strings to integers, a Dataset that yields raw token lists, a DataLoader with a collate_fn that pads dynamically, packing to skip padding inside the LSTM, the sort/unsort dance to keep labels aligned, and a pooling strategy to collapse the sequence back into a single vector.
Get this pipeline right once and almost every future sequence model — attention models, transformers, even speech recognizers — reuses the same shapes. Tokens go in, padded tensors flow through a model that knows how to ignore padding, and a pooled vector comes out for the head to classify.
Next step
Flip pool_strategy from last to mean and rerun. The accuracy delta, and where in the loss curves it shows up, will teach you more about your data than any blog post.