---
title: "Data Loading for Satellite Imagery"
subtitle: "Efficient data loading patterns"
jupyter: geoai
format:
  html:
    code-fold: false
---
## Introduction to Satellite Data Loading
Efficient data loading is crucial for satellite imagery analysis due to large file sizes, multiple spectral bands, and complex geospatial metadata.
```{python}
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import rasterio
from rasterio.windows import Window
import xarray as xr
from pathlib import Path
import matplotlib.pyplot as plt
print (f"PyTorch version: { torch. __version__} " )
print (f"Rasterio version: { rasterio. __version__} " )
```
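
The introduction mentions complex geospatial metadata; rasterio exposes it per file. A quick metadata probe looks like the sketch below (not executed here, since no imagery ships with this page; `scene.tif` is a placeholder path):

```{python}
#| eval: false
with rasterio.open("scene.tif") as src:  # placeholder path
    print(f"Bands: {src.count}, size: {src.width}x{src.height}")
    print(f"CRS: {src.crs}")              # coordinate reference system
    print(f"Transform: {src.transform}")  # pixel-to-world affine mapping
    print(f"Dtype: {src.dtypes[0]}, nodata: {src.nodata}")
```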
## Basic Dataset Patterns
### Simple satellite dataset
```{python}
class SatelliteDataset(Dataset):
    """Basic satellite imagery dataset."""

    def __init__(self, image_paths, patch_size=256, transform=None):
        self.image_paths = [Path(p) for p in image_paths]
        self.patch_size = patch_size
        self.transform = transform
        # Pre-compute dataset info
        self._scan_images()

    def _scan_images(self):
        """Scan images to get metadata."""
        self.image_info = []
        for path in self.image_paths:
            # In practice, you'd open actual files;
            # for the demo, simulate metadata
            info = {
                'path': path,
                'width': 1024,
                'height': 1024,
                'bands': 4,
                'dtype': np.uint16
            }
            self.image_info.append(info)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        info = self.image_info[idx]
        # Simulate loading satellite data
        # In practice: data = rasterio.open(info['path']).read()
        data = np.random.randint(0, 4096,
                                 (info['bands'], self.patch_size, self.patch_size),
                                 dtype=np.uint16)
        # Convert to tensor and normalize the 12-bit range
        tensor_data = torch.from_numpy(data).float() / 4095.0

        sample = {
            'image': tensor_data,
            'path': str(info['path']),
            'metadata': info
        }
        if self.transform:
            sample = self.transform(sample)
        return sample

# Example usage
image_paths = ['image1.tif', 'image2.tif', 'image3.tif']
dataset = SatelliteDataset(image_paths, patch_size=256)
print(f"Dataset size: {len(dataset)}")
print(f"Sample keys: {list(dataset[0].keys())}")
print(f"Image shape: {dataset[0]['image'].shape}")
```
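
The `__getitem__` above fabricates pixel data so the page runs anywhere. With real files, the read boils down to a rasterio open/read pair; a minimal sketch (not executed here, since the demo paths don't exist on disk, and assuming the same 12-bit value range as the simulation):

```{python}
#| eval: false
def load_patch(path, patch_size=256):
    """Read the top-left patch of every band from a GeoTIFF."""
    with rasterio.open(path) as src:
        # read() returns a [bands, height, width] numpy array
        data = src.read(window=Window(0, 0, patch_size, patch_size))
    # Same 12-bit normalization as the simulated version above
    return torch.from_numpy(data.astype(np.float32)) / 4095.0
```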
### Multi-temporal dataset
```{python}
class TemporalSatelliteDataset(Dataset):
    """Dataset for temporal satellite imagery sequences."""

    def __init__(self, data_root, sequence_length=5, time_step=30):
        self.data_root = Path(data_root)
        self.sequence_length = sequence_length
        self.time_step = time_step  # Days between images
        # In practice, scan directory for date-organized images
        self.sequences = self._find_temporal_sequences()

    def _find_temporal_sequences(self):
        """Find valid temporal sequences."""
        # Simulate finding temporal sequences
        sequences = []
        for i in range(10):  # 10 example sequences
            start_date = f"2020-{(i % 12) + 1:02d}-01"
            sequence = {
                'start_date': start_date,
                'location_id': f'tile_{i:03d}',
                'file_pattern': f'tile_{i:03d}_*.tif'
            }
            sequences.append(sequence)
        return sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        # Load temporal sequence
        images = []
        for t in range(self.sequence_length):
            # Simulate temporal progression:
            # each image in the sequence has slight variations
            base_image = np.random.randn(4, 256, 256) + t * 0.1
            images.append(base_image)
        # Stack temporal dimension: [T, C, H, W]
        temporal_stack = np.stack(images, axis=0)
        tensor_stack = torch.from_numpy(temporal_stack).float()
        return {
            'images': tensor_stack,
            'sequence_id': sequence['location_id'],
            'start_date': sequence['start_date'],
            'time_steps': self.sequence_length
        }

# Example usage
temporal_dataset = TemporalSatelliteDataset('data/', sequence_length=5)
sample = temporal_dataset[0]
print(f"Temporal dataset size: {len(temporal_dataset)}")
print(f"Image sequence shape: {sample['images'].shape}")
print(f"Sequence ID: {sample['sequence_id']}")
```
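
Batching these samples with a plain `DataLoader` adds a batch axis in front of the stacked time axis, yielding the `[B, T, C, H, W]` layout that sequence models typically consume:

```{python}
temporal_loader = DataLoader(temporal_dataset, batch_size=2, shuffle=True)
temporal_batch = next(iter(temporal_loader))
print(f"Batched sequence shape: {temporal_batch['images'].shape}")  # [B, T, C, H, W]
```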
## Memory-Efficient Loading
### Windowed reading for large files
```{python}
class WindowedSatelliteDataset(Dataset):
    """Dataset that reads windows from large satellite images."""

    def __init__(self, image_path, window_size=512, stride=256, max_windows=None):
        self.image_path = Path(image_path)
        self.window_size = window_size
        self.stride = stride
        # Pre-compute all valid windows
        self.windows = self._compute_windows()
        if max_windows and len(self.windows) > max_windows:
            self.windows = self.windows[:max_windows]

    def _compute_windows(self):
        """Compute all valid windows for the image."""
        # In practice, use rasterio to get actual dimensions;
        # simulate large image dimensions here
        img_height, img_width = 4096, 4096
        windows = []
        for row in range(0, img_height - self.window_size + 1, self.stride):
            for col in range(0, img_width - self.window_size + 1, self.stride):
                window = Window(col, row, self.window_size, self.window_size)
                windows.append(window)
        return windows

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        window = self.windows[idx]
        # In practice:
        # with rasterio.open(self.image_path) as src:
        #     data = src.read(window=window)
        # Simulate reading a window
        data = np.random.randint(0, 2048,
                                 (4, self.window_size, self.window_size),
                                 dtype=np.uint16)
        tensor_data = torch.from_numpy(data).float() / 2047.0
        return {
            'image': tensor_data,
            'window': window,
            'window_bounds': (window.col_off, window.row_off,
                              window.width, window.height)
        }

# Example usage
windowed_dataset = WindowedSatelliteDataset('large_image.tif',
                                            window_size=512,
                                            stride=256,
                                            max_windows=100)
print(f"Windowed dataset size: {len(windowed_dataset)}")
print(f"First window shape: {windowed_dataset[0]['image'].shape}")
print(f"Window bounds: {windowed_dataset[0]['window_bounds']}")
```
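
With a real file, both the image dimensions in `_compute_windows` and the per-window read come from rasterio. A sketch of the real read path (not executed, since `large_image.tif` is a placeholder):

```{python}
#| eval: false
def read_window(image_path, window):
    """Read only the requested block from disk, never the full image."""
    with rasterio.open(image_path) as src:
        # src.height and src.width would replace the simulated 4096x4096
        data = src.read(window=window)  # [bands, window.height, window.width]
    return torch.from_numpy(data.astype(np.float32))
```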
### Lazy loading with caching
```{python}
from functools import lru_cache
from threading import Lock
class CachedSatelliteDataset(Dataset):
    """Dataset with intelligent caching for repeated access."""

    def __init__(self, image_paths, cache_size=50, patch_size=256):
        self.image_paths = [Path(p) for p in image_paths]
        self.patch_size = patch_size
        self.cache_lock = Lock()
        # LRU cache for loaded images
        self._load_image = lru_cache(maxsize=cache_size)(self._load_image_uncached)

    def _load_image_uncached(self, image_path):
        """Load image without caching (wrapped by LRU cache)."""
        # In practice: load with rasterio or another library;
        # simulate loading time and memory usage here
        print(f"Loading {image_path} into cache...")
        # Simulate different image sizes and properties
        bands = np.random.choice([3, 4, 8, 12])  # Different satellite sensors
        data = np.random.randint(0, 4096,
                                 (bands, self.patch_size, self.patch_size),
                                 dtype=np.uint16)
        return {
            'data': data,
            'bands': bands,
            'loaded_at': torch.tensor(0)  # Timestamp placeholder
        }

    def __len__(self):
        return len(self.image_paths) * 4  # Multiple patches per image

    def __getitem__(self, idx):
        image_idx = idx // 4
        patch_idx = idx % 4
        image_path = str(self.image_paths[image_idx])
        with self.cache_lock:
            image_data = self._load_image(image_path)
        # Extract patch (simulate different patches from the same image)
        data = image_data['data'].copy()
        # Add some variation for different patches; keep the noise signed
        # (casting it straight to uint16 would wrap negative values around)
        if patch_idx > 0:
            noise = np.random.normal(0, 50, data.shape).astype(np.int32)
            data = np.clip(data.astype(np.int32) + noise, 0, 4095).astype(np.uint16)
        tensor_data = torch.from_numpy(data).float() / 4095.0
        return {
            'image': tensor_data,
            'image_idx': image_idx,
            'patch_idx': patch_idx,
            'bands': image_data['bands']
        }

    def cache_info(self):
        """Get cache statistics."""
        return self._load_image.cache_info()

# Example usage
cached_dataset = CachedSatelliteDataset(image_paths[:3], cache_size=10)
# Load some samples (first loads will cache images)
for i in range(6):
    sample = cached_dataset[i]
    print(f"Sample {i}: bands={sample['bands']}, "
          f"image_idx={sample['image_idx']}, patch_idx={sample['patch_idx']}")
print(f"Cache statistics: {cached_dataset.cache_info()}")
```
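
One caveat: `lru_cache` is per-process, so with `num_workers > 0` each DataLoader worker builds its own independent cache and hit rates drop accordingly. For cache-heavy datasets, single-process loading can be the better trade-off:

```{python}
# Single-process loading shares one cache; worker processes would each
# maintain their own copy, reloading images already cached elsewhere.
cached_loader = DataLoader(cached_dataset, batch_size=4, num_workers=0)
_ = next(iter(cached_loader))
print(f"Cache after one batch: {cached_dataset.cache_info()}")
```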
## Advanced Data Loading Strategies
### Multi-resolution dataset
```{python}
import torch.nn.functional as F

class MultiResolutionDataset(Dataset):
    """Dataset providing multiple resolutions of the same data."""

    def __init__(self, image_paths, resolutions=[128, 256, 512]):
        self.image_paths = [Path(p) for p in image_paths]
        self.resolutions = sorted(resolutions)
        self.base_resolution = max(resolutions)

    def __len__(self):
        return len(self.image_paths)

    def _resize_tensor(self, tensor, target_size):
        """Resize tensor to target size."""
        # Add batch dimension for interpolation
        tensor_4d = tensor.unsqueeze(0)  # [1, C, H, W]
        resized = F.interpolate(
            tensor_4d,
            size=(target_size, target_size),
            mode='bilinear',
            align_corners=False
        )
        return resized.squeeze(0)  # Remove batch dimension

    def __getitem__(self, idx):
        # Load base resolution data
        base_data = np.random.randint(0, 4096,
                                      (4, self.base_resolution, self.base_resolution),
                                      dtype=np.uint16)
        base_tensor = torch.from_numpy(base_data).float() / 4095.0
        # Create multi-resolution versions
        multi_res = {}
        for res in self.resolutions:
            if res == self.base_resolution:
                multi_res[f'image_{res}'] = base_tensor
            else:
                multi_res[f'image_{res}'] = self._resize_tensor(base_tensor, res)
        # Add metadata
        multi_res.update({
            'path': str(self.image_paths[idx]),
            'base_resolution': self.base_resolution,
            'available_resolutions': self.resolutions
        })
        return multi_res

# Example usage
multi_res_dataset = MultiResolutionDataset(image_paths, resolutions=[128, 256, 512])
sample = multi_res_dataset[0]
print("Multi-resolution sample keys:", list(sample.keys()))
for key in sample.keys():
    if key.startswith('image_'):
        print(f"  {key}: {sample[key].shape}")
```
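
Interpolating from a full-resolution read works, but real multi-resolution pipelines usually lean on on-disk overviews (e.g. in Cloud-Optimized GeoTIFFs): rasterio's `out_shape` argument reads directly at reduced resolution, touching far less data. A sketch, assuming a file with overviews built:

```{python}
#| eval: false
from rasterio.enums import Resampling

def read_at_resolution(image_path, target_size):
    """Decimated read; rasterio uses the closest overview level if one exists."""
    with rasterio.open(image_path) as src:
        data = src.read(
            out_shape=(src.count, target_size, target_size),
            resampling=Resampling.bilinear,
        )
    return torch.from_numpy(data.astype(np.float32))
```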
### Balanced sampling dataset
```{python}
class BalancedSatelliteDataset(Dataset):
    """Dataset with balanced sampling across different conditions."""

    def __init__(self, image_paths, labels, balance_strategy='oversample'):
        self.image_paths = [Path(p) for p in image_paths]
        self.labels = np.array(labels)
        self.balance_strategy = balance_strategy
        # Compute class counts and sampling indices
        self.class_counts = np.bincount(self.labels)
        self.num_classes = len(self.class_counts)
        self._compute_sampling_indices()

    def _compute_sampling_indices(self):
        """Compute sampling indices for balanced loading."""
        if self.balance_strategy == 'oversample':
            # Oversample minority classes
            max_count = np.max(self.class_counts)
            self.sampling_indices = []
            for class_id in range(self.num_classes):
                class_indices = np.where(self.labels == class_id)[0]
                # Repeat indices to match the max count
                repeats = max_count // len(class_indices)
                remainder = max_count % len(class_indices)
                oversampled = np.tile(class_indices, repeats)
                if remainder > 0:
                    extra = np.random.choice(class_indices, remainder, replace=False)
                    oversampled = np.concatenate([oversampled, extra])
                self.sampling_indices.extend(oversampled)
            # Shuffle the indices
            self.sampling_indices = np.array(self.sampling_indices)
            np.random.shuffle(self.sampling_indices)
        elif self.balance_strategy == 'undersample':
            # Undersample majority classes
            min_count = np.min(self.class_counts)
            self.sampling_indices = []
            for class_id in range(self.num_classes):
                class_indices = np.where(self.labels == class_id)[0]
                sampled = np.random.choice(class_indices, min_count, replace=False)
                self.sampling_indices.extend(sampled)
            self.sampling_indices = np.array(self.sampling_indices)
            np.random.shuffle(self.sampling_indices)

    def __len__(self):
        return len(self.sampling_indices)

    def __getitem__(self, idx):
        actual_idx = self.sampling_indices[idx]
        # Load satellite image (simulated)
        data = np.random.randint(0, 4096, (4, 256, 256), dtype=np.uint16)
        tensor_data = torch.from_numpy(data).float() / 4095.0
        return {
            'image': tensor_data,
            'label': torch.tensor(self.labels[actual_idx], dtype=torch.long),
            'original_idx': actual_idx,
            'path': str(self.image_paths[actual_idx])
        }

# Example usage with imbalanced classes
np.random.seed(42)
labels = np.random.choice([0, 1, 2], size=50, p=[0.7, 0.2, 0.1])  # Imbalanced
# Ensure we have one image path per label index used below
image_paths = [f"image_{i}.tif" for i in range(len(labels))]
balanced_dataset = BalancedSatelliteDataset(
    image_paths[:50],
    labels,
    balance_strategy='oversample'
)

# Check class distribution in the balanced dataset
sample_labels = [balanced_dataset[i]['label'].item() for i in range(len(balanced_dataset))]
balanced_counts = np.bincount(sample_labels)
print(f"Original class distribution: {np.bincount(labels)}")
print(f"Balanced class distribution: {balanced_counts}")
print(f"Balanced dataset size: {len(balanced_dataset)}")
```
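
An alternative that avoids materializing duplicated indices is PyTorch's `WeightedRandomSampler`: it reweights draws at load time, leaves the dataset untouched, and re-randomizes every epoch:

```{python}
from torch.utils.data import WeightedRandomSampler

# Inverse-frequency weight for each sample's class
class_weights = 1.0 / np.bincount(labels)
sample_weights = class_weights[labels]
sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights).double(),
    num_samples=len(labels),
    replacement=True,
)
# Pass sampler instead of shuffle=True:
# DataLoader(some_dataset, batch_size=8, sampler=sampler)
sampled = list(sampler)
print(f"Sampled class distribution: {np.bincount(labels[sampled])}")
```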
## DataLoader Optimization
### Custom collate functions
```{python}
def satellite_collate_fn(batch):
    """Custom collate function for satellite imagery batches."""
    # Handle a variable number of bands
    max_bands = max(sample['image'].shape[0] for sample in batch)
    batch_size = len(batch)
    # Get common spatial dimensions
    height = batch[0]['image'].shape[1]
    width = batch[0]['image'].shape[2]

    # Pad images to the same number of bands
    padded_images = torch.zeros(batch_size, max_bands, height, width)
    labels = []
    paths = []
    band_masks = torch.zeros(batch_size, max_bands, dtype=torch.bool)
    for i, sample in enumerate(batch):
        img = sample['image']
        num_bands = img.shape[0]
        padded_images[i, :num_bands] = img
        band_masks[i, :num_bands] = True
        if 'label' in sample:
            labels.append(sample['label'])
        paths.append(sample['path'])

    result = {
        'image': padded_images,
        'band_mask': band_masks,
        'path': paths
    }
    if labels:
        result['label'] = torch.stack(labels)
    return result

def variable_size_collate_fn(batch):
    """Collate function for variable-sized images."""
    # Find max dimensions
    max_height = max(sample['image'].shape[1] for sample in batch)
    max_width = max(sample['image'].shape[2] for sample in batch)
    max_bands = max(sample['image'].shape[0] for sample in batch)
    batch_size = len(batch)

    # Create padded batch
    padded_batch = torch.zeros(batch_size, max_bands, max_height, max_width)
    size_masks = []
    for i, sample in enumerate(batch):
        img = sample['image']
        c, h, w = img.shape
        padded_batch[i, :c, :h, :w] = img
        # Create mask for valid pixels
        mask = torch.zeros(max_height, max_width, dtype=torch.bool)
        mask[:h, :w] = True
        size_masks.append(mask)

    return {
        'image': padded_batch,
        'size_mask': torch.stack(size_masks),
        'original_sizes': [(s['image'].shape[1], s['image'].shape[2]) for s in batch]
    }

# Example usage
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=satellite_collate_fn,
    num_workers=0,
    pin_memory=True
)

# Test the dataloader
batch = next(iter(dataloader))
print(f"Batch image shape: {batch['image'].shape}")
print(f"Band mask shape: {batch['band_mask'].shape}")
print(f"Paths: {len(batch['path'])}")
```
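
The variable-size collate can be exercised directly with a synthetic mixed-size batch, without building a dedicated dataset:

```{python}
# Two images with different spatial extents, padded into one batch
mixed_batch = [
    {'image': torch.rand(4, 200, 240)},
    {'image': torch.rand(4, 256, 256)},
]
padded = variable_size_collate_fn(mixed_batch)
print(f"Padded batch shape: {padded['image'].shape}")
print(f"Original sizes: {padded['original_sizes']}")
print(f"Valid pixels in sample 0: {padded['size_mask'][0].sum().item()}")
```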
### Performance optimization
```{python}
def create_optimized_dataloader(dataset, batch_size=32, num_workers=4):
    """Create an optimized dataloader for satellite imagery."""
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # Pin memory if a GPU is available
        drop_last=True,  # Consistent batch sizes
        collate_fn=satellite_collate_fn,
    )
    if num_workers > 0:
        # These options require worker processes and error out with num_workers=0
        loader_kwargs['persistent_workers'] = True  # Keep workers alive between epochs
        loader_kwargs['prefetch_factor'] = 2        # Prefetch 2 batches per worker
    return DataLoader(dataset, **loader_kwargs)

# Memory usage monitoring
def monitor_memory_usage(dataloader, num_batches=5):
    """Monitor GPU memory usage during data loading."""
    if torch.cuda.is_available():
        print("GPU memory monitoring:")
        torch.cuda.reset_peak_memory_stats()
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            current_memory = torch.cuda.memory_allocated()
            peak_memory = torch.cuda.max_memory_allocated()
            print(f"Batch {i}: Current={current_memory / 1e6:.1f} MB, "
                  f"Peak={peak_memory / 1e6:.1f} MB")
    else:
        print("GPU not available for memory monitoring")

# Example usage
optimized_loader = create_optimized_dataloader(dataset, batch_size=8)
# monitor_memory_usage(optimized_loader)
print("Optimized dataloader created")
```
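
Good values for `num_workers` and `batch_size` vary enough across storage and CPU setups that measuring beats guessing. A rough wall-clock sketch using the demo dataset:

```{python}
import time

def benchmark_loader(loader, num_batches=5):
    """Time the first few batches as a rough throughput estimate."""
    start = time.perf_counter()
    n = 0
    for batch in loader:
        n += 1
        if n >= num_batches:
            break
    elapsed = time.perf_counter() - start
    print(f"{n} batches in {elapsed:.3f}s ({n / elapsed:.1f} batches/s)")

# num_workers=0 keeps this runnable anywhere; rerun with higher values to compare
benchmark_loader(DataLoader(dataset, batch_size=2, num_workers=0,
                            collate_fn=satellite_collate_fn))
```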
## Real-World Integration Examples
### Integration with preprocessing pipeline
```{python}
class PreprocessingSatelliteDataset(Dataset):
    """Dataset with an integrated preprocessing pipeline."""

    def __init__(self, image_paths, preprocessing_config=None):
        self.image_paths = [Path(p) for p in image_paths]
        self.config = preprocessing_config or self._default_config()

    def _default_config(self):
        """Default preprocessing configuration."""
        return {
            'normalize': True,
            'clip_percentiles': (2, 98),
            'target_bands': [2, 3, 4, 7],  # RGB + NIR
            'target_resolution': 256,
            'augment': True
        }

    def _preprocess_image(self, image):
        """Apply the preprocessing pipeline."""
        # Select target bands
        if self.config['target_bands']:
            available_bands = min(image.shape[0], max(self.config['target_bands']) + 1)
            target_bands = [b for b in self.config['target_bands'] if b < available_bands]
            image = image[target_bands]
        # Normalize band by band
        if self.config['normalize']:
            for band in range(image.shape[0]):
                band_data = image[band]
                if self.config['clip_percentiles']:
                    p_low, p_high = self.config['clip_percentiles']
                    low_val = np.percentile(band_data, p_low)
                    high_val = np.percentile(band_data, p_high)
                    band_data = np.clip(band_data, low_val, high_val)
                # Normalize to [0, 1]
                band_min, band_max = band_data.min(), band_data.max()
                if band_max > band_min:
                    image[band] = (band_data - band_min) / (band_max - band_min)
        return image

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Simulate loading a multi-band satellite image
        num_bands = np.random.choice([4, 8, 12])  # Different sensors
        base_resolution = np.random.choice([256, 512])
        # Simulate realistic satellite data values
        image = np.random.randint(100, 4000,
                                  (num_bands, base_resolution, base_resolution),
                                  dtype=np.uint16).astype(np.float32)
        # Apply preprocessing
        processed_image = self._preprocess_image(image)
        # Convert to tensor
        tensor_image = torch.from_numpy(processed_image)
        return {
            'image': tensor_image,
            'path': str(self.image_paths[idx]),
            'original_bands': num_bands,
            'processed_bands': processed_image.shape[0]
        }

# Example usage
preprocessing_config = {
    'normalize': True,
    'clip_percentiles': (1, 99),
    'target_bands': [0, 1, 2, 3],  # First 4 bands
    'augment': False
}
preprocessed_dataset = PreprocessingSatelliteDataset(
    image_paths,
    preprocessing_config=preprocessing_config
)

sample = preprocessed_dataset[0]
print(f"Preprocessed sample shape: {sample['image'].shape}")
print(f"Original bands: {sample['original_bands']}")
print(f"Processed bands: {sample['processed_bands']}")
print(f"Value range: [{sample['image'].min():.3f}, {sample['image'].max():.3f}]")
```
## Summary
Key data loading strategies for satellite imagery:
- **Memory efficiency**: Window-based reading, caching, lazy loading
- **Multi-temporal**: Handle time series of satellite observations
- **Multi-resolution**: Provide different spatial resolutions
- **Balanced sampling**: Handle imbalanced datasets
- **Custom collation**: Handle variable bands and sizes
- **Preprocessing integration**: Normalization, band selection, augmentation
- **Performance optimization**: Multi-processing, memory pinning, prefetching