This cheatsheet shows how to load and work with pre-trained models for geospatial AI, using real examples with small sample data.
Setup and Imports
import torch
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Report library versions / availability so results can be reproduced.
print(f"PyTorch version: {torch.__version__}")
# Reaching this line means `import timm` above succeeded.
print("TIMM available: ✓")

# Set random seeds for reproducible results (PyTorch and NumPy RNGs).
torch.manual_seed(42)
np.random.seed(42)
PyTorch version: 2.7.1
TIMM available: ✓
TIMM (Torch Image Models) - Quick and Reliable
TIMM is one of the most reliable ways to load pre-trained vision models. Let's start with a small ResNet model.
# Load a lightweight ResNet-18 from timm with pretrained weights.
# NOTE(review): num_classes=10 differs from the pretrained classifier size, so
# the head is presumably freshly initialised -- fine for feature extraction,
# but its logits are not meaningful until the model is fine-tuned.
model = timm.create_model('resnet18', pretrained=True, num_classes=10)
model.eval()  # inference mode (affects BatchNorm / Dropout behaviour)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model: {type(model).__name__}")
print(f"Parameters: {param_count:,}")

# Check input requirements: the preprocessing config timm resolves for this model.
data_config = timm.data.resolve_model_data_config(model)
print(f"Expected input size: {data_config['input_size']}")
for label, key in (("Mean", "mean"), ("Std", "std")):
    print(f"{label}: {data_config[key]}")
Raw data: torch.Size([6, 256, 256]) -> Preprocessed: torch.Size([1, 6, 224, 224])
Raw range: [0.000, 1.000]
Preprocessed range: [-2.454, 6.526]
Feature Extraction
def _find_submodule(named_modules, candidate):
    """Return the ``(name, module)`` pair best matching *candidate*, or None.

    An exact name match wins; otherwise the first module (in
    ``named_modules`` order) whose name contains *candidate* is returned.
    """
    for name, module in named_modules:
        if name == candidate:
            return name, module
    for name, module in named_modules:
        if candidate in name:
            return name, module
    return None


def extract_features(model, data, layer_name='avgpool'):
    """Capture the output of one named submodule during a forward pass.

    The submodule is looked up among ``model.named_modules()``. The requested
    ``layer_name`` is tried first, then the common pooling-layer names
    ``'global_pool'``, ``'avgpool'`` and ``'head.global_pool'`` as fallbacks,
    so the same call works across implementations (e.g. timm vs torchvision).
    For each candidate an exact name match is preferred. Otherwise the first
    module whose name contains the candidate is used. For example, a top-level
    ``'fc'`` is chosen over an earlier nested name that merely contains ``'fc'``.

    Args:
        model: A ``torch.nn.Module``. It is called once on ``data`` inside
            ``torch.no_grad()``. Its train/eval mode is not changed here, so
            call ``model.eval()`` beforehand if needed.
        data: Input passed directly to ``model(data)``.
        layer_name: Preferred submodule name to hook.

    Returns:
        dict mapping the hooked module's full name to its output (detached
        when it is a tensor). The dict is empty if the hooked module was not
        executed during the forward pass.

    Raises:
        ValueError: If no candidate name matches any submodule.
    """
    named_modules = list(model.named_modules())

    # Try requested layer name first, then common fallbacks.
    candidate_names = [layer_name, 'global_pool', 'avgpool', 'head.global_pool']
    match = None
    for candidate in candidate_names:
        match = _find_submodule(named_modules, candidate)
        if match is not None:
            break
    if match is None:
        available = [name for name, _ in named_modules]
        raise ValueError(
            f"Requested layer '{layer_name}' not found. Available modules include: "
            f"{available[:20]}{' ...' if len(available) > 20 else ''}"
        )
    target_name, target_module = match

    features = {}

    def hook(_module, _inputs, output):
        # Tensors are detached from autograd. Non-tensor outputs (e.g. tuples)
        # are stored unchanged instead of failing on `.detach()`.
        features[target_name] = (
            output.detach() if isinstance(output, torch.Tensor) else output
        )

    handle = target_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(data)
    finally:
        # Always unregister, even if the forward pass raises, so the hook
        # does not stay attached to the caller's model.
        handle.remove()
    return features


# Extract features from our sample.
# NOTE(review): model_6band and sample_data are not defined in this section.
# They presumably come from an earlier 6-band preprocessing cell -- confirm.
features = extract_features(model_6band, sample_data, 'global_pool')
feature_name, feature_tensor = next(iter(features.items()))
print(f"Feature layer: {feature_name}")
print(f"Feature shape: {feature_tensor.shape}")
print(f"Feature stats: mean={feature_tensor.mean():.3f}, std={feature_tensor.std():.3f}")