
Transfer Learning in NLP: Standing on the Shoulders of Giants
A complete beginner's guide to transfer learning in NLP - how pretrained models work, why freezing encoders makes sense, and how to use sentence transformers effectively.
Imagine you want to become a chef. You could start from scratch — learning what fire is, how heat works, basic chemistry. Or you could start with knowledge that master chefs have already figured out, and focus on your specific recipes. Transfer learning is the second approach: borrowing intelligence from models trained on massive datasets, and adapting it to your specific problem. This post explains how transfer learning revolutionized NLP, why it works so well, and how to use it effectively.
Before transfer learning, every NLP project started from zero. Want to classify movie reviews? Train a model from scratch on your 10,000 reviews. Want to detect spam? Train from scratch on your emails. Want to answer questions? Train from scratch on your Q&A pairs.
This had three major problems:
The Data Hunger Problem
Deep learning models have millions of parameters. To train them well, you need millions of examples. But most real-world projects have thousands, not millions. Training from scratch on small datasets leads to severe overfitting — the model memorizes the training data but fails on new examples.
The Wasted Compute Problem
Training a language model from scratch takes weeks on expensive GPUs. Every project repeats this expensive process, even though they're all learning the same basic things: what words mean, how grammar works, how sentences relate to each other. It's like every chef learning from scratch that water boils at 100°C — wasteful duplication of effort.
The Shallow Patterns Problem
With limited data, models learn superficial patterns. A spam classifier might learn 'if email contains "free money", it's spam' — but miss deeper patterns like writing style, urgency markers, or social engineering tactics. These deeper patterns require massive datasets to learn.
Transfer learning flips the script. Instead of starting from zero, you start with a model that's already been trained on billions of words. This model has already learned:
- What words mean and how they relate to each other
- Grammar and syntax patterns
- Common phrases and idioms
- Semantic relationships (synonyms, antonyms, analogies)
- Context and how meaning changes based on surrounding words
You take this pretrained model and adapt it to your specific task. This is called transfer learning — transferring knowledge from one task (general language understanding) to another (your specific problem).
Let's understand what happens when a model is 'pretrained'. The most common approach is called masked language modeling:
The model is shown billions of sentences with random words masked out, and learns to predict the missing words:
```
Original: "The cat sat on the mat"
Masked:   "The cat [MASK] on the mat"
Task:     Predict that [MASK] = "sat"

Original: "I love eating pizza for dinner"
Masked:   "I love [MASK] pizza for dinner"
Task:     Predict that [MASK] = "eating"

Original: "The weather is beautiful today"
Masked:   "The [MASK] is beautiful today"
Task:     Predict that [MASK] = "weather"
```
To predict the masked word, the model must understand:
- Context: What words appear before and after
- Grammar: What part of speech fits here (noun, verb, adjective)
- Semantics: What meaning makes sense in this context
- World knowledge: Common patterns and relationships
After training on billions of sentences, the model develops a rich internal representation of language. It hasn't just memorized words — it's learned the deep structure of how language works.
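The masking step itself can be sketched in a few lines of plain Python. This is a toy illustration: real pretraining masks subword tokens rather than whole words, and masks about 15% of them instead of exactly one.

```python
import random

def mask_one_word(sentence, rng=None):
    """Replace one random word with [MASK]; return (masked sentence, target word)."""
    rng = rng or random.Random(0)  # fixed seed so the example is reproducible
    words = sentence.split()
    i = rng.randrange(len(words))
    target = words[i]
    words[i] = "[MASK]"
    return " ".join(words), target

masked, target = mask_one_word("The cat sat on the mat")
print(masked)
print("target:", target)
```

The pretraining objective is then simply: given the masked sentence, predict the target word.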
Sentence transformers are pretrained models specifically trained to produce good sentence embeddings. They're trained using contrastive learning:
Training pairs:
```
Similar sentences (should have similar embeddings):
- "Book a flight to Tokyo" ↔ "Reserve a plane ticket to Tokyo"
- "What's the weather?" ↔ "How's the weather today?"

Dissimilar sentences (should have different embeddings):
- "Book a flight" ↔ "What's the weather?"
- "I love pizza" ↔ "The sky is blue"
```
The model learns to make similar sentences close in embedding space, and dissimilar sentences far apart. This training creates embeddings where semantic similarity = geometric proximity. Sentences that mean similar things end up close together in the 384-dimensional space.
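'Close' and 'far apart' here mean cosine similarity. A minimal sketch with made-up 3-dimensional vectors (real sentence-transformer embeddings have 384 dimensions, but the geometry works the same way):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: near 1 = same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Made-up embeddings: the two flight sentences point in a similar direction
flight_a = [0.9, 0.1, 0.2]  # "Book a flight to Tokyo"
flight_b = [0.8, 0.2, 0.1]  # "Reserve a plane ticket to Tokyo"
weather  = [0.1, 0.9, 0.1]  # "What's the weather?"

print(cosine_similarity(flight_a, flight_b))  # high, near 1
print(cosine_similarity(flight_a, weather))   # much lower
```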
There are two main ways to use a pretrained model:
Approach 1: Feature Extraction (Frozen Encoder)
Use the pretrained model as a frozen feature extractor. You never update its weights — you just use it to convert text into embeddings, then train a small classifier on top.
```python
from sentence_transformers import SentenceTransformer
import torch.nn as nn

# 1. Load pretrained model (frozen)
encoder = SentenceTransformer('all-MiniLM-L6-v2')
# We NEVER update the encoder's weights

# 2. Convert all text to embeddings once
train_embeddings = encoder.encode(train_texts)  # (N, 384)

# 3. Train a small classifier on the embeddings
classifier = nn.Sequential(
    nn.Linear(384, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes),  # num_classes = number of labels in your task
)
# Only classifier weights are updated during training
```
Pros:
- Fast: Only training a small classifier, not the entire encoder
- Low memory: Don't need to store gradients for the encoder
- Works on CPU: No need for expensive GPUs
- Can't overfit the encoder: The pretrained weights stay untouched
Cons:
- Can't adapt encoder: If your domain is very different from the pretraining data, you're stuck
- Slightly lower accuracy: Fine-tuning usually gives 2-5% better accuracy
Approach 2: Fine-Tuning (Update the Encoder)
Update the pretrained model's weights on your specific task. You start with the pretrained weights and continue training, but with a very small learning rate.
```python
from transformers import AutoModel
import torch
import torch.nn as nn

# 1. Load pretrained model (will be updated)
encoder = AutoModel.from_pretrained('bert-base-uncased')

# 2. Add a classifier head (BERT's hidden size is 768)
classifier = nn.Linear(768, num_classes)

# 3. Train the ENTIRE model (encoder + classifier)
# Use a small learning rate for the encoder
optimizer = torch.optim.Adam([
    {'params': encoder.parameters(), 'lr': 1e-5},     # Small LR for encoder
    {'params': classifier.parameters(), 'lr': 1e-3},  # Normal LR for classifier
])

# Each forward pass runs the encoder, pools its output, then classifies:
# outputs = encoder(**tokenized_batch)
# logits = classifier(outputs.last_hidden_state[:, 0])  # [CLS] token
```
Pros:
- Best accuracy: Usually 2-5% better than frozen features
- Adapts to your domain: Can learn domain-specific patterns
Cons:
- Slow: Training the entire encoder takes much longer
- Needs GPU: Too slow on CPU
- High memory: Need to store gradients for millions of parameters
- Can overfit: With small datasets, you might make the encoder worse
Which Approach to Choose?
For most projects, start with frozen features; fine-tune only if you have a large labeled dataset, a GPU, and a measured need for the extra accuracy. Let's dig deeper into why freezing the encoder is often the right choice, especially for small datasets.
The pretrained encoder was trained on billions of sentences. Your dataset has thousands. If you try to 'improve' it with your tiny dataset, you'll almost certainly make it worse. This is called catastrophic forgetting — the model forgets its general knowledge while memorizing your specific examples.
Computing gradients through a transformer encoder is expensive. By freezing it, you:
- Compute embeddings once: Convert all text to embeddings before training starts
- No backprop through encoder: Only compute gradients for the small classifier
- Train on CPU: The classifier is small enough to train without a GPU
- Iterate faster: Training takes minutes instead of hours
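In PyTorch, freezing is just setting requires_grad = False on the encoder's parameters. A minimal sketch, using a small stand-in module in place of a real transformer encoder (which would have millions of parameters):

```python
import torch.nn as nn

# Stand-in for a pretrained encoder; a real one has millions of parameters
encoder = nn.Linear(384, 384)
classifier = nn.Linear(384, 10)  # hypothetical 10-class task

# Freeze the encoder: no gradients are computed or stored for it
for p in encoder.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
print(f"trainable: {trainable:,}  frozen: {frozen:,}")
```

Only the classifier's parameters go to the optimizer, so each training step updates a few thousand weights instead of millions.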
With a frozen encoder, embeddings never change. You can compute them once and save to disk:
```python
import torch
from pathlib import Path

def get_embeddings(texts, cache_path):
    # Check if we already computed these
    if Path(cache_path).exists():
        print("Loading from cache (instant!)")
        return torch.load(cache_path)
    # First time - compute and save
    print("Computing embeddings (takes a few minutes)...")
    embeddings = encoder.encode(texts)
    torch.save(embeddings, cache_path)
    return embeddings

# First run: computes embeddings (slow)
train_emb = get_embeddings(train_texts, 'train_embeddings.pt')
# Subsequent runs: loads from cache (instant)
train_emb = get_embeddings(train_texts, 'train_embeddings.pt')
```
This is a huge time saver. Computing 15,000 embeddings takes 2-3 minutes. If you're experimenting with different classifier architectures, you'd waste hours recomputing the same embeddings. With caching, subsequent runs start instantly.
Let's see a complete example of using sentence transformers for classification:
```python
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn

# 1. Load pretrained sentence transformer
model = SentenceTransformer('all-MiniLM-L6-v2')
# This model produces 384-dimensional embeddings

# 2. Encode your texts
train_texts = [
    "Book a flight to Tokyo",
    "What's the weather today?",
    "Play some music",
    # ... thousands more
]
train_embeddings = model.encode(
    train_texts,
    batch_size=256,           # Process 256 at a time
    show_progress_bar=True,   # Show progress
    convert_to_tensor=True,   # Return PyTorch tensor
)
print(train_embeddings.shape)  # (N, 384)

# 3. Build a classifier on top
# (num_classes and train_labels come from your dataset)
classifier = nn.Sequential(
    nn.Linear(384, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, num_classes),
)

# 4. Train the classifier (not the encoder!)
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
for epoch in range(50):
    # Forward pass
    logits = classifier(train_embeddings)
    loss = criterion(logits, train_labels)
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 5. Inference on new text
# eval() matters here: BatchNorm fails in train mode on a batch of one
classifier.eval()
with torch.no_grad():
    new_text = "Reserve a plane ticket"
    new_embedding = model.encode([new_text], convert_to_tensor=True)
    prediction = classifier(new_embedding).argmax(dim=1)
```
There are many pretrained sentence transformers. Here are the most popular:
| Model | Embedding Dim | Disk Size | Speed | Quality | Best For |
|---|---|---|---|---|---|
| all-MiniLM-L6-v2 | 384 | 80MB | Fast | Good | General purpose, CPU-friendly |
| all-mpnet-base-v2 | 768 | 420MB | Medium | Best | When you need highest quality |
| paraphrase-MiniLM-L6-v2 | 384 | 80MB | Fast | Good | Paraphrase detection |
| multi-qa-MiniLM-L6-cos-v1 | 384 | 80MB | Fast | Good | Question answering |
Start with all-MiniLM-L6-v2 unless you have a specific reason to choose another model; it is small, fast, and produces high-quality embeddings.
Common mistakes to avoid:
- Fine-tuning on tiny datasets: With <5,000 examples, stick to frozen features. Fine-tuning will overfit.
- Not caching embeddings: Computing embeddings takes minutes. Cache them to disk and reuse.
- Using the wrong model: all-MiniLM-L6-v2 is for general text. For code, use code-specific models. For scientific text, use scientific models.
- Forgetting to normalize: Some models require L2 normalization of embeddings. Check the model card.
- Comparing embeddings with wrong metric: Use cosine similarity, not Euclidean distance, for sentence embeddings.
- Training the encoder on small data: If you have <10,000 examples, don't fine-tune. You'll make it worse.
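The normalization point is easy to verify by hand: after L2 normalization every vector has length 1, and cosine similarity reduces to a plain dot product (a toy sketch with made-up 2-dimensional vectors):

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

a = l2_normalize([3.0, 4.0])  # becomes [0.6, 0.8]
b = l2_normalize([4.0, 3.0])  # becomes [0.8, 0.6]

# For unit vectors, cosine similarity is just the dot product
dot = sum(x * y for x, y in zip(a, b))
print(dot)
```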
Transfer learning has revolutionized NLP. Here's what changed:
| Before Transfer Learning | After Transfer Learning |
|---|---|
| Need millions of labeled examples | Need thousands of labeled examples |
| Train for weeks on expensive GPUs | Train for hours on CPU |
| Each project starts from scratch | Each project starts with pretrained knowledge |
| Accuracy: 60-70% on hard tasks | Accuracy: 85-95% on hard tasks |
| Only big companies can do NLP | Anyone can do NLP |
Key Takeaways
- Transfer learning means starting with pretrained knowledge instead of training from scratch.
- Pretrained models learned from billions of sentences and understand deep language patterns.
- Two approaches: Feature extraction (frozen encoder) and fine-tuning (update encoder).
- Start with frozen features: Faster, works on CPU, can't overfit the encoder.
- Only fine-tune if: You have 10,000+ examples, a GPU, and need that extra 2-5% accuracy.
- Cache embeddings: Compute once, save to disk, reuse forever.
- all-MiniLM-L6-v2 is a great default: Small, fast, high-quality embeddings.
- Transfer learning democratized NLP: Anyone can now build state-of-the-art systems.