TorchGeo is a PyTorch domain library for geospatial data, providing datasets, samplers, transforms, and pre-trained models tailored to satellite imagery and remote sensing applications.
import torch
import torchgeo
from torchgeo.datasets import RasterDataset, stack_samples
from torchgeo.transforms import AugmentationSequential
import matplotlib.pyplot as plt
import numpy as np

print(f"TorchGeo version: {torchgeo.__version__}")
print(f"PyTorch version: {torch.__version__}")
TorchGeo version: 0.7.1
PyTorch version: 2.7.1
Core Dataset Classes
RasterDataset basics
from torchgeo.datasets import RasterDataset
from torchgeo.samplers import RandomGeoSampler
import tempfile
import os
from pathlib import Path

# Create a simple custom dataset (not inheriting from RasterDataset for demo)
from torchgeo.datasets import BoundingBox
from rtree.index import Index, Property

class SampleGeoDataset:
    """Sample geospatial dataset for demonstration"""

    def __init__(self, transforms=None):
        self.transforms = transforms
        # Define dataset bounds
        self.bounds = BoundingBox(-10.0, 10.0, -10.0, 10.0, 0, 100)
        # Define resolution (meters per pixel)
        self.res = 10.0  # 10 meter resolution
        # Create spatial index required by TorchGeo samplers
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        # Add the dataset bounds to the index
        self.index.insert(0, tuple(self.bounds))

    def __getitem__(self, query):
        # Create synthetic data for demonstration
        sample = {
            'image': torch.rand(3, 256, 256),  # RGB image
            'bbox': query,
            'crs': 'EPSG:4326'
        }
        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def __len__(self):
        return 1000  # Arbitrary length for sampling

# Initialize dataset
dataset = SampleGeoDataset()
print(f"Dataset created: {type(dataset).__name__}")
print(f"Dataset bounds: {dataset.bounds}")
from torchgeo.datasets import RESISC45, EuroSAT

# Note: These require downloaded data files
# For demonstration, we show the usage patterns

# RESISC45 - Remote sensing image scene classification
# resisc45 = RESISC45(root='data/resisc45', download=True)
# print(f"RESISC45 classes: {len(resisc45.classes)}")

# EuroSAT - Sentinel-2 image classification
# eurosat = EuroSAT(root='data/eurosat', download=True)
# print(f"EuroSAT classes: {len(eurosat.classes)}")

print("Vision dataset classes ready for use with downloaded data")
Vision dataset classes ready for use with downloaded data
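Once downloaded, these benchmark datasets behave like ordinary map-style datasets and plug directly into a standard DataLoader. A minimal sketch, assuming you are willing to download EuroSAT on first use (the root path below is a placeholder):

from torch.utils.data import DataLoader
from torchgeo.datasets import EuroSAT

eurosat = EuroSAT(root="data/eurosat", split="train", download=True)
loader = DataLoader(eurosat, batch_size=32, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch["image"].shape)   # multispectral chips, e.g. [32, 13, 64, 64]
print(batch["label"].shape)   # [32] class indices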
Geospatial Sampling
RandomGeoSampler
from torchgeo.samplers import RandomGeoSampler, GridGeoSampler
from torchgeo.datasets import BoundingBox

# Define a region of interest
roi = BoundingBox(
    minx=-10.0, maxx=10.0,
    miny=-10.0, maxy=10.0,
    mint=0, maxt=100
)

# For demonstration, show sampler concepts without full implementation
print("TorchGeo Samplers:")
print("- RandomGeoSampler: Randomly samples patches from spatial regions")
print("- GridGeoSampler: Systematically samples patches in a grid pattern")
print("- Units can be PIXELS or CRS (coordinate reference system)")
print(f"Sample ROI: {roi}")

# Note: Actual usage requires proper GeoDataset implementation
# random_sampler = RandomGeoSampler(dataset=dataset, size=256, length=100, roi=roi)
TorchGeo Samplers:
- RandomGeoSampler: Randomly samples patches from spatial regions
- GridGeoSampler: Systematically samples patches in a grid pattern
- Units can be PIXELS or CRS (coordinate reference system)
Sample ROI: BoundingBox(minx=-10.0, maxx=10.0, miny=-10.0, maxy=10.0, mint=0, maxt=100)
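The prints above stand in for the real wiring. As a hedged sketch, assuming Sentinel-2 GeoTIFFs already sit under a local directory (the path is hypothetical), a RandomGeoSampler is typically paired with a DataLoader and TorchGeo's stack_samples collate function:

from torch.utils.data import DataLoader
from torchgeo.datasets import Sentinel2, stack_samples
from torchgeo.samplers import RandomGeoSampler

sentinel2 = Sentinel2(paths="data/sentinel2")                  # hypothetical local imagery
sampler = RandomGeoSampler(sentinel2, size=256, length=100)    # 100 random 256x256-pixel patches
loader = DataLoader(sentinel2, sampler=sampler, collate_fn=stack_samples)

for batch in loader:
    print(batch["image"].shape)   # [1, C, 256, 256] with the default batch_size=1
    break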
GridGeoSampler
# Grid-based systematic sampling concept
print("GridGeoSampler Usage Pattern:")
print("- size: Patch size in pixels (e.g., 256)")
print("- stride: Step size between patches (e.g., 128 for overlap)")
print("- roi: Region of interest as BoundingBox")
print("- Provides systematic spatial coverage")

# Example conceptual usage:
# grid_sampler = GridGeoSampler(dataset=dataset, size=256, stride=128, roi=roi)
GridGeoSampler Usage Pattern:
- size: Patch size in pixels (e.g., 256)
- stride: Step size between patches (e.g., 128 for overlap)
- roi: Region of interest as BoundingBox
- Provides systematic spatial coverage
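For completeness, here is the matching sketch for grid-based tiling over the same hypothetical Sentinel-2 directory. Overlapping tiles (stride smaller than size) are the usual setup for running a trained model across an entire scene:

from torch.utils.data import DataLoader
from torchgeo.datasets import Sentinel2, stack_samples
from torchgeo.samplers import GridGeoSampler

sentinel2 = Sentinel2(paths="data/sentinel2")                    # hypothetical local imagery
grid_sampler = GridGeoSampler(sentinel2, size=256, stride=128)   # 50% overlap between tiles
tile_loader = DataLoader(sentinel2, sampler=grid_sampler, collate_fn=stack_samples)

print(f"Number of tiles to process: {len(grid_sampler)}")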
Data Transforms
Basic transforms
import torchvision.transforms as T
from torchgeo.transforms import AugmentationSequential

# Standard computer vision transforms for preprocessing
normalization_transform = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Basic geometric augmentations
basic_augments = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
])

print("Transform sequences created:")
print("- Normalization transform for pretrained models")
print("- Basic augmentations for training")
print("- TorchGeo's AugmentationSequential preserves spatial relationships")
Transform sequences created:
- Normalization transform for pretrained models
- Basic augmentations for training
- TorchGeo's AugmentationSequential preserves spatial relationships
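The last point deserves a concrete example. Below is a sketch of AugmentationSequential following the dict-based interface used in the TorchGeo tutorials (treat the exact keyword names as assumptions to verify against your installed version): the same random flip is applied to every keyed tensor, so an image and its segmentation mask stay aligned.

import torch
import kornia.augmentation as K
from torchgeo.transforms import AugmentationSequential

augment = AugmentationSequential(
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    data_keys=["image", "mask"],
)

sample = {
    "image": torch.rand(1, 3, 256, 256),                              # [B, C, H, W]
    "mask": torch.randint(0, 2, (1, 1, 256, 256), dtype=torch.float),
}
augmented = augment(sample)   # identical flips applied to image and mask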
from torchgeo.datasets import IntersectionDataset, UnionDataset

# Multi-modal data fusion concept
print("TorchGeo Dataset Fusion:")
print("- IntersectionDataset: Combines data that exists in ALL datasets")
print("- UnionDataset: Combines data that exists in ANY dataset")
print("- Useful for multi-modal analysis (optical + SAR + DEM)")

# Example fusion workflow:
print("\nTypical fusion workflow:")
print("1. Load optical imagery dataset (Sentinel-2)")
print("2. Load elevation dataset (DEM)")
print("3. Load land cover dataset (labels)")
print("4. Use IntersectionDataset to ensure spatial-temporal alignment")
print("5. Sample consistent patches across all modalities")

# Note: Requires proper GeoDataset implementations.
# IntersectionDataset combines two datasets at a time; the & operator chains them:
# fused_ds = optical_ds & dem_ds & landcover_ds
TorchGeo Dataset Fusion:
- IntersectionDataset: Combines data that exists in ALL datasets
- UnionDataset: Combines data that exists in ANY dataset
- Useful for multi-modal analysis (optical + SAR + DEM)
Typical fusion workflow:
1. Load optical imagery dataset (Sentinel-2)
2. Load elevation dataset (DEM)
3. Load land cover dataset (labels)
4. Use IntersectionDataset to ensure spatial-temporal alignment
5. Sample consistent patches across all modalities
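As a concrete, hedged sketch of step 4, assuming Sentinel-2 imagery and a USDA Cropland Data Layer product are available locally at the hypothetical paths below, the & operator on two GeoDatasets builds an IntersectionDataset whose samples carry both an image and a mask:

from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, Sentinel2, stack_samples
from torchgeo.samplers import RandomGeoSampler

imagery = Sentinel2(paths="data/sentinel2")   # hypothetical local imagery
landcover = CDL(paths="data/cdl")             # hypothetical local land cover rasters

fused = imagery & landcover                   # IntersectionDataset: overlap of both extents

sampler = RandomGeoSampler(fused, size=256, length=50)
loader = DataLoader(fused, sampler=sampler, collate_fn=stack_samples)
# Each sample dict contains spatially aligned "image" and "mask" tensors.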
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class GeospatialDataModule(pl.LightningDataModule):
    """Data module for geospatial training"""

    def __init__(self, batch_size=32, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        print("Setting up geospatial data module:")
        print("- Train/val split: 80/20")
        print("- Spatial sampling strategy")
        print("- Multi-worker data loading")

    def train_dataloader(self):
        print("Creating train dataloader with TorchGeo samplers")
        return None  # Would return actual DataLoader with a GeoSampler

    def val_dataloader(self):
        print("Creating validation dataloader")
        return None  # Would return actual DataLoader

# Example usage pattern
print("PyTorch Lightning + TorchGeo Integration:")
print("- Use GeoDataModule for spatial-aware data loading")
print("- Combine with GeoSamplers for patch-based training")
print("- Stack samples for batch processing")
print("- Supports multi-modal geospatial data")

datamodule = GeospatialDataModule(batch_size=8)
datamodule.setup()
PyTorch Lightning + TorchGeo Integration:
- Use GeoDataModule for spatial-aware data loading
- Combine with GeoSamplers for patch-based training
- Stack samples for batch processing
- Supports multi-modal geospatial data
Setting up geospatial data module:
- Train/val split: 80/20
- Spatial sampling strategy
- Multi-worker data loading
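TorchGeo also ships ready-made Lightning components, so for benchmark datasets the skeleton above collapses to a few lines. A hedged sketch, assuming the EuroSAT download path is acceptable and that the datamodule and task argument names match the installed version's documentation:

from pytorch_lightning import Trainer
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.trainers import ClassificationTask

dm = EuroSATDataModule(root="data/eurosat", batch_size=32, num_workers=4, download=True)
task = ClassificationTask(model="resnet18", in_channels=13, num_classes=10, lr=1e-3)

trainer = Trainer(max_epochs=1, accelerator="auto")
# trainer.fit(task, datamodule=dm)   # commented out: requires the EuroSAT download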
Pre-trained Models
Using TorchGeo models
from torchgeo.models import ResNet18_Weights
import torchvision.models as models

# Load pre-trained weights for satellite imagery
# weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
# model = models.resnet18(weights=weights)

# For demonstration without actual weights:
model = models.resnet18(weights=None)

# Adapt the first convolution for multispectral input
model.conv1 = torch.nn.Conv2d(
    in_channels=12,   # 12 spectral bands (Sentinel-2 has 13; the cirrus band B10 is often dropped)
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False
)

print(f"Model adapted for {model.conv1.in_channels} input channels")
Model adapted for 12 input channels
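To use the actual multispectral weights rather than a re-initialized first layer, the pattern in the TorchGeo documentation goes through timm. A sketch that downloads the checkpoint on first use (the enum name matches the import above):

import timm
from torchgeo.models import ResNet18_Weights

weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
backbone = timm.create_model(
    "resnet18",
    in_chans=weights.meta["in_chans"],   # matches the pre-training band count
    num_classes=0,                       # feature extractor only
)
backbone.load_state_dict(weights.get_state_dict(progress=True), strict=False)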
Fine-tuning for classification
import torch.nn as nn

class GeospatialClassifier(nn.Module):
    """Classifier for geospatial data"""

    def __init__(self, backbone, num_classes=10):
        super().__init__()
        self.backbone = backbone
        # Replace classifier head
        if hasattr(backbone, 'fc'):
            in_features = backbone.fc.in_features
            backbone.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

# Create classifier
classifier = GeospatialClassifier(model, num_classes=10)
print(f"Classifier created for {classifier.backbone.fc.out_features} classes")
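With the head replaced, fine-tuning is a standard PyTorch loop. The snippet below runs one optimization step on random stand-in tensors so it executes without any data on disk; a real run would pull batches from a GeoSampler-backed DataLoader.

import torch.optim as optim

optimizer = optim.AdamW(classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

images = torch.rand(4, 12, 224, 224)   # fake 12-band batch
labels = torch.randint(0, 10, (4,))    # fake class indices

classifier.train()
optimizer.zero_grad()
logits = classifier(images)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(f"Fine-tuning step complete, loss = {loss.item():.3f}")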
class TemporalDataset:
    """Dataset for temporal satellite imagery"""

    def __init__(self, time_steps=5):
        self.time_steps = time_steps
        self.bounds = BoundingBox(-10.0, 10.0, -10.0, 10.0, 0, 100)
        # Create spatial index
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        self.index.insert(0, tuple(self.bounds))

    def __getitem__(self, query):
        # Simulate temporal data
        temporal_images = []
        for t in range(self.time_steps):
            # Each time step has slightly different data
            image = torch.rand(3, 256, 256) + t * 0.1
            temporal_images.append(image)
        return {
            'image': torch.stack(temporal_images, dim=0),  # [T, C, H, W]
            'bbox': query,
            'timestamps': torch.arange(self.time_steps)
        }

# Create temporal dataset
temporal_ds = TemporalDataset(time_steps=5)
print("Temporal dataset created for time series analysis")
Temporal dataset created for time series analysis
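A quick check of the synthetic dataset above: indexing with any bounding box returns a stacked [T, C, H, W] tensor plus its timestamps.

query = BoundingBox(-1.0, 1.0, -1.0, 1.0, 0, 100)
sample = temporal_ds[query]
print(sample["image"].shape)   # torch.Size([5, 3, 256, 256])
print(sample["timestamps"])    # tensor([0, 1, 2, 3, 4])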
Performance Optimization
Caching and preprocessing
class CachedDataset:
    """Dataset with caching for repeated access"""

    def __init__(self, cache_size=1000):
        self.cache = {}
        self.cache_size = cache_size
        self.bounds = BoundingBox(-10.0, 10.0, -10.0, 10.0, 0, 100)
        # Create spatial index
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        self.index.insert(0, tuple(self.bounds))

    def __getitem__(self, query):
        query_key = str(query)
        if query_key in self.cache:
            return self.cache[query_key]

        # Generate/load data
        sample = {
            'image': torch.rand(3, 256, 256),
            'bbox': query
        }

        # Cache if space available
        if len(self.cache) < self.cache_size:
            self.cache[query_key] = sample

        return sample

print("Cached dataset implementation ready")
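A brief demonstration of the cache behavior: the second lookup for the same query returns the stored sample instead of regenerating it.

cached_ds = CachedDataset(cache_size=10)
query = BoundingBox(-1.0, 1.0, -1.0, 1.0, 0, 100)

first = cached_ds[query]    # generated and cached
second = cached_ds[query]   # served from the cache
print(first is second)      # True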