Early Stopping Explained: Knowing When to Stop Training

A complete beginner's guide to early stopping - how to automatically find the optimal training duration, prevent overfitting, and save the best model weights.

Imagine studying for an exam. If you stop too early, you haven't learned enough. If you study too long, you start overthinking and second-guessing yourself. There's a sweet spot — and finding it is crucial. Early stopping solves the same problem for neural networks: it automatically finds the optimal training duration, preventing both underfitting (stopping too early) and overfitting (training too long). This post explains exactly how early stopping works, why it's essential, and how to implement it correctly.

What You'll Learn

By the end of this post, you'll understand: what overfitting is and how to detect it, how early stopping monitors validation metrics, why you must restore the best weights (not the last), how to choose the patience parameter, and how to implement early stopping in PyTorch.

Training a neural network is an iterative process. Each epoch, the model sees the entire training dataset and updates its weights. But how many epochs should you train for? This is harder than it sounds.

If you stop training after 5 epochs when the model needs 50, you get underfitting. The model hasn't learned the patterns in your data yet. Both training and test accuracy are low.

The Unprepared Student

Imagine taking an exam after studying for only 1 hour when you needed 10. You haven't learned the material yet. Your practice test score is low, and your real exam score is low. That's underfitting — the model simply hasn't learned enough.

If you train for 500 epochs when the model only needed 50, you get overfitting. The model starts memorizing the training data instead of learning generalizable patterns. Training accuracy keeps improving, but test accuracy plateaus or even decreases.

Here's what happens during overfitting:

  1. Early epochs: Model learns general patterns (good)
  2. Middle epochs: Model refines understanding (still good)
  3. Late epochs: Model starts memorizing training examples (bad)
  4. Very late epochs: Model has memorized training data perfectly but fails on new data (very bad)

The Over-Prepared Student

Imagine studying so much that you memorize every practice problem word-for-word, including the typos. You ace the practice test (100%) but fail the real exam because the questions are slightly different. That's overfitting — memorization instead of understanding.

There's an optimal number of epochs where the model has learned the patterns but hasn't started memorizing. This is where test accuracy is highest. The problem: you don't know this number in advance. It depends on:

  • Your dataset size
  • Model complexity
  • Learning rate
  • Regularization strength
  • Random initialization

You could guess ("let's try 100 epochs"), but that's wasteful. Early stopping finds the sweet spot automatically.

Early stopping monitors a validation metric (usually validation accuracy or validation loss) after each epoch. When the metric stops improving, training stops. Here's the algorithm:

early_stopping_algorithm.py
python
# Early stopping algorithm (pseudocode)

best_val_metric = 0  # or infinity for loss
best_weights = None
patience_counter = 0
patience = 5  # How many epochs to wait

for epoch in range(max_epochs):
    # Train for one epoch
    train_one_epoch()
    
    # Evaluate on validation set
    val_metric = evaluate_on_validation()
    
    # Check if this is the best we've seen
    if val_metric > best_val_metric:  # or < for loss
        # New best! Save the weights
        best_val_metric = val_metric
        best_weights = copy_model_weights()
        patience_counter = 0  # Reset counter
    else:
        # No improvement
        patience_counter += 1
        
        # Have we waited long enough?
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Restore the best weights (CRITICAL!)
restore_weights(best_weights)

Let's break down each component:

You need three datasets:

  • Training set: Used to update weights
  • Validation set: Used to monitor progress and decide when to stop
  • Test set: Used only at the very end to report final performance

The validation set is crucial. You can't use training accuracy (it always improves, even during overfitting) or test accuracy (that would be cheating — you'd be peeking at the exam). The validation set is your honest progress check.
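The three-way split itself takes only a few lines. Here is a minimal pure-Python sketch; the `train_val_test_split` helper and the 80/10/10 fractions are illustrative choices, not from any library (in PyTorch you would typically use `torch.utils.data.random_split` instead):

```python
import random

def train_val_test_split(data, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle indices and carve out validation and test index sets."""
    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_test = int(len(data) * test_frac)
    n_val = int(len(data) * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]
    return train_idx, val_idx, test_idx

# 1000 examples -> 800 train / 100 validation / 100 test
train_idx, val_idx, test_idx = train_val_test_split(list(range(1000)))
print(len(train_idx), len(val_idx), len(test_idx))  # 800 100 100
```

Shuffling before splitting matters: if your data is ordered (say, by class), a naive slice would put entire classes into only one of the splits.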

Patience is how many epochs you wait without improvement before stopping. Why not stop immediately after the first epoch without improvement? Because validation metrics are noisy — they can fluctuate randomly.

patience_example.txt
text
Epoch | Val Accuracy | Action
------|--------------|-------
1     | 0.75         | New best! Save weights, reset counter
2     | 0.78         | New best! Save weights, reset counter
3     | 0.77         | No improvement, counter = 1
4     | 0.76         | No improvement, counter = 2
5     | 0.81         | New best! Save weights, reset counter
6     | 0.80         | No improvement, counter = 1
7     | 0.79         | No improvement, counter = 2
8     | 0.79         | No improvement, counter = 3
9     | 0.78         | No improvement, counter = 4
10    | 0.78         | No improvement, counter = 5 → STOP!

Notice epoch 5 — validation accuracy improved after 2 epochs of decline. If patience was 1, we would have stopped too early. Patience gives the model a chance to recover from temporary dips.

Choosing Patience

Common patience values: 3-10 epochs. Smaller patience (3-5) stops faster but might stop too early. Larger patience (7-10) is more conservative but takes longer. Start with patience=5 — it's a good default.
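To make the counter mechanics concrete, here is a small pure-Python simulation of the trace in the table above (the `run_early_stopping` helper is illustrative, not part of any library):

```python
def run_early_stopping(val_accuracies, patience=5):
    """Return (stop_epoch, best_epoch, best_acc) for a validation trace.

    Epochs are 1-indexed; stop_epoch is None if patience never runs out.
    """
    best_acc, best_epoch, counter = float('-inf'), None, 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best_acc:
            best_acc, best_epoch, counter = acc, epoch, 0  # new best: reset
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch, best_acc  # patience exhausted
    return None, best_epoch, best_acc

# The validation accuracies from the table above:
trace = [0.75, 0.78, 0.77, 0.76, 0.81, 0.80, 0.79, 0.79, 0.78, 0.78]
print(run_early_stopping(trace))  # (10, 5, 0.81)
```

With `patience=2` the same trace stops at epoch 4, before the recovery at epoch 5 — exactly the too-aggressive behavior the table warns about.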

This is the step beginners most often get wrong. When early stopping fires, you must restore the best weights, not the last weights.

Why? Because the last few epochs were overfitting — that's why validation accuracy stopped improving! The best weights are from several epochs ago, when validation accuracy peaked.

save_restore_weights.py
python
import copy

# WRONG: Just stopping without restoring
for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    if should_stop():
        break
# Model now has the LAST weights (overfit!)

# CORRECT: Track the best accuracy, save best weights, restore them
best_val_acc = 0.0
best_weights = None
for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    if val_acc > best_val_acc:
        # New best: record the accuracy and save a COPY of the weights
        best_val_acc = val_acc
        best_weights = copy.deepcopy(model.state_dict())
    if should_stop():
        break

# Restore the best weights
model.load_state_dict(best_weights)
# Model now has the BEST weights (optimal!)

The Most Common Early Stopping Bug

Forgetting to restore the best weights means you evaluate the overfit model. Your test accuracy will be mysteriously low. ALWAYS restore best_weights after training stops. Use copy.deepcopy() to make a true copy, not just a reference.
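The reference-vs-copy pitfall can be demonstrated without PyTorch at all. Below, a plain dict of lists stands in for `model.state_dict()` (the dict and its values are made up for illustration):

```python
import copy

# A dict of lists stands in for a state_dict of tensors
weights = {'layer1': [0.5, -0.3], 'layer2': [1.2]}

snapshot_ref = weights                  # WRONG: just another name for the same dict
snapshot_copy = copy.deepcopy(weights)  # CORRECT: a fully independent copy

# Simulate further training mutating the weights in place
weights['layer1'][0] = 99.0

print(snapshot_ref['layer1'][0])   # 99.0 -- the "saved" weights changed too!
print(snapshot_copy['layer1'][0])  # 0.5  -- the deep copy is safe
```

The same thing happens in real PyTorch code: `model.state_dict()` returns references to the live parameter tensors, which `optimizer.step()` updates in place, so a snapshot taken without `copy.deepcopy()` silently tracks the current (eventually overfit) weights.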

Here's a complete, production-ready implementation:

early_stopping_pytorch.py
python
import copy
import torch
import torch.nn as nn

class EarlyStopping:
    """Early stopping to stop training when validation metric stops improving."""
    
    def __init__(self, patience=5, min_delta=0.0, mode='max'):
        """
        Args:
            patience: How many epochs to wait after last improvement
            min_delta: Minimum change to qualify as improvement
            mode: 'max' for accuracy (higher is better), 'min' for loss (lower is better)
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.best_weights = None
        self.early_stop = False
    
    def __call__(self, val_metric, model):
        """Check if we should stop training.
        
        Args:
            val_metric: Current validation metric (accuracy or loss)
            model: The model being trained
        
        Returns:
            True if training should stop, False otherwise
        """
        score = val_metric if self.mode == 'max' else -val_metric
        
        if self.best_score is None:
            # First epoch
            self.best_score = score
            self.best_weights = copy.deepcopy(model.state_dict())
        elif score <= self.best_score + self.min_delta:
            # No improvement
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement!
            self.best_score = score
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
        
        return self.early_stop
    
    def restore_best_weights(self, model):
        """Restore the best weights to the model."""
        if self.best_weights is not None:
            model.load_state_dict(self.best_weights)

# Usage example
model = YourModel()
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()
early_stopping = EarlyStopping(patience=5, mode='max')  # max for accuracy

for epoch in range(max_epochs):
    # Training
    model.train()
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
    
    # Validation
    model.eval()
    val_acc = evaluate_on_validation(model, val_loader)
    
    # Check early stopping
    if early_stopping(val_acc, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# CRITICAL: Restore best weights
early_stopping.restore_best_weights(model)
print(f"Restored best weights (val_acc={early_stopping.best_score:.4f})")

What should you monitor? The most common choices:

Metric                | Mode | When to Use               | Pros                                             | Cons
----------------------|------|---------------------------|--------------------------------------------------|------------------------------------------------
Validation Accuracy   | max  | Classification tasks      | Easy to interpret, directly measures performance | Can be misleading with imbalanced classes
Validation Loss       | min  | Any task                  | Smooth signal, less noisy than accuracy          | Harder to interpret, can decrease while accuracy plateaus
F1 Score              | max  | Imbalanced classification | Better than accuracy for imbalanced data         | More complex to compute
Validation Perplexity | min  | Language modeling         | Standard metric for LMs                          | Only for language tasks

Default Choice

For classification: use validation accuracy (mode='max'). For regression: use validation loss (mode='min'). These are the most interpretable and work well in practice.
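The imbalanced-classes caveat from the table is easy to see with a toy example: on a dataset that is 95% one class, a classifier that learned nothing still scores 95% accuracy (pure Python, with made-up labels):

```python
# 95 negatives, 5 positives -- a heavily imbalanced toy label set
labels = [0] * 95 + [1] * 5
preds = [0] * 100  # a useless model that always predicts the majority class

accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
print(accuracy)  # 0.95 -- looks great, yet the model never finds a positive
```

On data like this, early stopping on accuracy would happily "converge" to the useless model; a metric like F1 would expose it.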

Early stopping isn't the only way to decide when to stop. Let's compare:

Method                   | How It Works                          | Pros                             | Cons
-------------------------|---------------------------------------|----------------------------------|-------------------------------------------------------
Fixed epochs             | Train for N epochs, then stop         | Simple, reproducible             | Wastes time if N too large, undertrains if N too small
Early stopping           | Stop when validation metric plateaus  | Automatic, prevents overfitting  | Requires validation set, adds complexity
Learning rate scheduling | Reduce LR when progress slows         | Can extend training productively | Doesn't prevent overfitting, just delays it
Manual inspection        | Watch training curves, stop manually  | Flexible, uses human judgment    | Not reproducible, requires constant monitoring
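The learning-rate-scheduling row can be sketched in a few lines of plain Python: halve the learning rate after `lr_patience` epochs without improvement. (PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau` implements this idea properly; the `schedule_lr` helper below is only an illustration.)

```python
def schedule_lr(val_accuracies, lr=1e-3, lr_patience=2, factor=0.5):
    """Return the learning rate in effect after each epoch of a trace."""
    best, counter, lrs = float('-inf'), 0, []
    for acc in val_accuracies:
        if acc > best:
            best, counter = acc, 0
        else:
            counter += 1
            if counter >= lr_patience:
                lr *= factor   # progress stalled: shrink the learning rate
                counter = 0
        lrs.append(lr)
    return lrs

# The LR halves at epochs 4 and 7, where two epochs pass without a new best
print(schedule_lr([0.70, 0.75, 0.74, 0.74, 0.76, 0.75, 0.75]))
```

Note that this only slows the slide into overfitting; unlike early stopping, it never stops training or restores earlier weights, which is why the two techniques are usually combined.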

Best practice: Combine early stopping with a maximum epoch limit. This gives you automatic stopping with a safety net:

combined_stopping.py
python
max_epochs = 100  # Safety net
early_stopping = EarlyStopping(patience=5)

for epoch in range(max_epochs):
    train_one_epoch()
    val_acc = evaluate()
    
    if early_stopping(val_acc, model):
        print(f"Early stopping at epoch {epoch+1}")
        break
else:
    # Loop completed without early stopping
    print(f"Reached max epochs ({max_epochs})")

early_stopping.restore_best_weights(model)

A picture is worth a thousand words. Here's what early stopping looks like in practice:

plot_early_stopping.py
python
import matplotlib.pyplot as plt

# Synthetic training history for illustration: training loss decays
# steadily, validation loss bottoms out at epoch 23 then rises (overfitting)
epochs = list(range(1, 51))
train_loss = [1.0 * 0.94 ** e for e in epochs]
val_loss = [0.30 + 0.0005 * (e - 23) ** 2 for e in epochs]

plt.figure(figsize=(10, 6))
plt.plot(epochs, train_loss, label='Training Loss', color='blue')
plt.plot(epochs, val_loss, label='Validation Loss', color='orange')

# Mark the best epoch (where early stopping would restore weights)
best_epoch = 23  # Example: validation loss was lowest at epoch 23
plt.axvline(x=best_epoch, color='green', linestyle='--', 
            label=f'Best Epoch ({best_epoch})')

# Mark where early stopping fired
stop_epoch = 28  # Example: stopped at epoch 28 (patience=5)
plt.axvline(x=stop_epoch, color='red', linestyle='--',
            label=f'Early Stop ({stop_epoch})')

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Early Stopping: Validation Loss Plateaus')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

The key insight: training loss keeps decreasing (the model keeps improving on training data), but validation loss starts increasing after epoch 23 (the model is overfitting). Early stopping detects this and restores the weights from epoch 23.

  1. Not restoring best weights: The #1 bug. Always call restore_best_weights() after training stops.
  2. Using training metric instead of validation: Training accuracy always improves, even during overfitting. Use validation metric.
  3. Patience too small: patience=1 is too aggressive. Use at least 3-5.
  4. No validation set: You need a separate validation set. Don't use test set for early stopping (that's cheating).
  5. Forgetting copy.deepcopy(): model.state_dict() returns references, not copies. Use copy.deepcopy().
  6. Monitoring the wrong metric: For classification, monitor accuracy (mode='max'). For regression, monitor loss (mode='min').
  7. Not setting max_epochs: Always have a maximum epoch limit as a safety net.

Sometimes validation metrics improve by tiny amounts (0.001%) due to noise. You might want to ignore these tiny improvements and only count "real" improvements. That's what min_delta does:

min_delta_example.py
python
# Without min_delta
early_stopping = EarlyStopping(patience=5, min_delta=0.0)
# Improvement of 0.0001 counts as improvement

# With min_delta
early_stopping = EarlyStopping(patience=5, min_delta=0.001)
# Only improvements > 0.001 count as improvement
# Tiny fluctuations are ignored

# Example:
# Epoch 1: val_acc = 0.850 → New best
# Epoch 2: val_acc = 0.8505 → Improvement of 0.0005 < 0.001 → No improvement
# Epoch 3: val_acc = 0.852 → Improvement of 0.002 > 0.001 → New best!

When to use min_delta: If your validation metric is very noisy and fluctuates by small amounts, set min_delta to filter out noise. For most problems, min_delta=0.0 (the default) works fine.
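The acceptance rule reduces to a one-line comparison per mode. Here it is as a tiny standalone helper (an illustrative function, not part of the `EarlyStopping` class above):

```python
def is_improvement(current, best, min_delta=0.0, mode='max'):
    """True if `current` beats `best` by more than `min_delta`."""
    if mode == 'max':
        return current > best + min_delta
    return current < best - min_delta  # mode == 'min': lower is better

print(is_improvement(0.8505, 0.850, min_delta=0.001))  # False: +0.0005 is noise
print(is_improvement(0.852, 0.850, min_delta=0.001))   # True: +0.002 is real
```

For `mode='min'` the threshold flips: a loss only counts as an improvement if it drops more than `min_delta` below the best seen so far.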

Here's a complete training script with early stopping:

complete_training.py
python
import torch
import torch.nn as nn
from early_stopping import EarlyStopping

def train_with_early_stopping(
    model, train_loader, val_loader, 
    max_epochs=100, patience=5
):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    early_stopping = EarlyStopping(patience=patience, mode='max')
    
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(max_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validation phase
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                output = model(batch_x)
                pred = output.argmax(dim=1)
                correct += (pred == batch_y).sum().item()
                total += batch_y.size(0)
        
        val_acc = correct / total
        
        # Record history
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)
        
        # Print progress
        print(f"Epoch {epoch+1}/{max_epochs} - "
              f"train_loss: {train_loss:.4f} - "
              f"val_acc: {val_acc:.4f}")
        
        # Check early stopping
        if early_stopping(val_acc, model):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            print(f"Best validation accuracy: {early_stopping.best_score:.4f}")
            break
    
    # Restore best weights
    early_stopping.restore_best_weights(model)
    print("\nRestored best model weights")
    
    return history

# Usage
model = YourModel()
history = train_with_early_stopping(
    model, train_loader, val_loader,
    max_epochs=100, patience=5
)

  1. Early stopping prevents overfitting by monitoring validation metrics and stopping when they plateau.
  2. Patience controls sensitivity: Higher patience is more conservative, lower patience stops faster.
  3. Always restore best weights: The last weights are overfit; the best weights are from several epochs ago.
  4. Use copy.deepcopy(): Make true copies of weights, not references.
  5. Monitor validation metrics: Never use training metrics (they always improve) or test metrics (that's cheating).
  6. Combine with max_epochs: Set a maximum epoch limit as a safety net.
  7. Default settings work well: patience=5, min_delta=0.0, mode='max' for accuracy.
  8. Early stopping is nearly free regularization: one main hyperparameter (patience), and the defaults above work well in practice.

The Bottom Line

Early stopping is one of the simplest and most effective techniques in deep learning. It automatically finds the optimal training duration, prevents overfitting, and requires minimal tuning. The key is understanding the patience parameter and always restoring the best weights. Get these right, and early stopping will save you hours of wasted training time and prevent overfitting without any manual intervention.
