
From Words to Intelligence: Building an MLP Classifier on Pretrained Sentence Embeddings
A deep dive into pretrained sentence embeddings, MLP architecture, BatchNorm, Dropout, Adam, and early stopping — with full PyTorch implementation.
Imagine you want to teach a computer to understand what someone means when they type a sentence — not just match keywords, but actually understand. A phrase like 'book me a flight' and 'reserve a plane ticket' mean exactly the same thing, yet share almost no words. Classic approaches like TF-IDF fail here completely. In this post, we'll build a system that genuinely handles this, by combining pretrained sentence embeddings with a multi-layer perceptron (MLP) built in PyTorch. Along the way, we'll unpack every building block from scratch: what embeddings are, why hidden layers matter, how BatchNorm and Dropout prevent failure modes, why Adam beats plain SGD, and how early stopping keeps your model honest.
What We're Building
An intent classifier for natural language. Given a sentence like 'What's the weather in Tokyo?', our model will output a label like 'weather_query'. We'll use a frozen pretrained transformer to turn sentences into vectors, then train a small MLP on top of those vectors. Expected accuracy: >90% on 151 intent classes.
Before we talk about what we're building, let's understand what we're replacing and why. TF-IDF (Term Frequency–Inverse Document Frequency) represents a sentence as a sparse vector where each dimension corresponds to a word in the vocabulary. A sentence with 10,000 possible vocabulary words becomes a vector of 10,000 numbers, most of which are zero.
The problems are fundamental, not incidental. First, TF-IDF is completely blind to meaning. The words 'book', 'reserve', and 'schedule' all have completely different TF-IDF dimensions, so the model has no idea they're related. Second, word order is lost entirely. 'Dog bites man' and 'Man bites dog' produce identical TF-IDF vectors. Third, unseen words are invisible. If a user types a word not in your training vocabulary, it vanishes. A logistic regression on top of TF-IDF can only draw straight lines through this broken space — its accuracy ceiling is around 78% on hard intent datasets.
| Problem | TF-IDF Behavior | Impact |
|---|---|---|
| Synonyms | 'book' and 'reserve' are unrelated dimensions | Fails to generalize between paraphrases |
| Word order | 'dog bites man' = 'man bites dog' | Can't distinguish meaning from word arrangement |
| Unseen words | Out-of-vocabulary words are dropped | New phrasing causes silent failures |
| Sparsity | 99%+ of dimensions are zero | High memory use, poor statistical efficiency |
| No context | Each word scored independently | No understanding of phrase-level meaning |
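To see the synonym problem concretely, here is a minimal sketch using scikit-learn's TfidfVectorizer (an extra dependency used only for this demo): the two paraphrases from the introduction share no tokens, so their TF-IDF vectors are exactly orthogonal.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

sentences = [
    "book me a flight",
    "reserve a plane ticket",
]
vectorizer = TfidfVectorizer()
tfidf = vectorizer.fit_transform(sentences)  # sparse matrix, shape (2, vocab_size)

# The paraphrases share no tokens, so their TF-IDF vectors are orthogonal
sim = cosine_similarity(tfidf[0], tfidf[1])[0, 0]
print(f"TF-IDF cosine similarity: {sim:.2f}")  # 0.00 — TF-IDF sees no relation at all
```

Two sentences that any human reads as identical requests score a similarity of exactly zero — there is nothing a downstream classifier can do to recover from that.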
A sentence embedding is a dense vector — typically 384 or 768 numbers — that captures the meaning of a sentence, not just its words. Think of it as a GPS coordinate in meaning-space: sentences that mean similar things end up close together, regardless of the exact words used.
The Intuition: A Map of Meaning
Imagine plotting every possible sentence on a giant map. 'Book a flight' and 'Reserve a plane ticket' would be plotted millimetres apart. 'What's the weather?' would be in a completely different neighbourhood. A sentence embedding is just a coordinate on that map. The miracle is that someone else (Hugging Face, Google, etc.) already built the map for us — by training on billions of sentences.
The model we'll use is all-MiniLM-L6-v2, a small but highly capable sentence transformer from the sentence-transformers library. 'MiniLM' means it's a distilled (compressed) version of a larger model. 'L6' means it has 6 transformer layers. 'v2' is the second version. It produces 384-dimensional embeddings, weighs only ~80MB, and runs fast even on CPU. It was trained using a technique called contrastive learning — pushed to make semantically similar sentences close together in the 384-dimensional space.
from sentence_transformers import SentenceTransformer
import torch
# Load the pretrained model
model = SentenceTransformer('all-MiniLM-L6-v2')
# Encode some sentences
sentences = [
    'Book a flight to New York',
    'Reserve a plane ticket to NYC',
    'What is the weather like today?',
    'Will it rain tomorrow?'
]
# Each sentence becomes a 384-dimensional vector
embeddings = model.encode(sentences, convert_to_tensor=True)
print(f'Shape: {embeddings.shape}')  # (4, 384)
# Measure cosine similarity between sentences
from torch.nn.functional import cosine_similarity
sim_flight_reserve = cosine_similarity(
    embeddings[0].unsqueeze(0),
    embeddings[1].unsqueeze(0)
)
sim_flight_weather = cosine_similarity(
    embeddings[0].unsqueeze(0),
    embeddings[2].unsqueeze(0)
)
print(f'Similarity (flight vs reserve): {sim_flight_reserve.item():.3f}')  # ~0.85 — very similar
print(f'Similarity (flight vs weather): {sim_flight_weather.item():.3f}')  # ~0.15 — very different

Transfer Learning in Action
This is transfer learning: someone spent months and millions of dollars training a model on billions of sentences. We download the result for free and use those learned representations as our input features. We never update the weights of this encoder — we 'freeze' it and only train the small classifier on top.
Freezing means we don't compute gradients through the sentence transformer — its weights stay exactly as they were when we downloaded it. There are two strong reasons for this. First, the encoder is already excellent: it was trained on billions of sentences; we have ~15,000. Fine-tuning it on such a small dataset would make it worse, not better (a phenomenon called catastrophic forgetting). Second, it's dramatically cheaper: computing gradients through 6 transformer layers is expensive. By freezing it, our training loop only needs to update the small MLP weights — which is fast even on CPU.
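Mechanically, freezing is just turning off gradient tracking. Here is a sketch using a plain nn.Linear as a stand-in for the real encoder (the actual pipeline below sidesteps this entirely by precomputing embeddings, so no gradients ever flow through the transformer):

```python
import torch
import torch.nn as nn

# Stand-ins: a 'frozen encoder' and a small trainable head
encoder = nn.Linear(384, 384)    # pretend this is the pretrained encoder
classifier = nn.Linear(384, 151)

# Freeze the encoder: no gradients will be computed for its weights
for p in encoder.parameters():
    p.requires_grad_(False)

trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(f"Trainable encoder params: {trainable}")  # 0

# The optimizer only ever sees the classifier's parameters
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
```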
Encoding 15,000 sentences through a transformer takes a few minutes. If you recompute embeddings on every training run, you're wasting that time every single time. Since the encoder is frozen, the embeddings never change — so we compute them once and save them to disk.
"""Sentence embedding utilities with disk caching."""
from __future__ import annotations
from pathlib import Path
import torch
from torch import Tensor
def encode_texts(
    texts: list[str],
    model_name: str = "all-MiniLM-L6-v2",
    batch_size: int = 256,
    show_progress: bool = True,
) -> Tensor:
    """Encode texts to dense embeddings using a sentence-transformer.

    Returns:
        Tensor of shape (len(texts), 384), dtype float32
    """
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(model_name)
    # convert_to_numpy=True is faster; we convert to tensor after
    embeddings_np = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=show_progress,
        convert_to_numpy=True,
    )
    return torch.tensor(embeddings_np, dtype=torch.float32)

def load_or_encode(
    texts: list[str],
    split: str,  # 'train', 'val', 'test'
    cache_dir: str = "runs/embeddings",
    model_name: str = "all-MiniLM-L6-v2",
) -> Tensor:
    """Return cached embeddings if they exist, otherwise encode and cache."""
    cache_dir_path = Path(cache_dir)
    # Build a unique filename per (model, split) combination
    safe_model_name = model_name.replace("/", "_")
    cache_path = cache_dir_path / f"{safe_model_name}_{split}.pt"
    if cache_path.exists():
        print(f"Loading cached embeddings from {cache_path}")
        return torch.load(cache_path, weights_only=True)
    # Cache miss — compute and save
    print(f"Computing embeddings for {split} split ({len(texts)} sentences)...")
    embeddings = encode_texts(texts, model_name=model_name)
    cache_dir_path.mkdir(parents=True, exist_ok=True)
    torch.save(embeddings, cache_path)
    print(f"Saved embeddings to {cache_path}")
    return embeddings

How torch.save / torch.load Work
torch.save() serializes a tensor (or any Python object) to a binary file using pickle under the hood. torch.load() reads it back. For tensors, this preserves the exact dtype, shape, and values. The weights_only=True flag (PyTorch 2.0+) is a security measure — it refuses to unpickle arbitrary Python objects, only tensors.
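A minimal round trip, assuming PyTorch 1.13+ for the weights_only flag:

```python
import tempfile
from pathlib import Path
import torch

t = torch.randn(4, 384)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "embeddings.pt"
    torch.save(t, path)                           # serialize to disk
    loaded = torch.load(path, weights_only=True)  # refuses arbitrary pickled objects

# dtype, shape, and values survive the round trip exactly
assert loaded.dtype == t.dtype and loaded.shape == t.shape
assert torch.equal(loaded, t)
```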
Every neural network in PyTorch is a subclass of nn.Module. Think of nn.Module as a smart container that does several important things automatically: it tracks all the learnable parameters in your model so the optimizer can find them, it handles switching between training and evaluation modes, and it lets you save and load the entire model state with a single call.
The basic pattern has two parts: __init__ (where you define your layers) and forward (where you describe how data flows through them). Here's the minimal skeleton:
import torch
import torch.nn as nn
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()  # Always call this first — sets up nn.Module internals
        # Define layers here. nn.Module automatically 'sees' any nn.* assigned to self
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 2)
        self.relu = nn.ReLU()

    def forward(self, x):  # Called when you do: output = model(input)
        x = self.relu(self.layer1(x))
        x = self.layer2(x)
        return x

model = TinyNet()
# nn.Module automatically found all parameters
total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params}')  # (10*5 + 5) + (5*2 + 2) = 67
# Switch modes
model.train()  # Enables dropout, batch norm in training mode
model.eval()   # Disables dropout, uses running stats in batch norm
# Save and load weights
torch.save(model.state_dict(), 'model.pt')
model.load_state_dict(torch.load('model.pt'))

super().__init__() is Mandatory
Forgetting super().__init__() in your nn.Module subclass causes a cryptic error. It initializes the internal bookkeeping that lets PyTorch track your parameters. Always put it as the first line of __init__.
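A quick demonstration of the failure mode (the exact message may vary between PyTorch versions):

```python
import torch.nn as nn

class BrokenNet(nn.Module):
    def __init__(self):
        # super().__init__() deliberately missing
        self.layer = nn.Linear(10, 5)  # raises immediately

try:
    BrokenNet()
except AttributeError as e:
    # nn.Module.__setattr__ needs the internal registries that super().__init__() creates
    print(f"AttributeError: {e}")
```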
A logistic regression — or a neural network with no hidden layers — can only draw straight lines (or flat hyperplanes in high dimensions) to separate classes. This works fine when classes are linearly separable, but real data almost never is.
Think about the classic XOR problem. Four points: (0,0)→0, (0,1)→1, (1,0)→1, (1,1)→0. No single straight line can separate the 0s from the 1s. But a hidden layer can draw two lines and combine them — suddenly XOR is solvable. The same principle applies to intent classification: the boundaries between 151 intent classes in 384-dimensional embedding space are curved and complex. Hidden layers give the model the power to learn those curves.
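You don't even need training to see this: here is a sketch where the two hidden ReLU units are set by hand, using the identity XOR(x1, x2) = ReLU(x1 + x2) − 2·ReLU(x1 + x2 − 1):

```python
import torch
import torch.nn as nn

# One hidden layer of two ReLU units is enough — set the weights by hand
net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
with torch.no_grad():
    net[0].weight.copy_(torch.tensor([[1.0, 1.0], [1.0, 1.0]]))  # h1 = x1+x2, h2 = x1+x2-1
    net[0].bias.copy_(torch.tensor([0.0, -1.0]))
    net[2].weight.copy_(torch.tensor([[1.0, -2.0]]))             # out = h1 - 2*h2
    net[2].bias.zero_()

X = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
print(net(X).detach().squeeze())  # tensor([0., 1., 1., 0.]) — XOR, impossible for any single linear layer
```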
Here's the catch: stacking multiple nn.Linear layers without anything in between is mathematically equivalent to a single linear layer. No matter how many layers you stack, a linear-of-linear is still linear. You need nonlinear activations between layers to break this equivalence.
The most popular activation today is ReLU (Rectified Linear Unit): f(x) = max(0, x). It's dead simple — if the input is positive, pass it through unchanged; if negative, output zero. Despite its simplicity, ReLU has several advantages over older activations like sigmoid or tanh: it doesn't saturate for positive inputs (avoiding the vanishing gradient problem), it's computationally trivial, and empirically it trains faster.
import torch
import torch.nn as nn
# What ReLU does, element-wise
x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
relu = nn.ReLU()
print(relu(x)) # tensor([0., 0., 0., 1., 3.])
# Why stacking Linear without activation is pointless
W1 = torch.randn(4, 3) # First linear layer weights
W2 = torch.randn(2, 4) # Second linear layer weights
# Two layers multiplied together collapse into one
W_combined = W2 @ W1 # Shape: (2, 3) — same as a single layer!
# With ReLU in between, they CAN'T collapse — the model can learn curves
net_with_nonlinearity = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(),  # <--- This is the key ingredient
    nn.Linear(4, 2)
)

As data flows through many layers of a neural network, the distribution of activations (the numbers at each layer) can drift wildly — some layers might produce values in the thousands, others near zero. This makes training unstable: the gradients become tiny (vanishing) or enormous (exploding), and the model struggles to learn.
Batch Normalization solves this by normalizing the activations within each mini-batch — it subtracts the batch mean and divides by the batch standard deviation, so the activations have approximately mean=0 and variance=1. Then it applies learned scale (γ) and shift (β) parameters to let the model restore any distribution it needs.
The Coffee Shop Analogy
Imagine you're a barista and cups of coffee come to you at wildly inconsistent temperatures — some boiling, some cold. Batch norm is like a temperature regulator: it standardizes each batch of cups to a consistent temperature before they reach you. You (the next layer) can then focus on your job instead of constantly adjusting for the extreme variations.
import torch
import torch.nn as nn
# Imagine a batch of 4 samples, each with 3 features
batch = torch.tensor([
    [100.0, 0.001, 50.0],  # Sample 1: huge scale differences
    [200.0, 0.002, 75.0],  # Sample 2
    [ 50.0, 0.003, 25.0],  # Sample 3
    [150.0, 0.004, 60.0],  # Sample 4
])
batch_norm = nn.BatchNorm1d(num_features=3)  # 3 features
batch_norm.train()  # Training mode: normalize with the current batch's statistics
normed = batch_norm(batch)
print("Before BatchNorm - std per feature:", batch.std(dim=0))
# tensor([64.55, 0.0013, 21.02]) <- huge scale differences!
print("After BatchNorm - std per feature: ", normed.std(dim=0).detach())
# tensor([~1.15, ~1.15, ~1.15]) <- roughly uniform now
# Critical: BatchNorm behaves differently in train vs eval mode!
# During training: uses the CURRENT batch's mean and std
# During eval/test: uses RUNNING AVERAGES accumulated during training
# >>> model.train() before training, model.eval() before inference — ALWAYS <<<

The Most Common Bug with BatchNorm
Forgetting to call model.eval() before validation or testing is the #1 source of mysteriously bad test accuracy. During training mode, BatchNorm uses the current mini-batch's statistics. During eval mode, it uses stable running averages. If you evaluate in train mode, your 'validation accuracy' is computed with noisy, batch-dependent statistics — it's meaningless.
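A small illustration of how far apart the two modes can be. A fresh BatchNorm1d starts with running mean 0 and variance 1, so eval mode barely changes an un-normalized batch, while train mode fully normalizes it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 10 + 5  # batch far from mean 0, std 1

bn.train()
out_train = bn(x)   # normalized with THIS batch's statistics

bn.eval()
out_eval = bn(x)    # normalized with the (barely updated) running statistics

# Same input, very different outputs — evaluating in the wrong mode is silent
print(out_train.std(dim=0).detach())  # ≈ 1 per feature
print(out_eval.std(dim=0).detach())   # still large — nowhere near normalized
```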
A neural network trained long enough on a fixed dataset will start to overfit — it memorizes the training examples rather than learning generalizable patterns. Its training accuracy climbs toward 100%, while validation accuracy stagnates or falls.
Dropout is a remarkably simple fix: during each training step, randomly zero out a fraction of the neurons in a layer. With dropout=0.3, each neuron has a 30% chance of being silenced on any given forward pass. The remaining active neurons are scaled up to compensate, so the expected sum stays constant.
The Team Analogy for Dropout
Imagine training a football team where, at every practice session, 30% of players are randomly sent home. The remaining players can't rely on any specific teammate being there, so each player *must* learn to be useful on their own. The team becomes more robust because no single player becomes a crutch. Similarly, dropout forces neurons to learn independently useful features rather than co-adapting to each other.
import torch
import torch.nn as nn
dropout = nn.Dropout(p=0.3) # 30% of neurons zeroed out
x = torch.ones(1, 10) # 10 neurons all active
# During training: random neurons are zeroed
dropout.train()
out_train = dropout(x)
print("Training mode:", out_train)
# e.g. tensor([[1.4286, 0., 1.4286, 1.4286, 0., 1.4286, 0., 1.4286, 1.4286, 1.4286]])
# Notice: active neurons are scaled by 1/(1-0.3) ≈ 1.43 to keep expected sum equal
# During inference: ALL neurons are active, no scaling
dropout.eval()
out_eval = dropout(x)
print("Eval mode: ", out_eval)
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]) <- all ones, no dropout applied
# KEY RULE: Dropout only fires during training.
# Call model.train() before training steps, model.eval() before validation/test.

Now we combine all the pieces. Our MLP takes a 384-dimensional embedding as input and outputs logits for 151 classes. Between input and output, we stack blocks of [Linear → BatchNorm → ReLU → Dropout]. The key design decision is making the number of hidden layers dynamic — driven by a config list like [256, 128] — so we can experiment without rewriting code.
"""MLP on frozen sentence embeddings."""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch import Tensor
@dataclass
class MLPConfig:
    hidden_dims: list[int]  # e.g. [256, 128] — one entry per hidden layer
    dropout: float = 0.3
    lr: float = 1e-3
    epochs: int = 50
    batch_size: int = 128
    weight_decay: float = 1e-4  # L2 regularization built into Adam
    patience: int = 5  # early stopping patience
    seed: int = 42

class MLPClassifier(nn.Module):
    """Feed-forward MLP for intent classification on pre-computed embeddings.

    Architecture:
        Input (embed_dim)
        → [Linear → BatchNorm1d → ReLU → Dropout] × len(hidden_dims)
        → Linear → output (n_classes)
    """

    def __init__(self, embed_dim: int, n_classes: int, config: MLPConfig) -> None:
        super().__init__()
        self.config = config
        # Build the layer stack dynamically from config.hidden_dims.
        # [256, 128] gives 2 hidden layers, [512, 256, 128] gives 3, etc.
        layers = []
        prev_dim = embed_dim
        for h in config.hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h),
                nn.BatchNorm1d(h),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            ])
            prev_dim = h
        # Final classification layer — no activation (logits go into CrossEntropyLoss)
        layers.append(nn.Linear(prev_dim, n_classes))
        # nn.Sequential chains the layers: output of one = input to next
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return raw logits, shape (N, n_classes)."""
        return self.net(x)  # nn.Sequential handles the chaining

    def predict(self, x: Tensor) -> Tensor:
        """Return predicted class indices, shape (N,)."""
        self.eval()  # Disable dropout, use BatchNorm running stats
        with torch.no_grad():  # Don't store gradients — we're not training
            logits = self.forward(x)
        return logits.argmax(dim=1)  # Pick the class with the highest score

    @property
    def n_params(self) -> int:
        """Count total trainable parameters."""
        # p.numel() = number of elements in tensor p
        # p.requires_grad = True means this parameter is trainable
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Quick sanity check
if __name__ == "__main__":
    config = MLPConfig(hidden_dims=[256, 128])
    model = MLPClassifier(embed_dim=384, n_classes=151, config=config)
    print(f"Parameters: {model.n_params:,}")
    # (384*256+256) + (256*128+128) + (128*151+151) = 98,560 + 32,896 + 19,479 = 150,935,
    # plus 2*(256+128) = 768 BatchNorm scale/shift params → 151,703 total
    dummy_batch = torch.randn(8, 384)  # 8 samples, 384 dims each
    logits = model(dummy_batch)
    print(f"Output shape: {logits.shape}")  # (8, 151)
    preds = model.predict(dummy_batch)
    print(f"Predictions shape: {preds.shape}")  # (8,)

Why No Activation After the Last Layer?
The last nn.Linear outputs raw scores called logits. We do NOT apply softmax here because PyTorch's nn.CrossEntropyLoss already applies log-softmax internally — applying softmax first and then CrossEntropyLoss would compute log(softmax(softmax(x))), which is wrong. During inference, if you want probabilities, call torch.softmax(logits, dim=1) yourself. For argmax predictions, raw logits work fine since argmax is order-preserving.
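Both claims are easy to verify (the numbers here are illustrative, not from the model in this post):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 1.0, 0.1, 0.0]])
label = torch.tensor([0])

# argmax is the same on logits and probabilities — softmax is order-preserving
probs = torch.softmax(logits, dim=1)
assert logits.argmax(dim=1) == probs.argmax(dim=1)

# But feeding probabilities into CrossEntropyLoss double-applies softmax:
correct_loss = criterion(logits, label)
wrong_loss = criterion(probs, label)  # silently wrong — the loss is distorted
print(f"{correct_loss.item():.4f} vs {wrong_loss.item():.4f}")
```

The "wrong" version doesn't crash — it just quietly computes a flatter, distorted loss, which is why this bug is so easy to miss.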
Training a neural network means finding the weights that minimize the loss. We do this by gradient descent: compute the gradient of the loss with respect to every weight, then nudge each weight in the opposite direction of its gradient. The plain version of this is SGD (Stochastic Gradient Descent) — every parameter gets the same learning rate, every update.
Adam (Adaptive Moment Estimation) is a smarter optimizer that gives each parameter its own effective learning rate, adapted based on the history of its gradients. It tracks two things for each parameter: the first moment (running average of the gradient — like a velocity in that direction) and the second moment (running average of the squared gradient — how volatile has this parameter's gradient been?). Parameters with large, consistent gradients get smaller effective learning rates; parameters with small or noisy gradients get larger effective learning rates.
Adam vs SGD: The Highway Analogy
SGD is like driving at a fixed speed everywhere — 60 mph on the highway and 60 mph in a school zone. Adam is like adaptive cruise control that watches the road and adjusts: it speeds up on clear highways (parameters with consistent gradients) and slows down where the terrain is bumpy and uncertain (parameters with noisy gradients). In practice, Adam typically reaches a good solution in far fewer epochs than plain SGD, with much less learning-rate tuning.
import torch
import torch.nn as nn
model = nn.Linear(10, 5)
# SGD: one global learning rate for everything
sgd = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9  # Basic momentum
)
# Adam: adaptive per-parameter learning rates
# lr=1e-3 is the BASE learning rate (Adam scales it per-parameter)
# weight_decay=1e-4 adds L2 regularization: penalizes large weights to prevent overfitting
adam = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4  # equiv. to L2 regularization
)
# The update step is identical for all optimizers:
loss = nn.MSELoss()(model(torch.randn(4, 10)), torch.randn(4, 5))
adam.zero_grad()  # Clear gradients from previous step (ALWAYS do this first!)
loss.backward()   # Compute gradients via backpropagation
adam.step()       # Update weights using Adam's adaptive rule

Always zero_grad() Before backward()
PyTorch accumulates gradients by default — each backward() call adds to existing gradients rather than replacing them. If you forget zero_grad(), gradients pile up across steps and your weights get corrupted updates. The pattern is always: zero_grad() → loss.backward() → optimizer.step().
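You can watch the accumulation happen on a toy layer:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
x = torch.randn(4, 3)

model(x).sum().backward()
g1 = model.weight.grad.clone()

# Without zero_grad(), a second backward() ADDS to the stored gradients
model(x).sum().backward()
assert torch.allclose(model.weight.grad, 2 * g1)  # gradients doubled

# Clearing first gives the correct, fresh gradient
model.zero_grad()
model(x).sum().backward()
assert torch.allclose(model.weight.grad, g1)
```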
The weight_decay parameter in Adam adds a small penalty proportional to the magnitude of each weight. Mathematically, it adds λ * ||w||² to the loss, where λ is the weight_decay value. This penalizes very large weights, discouraging the model from over-relying on any single feature. Think of it as encouraging the model to spread its 'bets' across many features rather than putting all its weight on a few. A value of 1e-4 (0.0001) is a gentle nudge — large enough to matter, small enough not to overwhelm the signal.
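A sketch of the decay effect in isolation — using SGD rather than Adam so the arithmetic is exact, with a contrived zero-gradient loss so only the decay term acts:

```python
import torch

# A single parameter whose data gradient is exactly zero
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.1, weight_decay=0.5)

loss = (w * 0.0).sum()  # gradient of the loss itself is zero
loss.backward()
opt.step()

# SGD with decay: w ← w - lr * (grad + weight_decay * w)
#                   = 1.0 - 0.1 * (0 + 0.5 * 1.0) = 0.95
print(f"{w.item():.4f}")  # 0.9500
```

Even with no learning signal at all, the weight is pulled toward zero — that constant gentle pull is what keeps any single weight from growing unboundedly.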
For a classification problem with multiple classes, we use CrossEntropyLoss. It measures how well the model's predicted probability distribution matches the true label. Internally, PyTorch's nn.CrossEntropyLoss does three things in one: applies log-softmax to convert raw logits into log-probabilities, selects the log-probability for the correct class, and negates it (so minimizing the loss = maximizing confidence in the correct class).
import torch
import torch.nn as nn
criterion = nn.CrossEntropyLoss()
# Batch of 3 examples, 4 classes
logits = torch.tensor([
    [2.0, 1.0, 0.1, 0.0],  # Model strongly prefers class 0
    [0.0, 0.5, 2.0, 0.5],  # Model strongly prefers class 2
    [1.0, 1.0, 1.0, 1.0],  # Model is completely uncertain (uniform)
])
true_labels = torch.tensor([0, 2, 1])  # Ground truth
loss = criterion(logits, true_labels)
print(f'Loss: {loss.item():.4f}')  # ~0.78 — the two confident, correct rows stay cheap (~0.5 each); the uniform row costs -log(0.25) ≈ 1.39
# What happens when the model is confident but WRONG:
wrong_logits = torch.tensor([
    [0.0, 0.0, 5.0, 0.0],  # Very confident about class 2, but true label is 0
])
wrong_labels = torch.tensor([0])
bad_loss = criterion(wrong_logits, wrong_labels)
print(f'Bad loss: {bad_loss.item():.4f}')  # ~5.0 — high loss, big penalty

Training a model longer doesn't always make it better. After a certain point, the training loss keeps decreasing (the model is memorizing training data) but the validation accuracy plateaus or falls. This is overfitting. If you stop training at the wrong epoch, you get a model that performs great on training data and poorly on new data.
Early stopping monitors validation accuracy after each epoch. If it improves, we save the model weights and reset a counter. If it doesn't improve for patience consecutive epochs, we stop training and restore the best weights we ever saw. This way we automatically find the sweet spot without having to guess the right number of epochs.
import copy
import torch
# --- Early stopping state ---
best_val_acc = 0.0
best_state = None  # Will hold a copy of model weights at the best epoch
patience_counter = 0

for epoch in range(config.epochs):
    # ... (train one epoch here) ...
    val_acc = evaluate(model, X_val, y_val)  # float, e.g. 0.923

    if val_acc > best_val_acc:
        # New best! Snapshot the model weights
        best_val_acc = val_acc
        best_state = copy.deepcopy(model.state_dict())
        # deepcopy is critical — state_dict() returns references, not copies
        patience_counter = 0
        print(f'Epoch {epoch+1}: New best val_acc = {val_acc:.4f} ✓')
    else:
        patience_counter += 1
        print(f'Epoch {epoch+1}: No improvement ({patience_counter}/{config.patience})')
        if patience_counter >= config.patience:
            print(f'Early stopping triggered at epoch {epoch+1}')
            break

# After training (either early stop or max epochs), restore the best weights
model.load_state_dict(best_state)
print(f'Restored best weights (val_acc = {best_val_acc:.4f})')

copy.deepcopy() is Essential
model.state_dict() returns a dictionary of references to the current tensors. If you just store best_state = model.state_dict(), and the model continues training and changing its weights, best_state will silently update too — because it points to the same tensors. copy.deepcopy() makes a completely independent copy of all the weight values at that moment. Never skip this.
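A minimal demonstration of the aliasing:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

shallow = model.state_dict()                  # references to the live tensors
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the values

with torch.no_grad():
    model.weight.add_(100.0)  # simulate further training updates

# The shallow 'copy' silently tracked the change; the deepcopy did not
assert torch.equal(shallow["weight"], model.weight)
assert not torch.equal(snapshot["weight"], model.weight)
```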
Now let's put everything together into a complete MLPTrainer class. A training loop has a repeating structure: for each epoch, shuffle the training data, slice it into mini-batches, do forward→loss→backward→step for each batch, then evaluate on the validation set.
"""Training loop with early stopping and history tracking."""
import copy
import torch
import torch.nn as nn
from torch import Tensor
class MLPTrainer:
    """Manages the training loop, early stopping, and history."""

    def __init__(self, model: MLPClassifier, config: MLPConfig) -> None:
        self.model = model
        self.config = config
        # Adam: adaptive learning rates + weight decay (L2 regularization)
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
        )
        # CrossEntropyLoss = log_softmax + NLL loss, numerically stable
        self.criterion = nn.CrossEntropyLoss()

    def _run_epoch(self, X: Tensor, y: Tensor) -> float:
        """Run one training epoch, return average loss."""
        self.model.train()  # Enable dropout and batch norm training mode
        # Shuffle training data each epoch to prevent order-dependent learning
        perm = torch.randperm(len(X))
        X, y = X[perm], y[perm]
        total_loss = 0.0
        n_batches = 0
        for i in range(0, len(X), self.config.batch_size):
            X_batch = X[i : i + self.config.batch_size]
            y_batch = y[i : i + self.config.batch_size]
            self.optimizer.zero_grad()              # 1. Clear old gradients
            logits = self.model(X_batch)            # 2. Forward pass
            loss = self.criterion(logits, y_batch)  # 3. Compute loss
            loss.backward()                         # 4. Backprop: compute gradients
            self.optimizer.step()                   # 5. Update weights
            total_loss += loss.item()
            n_batches += 1
        return total_loss / n_batches

    @torch.no_grad()  # Decorator: all operations here skip gradient tracking
    def _evaluate(self, X: Tensor, y: Tensor) -> tuple[float, float]:
        """Return (loss, accuracy) on a dataset."""
        self.model.eval()  # Disable dropout, use BatchNorm running stats
        logits = self.model(X)
        loss = self.criterion(logits, y).item()
        acc = (logits.argmax(dim=1) == y).float().mean().item()
        return loss, acc

    def fit(
        self,
        X_train: Tensor, y_train: Tensor,
        X_val: Tensor, y_val: Tensor,
    ) -> dict[str, list[float]]:
        """Train with early stopping. Returns history dict."""
        torch.manual_seed(self.config.seed)
        history = {"train_loss": [], "val_loss": [], "val_acc": []}
        best_val_acc = 0.0
        best_state = None
        patience_counter = 0
        for epoch in range(self.config.epochs):
            train_loss = self._run_epoch(X_train, y_train)
            val_loss, val_acc = self._evaluate(X_val, y_val)
            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)
            if epoch % 5 == 0:
                print(
                    f"Epoch {epoch+1:3d} | "
                    f"train_loss={train_loss:.4f} | "
                    f"val_loss={val_loss:.4f} | "
                    f"val_acc={val_acc:.4f}"
                )
            # Early stopping logic
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = copy.deepcopy(self.model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.patience:
                    print(f"Early stopping at epoch {epoch+1} (best val_acc={best_val_acc:.4f})")
                    break
        # Always restore best weights before returning
        if best_state is not None:
            self.model.load_state_dict(best_state)
        return history

Now let's wire everything together in a runnable script. The pipeline is: load data → encode to embeddings (with caching) → train MLP → evaluate → plot curves → report results.
"""Full pipeline: data → embeddings → MLP training → evaluation."""
import argparse
import time
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from agent_router.data import load_clinc150
from agent_router.models.embedding_utils import load_or_encode
from agent_router.models.mlp_classifier import MLPClassifier, MLPConfig, MLPTrainer
from agent_router.eval import evaluate_classifier
def parse_args():
p = argparse.ArgumentParser(description="Train MLP on sentence embeddings")
p.add_argument("--hidden-dims", default="256,128")
p.add_argument("--dropout", type=float, default=0.3)
p.add_argument("--epochs", type=int, default=50)
p.add_argument("--lr", type=float, default=1e-3)
p.add_argument("--batch-size", type=int, default=128)
p.add_argument("--patience", type=int, default=5)
p.add_argument("--embed-model", default="all-MiniLM-L6-v2")
p.add_argument("--embed-cache-dir",default="runs/ticket_004/embeddings")
p.add_argument("--plot-path", default="runs/ticket_004/training_curves.png")
return p.parse_args()
def plot_history(history: dict, save_path: str) -> None:
"""Plot train/val loss and val accuracy on dual y-axes."""
epochs = range(1, len(history["train_loss"]) + 1)
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx() # Second y-axis sharing the same x-axis
ax1.plot(epochs, history["train_loss"], label="Train loss", color="steelblue")
ax1.plot(epochs, history["val_loss"], label="Val loss", color="tomato")
ax2.plot(epochs, history["val_acc"], label="Val accuracy",color="seagreen", linestyle="--")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax2.set_ylabel("Accuracy")
ax1.set_title("MLP Training Curves")
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines1 + lines2, labels1 + labels2, loc="center right")
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
plt.savefig(save_path, dpi=150, bbox_inches="tight")
print(f"Saved training curves to {save_path}")
def main():
    args = parse_args()
    hidden_dims = [int(x) for x in args.hidden_dims.split(",")]
    config = MLPConfig(
        hidden_dims=hidden_dims,
        dropout=args.dropout,
        lr=args.lr,
        epochs=args.epochs,
        batch_size=args.batch_size,
        patience=args.patience,
    )

    # 1. Load dataset
    dataset = load_clinc150()  # Returns train/val/test splits

    # 2. Encode texts to embeddings (loads from cache if available)
    X_train = load_or_encode(dataset.train_texts, "train", args.embed_cache_dir, args.embed_model)
    X_val = load_or_encode(dataset.val_texts, "val", args.embed_cache_dir, args.embed_model)
    X_test = load_or_encode(dataset.test_texts, "test", args.embed_cache_dir, args.embed_model)
    y_train = torch.tensor(dataset.train_labels)
    y_val = torch.tensor(dataset.val_labels)
    y_test = torch.tensor(dataset.test_labels)

    # 3. Build and train model
    model = MLPClassifier(embed_dim=384, n_classes=151, config=config)
    trainer = MLPTrainer(model, config)
    print(f"Model has {model.n_params:,} trainable parameters")
    history = trainer.fit(X_train, y_train, X_val, y_val)

    # 4. Plot training curves
    plot_history(history, args.plot_path)

    # 5. Evaluate on test set
    results = evaluate_classifier(model, X_test, y_test, dataset.label_names)
    print(f"\nTest Accuracy: {results.accuracy:.4f}")

    # 6. Measure MLP inference latency (embedding encoding not included)
    sample = X_test[:1]
    start = time.perf_counter()
    for _ in range(100):
        model.predict(sample)
    latency_ms = (time.perf_counter() - start) / 100 * 1000
    print(f"Inference latency (MLP forward only): {latency_ms:.2f} ms")
if __name__ == "__main__":
    main()

Good tests for neural networks don't just check that the code runs — they check that the math is right. Here are the key things worth testing and why each one matters:
"""Tests for the MLP classifier."""
import copy
import torch
import pytest
from agent_router.models.mlp_classifier import MLPClassifier, MLPConfig, MLPTrainer
# ── 1. Shape tests ─────────────────────────────────────────────────────────────
def test_forward_output_shape():
"""forward() must return (N, n_classes) — anything else means broken architecture."""
config = MLPConfig(hidden_dims=[64, 32])
model = MLPClassifier(embed_dim=384, n_classes=151, config=config)
x = torch.randn(16, 384) # 16 samples, 384-dim embeddings
logits = model(x)
assert logits.shape == (16, 151), f"Expected (16, 151), got {logits.shape}"
def test_predict_output_shape():
"""predict() must return (N,) integer class indices."""
config = MLPConfig(hidden_dims=[64])
model = MLPClassifier(embed_dim=384, n_classes=10, config=config)
x = torch.randn(8, 384)
preds = model.predict(x)
assert preds.shape == (8,), f"Expected (8,), got {preds.shape}"
assert preds.dtype == torch.int64, f"Expected int64, got {preds.dtype}"
# ── 2. Dynamic architecture ────────────────────────────────────────────────────
def test_dynamic_architecture_changes_params():
"""A wider/deeper network must have more parameters."""
config_small = MLPConfig(hidden_dims=[64])
config_large = MLPConfig(hidden_dims=[128, 64])
model_small = MLPClassifier(embed_dim=384, n_classes=10, config=config_small)
model_large = MLPClassifier(embed_dim=384, n_classes=10, config=config_large)
assert model_large.n_params > model_small.n_params, "Larger config must have more params"
# ── 3. n_params correctness ────────────────────────────────────────────────────
def test_n_params_correctness():
"""Manually count expected parameters for a simple known architecture."""
# Architecture: Linear(4→3) + BN(3) + Linear(3→2)
# Linear(4→3): weights=4*3=12, bias=3 → 15
# BN(3): weight=3, bias=3 → 6
# Linear(3→2): weights=3*2=6, bias=2 → 8
# Total: 15 + 6 + 8 = 29
config = MLPConfig(hidden_dims=[3], dropout=0.0)
model = MLPClassifier(embed_dim=4, n_classes=2, config=config)
assert model.n_params == 29, f"Expected 29 params, got {model.n_params}"
# ── 4. Training reduces loss ───────────────────────────────────────────────────
def test_training_reduces_loss():
"""Training on a tiny synthetic dataset should consistently reduce train loss."""
torch.manual_seed(42)
config = MLPConfig(hidden_dims=[32], epochs=20, patience=20, lr=1e-2)
model = MLPClassifier(embed_dim=16, n_classes=3, config=config)
trainer = MLPTrainer(model, config)
X = torch.randn(60, 16)
y = torch.randint(0, 3, (60,))
history = trainer.fit(X, y, X, y) # Using train=val just to test loss goes down
assert history["train_loss"][-1] < history["train_loss"][0], \
"Final train loss should be lower than initial loss"
# ── 5. Early stopping fires ────────────────────────────────────────────────────
def test_early_stopping_fires():
"""Training must stop before max_epochs when val acc plateaus."""
config = MLPConfig(hidden_dims=[16], epochs=100, patience=3, lr=1e-5) # tiny lr = no progress
model = MLPClassifier(embed_dim=8, n_classes=2, config=config)
trainer = MLPTrainer(model, config)
X = torch.randn(40, 8)
y = torch.randint(0, 2, (40,))
history = trainer.fit(X, y, X, y)
# With patience=3 and effectively no learning, must stop well before 100 epochs
assert len(history["val_acc"]) < 100, \
f"Early stopping should have fired before 100 epochs, ran {len(history['val_acc'])}"
# ── 6. predict() equals argmax(forward()) ─────────────────────────────────────
def test_predict_matches_argmax():
"""predict() must be equivalent to forward().argmax(dim=1)."""
config = MLPConfig(hidden_dims=[32])
model = MLPClassifier(embed_dim=64, n_classes=5, config=config)
x = torch.randn(20, 64)
# Get predictions both ways
preds_direct = model.predict(x)
model.eval()
with torch.no_grad():
preds_argmax = model(x).argmax(dim=1)
assert torch.all(preds_direct == preds_argmax), "predict() must match argmax(forward())"When you run this pipeline on CLINC150 (a 150-class intent dataset with ~15,000 sentences), you should expect test accuracy above 90%. This is a dramatic jump from the ~78% of TF-IDF + logistic regression. The gain comes almost entirely from the embeddings, not the MLP architecture. The sentence transformer has already done the hard work of mapping synonymous phrases to nearby points in 384-dimensional space — the MLP just needs to draw decision boundaries between well-separated clusters.
| Model | Features | Accuracy | Inference Latency | Why |
|---|---|---|---|---|
| TF-IDF + Logistic Regression | Sparse word counts | ~78% | ~1ms | Linear boundary, no synonym handling |
| TF-IDF + PyTorch Logistic Regression | Sparse word counts | ~78% | ~1ms | Same features, same ceiling |
| MLP on Frozen Embeddings (this post) | Dense semantic vectors | >90% | ~5ms | Pretrained meaning + nonlinear boundaries |
| Fine-tuned Transformer | Contextual token embeddings | ~95%+ | ~50ms | End-to-end training, most expressive |
Why Not Just Fine-Tune the Transformer?
Fine-tuning gives the best accuracy, but it's 10–100x more expensive to train, requires a GPU for reasonable speed, and is overkill for many production scenarios. The frozen-embeddings + MLP approach is a sweet spot: 90%+ accuracy, trains in minutes on CPU, and inference is fast. This is a common production pattern.
If your accuracy is unexpectedly low, work through these in order — most issues reduce to one of these five causes:
- Forgot model.eval() during validation — This is the #1 culprit. Dropout stays active in train mode and randomly zeros neurons during your validation forward pass. BatchNorm uses noisy batch statistics instead of stable running stats. Your 'validation accuracy' becomes meaningless. Fix: always call model.eval() before any evaluation code.
- Not restoring best weights — If early stopping fires but you forget load_state_dict(best_state), you evaluate the last epoch's weights, not the best epoch's. The last epoch is often overfit. Fix: always restore best_state after the training loop.
- Learning rate too high or too low — Too high (>1e-2 for Adam): loss oscillates wildly and never converges. Too low (<1e-5): training effectively doesn't happen. The default 1e-3 is well-tested for Adam on this type of problem.
- Hidden dimensions too small — With 151 output classes, you need enough representational capacity in the hidden layers. [256, 128] is appropriate. [32] is too small to separate 151 classes reliably.
- Embeddings not cached correctly — If the cache silently fails to load, you may be re-encoding from scratch on every run, possibly with different settings (batch size shouldn't change the embeddings, but a silent re-encode can mask other bugs). Verify with a print statement that the cache file is actually being read.
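The first item on this checklist is easy to see in isolation. Here is a minimal standalone sketch, using a bare `nn.Dropout` rather than the post's `MLPClassifier`, that shows exactly what goes wrong: in train mode roughly half the activations are zeroed (and the survivors are rescaled), while in eval mode Dropout is the identity.

```python
import torch
import torch.nn as nn

# Minimal sketch: why forgetting model.eval() corrupts validation metrics.
torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

dropout.train()                  # training mode: activations randomly zeroed,
train_out = dropout(x)           # survivors scaled by 1/(1-p) = 2

dropout.eval()                   # eval mode: Dropout is a no-op
eval_out = dropout(x)

print("zeros in train mode:", int((train_out == 0).sum()))
print("zeros in eval mode: ", int((eval_out == 0).sum()))       # always 0
print("eval output == input:", bool(torch.equal(eval_out, x)))  # always True
```

The same mode switch also flips BatchNorm from noisy per-batch statistics to its stable running averages, which is why a single forgotten `model.eval()` can make validation numbers both wrong and high-variance.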
| Concept | What it is | Why it matters |
|---|---|---|
| Sentence embeddings | Dense 384-dim vectors capturing semantic meaning | Synonyms map to nearby points; no TF-IDF ceiling |
| Transfer learning | Reusing weights trained on large data for a new task | Start with a powerful representation, no need to train from scratch |
| Frozen encoder | Pretrained model weights kept fixed, not updated by backprop | Prevents overfitting on small data; much faster training |
| nn.Module | PyTorch base class for all neural networks | Automatic parameter tracking, train/eval modes, save/load |
| Hidden layers + ReLU | Nonlinear transformations between input and output | Enables learning curved decision boundaries (unlike linear models) |
| BatchNorm1d | Normalizes activations within each mini-batch | Stabilizes training, allows higher learning rates |
| Dropout | Randomly zeroes neurons during training | Prevents overfitting; forces neurons to be independently useful |
| Adam optimizer | Adaptive per-parameter learning rates | Faster convergence than SGD; handles sparse gradients well |
| weight_decay | L2 penalty on weight magnitudes, built into Adam | Discourages over-reliance on any single feature |
| CrossEntropyLoss | log-softmax + NLL combined; takes raw logits | Numerically stable; correct loss for multi-class classification |
| Early stopping | Stop training when val metric stops improving | Finds the optimal epoch automatically; prevents overfitting |
| Embedding caching | Save computed embeddings to disk for reuse | Pay encoding cost once; subsequent runs start instantly |
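One row in this table is worth verifying numerically: `CrossEntropyLoss` really is log-softmax followed by negative log-likelihood, applied to raw logits. A quick sanity check, with shapes chosen to match the post's 151-class setup (any values work):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 151)             # raw scores: no softmax applied
targets = torch.randint(0, 151, (4,))    # integer class labels

# The fused, numerically stable version...
ce = F.cross_entropy(logits, targets)
# ...equals the two-step decomposition from the table.
nll = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(ce, nll))  # True
```

This is also why the model's `forward()` returns logits directly: applying a softmax yourself before the loss would be redundant and less numerically stable.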
We've covered a lot of ground. Starting from the limitations of TF-IDF, we built a complete pipeline that uses a pretrained sentence transformer to encode meaning into dense vectors, then trains a multi-layer perceptron to classify intents with over 90% accuracy. Every piece — BatchNorm for training stability, Dropout for generalization, Adam for fast convergence, early stopping for finding the optimal epoch — plays a specific role in making the system robust.
The most important insight is about the division of labour: the sentence transformer does the heavy lifting of understanding language (trained on billions of sentences, frozen, never updated), and the MLP does the lighter work of learning decision boundaries in that well-structured representation space. This separation is the foundation of the modern practice of transfer learning in NLP — and the same pattern (frozen pretrained encoder + small trained head) is used in production systems at Google, Meta, and virtually every company doing serious NLP work.
Where to Go Next
The natural next step is fine-tuning — instead of freezing the transformer, you update its weights too with a very small learning rate. This is how BERT, RoBERTa, and similar models are adapted to specific tasks. The infrastructure you've built here (data pipeline, training loop, evaluation, early stopping) carries forward to fine-tuning almost unchanged — you just swap the frozen encoder for a trainable one.
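As a sketch of what that change looks like in the optimizer, here is the standard two-learning-rate setup using PyTorch parameter groups. The `encoder` below is a hypothetical stand-in `nn.Module`, not a real transformer; in practice you would pass the actual pretrained encoder's parameters instead.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a placeholder "encoder" and the classification head.
encoder = nn.Sequential(nn.Linear(384, 384), nn.ReLU())  # pretend this is the transformer
head = nn.Linear(384, 151)                               # classification head

# Fine-tuning recipe: tiny LR for pretrained weights, normal LR for the new head.
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 2e-5},  # gently nudge pretrained weights
    {"params": head.parameters(), "lr": 1e-3},     # head trains from scratch
])

print([group["lr"] for group in optimizer.param_groups])  # [2e-05, 0.001]
```

The rest of the training loop is unchanged: `optimizer.step()` applies each group's learning rate to its own parameters, so the encoder drifts slowly away from its pretrained weights while the head converges quickly.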
Related Articles
Understanding Transformers: The Architecture Behind Modern AI
A comprehensive guide to understanding the Transformer architecture that powers GPT, BERT, and other modern language models.
Logistic Regression from Scratch in PyTorch: Every Line Explained
Build a multi-class classifier in PyTorch without nn.Linear, without optim.SGD, without CrossEntropyLoss. Just tensors, autograd, and arithmetic — so you finally see what those helpers actually do.
TF-IDF + Logistic Regression: The Classical ML Baseline You Should Try First
Before reaching for LLMs or neural networks for text classification, try the boring thing. Here's how TF-IDF + Logistic Regression works, why it's often embarrassingly competitive, and where it breaks.