This cheatsheet demonstrates practical fine-tuning techniques for geospatial models using small examples that run quickly.
Setup and Sample Data
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print("Quick fine-tuning examples")
```
PyTorch version: 2.7.1
Quick fine-tuning examples
Simple Classification Model
```python
class SimpleGeospatialModel(nn.Module):
    """Lightweight model for demonstration"""

    def __init__(self, num_bands=6, num_classes=5):
        super().__init__()

        # Simple CNN backbone
        self.features = nn.Sequential(
            nn.Conv2d(num_bands, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4)
        )

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)
        return self.classifier(features)


# Create model
model = SimpleGeospatialModel(num_bands=6, num_classes=5)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
```
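It is worth knowing how the parameters split between the backbone (`features`) and the head (`classifier`) before deciding what to freeze later. A small sketch for checking that split, using the `model` instance created above:

```python
# Sketch: parameter count per top-level block (backbone vs. head).
# Useful context for the freezing strategies discussed later.
for name, module in model.named_children():
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name:10}: {n_params:,} parameters")
```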
Early Stopping Implementation

```python
class EarlyStopping:
    """Early stopping utility"""

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False  # Continue training
        else:
            self.counter += 1
            return self.counter >= self.patience  # Stop if patience exceeded


# Demonstrate early stopping
def train_with_early_stopping():
    """Train with early stopping"""
    model_es = SimpleGeospatialModel(num_bands=6, num_classes=5)
    optimizer = optim.Adam(model_es.parameters(), lr=0.001)
    early_stopping = EarlyStopping(patience=2)

    print("=== Early Stopping Demo ===")
    for epoch in range(10):  # Max 10 epochs
        train_loss, train_acc = train_epoch(model_es, train_loader, optimizer, criterion)
        val_loss, val_acc = validate(model_es, val_loader, criterion)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.1f}%, "
              f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.1f}%")

        if early_stopping(val_loss):
            print(f"Early stopping at epoch {epoch+1}")
            break


train_with_early_stopping()
```
=== Early Stopping Demo ===
Epoch 1: Train Loss: 1.612, Train Acc: 16.2%, Val Loss: 1.567, Val Acc: 20.0%
Epoch 2: Train Loss: 1.568, Train Acc: 18.8%, Val Loss: 1.514, Val Acc: 40.0%
Epoch 3: Train Loss: 1.530, Train Acc: 15.0%, Val Loss: 1.459, Val Acc: 40.0%
Epoch 4: Train Loss: 1.492, Train Acc: 33.8%, Val Loss: 1.407, Val Acc: 40.0%
Epoch 5: Train Loss: 1.395, Train Acc: 42.5%, Val Loss: 1.369, Val Acc: 40.0%
Epoch 6: Train Loss: 1.373, Train Acc: 38.8%, Val Loss: 1.240, Val Acc: 60.0%
Epoch 7: Train Loss: 1.240, Train Acc: 37.5%, Val Loss: 1.252, Val Acc: 40.0%
Epoch 8: Train Loss: 1.161, Train Acc: 52.5%, Val Loss: 1.039, Val Acc: 50.0%
Epoch 9: Train Loss: 1.051, Train Acc: 43.8%, Val Loss: 0.896, Val Acc: 60.0%
Epoch 10: Train Loss: 1.009, Train Acc: 52.5%, Val Loss: 0.948, Val Acc: 60.0%
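The stopping criterion above only halts training; it does not keep the weights from the best epoch. A common extension is to snapshot the state dict whenever validation loss improves and restore it afterwards. A minimal sketch, where the `BestCheckpoint` helper is illustrative and not part of the cheatsheet's code:

```python
import copy

class BestCheckpoint:
    """Keep an in-memory copy of the weights from the best validation epoch."""

    def __init__(self):
        self.best_loss = float('inf')
        self.best_state = None

    def update(self, model, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # Deep copy so later optimizer steps don't overwrite the snapshot
            self.best_state = copy.deepcopy(model.state_dict())

    def restore(self, model):
        if self.best_state is not None:
            model.load_state_dict(self.best_state)


# Usage alongside EarlyStopping (sketch):
#   checkpoint = BestCheckpoint()
#   for epoch in range(10):
#       ...train / validate...
#       checkpoint.update(model_es, val_loss)
#       if early_stopping(val_loss):
#           break
#   checkpoint.restore(model_es)  # roll back to the best epoch
```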
Model Comparison
```python
def compare_final_performance():
    """Compare final performance of different strategies"""
    models = {
        'Full Training': model_full,
        'Frozen Features': model_frozen,
        'Layerwise LR': model_layerwise
    }

    print("\n=== Final Performance Comparison ===")
    for name, model in models.items():
        val_loss, val_acc = validate(model, val_loader, criterion)
        print(f"{name:15}: Val Acc = {val_acc:.1f}%, Val Loss = {val_loss:.3f}")


compare_final_performance()
```
=== Final Performance Comparison ===
Full Training : Val Acc = 40.0%, Val Loss = 1.371
Frozen Features: Val Acc = 20.0%, Val Loss = 1.575
Layerwise LR : Val Acc = 20.0%, Val Loss = 1.547
Transfer Learning Best Practices
```python
def show_best_practices():
    """Demonstrate transfer learning best practices"""
    print("\n=== Transfer Learning Best Practices ===")

    practices = {
        "Start with lower learning rates": "0.0001 - 0.001 typically work well",
        "Freeze early layers initially": "Then gradually unfreeze if needed",
        "Use different LRs for different layers": "Lower for pretrained, higher for new layers",
        "Monitor validation carefully": "Use early stopping to prevent overfitting",
        "Data augmentation is crucial": "Especially with limited training data",
        "Gradual unfreezing": "Unfreeze layers progressively during training"
    }

    for practice, explanation in practices.items():
        print(f"• {practice}: {explanation}")


show_best_practices()
```
=== Transfer Learning Best Practices ===
• Start with lower learning rates: 0.0001 - 0.001 typically work well
• Freeze early layers initially: Then gradually unfreeze if needed
• Use different LRs for different layers: Lower for pretrained, higher for new layers
• Monitor validation carefully: Use early stopping to prevent overfitting
• Data augmentation is crucial: Especially with limited training data
• Gradual unfreezing: Unfreeze layers progressively during training
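Two of these practices, gradual unfreezing and data augmentation, are listed above but not demonstrated elsewhere in this cheatsheet. Below is a minimal gradual-unfreezing sketch against `SimpleGeospatialModel`; the unfreezing schedule, the `set_requires_grad` helper, and the choice to rebuild the optimizer after each change are illustrative assumptions, and the loop reuses the `train_epoch`, `train_loader`, and `criterion` objects defined in this cheatsheet's earlier cells.

```python
def set_requires_grad(module, flag):
    """Enable or disable gradients for every parameter in a module."""
    for param in module.parameters():
        param.requires_grad = flag


def gradual_unfreeze(model, epoch):
    """Illustrative schedule: train the classifier first, then open up the backbone."""
    conv_layers = [m for m in model.features if isinstance(m, nn.Conv2d)]
    if epoch == 0:
        set_requires_grad(model.features, False)  # classifier only
    elif epoch == 2:
        set_requires_grad(conv_layers[-1], True)  # unfreeze the last conv layer
    elif epoch == 4:
        set_requires_grad(model.features, True)   # unfreeze everything


model_gradual = SimpleGeospatialModel(num_bands=6, num_classes=5)
for epoch in range(6):
    gradual_unfreeze(model_gradual, epoch)
    # Rebuild the optimizer so newly unfrozen parameters receive updates
    # (this resets Adam's moment estimates, which is acceptable for a short demo).
    optimizer_gradual = optim.Adam(
        (p for p in model_gradual.parameters() if p.requires_grad), lr=0.001
    )
    train_loss, train_acc = train_epoch(model_gradual, train_loader, optimizer_gradual, criterion)
```

Data augmentation can be kept just as lightweight for multi-band tensors; a sketch of simple geometric augmentations (not applied by `SyntheticGeospatialDataset` in this cheatsheet) might look like this:

```python
def augment_multiband(image):
    """Random flips and 90-degree rotations for a (bands, H, W) tensor."""
    if torch.rand(1) < 0.5:
        image = torch.flip(image, dims=[-1])     # horizontal flip
    if torch.rand(1) < 0.5:
        image = torch.flip(image, dims=[-2])     # vertical flip
    k = int(torch.randint(0, 4, (1,)))
    return torch.rot90(image, k, dims=[-2, -1])  # random 90-degree rotation
```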
Feature Visualization
```python
def visualize_learned_features(model, sample_image):
    """Visualize what the model has learned"""
    model.eval()

    # Get intermediate features
    features = []

    def hook_fn(module, input, output):
        features.append(output.detach())

    # Register hooks on conv layers
    hooks = []
    for name, module in model.features.named_modules():
        if isinstance(module, nn.Conv2d):
            hooks.append(module.register_forward_hook(hook_fn))

    # Forward pass
    with torch.no_grad():
        _ = model(sample_image.unsqueeze(0))

    # Remove hooks
    for hook in hooks:
        hook.remove()

    print("\n=== Feature Map Analysis ===")
    for i, feature_map in enumerate(features):
        print(f"Layer {i+1}: {feature_map.shape}")

    return features


# Analyze learned features
sample_img, _ = train_dataset[0]
learned_features = visualize_learned_features(model_full, sample_img)
```
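The helper above only prints the shapes of the captured activations. To actually look at them, a few channels of the first convolutional layer's output can be rendered with matplotlib; this is a sketch using the `learned_features` list returned above, and the number of channels shown is an arbitrary choice:

```python
# Sketch: display the first few channels of the first conv layer's activations.
first_layer = learned_features[0].squeeze(0)  # shape: (channels, H, W)
n_show = min(6, first_layer.shape[0])

fig, axes = plt.subplots(1, n_show, figsize=(2 * n_show, 2))
for i in range(n_show):
    axes[i].imshow(first_layer[i].numpy(), cmap='viridis')
    axes[i].set_title(f"ch {i}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```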
Key Takeaways
print("\n=== Fine-tuning Strategy Summary ===")strategies = {"Full Training": "Train all parameters - best when you have lots of data","Frozen Features": "Only train classifier - fastest, good for small datasets", "Layerwise LR": "Different learning rates - balanced approach","Gradual Unfreezing": "Progressive training - best for complex adaptation","Early Stopping": "Prevent overfitting - essential for small datasets"}for strategy, description in strategies.items():print(f"β’ {strategy}: {description}")print(f"\nTraining completed successfully! All examples ran quickly.")
=== Fine-tuning Strategy Summary ===
• Full Training: Train all parameters - best when you have lots of data
• Frozen Features: Only train classifier - fastest, good for small datasets
• Layerwise LR: Different learning rates - balanced approach
• Gradual Unfreezing: Progressive training - best for complex adaptation
• Early Stopping: Prevent overfitting - essential for small datasets
Training completed successfully! All examples ran quickly.
Summary
Start simple with frozen features and classifier-only training
Use appropriate learning rates - lower for pretrained layers
Monitor validation carefully to avoid overfitting
Implement early stopping for robust training
Consider gradual unfreezing for complex adaptations
Visualize features to understand what the model learns
These techniques work across different model architectures and can be scaled up for larger, real-world applications.
Source Code
---title: "Fine-tuning Strategies"subtitle: "Basic fine-tuning approaches"jupyter: geoaiformat: html: code-fold: false---# Fine-tuning StrategiesThis cheatsheet demonstrates practical fine-tuning techniques for geospatial models using small examples that run quickly.## Setup and Sample Data```{python}import torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderimport numpy as npimport matplotlib.pyplot as plt# Set seeds for reproducibilitytorch.manual_seed(42)np.random.seed(42)print(f"PyTorch version: {torch.__version__}")print("Quick fine-tuning examples")```## Simple Classification Model```{python}class SimpleGeospatialModel(nn.Module):"""Lightweight model for demonstration"""def__init__(self, num_bands=6, num_classes=5):super().__init__()# Simple CNN backboneself.features = nn.Sequential( nn.Conv2d(num_bands, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(4) )# Classifier headself.classifier = nn.Sequential( nn.Linear(128*4*4, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, num_classes) )def forward(self, x): features =self.features(x) features = features.view(features.size(0), -1)returnself.classifier(features)# Create modelmodel = SimpleGeospatialModel(num_bands=6, num_classes=5)print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")```## Synthetic Dataset for Fast Training```{python}class SyntheticGeospatialDataset(Dataset):"""Synthetic dataset that generates data on-the-fly"""def__init__(self, num_samples=100, size=64, num_bands=6, num_classes=5):self.num_samples = num_samplesself.size = sizeself.num_bands = num_bandsself.num_classes = num_classes# Fixed seed for consistent synthetic dataself.rng = np.random.RandomState(42)def__len__(self):returnself.num_samplesdef__getitem__(self, idx):# Generate synthetic satellite-like image# Different patterns for different classes class_label = idx %self.num_classes# Create class-specific patternsif class_label ==0: # Water image =self.rng.normal(0.2, 0.1, (self.num_bands, self.size, self.size))elif class_label ==1: # Forest image =self.rng.normal(0.4, 0.15, (self.num_bands, self.size, self.size))elif class_label ==2: # Urban image =self.rng.normal(0.6, 0.2, (self.num_bands, self.size, self.size))elif class_label ==3: # Agriculture image =self.rng.normal(0.5, 0.12, (self.num_bands, self.size, self.size))else: # Bare soil image =self.rng.normal(0.7, 0.18, (self.num_bands, self.size, self.size))# Add some spatial structure image = np.clip(image, 0, 1)return torch.FloatTensor(image), torch.LongTensor([class_label])# Create datasetstrain_dataset = SyntheticGeospatialDataset(num_samples=80, size=64)val_dataset = SyntheticGeospatialDataset(num_samples=20, size=64)train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)print(f"Training samples: {len(train_dataset)}")print(f"Validation samples: {len(val_dataset)}")# Show samplesample_image, sample_label = train_dataset[0]print(f"Sample shape: {sample_image.shape}, Label: {sample_label.item()}")```## Fine-tuning Strategy 1: Full Model Training```{python}def train_epoch(model, dataloader, optimizer, criterion):"""Train for one epoch""" model.train() total_loss =0 correct =0 total =0for batch_idx, (data, targets) inenumerate(dataloader): targets = targets.squeeze() optimizer.zero_grad() outputs = model(data) loss = criterion(outputs, 
targets) loss.backward() optimizer.step() total_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item()return total_loss /len(dataloader), 100.* correct / totaldef validate(model, dataloader, criterion):"""Validate model""" model.eval() total_loss =0 correct =0 total =0with torch.no_grad():for data, targets in dataloader: targets = targets.squeeze() outputs = model(data) loss = criterion(outputs, targets) total_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item()return total_loss /len(dataloader), 100.* correct / total# Setup for full trainingmodel_full = SimpleGeospatialModel(num_bands=6, num_classes=5)optimizer = optim.Adam(model_full.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()print("=== Full Model Training ===")# Quick training (just 3 epochs for demo)for epoch inrange(3): train_loss, train_acc = train_epoch(model_full, train_loader, optimizer, criterion) val_loss, val_acc = validate(model_full, val_loader, criterion)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.1f}%, "f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.1f}%")```## Fine-tuning Strategy 2: Frozen Feature Extractor```{python}# Create a new model with frozen featuresmodel_frozen = SimpleGeospatialModel(num_bands=6, num_classes=5)# Freeze feature layersfor param in model_frozen.features.parameters(): param.requires_grad =False# Count trainable parameterstrainable_params =sum(p.numel() for p in model_frozen.parameters() if p.requires_grad)total_params =sum(p.numel() for p in model_frozen.parameters())print(f"=== Frozen Features Training ===")print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")print(f"Frozen: {total_params - trainable_params:,} parameters")# Only optimize classifieroptimizer_frozen = optim.Adam(model_frozen.classifier.parameters(), lr=0.001)# Quick trainingfor epoch inrange(3): train_loss, train_acc = train_epoch(model_frozen, train_loader, optimizer_frozen, criterion) val_loss, val_acc = validate(model_frozen, val_loader, criterion)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.1f}%, "f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.1f}%")```## Fine-tuning Strategy 3: Layer-wise Learning Rates```{python}# Different learning rates for different partsdef create_layerwise_optimizer(model, base_lr=0.001):"""Create optimizer with different learning rates for different layers""" params_groups = [ {'params': model.features.parameters(), 'lr': base_lr *0.1}, # Lower LR for features {'params': model.classifier.parameters(), 'lr': base_lr} # Higher LR for classifier ]return optim.Adam(params_groups)model_layerwise = SimpleGeospatialModel(num_bands=6, num_classes=5)optimizer_layerwise = create_layerwise_optimizer(model_layerwise)print("=== Layerwise Learning Rates ===")print("Features: 0.0001, Classifier: 0.001")# Quick trainingfor epoch inrange(3): train_loss, train_acc = train_epoch(model_layerwise, train_loader, optimizer_layerwise, criterion) val_loss, val_acc = validate(model_layerwise, val_loader, criterion)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.1f}%, "f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.1f}%")```## Learning Rate Scheduling```{python}# Demonstrate learning rate schedulingfrom torch.optim.lr_scheduler import StepLR, CosineAnnealingLRdef train_with_scheduler():"""Train model with learning rate scheduling""" model_sched = 
SimpleGeospatialModel(num_bands=6, num_classes=5) optimizer = optim.Adam(model_sched.parameters(), lr=0.01) # Higher initial LR scheduler = StepLR(optimizer, step_size=2, gamma=0.5) # Reduce LR every 2 epochsprint("=== Learning Rate Scheduling ===")for epoch inrange(4): # 4 epochs to see LR changes train_loss, train_acc = train_epoch(model_sched, train_loader, optimizer, criterion) val_loss, val_acc = validate(model_sched, val_loader, criterion) current_lr = optimizer.param_groups[0]['lr']print(f"Epoch {epoch+1}: LR: {current_lr:.4f}, Train Loss: {train_loss:.3f}, "f"Train Acc: {train_acc:.1f}%, Val Acc: {val_acc:.1f}%") scheduler.step()train_with_scheduler()```## Early Stopping Implementation```{python}class EarlyStopping:"""Early stopping utility"""def__init__(self, patience=3, min_delta=0.001):self.patience = patienceself.min_delta = min_deltaself.counter =0self.best_loss =float('inf')def__call__(self, val_loss):if val_loss <self.best_loss -self.min_delta:self.best_loss = val_lossself.counter =0returnFalse# Continue trainingelse:self.counter +=1returnself.counter >=self.patience # Stop if patience exceeded# Demonstrate early stoppingdef train_with_early_stopping():"""Train with early stopping""" model_es = SimpleGeospatialModel(num_bands=6, num_classes=5) optimizer = optim.Adam(model_es.parameters(), lr=0.001) early_stopping = EarlyStopping(patience=2)print("=== Early Stopping Demo ===")for epoch inrange(10): # Max 10 epochs train_loss, train_acc = train_epoch(model_es, train_loader, optimizer, criterion) val_loss, val_acc = validate(model_es, val_loader, criterion)print(f"Epoch {epoch+1}: Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.1f}%, "f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.1f}%")if early_stopping(val_loss):print(f"Early stopping at epoch {epoch+1}")breaktrain_with_early_stopping()```## Model Comparison```{python}def compare_final_performance():"""Compare final performance of different strategies""" models = {'Full Training': model_full,'Frozen Features': model_frozen,'Layerwise LR': model_layerwise }print("\n=== Final Performance Comparison ===")for name, model in models.items(): val_loss, val_acc = validate(model, val_loader, criterion)print(f"{name:15}: Val Acc = {val_acc:.1f}%, Val Loss = {val_loss:.3f}")compare_final_performance()```## Transfer Learning Best Practices```{python}def show_best_practices():"""Demonstrate transfer learning best practices"""print("\n=== Transfer Learning Best Practices ===") practices = {"Start with lower learning rates": "0.0001 - 0.001 typically work well","Freeze early layers initially": "Then gradually unfreeze if needed","Use different LRs for different layers": "Lower for pretrained, higher for new layers","Monitor validation carefully": "Use early stopping to prevent overfitting","Data augmentation is crucial": "Especially with limited training data","Gradual unfreezing": "Unfreeze layers progressively during training" }for practice, explanation in practices.items():print(f"β’ {practice}: {explanation}")show_best_practices()```## Feature Visualization```{python}def visualize_learned_features(model, sample_image):"""Visualize what the model has learned""" model.eval()# Get intermediate features features = []def hook_fn(module, input, output): features.append(output.detach())# Register hooks on conv layers hooks = []for name, module in model.features.named_modules():ifisinstance(module, nn.Conv2d): hooks.append(module.register_forward_hook(hook_fn))# Forward passwith torch.no_grad(): _ = model(sample_image.unsqueeze(0))# Remove 
hooksfor hook in hooks: hook.remove()print(f"\n=== Feature Map Analysis ===")for i, feature_map inenumerate(features):print(f"Layer {i+1}: {feature_map.shape}")return features# Analyze learned featuressample_img, _ = train_dataset[0]learned_features = visualize_learned_features(model_full, sample_img)```## Key Takeaways```{python}print("\n=== Fine-tuning Strategy Summary ===")strategies = {"Full Training": "Train all parameters - best when you have lots of data","Frozen Features": "Only train classifier - fastest, good for small datasets", "Layerwise LR": "Different learning rates - balanced approach","Gradual Unfreezing": "Progressive training - best for complex adaptation","Early Stopping": "Prevent overfitting - essential for small datasets"}for strategy, description in strategies.items():print(f"β’ {strategy}: {description}")print(f"\nTraining completed successfully! All examples ran quickly.")```## Summary- **Start simple** with frozen features and classifier-only training- **Use appropriate learning rates** - lower for pretrained layers- **Monitor validation** carefully to avoid overfitting - **Implement early stopping** for robust training- **Consider gradual unfreezing** for complex adaptations- **Visualize features** to understand what the model learnsThese techniques work across different model architectures and can be scaled up for larger, real-world applications.