
Adam Optimizer Explained: Why It's Better Than Plain Gradient Descent
A complete beginner's guide to the Adam optimizer - how it adapts learning rates per parameter, why it converges faster than SGD, and how to use it effectively in PyTorch.
Imagine driving a car where you can only set one speed for the entire journey — 60 mph on highways, 60 mph in school zones, 60 mph on bumpy roads. That's what plain gradient descent (SGD) does: one learning rate for all parameters. Adam (Adaptive Moment Estimation) is like having adaptive cruise control that automatically adjusts speed based on road conditions. This post explains exactly how Adam works, why it's become the default optimizer for most deep learning tasks, and how to use it effectively.
Let's start by understanding what we're improving upon. SGD (Stochastic Gradient Descent) is the simplest optimizer. The update rule is:
```python
# Plain SGD update rule
weight = weight - learning_rate * gradient

# Example:
# If gradient = 0.5 and learning_rate = 0.1
# weight = weight - 0.1 * 0.5 = weight - 0.05
```
Every parameter gets the same learning rate. This causes three major problems:
Imagine you're training a network with 1 million parameters. Some parameters have large, consistent gradients (they know which direction to go). Others have tiny, noisy gradients (they're uncertain). With one global learning rate:
- Large gradients: If learning rate is too high, these parameters overshoot and oscillate
- Small gradients: If learning rate is too low, these parameters barely move and learning is slow
You're forced to choose a learning rate that's a compromise — not optimal for anyone.
The Zigzag Problem
Mini-batch gradients are noisy estimates of the true gradient. One batch might say 'go left', the next says 'go right'. SGD follows these noisy signals directly, leading to a zigzag path instead of a smooth descent.
Loss landscapes often have ravines (steep in one direction, flat in another) and plateaus (flat everywhere). SGD struggles with both:
- Ravines: SGD bounces between the steep walls instead of smoothly descending
- Plateaus: Gradients are tiny, so SGD barely moves even though there's a cliff edge nearby
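To make the ravine problem concrete, here is a minimal pure-Python sketch. The 2-D quadratic loss and the learning rate are illustrative assumptions, not from any real network: with one global learning rate, the steep x direction oscillates while the flat y direction crawls.

```python
# Hypothetical ravine: loss = 0.5 * (10 * x**2 + y**2),
# steep along x (curvature 10), flat along y (curvature 1).
def sgd_path(lr, steps=20):
    x, y = 1.0, 1.0
    xs = []
    for _ in range(steps):
        gx, gy = 10 * x, y   # gradients of the quadratic
        x -= lr * gx
        y -= lr * gy
        xs.append(x)
    return x, y, xs

x, y, xs = sgd_path(lr=0.15)
# Along x the update multiplies x by (1 - 0.15 * 10) = -0.5 each step,
# so x flips sign every iteration (the zigzag); along y it multiplies
# by 0.85, so y only creeps toward zero.
```

Raise the learning rate to speed up y and the x direction diverges outright (any lr above 0.2 makes the x multiplier exceed 1 in magnitude); lower it to tame x and y slows even further. That is exactly the compromise Adam removes.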
Adam combines two earlier innovations: Momentum and RMSprop. Let's understand each before seeing how Adam combines them.
Momentum adds 'inertia' to gradient descent. Instead of following the current gradient exactly, we maintain a velocity — a running average of recent gradients.
```python
# Momentum update rule
velocity = beta * velocity + (1 - beta) * gradient
weight = weight - learning_rate * velocity

# beta is typically 0.9 (90% old velocity, 10% new gradient)
# This smooths out noise and builds up speed in consistent directions
```
The Bowling Ball Analogy
Think of it like pushing a ball down a hill. The ball doesn't instantly change direction with every bump — it has momentum that smooths out the path. If gradients consistently point in one direction, velocity builds up and we move faster. If gradients oscillate, velocity averages them out and we move more carefully.
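The velocity rule can be sketched in a few lines of plain Python. The gradient sequences below are made up for illustration: one is pure alternating noise, the other is perfectly consistent.

```python
def velocities(gradients, beta=0.9):
    """Running velocity for a stream of gradients (the momentum update)."""
    v, out = 0.0, []
    for g in gradients:
        v = beta * v + (1 - beta) * g
        out.append(v)
    return out

noisy = velocities([1.0, -1.0] * 10)   # alternating noise
steady = velocities([1.0] * 20)        # consistent direction
# Noise largely cancels: the noisy velocity hovers near zero.
# Consistent gradients build up: the steady velocity approaches 1.
```

Same learning rate, same gradient magnitudes, but momentum moves decisively only where the signal is consistent.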
RMSprop (Root Mean Square Propagation) adapts the learning rate for each parameter based on the magnitude of recent gradients.
```python
# RMSprop update rule
squared_gradient_avg = beta * squared_gradient_avg + (1 - beta) * gradient**2
weight = weight - learning_rate / sqrt(squared_gradient_avg + epsilon) * gradient

# Parameters with large gradients get smaller effective learning rates
# Parameters with small gradients get larger effective learning rates
```
The key insight: divide the learning rate by the square root of the average squared gradient. This means:
- Large gradients → Large denominator → Smaller effective learning rate → Smaller steps
- Small gradients → Small denominator → Larger effective learning rate → Larger steps
Each parameter gets its own adaptive learning rate based on its gradient history.
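Here is a small sketch of that effect, using made-up gradient magnitudes: two parameters see the same base learning rate, but their effective step sizes end up three orders of magnitude apart.

```python
import math

def effective_lr(gradients, lr=0.01, beta=0.9, epsilon=1e-8):
    """Effective learning rate after running the RMSprop squared-gradient average."""
    avg = 0.0
    for g in gradients:
        avg = beta * avg + (1 - beta) * g**2
    return lr / math.sqrt(avg + epsilon)

big = effective_lr([10.0] * 50)    # consistently large gradients
small = effective_lr([0.01] * 50)  # consistently tiny gradients
# big   ≈ 0.001: divided by sqrt(avg) ≈ 10, so steps shrink
# small ≈ 1.0:   divided by sqrt(avg) ≈ 0.01, so steps grow
```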
Adam combines momentum (for smoothing) and RMSprop (for adaptive rates). Here's the complete algorithm:
```python
# Adam update rule (simplified)

# 1. Compute first moment (momentum-like)
m = beta1 * m + (1 - beta1) * gradient

# 2. Compute second moment (RMSprop-like)
v = beta2 * v + (1 - beta2) * gradient**2

# 3. Bias correction (important for early steps)
m_corrected = m / (1 - beta1**t)  # t = current step number
v_corrected = v / (1 - beta2**t)

# 4. Update weight
weight = weight - learning_rate * m_corrected / (sqrt(v_corrected) + epsilon)

# Typical hyperparameters:
# beta1 = 0.9 (momentum decay)
# beta2 = 0.999 (RMSprop decay)
# epsilon = 1e-8 (numerical stability)
```
Let's break down each component:
m is a running average of gradients (like momentum). beta1 = 0.9 means we keep 90% of the old average and add 10% of the new gradient. This smooths out noise and builds up speed in consistent directions.
v is a running average of squared gradients (like RMSprop). beta2 = 0.999 means we keep 99.9% of the old average and add 0.1% of the new squared gradient. This tracks the 'volatility' of each parameter's gradients.
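A useful rule of thumb for reading these decay values: an exponential moving average with decay beta averages over roughly 1 / (1 - beta) recent steps.

```python
# beta1 = 0.9   -> m averages over roughly the last 10 gradients
# beta2 = 0.999 -> v averages over roughly the last 1000 squared gradients
window_m = round(1 / (1 - 0.9))
window_v = round(1 / (1 - 0.999))
```

That is why m reacts to direction changes within a few steps while v keeps a long memory of gradient magnitudes.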
Here's a subtle but important detail. At the start of training, m and v are initialized to zero. This creates a bias toward zero in the early steps. Adam corrects this by dividing by (1 - beta**t), where t is the step number.
```python
# Why bias correction matters
# Suppose beta1 = 0.9, and we're at step 1

# Without correction:
m = 0.9 * 0 + 0.1 * gradient   # = 0.1 * gradient (too small!)

# With correction:
m_corrected = (0.1 * gradient) / (1 - 0.9**1)   # = gradient (correct!)

# As t increases, (1 - beta**t) approaches 1, so correction becomes negligible
```
The final update divides the smoothed gradient (m_corrected) by the square root of the smoothed squared gradient (sqrt(v_corrected)). This gives each parameter an adaptive learning rate based on its gradient history.
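Putting all four steps together, here is a from-scratch sketch of Adam driving a single parameter toward the minimum of a toy loss f(w) = w**2. The starting point and step count are arbitrary choices for illustration.

```python
import math

def adam_minimize(steps=200, lr=0.1, beta1=0.9, beta2=0.999, epsilon=1e-8):
    w, m, v = 5.0, 0.0, 0.0          # start far from the minimum at w = 0
    for t in range(1, steps + 1):
        g = 2 * w                    # gradient of f(w) = w**2
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)   # bias correction
        v_hat = v / (1 - beta2**t)
        w -= lr * m_hat / (math.sqrt(v_hat) + epsilon)
    return w

final_w = adam_minimize()   # ends close to 0 after 200 steps
```

Notice that early steps have magnitude roughly lr (here 0.1) regardless of how large the gradient is, because m_hat and sqrt(v_hat) track the same gradient scale and cancel.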
The Adaptive Cruise Control Analogy
This is the adaptive cruise control from the introduction: m tells Adam which way the road is heading, v tells it how bumpy the road is, and each parameter's speed is adjusted accordingly.
Adam provides several key advantages over plain SGD:
In practice, Adam often converges substantially faster than plain SGD, especially early in training. Why? Because it adapts the learning rate per parameter. Parameters that need large steps get them, parameters that need small steps get them. No more one-size-fits-all compromise.
With SGD, choosing the right learning rate is critical and problem-specific. Too high and training explodes, too low and it crawls. Adam is much more forgiving. The default lr=1e-3 (0.001) works well for most problems. You can often use it without tuning.
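A quick illustration of that forgiveness, on a toy steep quadratic (all values made up for the sketch): SGD blows up at a learning rate Adam handles without trouble, because Adam's normalized step stays roughly within ±lr regardless of gradient scale.

```python
import math

def sgd_final(lr, steps=50):
    w = 1.0
    for _ in range(steps):
        w -= lr * 10 * w             # gradient of f(w) = 5 * w**2
    return w

def adam_final(lr, steps=50, beta1=0.9, beta2=0.999, eps=1e-8):
    w, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = 10 * w
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        w -= lr * (m / (1 - beta1**t)) / (math.sqrt(v / (1 - beta2**t)) + eps)
    return w

# At lr = 0.25, SGD multiplies w by (1 - 0.25 * 10) = -1.5 every step and
# diverges; Adam at the same lr stays bounded and heads toward zero.
```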
In problems like NLP, many parameters have zero gradients most of the time (sparse gradients). Adam handles this well because it adapts per parameter. Parameters that rarely update get larger effective learning rates when they do update.
The default hyperparameters (beta1=0.9, beta2=0.999, lr=1e-3) work well for most problems. This is why Adam has become the default optimizer — it 'just works' without extensive tuning.
PyTorch makes Adam easy to use. Here's a complete example:
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define your model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)
criterion = nn.CrossEntropyLoss()

# Create Adam optimizer
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,             # Learning rate (default: 1e-3)
    betas=(0.9, 0.999),  # (beta1, beta2) - momentum and RMSprop decay
    eps=1e-8,            # Epsilon for numerical stability
    weight_decay=0       # L2 regularization (more on this below)
)

# Training loop
for epoch in range(num_epochs):
    for batch_x, batch_y in train_loader:
        # 1. Zero gradients from previous step
        optimizer.zero_grad()
        # 2. Forward pass
        output = model(batch_x)
        loss = criterion(output, batch_y)
        # 3. Backward pass (compute gradients)
        loss.backward()
        # 4. Update weights using Adam
        optimizer.step()

# That's it! Adam handles all the complexity internally
```

| Parameter | Default | What It Does | When to Change |
|---|---|---|---|
| lr | 1e-3 | Base learning rate | Increase if training is too slow, decrease if loss explodes |
| beta1 | 0.9 | Momentum decay (first moment) | Rarely changed; 0.9 works well |
| beta2 | 0.999 | RMSprop decay (second moment) | Rarely changed; 0.999 works well |
| eps | 1e-8 | Numerical stability constant | Rarely changed; occasionally raised (e.g., to 1e-4) for mixed-precision training |
| weight_decay | 0 | L2 regularization strength | Set to 1e-4 or 1e-5 to prevent overfitting |
Weight decay is L2 regularization built into the optimizer. It adds a penalty for large weights, helping prevent overfitting.
```python
# Without weight decay
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# With weight decay (recommended for most problems)
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4  # Penalize large weights
)

# What weight decay does, conceptually:
# weight = weight - lr * gradient - lr * weight_decay * weight
#                                   ^^^^^^^^^^^^^^^^^^^^^^^^^
# This term shrinks weights toward zero
# (In optim.Adam the penalty is actually folded into the gradient before the
#  adaptive scaling; AdamW applies it directly as written here.)
```
Common weight decay values:
- 0: No regularization (only use if you have tons of data)
- 1e-5 (0.00001): Mild regularization
- 1e-4 (0.0001): Standard choice for most problems
- 1e-3 (0.001): Strong regularization (if overfitting is severe)
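To get a feel for how gentle these values are, here is a sketch of the decay term acting alone (gradient assumed zero; the numbers are illustrative):

```python
lr, weight_decay = 1e-3, 1e-4
w = 1.0
for _ in range(10_000):
    w -= lr * weight_decay * w   # only the shrinkage term of the update
# Each step multiplies w by (1 - 1e-7); after 10,000 steps the weight
# has shrunk by only about 0.1%.
```

Weight decay is a gentle, persistent pull toward zero, not a hard constraint; only weights that the gradients don't actively support get eroded over many steps.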
Start with 1e-4
If you're unsure, use weight_decay=1e-4 and adjust from there: increase it if the model still overfits, reduce or remove it if the model underfits.
Adam is the default choice for most problems, but SGD with momentum still has its place:
| Optimizer | Best For | Pros | Cons |
|---|---|---|---|
| Adam | Most problems, especially NLP and small datasets | Fast convergence, works out of box, less tuning needed | Sometimes worse final performance than well-tuned SGD |
| SGD + Momentum | Computer vision, very large datasets, when you have time to tune | Can achieve slightly better final accuracy | Requires careful learning rate tuning, slower convergence |
| AdamW | Transformers, modern NLP | Better weight decay than Adam, state-of-the-art results | Slightly more complex |
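The difference between Adam's and AdamW's weight decay shows up in a single pure-Python step. The numbers and the t = 1 simplification below are illustrative: plain Adam folds the decay into the gradient, so it passes through the adaptive scaling, while AdamW applies it to the weight directly.

```python
import math

def first_step(w, g, lr=0.1, wd=0.1, beta1=0.9, beta2=0.999,
               eps=1e-8, decoupled=False):
    if not decoupled:
        g = g + wd * w                          # Adam: decay enters the gradient
    m_hat = ((1 - beta1) * g) / (1 - beta1)     # bias-corrected moments at t = 1
    v_hat = ((1 - beta2) * g**2) / (1 - beta2)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        w = w - lr * wd * w                     # AdamW: decay applied to the weight
    return w

adam_w1 = first_step(w=1.0, g=0.0)                    # ≈ 0.90
adamw_w1 = first_step(w=1.0, g=0.0, decoupled=True)   # ≈ 0.99
# With a zero gradient, Adam's adaptive scaling inflates the tiny decay term
# into a full-sized step; AdamW shrinks the weight by exactly lr * wd.
```

This coupling is why the same weight_decay value regularizes differently in Adam and AdamW, and part of why AdamW is preferred for transformers.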
The Practical Rule
Default to Adam (or AdamW for transformers). Reach for SGD with momentum only when you're chasing the last fraction of accuracy on a large vision dataset and have time to tune the learning rate.
- Learning rate too high: If loss explodes to NaN in the first few steps, your learning rate is too high. Try 1e-4 instead of 1e-3.
- Not using weight decay: Without regularization, models often overfit. Start with weight_decay=1e-4.
- Forgetting optimizer.zero_grad(): Gradients accumulate by default. Always call zero_grad() before backward().
- Using Adam for everything: For computer vision with huge datasets, well-tuned SGD can outperform Adam. Don't be dogmatic.
- Not adjusting learning rate: For very long training runs, consider learning rate scheduling (reduce lr when progress plateaus).
- Comparing Adam and SGD with same learning rate: Adam typically needs a smaller learning rate than SGD. Don't compare them with the same lr.
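The zero_grad mistake deserves a tiny illustration. In plain Python terms (gradient values made up), here is what happens when old gradients are never cleared:

```python
# Each batch computes a gradient of 0.5, but without clearing between
# batches the stored gradient keeps growing, the way loss.backward()
# accumulates into .grad.
stored_grad = 0.0
used_gradients = []
for batch_grad in [0.5, 0.5, 0.5]:
    stored_grad += batch_grad      # accumulation, no reset
    used_gradients.append(stored_grad)

# The three updates use gradients 0.5, 1.0, 1.5 instead of 0.5 each time,
# silently inflating the effective learning rate.
```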
For long training runs, you might want to reduce the learning rate over time. PyTorch provides several schedulers:
```python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Option 1: Reduce LR when validation loss plateaus
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',  # Minimize validation loss
    factor=0.5,  # Multiply LR by 0.5 when a plateau is detected
    patience=5   # Wait 5 epochs without improvement before reducing
)

# Training loop
for epoch in range(num_epochs):
    train_loss = train_one_epoch()
    val_loss = validate()
    # Update learning rate based on validation loss
    scheduler.step(val_loss)

# Option 2: Reduce LR every N epochs
scheduler = StepLR(
    optimizer,
    step_size=30,  # Reduce every 30 epochs
    gamma=0.1      # Multiply LR by 0.1
)

for epoch in range(num_epochs):
    train_one_epoch()
    scheduler.step()  # Update LR after each epoch
```
- Adam adapts learning rates per parameter based on gradient history, making it much more effective than plain SGD.
- It combines momentum (smoothing) and RMSprop (adaptive rates) to get the best of both worlds.
- Default hyperparameters work well: lr=1e-3, beta1=0.9, beta2=0.999 are good starting points.
- Use weight decay: weight_decay=1e-4 provides mild regularization and helps prevent overfitting.
- Adam usually converges faster than SGD, with less hyperparameter tuning needed.
- Always call optimizer.zero_grad() before backward() to clear old gradients.
- For transformers, use AdamW (Adam with decoupled weight decay) for best results.
- Consider learning rate scheduling for very long training runs.
The Bottom Line
Reach for Adam (lr=1e-3, weight_decay=1e-4) as your default optimizer. It works well out of the box on most problems, and you can revisit the choice later if you have evidence that something else would do better.