
Early Stopping Explained: Knowing When to Stop Training
A complete beginner's guide to early stopping - how to automatically find the optimal training duration, prevent overfitting, and save the best model weights.
Imagine studying for an exam. If you stop too early, you haven't learned enough. If you study too long, you start overthinking and second-guessing yourself. There's a sweet spot — and finding it is crucial. Early stopping solves the same problem for neural networks: it automatically finds the optimal training duration, preventing both underfitting (stopping too early) and overfitting (training too long). This post explains exactly how early stopping works, why it's essential, and how to implement it correctly.
What You'll Learn
Training a neural network is an iterative process. Each epoch, the model sees the entire training dataset and updates its weights. But how many epochs should you train for? This is harder than it sounds.
The Unprepared Student
If you stop training after 5 epochs when the model needs 50, you get underfitting. The model hasn't learned the patterns in your data yet. Both training and test accuracy are low.
The Over-Prepared Student
If you train for 500 epochs when the model only needed 50, you get overfitting. The model starts memorizing the training data instead of learning generalizable patterns. Training accuracy keeps improving, but test accuracy plateaus or even decreases.
Here's what happens during overfitting:
- Early epochs: Model learns general patterns (good)
- Middle epochs: Model refines understanding (still good)
- Late epochs: Model starts memorizing training examples (bad)
- Very late epochs: Model has memorized training data perfectly but fails on new data (very bad)
There's an optimal number of epochs where the model has learned the patterns but hasn't started memorizing. This is where test accuracy is highest. The problem: you don't know this number in advance. It depends on:
- Your dataset size
- Model complexity
- Learning rate
- Regularization strength
- Random initialization
You could guess ("let's try 100 epochs"), but that's wasteful. Early stopping finds the sweet spot automatically.
Early stopping monitors a validation metric (usually validation accuracy or validation loss) after each epoch. When the metric stops improving, training stops. Here's the algorithm:
# Early stopping algorithm (pseudocode)
best_val_metric = 0  # or infinity for loss
best_weights = None
patience_counter = 0
patience = 5  # How many epochs to wait

for epoch in range(max_epochs):
    # Train for one epoch
    train_one_epoch()

    # Evaluate on validation set
    val_metric = evaluate_on_validation()

    # Check if this is the best we've seen
    if val_metric > best_val_metric:  # or < for loss
        # New best! Save the weights
        best_val_metric = val_metric
        best_weights = copy_model_weights()
        patience_counter = 0  # Reset counter
    else:
        # No improvement
        patience_counter += 1

        # Have we waited long enough?
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best weights (CRITICAL!)
restore_weights(best_weights)

Let's break down each component:
You need three datasets:
- Training set: Used to update weights
- Validation set: Used to monitor progress and decide when to stop
- Test set: Used only at the very end to report final performance
The validation set is crucial. You can't use training accuracy (it always improves, even during overfitting) or test accuracy (that would be cheating — you'd be peeking at the exam). The validation set is your honest progress check.
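To make the three-way split concrete, here's one way to do it with PyTorch's `random_split`. This is a sketch: the synthetic `TensorDataset` and the 80/10/10 ratios are illustrative choices, not requirements.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Illustrative synthetic dataset: 1,000 examples, 20 features, 2 classes
full_dataset = TensorDataset(
    torch.randn(1000, 20),
    torch.randint(0, 2, (1000,)),
)

# 80/10/10 split; a fixed generator seed keeps the split reproducible
train_set, val_set, test_set = random_split(
    full_dataset, [800, 100, 100],
    generator=torch.Generator().manual_seed(42),
)
print(len(train_set), len(val_set), len(test_set))  # 800 100 100
```

Seeding the generator matters: if the split changes between runs, your validation metric is no longer comparable across experiments.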
Choosing Patience
Patience is how many epochs you wait without improvement before stopping. Why not stop immediately after the first epoch without improvement? Because validation metrics are noisy — they can fluctuate randomly.
Epoch | Val Accuracy | Action
------|--------------|-------
1 | 0.75 | New best! Save weights, reset counter
2 | 0.78 | New best! Save weights, reset counter
3 | 0.77 | No improvement, counter = 1
4 | 0.76 | No improvement, counter = 2
5 | 0.81 | New best! Save weights, reset counter
6 | 0.80 | No improvement, counter = 1
7 | 0.79 | No improvement, counter = 2
8 | 0.79 | No improvement, counter = 3
9 | 0.78 | No improvement, counter = 4
10 | 0.78 | No improvement, counter = 5 → STOP!

Notice epoch 5 — validation accuracy improved after 2 epochs of decline. If patience were 1, we would have stopped too early. Patience gives the model a chance to recover from temporary dips.
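You can replay the table above with a few lines of plain Python — a sketch of just the counter bookkeeping, without the weight saving:

```python
# Validation accuracies from the table above
val_accs = [0.75, 0.78, 0.77, 0.76, 0.81, 0.80, 0.79, 0.79, 0.78, 0.78]
patience = 5
best, counter, stop_epoch = 0.0, 0, None

for epoch, acc in enumerate(val_accs, start=1):
    if acc > best:
        best, counter = acc, 0   # new best: would save weights here
    else:
        counter += 1             # no improvement this epoch
        if counter >= patience:
            stop_epoch = epoch   # patience exhausted
            break

print(best, stop_epoch)  # 0.81 10
```

The best accuracy (0.81, from epoch 5) is what gets kept, even though training ran until epoch 10.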
The Most Common Early Stopping Bug
This is the part beginners most often get wrong: when early stopping fires, you must restore the best weights, not the last weights.
Why? Because the last few epochs were overfitting — that's why validation accuracy stopped improving! The best weights are from several epochs ago, when validation accuracy peaked.
import copy

# WRONG: Just stopping without restoring
for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    if should_stop():
        break
# Model now has the LAST weights (overfit!)

# CORRECT: Save best weights and restore them
best_val_acc = 0.0
best_weights = None
for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    if val_acc > best_val_acc:
        # Save a COPY of the weights (and update the best metric)
        best_val_acc = val_acc
        best_weights = copy.deepcopy(model.state_dict())
    if should_stop():
        break
# Restore the best weights
model.load_state_dict(best_weights)
# Model now has the BEST weights (optimal!)
Here's a complete, production-ready implementation:
import copy
import torch
import torch.nn as nn
class EarlyStopping:
    """Early stopping to stop training when validation metric stops improving."""

    def __init__(self, patience=5, min_delta=0.0, mode='max'):
        """
        Args:
            patience: How many epochs to wait after last improvement
            min_delta: Minimum change to qualify as improvement
            mode: 'max' for accuracy (higher is better), 'min' for loss (lower is better)
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.best_weights = None
        self.early_stop = False

    def __call__(self, val_metric, model):
        """Check if we should stop training.

        Args:
            val_metric: Current validation metric (accuracy or loss)
            model: The model being trained

        Returns:
            True if training should stop, False otherwise
        """
        score = val_metric if self.mode == 'max' else -val_metric
        if self.best_score is None:
            # First epoch
            self.best_score = score
            self.best_weights = copy.deepcopy(model.state_dict())
        elif score < self.best_score + self.min_delta:
            # No improvement
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement!
            self.best_score = score
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
        return self.early_stop

    def restore_best_weights(self, model):
        """Restore the best weights to the model."""
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)

# Usage example
model = YourModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
early_stopping = EarlyStopping(patience=5, mode='max')  # max for accuracy

for epoch in range(max_epochs):
    # Training
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch.x), batch.y)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_acc = evaluate_on_validation(model, val_loader)

    # Check early stopping
    if early_stopping(val_acc, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# CRITICAL: Restore best weights
early_stopping.restore_best_weights(model)
print(f"Restored best weights (val_acc={early_stopping.best_score:.4f})")

What should you monitor? The most common choices:
| Metric | Mode | When to Use | Pros | Cons |
|---|---|---|---|---|
| Validation Accuracy | max | Classification tasks | Easy to interpret, directly measures performance | Can be misleading with imbalanced classes |
| Validation Loss | min | Any task | Smooth signal, less noisy than accuracy | Harder to interpret, can decrease while accuracy plateaus |
| F1 Score | max | Imbalanced classification | Better than accuracy for imbalanced data | More complex to compute |
| Validation Perplexity | min | Language modeling | Standard metric for LMs | Only for language tasks |
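For loss-style metrics the only change is `mode='min'`. Internally that works by negating the metric, so the "higher score is better" comparison applies either way. Here's a standalone sketch of that trick (the loss values are made up):

```python
def to_score(val_metric, mode='max'):
    """Negate loss-style metrics so that a higher score is always better."""
    return val_metric if mode == 'max' else -val_metric

# Made-up validation losses: improving for 3 epochs, then plateauing
val_losses = [0.90, 0.75, 0.70, 0.72, 0.71]
best = None
improvements = []
for loss in val_losses:
    score = to_score(loss, mode='min')
    improved = best is None or score > best
    if improved:
        best = score
    improvements.append(improved)

print(improvements)  # [True, True, True, False, False]
```

Epochs 4 and 5 correctly count as "no improvement" even though the raw loss values went down, then up, by tiny amounts.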
Early stopping isn't the only way to decide when to stop. Let's compare:
| Method | How It Works | Pros | Cons |
|---|---|---|---|
| Fixed epochs | Train for N epochs, then stop | Simple, reproducible | Wastes time if N too large, undertrains if N too small |
| Early stopping | Stop when validation metric plateaus | Automatic, prevents overfitting | Requires validation set, adds complexity |
| Learning rate scheduling | Reduce LR when progress slows | Can extend training productively | Doesn't prevent overfitting, just delays it |
| Manual inspection | Watch training curves, stop manually | Flexible, uses human judgment | Not reproducible, requires constant monitoring |
Best practice: Combine early stopping with a maximum epoch limit. This gives you automatic stopping with a safety net:
max_epochs = 100  # Safety net
early_stopping = EarlyStopping(patience=5)

for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    if early_stopping(val_acc, model):
        print(f"Early stopping at epoch {epoch+1}")
        break
else:
    # Loop completed without early stopping
    print(f"Reached max epochs ({max_epochs})")

early_stopping.restore_best_weights(model)

A picture is worth a thousand words. Here's what early stopping looks like in practice:
import matplotlib.pyplot as plt
# Training history
epochs = list(range(1, 51))
train_loss = [...] # Decreases monotonically
val_loss = [...] # Decreases then increases (overfitting)
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_loss, label='Training Loss', color='blue')
plt.plot(epochs, val_loss, label='Validation Loss', color='orange')
# Mark the best epoch (where early stopping would restore weights)
best_epoch = 23 # Example: validation loss was lowest at epoch 23
plt.axvline(x=best_epoch, color='green', linestyle='--',
label=f'Best Epoch ({best_epoch})')
# Mark where early stopping fired
stop_epoch = 28 # Example: stopped at epoch 28 (patience=5)
plt.axvline(x=stop_epoch, color='red', linestyle='--',
label=f'Early Stop ({stop_epoch})')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Early Stopping: Validation Loss Plateaus')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

The key insight: training loss keeps decreasing (the model keeps improving on training data), but validation loss starts increasing after epoch 23 (the model is overfitting). Early stopping detects this and restores the weights from epoch 23.
- Not restoring best weights: The #1 bug. Always call restore_best_weights() after training stops.
- Using training metric instead of validation: Training accuracy always improves, even during overfitting. Use validation metric.
- Patience too small: patience=1 is too aggressive. Use at least 3-5.
- No validation set: You need a separate validation set. Don't use test set for early stopping (that's cheating).
- Forgetting copy.deepcopy(): model.state_dict() returns references, not copies. Use copy.deepcopy().
- Monitoring the wrong metric: For classification, monitor accuracy (mode='max'). For regression, monitor loss (mode='min').
- Not setting max_epochs: Always have a maximum epoch limit as a safety net.
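Pitfall 5 deserves a demonstration. `state_dict()` returns a fresh dict, but the tensors inside it still alias the live parameters, so in-place training updates silently change your "saved" weights. The same trap can be shown with a plain dict of lists as a stand-in:

```python
import copy

# Stand-in for model parameters: a dict holding a mutable "weight tensor"
weights = {'layer1': [0.5, -0.2]}

snapshot_ref = weights                   # WRONG: same underlying objects
snapshot_copy = copy.deepcopy(weights)   # CORRECT: fully independent copy

weights['layer1'][0] = 99.0              # training keeps mutating the weights

print(snapshot_ref['layer1'][0])   # 99.0 — the "saved" value changed too
print(snapshot_copy['layer1'][0])  # 0.5 — the deepcopy kept the old value
```

With a real model the symptom is subtle: restore_best_weights() appears to work, but you silently get the last (overfit) weights back.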
Sometimes validation metrics improve by tiny amounts (say, 0.0001) due to noise. You might want to ignore these tiny improvements and only count "real" improvements. That's what min_delta does:
# Without min_delta
early_stopping = EarlyStopping(patience=5, min_delta=0.0)
# Improvement of 0.0001 counts as improvement
# With min_delta
early_stopping = EarlyStopping(patience=5, min_delta=0.001)
# Only improvements > 0.001 count as improvement
# Tiny fluctuations are ignored
# Example:
# Epoch 1: val_acc = 0.850 → New best
# Epoch 2: val_acc = 0.8505 → Improvement of 0.0005 < 0.001 → No improvement
# Epoch 3: val_acc = 0.852 → Improvement of 0.002 > 0.001 → New best!

When to use min_delta: If your validation metric is very noisy and fluctuates by small amounts, set min_delta to filter out noise. For most problems, min_delta=0.0 (the default) works fine.
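The improvement test itself is a single comparison — a sketch mirroring the `score < best + min_delta` check in the class above, written from the "did it improve?" side:

```python
def improved(new_metric, best_metric, min_delta=0.001):
    """Improvement only counts if it beats the best by at least min_delta."""
    return new_metric >= best_metric + min_delta

# The three epochs from the comment above (mode='max', min_delta=0.001)
print(improved(0.850, 0.0))      # True: first real score is a new best
print(improved(0.8505, 0.850))   # False: +0.0005 is below min_delta
print(improved(0.852, 0.850))    # True: +0.002 clears min_delta
```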
Here's a complete training script with early stopping:
import torch
import torch.nn as nn
from early_stopping import EarlyStopping
def train_with_early_stopping(
    model, train_loader, val_loader,
    max_epochs=100, patience=5
):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, mode='max')
    history = {'train_loss': [], 'val_acc': []}

    for epoch in range(max_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Validation phase
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                output = model(batch_x)
                pred = output.argmax(dim=1)
                correct += (pred == batch_y).sum().item()
                total += batch_y.size(0)
        val_acc = correct / total

        # Record history
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)

        # Print progress
        print(f"Epoch {epoch+1}/{max_epochs} - "
              f"train_loss: {train_loss:.4f} - "
              f"val_acc: {val_acc:.4f}")

        # Check early stopping
        if early_stopping(val_acc, model):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            print(f"Best validation accuracy: {early_stopping.best_score:.4f}")
            break

    # Restore best weights
    early_stopping.restore_best_weights(model)
    print("\nRestored best model weights")
    return history

# Usage
model = YourModel()
history = train_with_early_stopping(
    model, train_loader, val_loader,
    max_epochs=100, patience=5
)

The Bottom Line
- Early stopping prevents overfitting by monitoring validation metrics and stopping when they plateau.
- Patience controls sensitivity: Higher patience is more conservative, lower patience stops faster.
- Always restore best weights: The last weights are overfit; the best weights are from several epochs ago.
- Use copy.deepcopy(): Make true copies of weights, not references.
- Monitor validation metrics: Never use training metrics (they always improve) or test metrics (that's cheating).
- Combine with max_epochs: Set a maximum epoch limit as a safety net.
- Default settings work well: patience=5, min_delta=0.0, mode='max' for accuracy.
- Early stopping is nearly free regularization: Its one hyperparameter, patience, rarely needs tuning beyond the default.
Related Articles
Adam Optimizer Explained: Why It's Better Than Plain Gradient Descent
A complete beginner's guide to the Adam optimizer - how it adapts learning rates per parameter, why it converges faster than SGD, and how to use it effectively in PyTorch.
Batch Normalization Explained: Why Your Neural Network Needs It
A complete beginner's guide to Batch Normalization - what it is, why it works, how to implement it, and the critical train vs eval mode difference that trips up everyone.
Dropout Explained: The Surprisingly Simple Trick That Prevents Overfitting
A complete beginner's guide to Dropout regularization - why randomly turning off neurons makes neural networks smarter, how it works, and how to use it correctly in PyTorch.