import os
import warnings
import logging
# Raise the threshold of the 'PIL.TiffImagePlugin' logger to CRITICAL so that
# lower-severity records from it are dropped.
logging.getLogger('PIL.TiffImagePlugin').setLevel(logging.CRITICAL)
# Ignore every UserWarning raised from this point on.
warnings.filterwarnings('ignore', category=UserWarning)
DATA_PATH = '/Users/kellycaylor/dev/geoAI/data'

This tutorial shows how to fine-tune a geospatial foundation model to predict the next NDVI image in a timeseries. We use Sentinel-2 data from Santa Barbara, encode temporal information (day-of-year, month) to capture seasonality, and train using TerraTorch's standard workflow.
Step 0: Setup
# Set the HF_* cache environment variables to sub-directories of DATA_PATH.
os.environ["HF_HOME"] = os.path.join(DATA_PATH, "hfhome")
os.environ["HF_HUB_CACHE"] = os.path.join(DATA_PATH, "hub")
os.environ["HF_DATASETS_CACHE"] = os.path.join(DATA_PATH, "datasets")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(DATA_PATH, "transformers")

Step 1: Download and Organize Data
Download NDVI timeseries from STAC
We'll download three years (2022–2024) of Sentinel-2 data over Santa Barbara and compute NDVI:
from datetime import datetime, timedelta
import numpy as np
import rasterio
from rasterio.io import MemoryFile
from pystac_client import Client
import planetary_computer as pc
from pathlib import Path
# Area of interest, passed to the STAC search and later to the NDVI helper.
# Presumably ordered (west, south, east, north) in degrees - confirm.
bbox = [-120.1, 34.4, -119.7, 34.5]

# Three calendar years, so every season is observed more than once.
start_date = "2022-01-01"
end_date = "2024-12-31"

# Open the Planetary Computer STAC endpoint; `pc.sign_inplace` is installed
# as the client's modifier.
catalog = Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=pc.sign_inplace,
)

# Sentinel-2 L2A items intersecting the bbox within the date range whose
# "eo:cloud_cover" property is below 20.
search = catalog.search(
    collections=["sentinel-2-l2a"],
    bbox=bbox,
    datetime=f"{start_date}/{end_date}",
    query={"eo:cloud_cover": {"lt": 20}},
)

# Materialise the result pages into a list so it can be sliced and counted.
items = list(search.items())
print(f"Found {len(items)} scenes")

Compute and save NDVI images with temporal metadata
Organize data for TerraTorch regression DataModule:
def _bbox_in_raster_crs(bbox, crs):
    """Return the lon/lat *bbox* expressed in the raster's *crs*."""
    # Local import: only this helper needs it, and rasterio is already a
    # dependency of the file.
    from rasterio.warp import transform_bounds

    if crs is None:
        # No CRS recorded on the dataset: nothing to reproject into.
        return tuple(bbox)
    return transform_bounds("EPSG:4326", crs, *bbox)


def compute_ndvi_from_item(item, bbox, output_dir):
    """Compute NDVI for one STAC item and save it as a single-band GeoTIFF.

    Args:
        item: STAC item exposing ``assets["B04"]`` (red) and ``assets["B08"]``
            (NIR) with an ``href`` each, and a ``properties['datetime']``
            ISO-8601 string.
        bbox: ``(west, south, east, north)`` in longitude/latitude degrees
            (EPSG:4326).
        output_dir: ``pathlib.Path`` of the directory to write into.

    Returns:
        ``(output_path, date)``: the path of the written GeoTIFF and the
        parsed acquisition ``datetime``.
    """
    # Get red and NIR bands
    red_href = pc.sign(item.assets["B04"].href)
    nir_href = pc.sign(item.assets["B08"].href)

    with rasterio.open(red_href) as src:
        crs = src.crs
        # `src.window()` interprets its arguments in the dataset's own CRS,
        # so the geographic bbox must be reprojected first; passing lon/lat
        # directly addresses the wrong part of a projected raster.
        window = src.window(*_bbox_in_raster_crs(bbox, crs))
        red = src.read(1, window=window).astype(float)
        transform = src.window_transform(window)

    with rasterio.open(nir_href) as src:
        # Reprojected against this band's own CRS rather than assuming it
        # matches the red band's.
        nir_window = src.window(*_bbox_in_raster_crs(bbox, src.crs))
        nir = src.read(1, window=nir_window).astype(float)

    # The small epsilon keeps the division defined where both bands are 0.
    ndvi = (nir - red) / (nir + red + 1e-8)

    # Assumes properties['datetime'] is an ISO-8601 string, optionally with a
    # trailing 'Z', which is rewritten as an explicit UTC offset before parsing.
    date = datetime.fromisoformat(item.properties['datetime'].replace('Z', '+00:00'))

    # The filename carries only the calendar date, so two items acquired on
    # the same day map to the same path and the later write replaces the
    # earlier one.
    filename = f"ndvi_{date.strftime('%Y%m%d')}.tif"
    output_path = output_dir / filename

    with rasterio.open(
        output_path, 'w',
        driver='GTiff',
        height=ndvi.shape[0],
        width=ndvi.shape[1],
        count=1,
        dtype=ndvi.dtype,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(ndvi, 1)

    return output_path, date
# Directory that receives one NDVI GeoTIFF per processed scene.
ndvi_dir = Path(DATA_PATH, "santa_barbara_ndvi")
ndvi_dir.mkdir(parents=True, exist_ok=True)

# Process at most the first 100 search results (demo-sized sample). A scene
# that fails for any reason is reported and skipped so that one bad item
# does not abort the whole download.
metadata = []
for item in items[:100]:
    try:
        output_path, date = compute_ndvi_from_item(item, bbox, ndvi_dir)
    except Exception as e:
        print(f"Skipping scene: {e}")
    else:
        metadata.append({'path': str(output_path), 'date': date})
print(f"Saved {len(metadata)} NDVI images")

Create train/val/test splits
Split by time to avoid leakage:
import pandas as pd
# Fail with a clear message instead of a KeyError on the missing 'date'
# column when every scene was skipped.
if not metadata:
    raise RuntimeError("No NDVI images were saved; cannot build train/val/test splits")

# One row per saved NDVI file.
df = pd.DataFrame(metadata)

# Output filenames encode only the calendar date, so two scenes from the same
# day share one path (the later write replaced the earlier file). Keep the
# row of the last write so each file appears once in the timeseries.
df = df.drop_duplicates(subset='path', keep='last')
df = df.sort_values('date').reset_index(drop=True)

# Chronological 60% / 20% / 20% split: validation and test always lie after
# the training period, so no future image leaks into training.
n = len(df)
train_end = int(0.6 * n)
val_end = int(0.8 * n)
train_df = df.iloc[:train_end]
val_df = df.iloc[train_end:val_end]
test_df = df.iloc[val_end:]

print(f"Train: {len(train_df)} ({train_df['date'].min()} to {train_df['date'].max()})")
print(f"Val: {len(val_df)} ({val_df['date'].min()} to {val_df['date'].max()})")
print(f"Test: {len(test_df)} ({test_df['date'].min()} to {test_df['date'].max()})")

Organize for TerraTorch regression format
Regression format: image_X.tif and label_X.tif
# Root of the regression dataset, with one sub-directory per split.
regression_dir = Path(DATA_PATH, "ndvi_regression")
for split_name in ('train', 'val', 'test'):
    regression_dir.joinpath(split_name).mkdir(parents=True, exist_ok=True)
def create_regression_pairs(df, split_dir, history_length=10):
    """Write (history stack, next image) GeoTIFF pairs for one split.

    For every row ``i >= history_length`` the previous ``history_length``
    single-band images are stacked into ``image_{i:04d}.tif`` and row ``i``'s
    image is copied to ``label_{i:04d}.tif``.

    Args:
        df: DataFrame with ``'path'`` and ``'date'`` columns, one row per
            image, already in the order the history should follow.
        split_dir: ``pathlib.Path`` of an existing directory to write into.
        history_length: Number of preceding images stacked as input bands.

    Returns:
        DataFrame with columns ``'input'``, ``'target'`` and ``'date'``, one
        row per written pair; empty (but with those columns) when ``df`` has
        no more than ``history_length`` rows.

    Raises:
        ValueError: If ``history_length`` is smaller than 1.
    """
    if history_length < 1:
        raise ValueError(f"history_length must be >= 1, got {history_length}")

    pairs = []
    for i in range(history_length, len(df)):
        # Input: the history_length images immediately preceding row i.
        history_paths = df.iloc[i - history_length:i]['path'].tolist()

        # `df.iloc[i]` is a single row, so indexing it by column already
        # yields a scalar; there is no `.values` to unwrap.
        target_row = df.iloc[i]
        target_path = target_row['path']
        target_date = target_row['date']

        # Read each history band once; the first file also supplies the
        # profile used for the stacked output.
        profile = None
        history_data = []
        for path in history_paths:
            with rasterio.open(path) as src:
                if profile is None:
                    profile = src.profile
                history_data.append(src.read(1))
        profile.update(count=history_length)
        history_stack = np.stack(history_data)

        # Save input (history stack)
        input_path = split_dir / f"image_{i:04d}.tif"
        with rasterio.open(input_path, 'w', **profile) as dst:
            dst.write(history_stack)

        # Copy target, keeping its own profile.
        target_out = split_dir / f"label_{i:04d}.tif"
        with rasterio.open(target_path) as src:
            data = src.read()
            target_profile = src.profile
        with rasterio.open(target_out, 'w', **target_profile) as dst:
            dst.write(data)

        pairs.append({
            'input': str(input_path),
            'target': str(target_out),
            'date': target_date,
        })

    # Explicit columns keep the schema stable even when no pair was produced.
    return pd.DataFrame(pairs, columns=['input', 'target', 'date'])
# Build pairs separately for each split frame, writing into that split's own
# directory, so no history window reaches across a split boundary.
train_pairs, val_pairs, test_pairs = (
    create_regression_pairs(split_df, regression_dir / split_name)
    for split_df, split_name in (
        (train_df, 'train'),
        (val_df, 'val'),
        (test_df, 'test'),
    )
)

print(f"Created {len(train_pairs)} training pairs")
print(f"Created {len(val_pairs)} validation pairs")
print(f"Created {len(test_pairs)} test pairs")

Step 2: Add Temporal Encodings
Create temporal metadata file for each split:
def add_temporal_features(df):
    """Add day-of-year/month columns and their sin/cos encodings to *df*.

    The frame is modified in place and also returned. The ``'date'`` column
    may hold datetimes or date strings; strings are parsed for the
    computation but the column itself is left untouched.

    Args:
        df: DataFrame with a ``'date'`` column. A completely empty frame
            (no rows and no ``'date'`` column) is accepted and simply gains
            the new, empty columns.

    Returns:
        The same DataFrame with ``day_of_year``, ``month``, ``doy_sin``,
        ``doy_cos``, ``month_sin`` and ``month_cos`` columns added.

    Raises:
        KeyError: If a non-empty frame has no ``'date'`` column.
    """
    if df.empty and 'date' not in df.columns:
        # A split that produced no pairs: add empty feature columns so the
        # downstream CSV still has a consistent header.
        for column in ('day_of_year', 'month',
                       'doy_sin', 'doy_cos', 'month_sin', 'month_cos'):
            df[column] = pd.Series(index=df.index, dtype='float64')
        return df

    dates = df['date']
    if not pd.api.types.is_datetime64_any_dtype(dates):
        # e.g. dates that went through a CSV round-trip arrive as strings.
        dates = pd.to_datetime(dates)

    df['day_of_year'] = dates.dt.dayofyear
    df['month'] = dates.dt.month

    # Cyclic encoding so the end of one year sits next to the start of the
    # following one. The period is fixed at 365, so day 366 of a leap year
    # lands marginally past one full cycle.
    doy_angle = 2 * np.pi * df['day_of_year'] / 365
    month_angle = 2 * np.pi * df['month'] / 12
    df['doy_sin'] = np.sin(doy_angle)
    df['doy_cos'] = np.cos(doy_angle)
    df['month_sin'] = np.sin(month_angle)
    df['month_cos'] = np.cos(month_angle)
    return df
# Encode each split and write its metadata CSV right away; the splits are
# independent of one another, so they are handled one at a time.
train_pairs = add_temporal_features(train_pairs)
train_pairs.to_csv(regression_dir / 'train_metadata.csv', index=False)

val_pairs = add_temporal_features(val_pairs)
val_pairs.to_csv(regression_dir / 'val_metadata.csv', index=False)

test_pairs = add_temporal_features(test_pairs)
test_pairs.to_csv(regression_dir / 'test_metadata.csv', index=False)

Step 3: Custom DataModule with Temporal Features
Create a custom DataModule that loads temporal encodings:
import albumentations as A
import pytorch_lightning as pl
import torch
from albumentations.pytorch import ToTensorV2
from lightning.pytorch import LightningDataModule
from torch.utils.data import Dataset, DataLoader
class TemporalNDVIDataset(Dataset):
    """Samples made of an NDVI history stack, its target image and a
    four-value temporal encoding, all driven by one metadata CSV.

    The CSV needs the columns ``input``, ``target``, ``doy_sin``,
    ``doy_cos``, ``month_sin`` and ``month_cos``.
    """

    def __init__(self, metadata_csv, transform=None):
        """Read *metadata_csv* and remember the optional *transform*.

        *transform* is called as ``transform(image=..., mask=...)`` and must
        return a mapping with ``'image'`` and ``'mask'`` entries.
        """
        self.metadata = pd.read_csv(metadata_csv)
        self.transform = transform

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return a dict with ``'image'``, ``'label'`` and ``'temporal'``."""
        record = self.metadata.iloc[idx]

        # All bands of the history stack, moved from (bands, H, W) to
        # (H, W, bands) because the transform expects channels last.
        with rasterio.open(record['input']) as history_src:
            image = history_src.read().astype(np.float32).transpose(1, 2, 0)

        # First band of the target raster, (H, W).
        with rasterio.open(record['target']) as target_src:
            label = target_src.read(1).astype(np.float32)

        temporal = np.array(
            [record[name] for name in ('doy_sin', 'doy_cos', 'month_sin', 'month_cos')],
            dtype=np.float32,
        )

        if self.transform:
            augmented = self.transform(image=image, mask=label)
            image, label = augmented['image'], augmented['mask']

        return {
            'image': image,
            'label': label,
            'temporal': torch.FloatTensor(temporal),
        }
class TemporalNDVIDataModule(LightningDataModule):
    """DataModule serving the train/val/test ``TemporalNDVIDataset`` splits.

    Derives from ``lightning.pytorch``'s ``LightningDataModule`` because the
    ``Trainer`` used with it is imported from ``lightning.pytorch``; classes
    from the separate ``pytorch_lightning`` package are not recognised by
    that Trainer's ``isinstance`` checks.

    Args:
        data_dir: Directory holding ``train_metadata.csv``,
            ``val_metadata.csv`` and ``test_metadata.csv``.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker count passed to every ``DataLoader``.
        means: Per-channel (or single, broadcast) mean for normalisation;
            defaults to ``[0.327]``.
        stds: Matching standard deviation(s); defaults to ``[0.322]``.
    """

    def __init__(self, data_dir, batch_size=4, num_workers=0,
                 means=None, stds=None):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Defaults are NDVI statistics attributed to TerraMesh by the
        # original author - TODO confirm the source.
        self.means = means or [0.327]
        self.stds = stds or [0.322]

    def _build_transform(self, train):
        """Compose resize, optional flip, normalisation and tensor conversion."""
        steps = [A.Resize(224, 224)]
        if train:
            # Augmentation is applied to the training split only.
            steps.append(A.HorizontalFlip(p=0.5))
        steps.append(
            # Normalize multiplies mean and std by max_pixel_value, which
            # defaults to 255 for 8-bit imagery. NDVI is already a small
            # float, so the scale must be 1.0 or the statistics are applied
            # 255 times too large.
            A.Normalize(mean=self.means, std=self.stds, max_pixel_value=1.0)
        )
        steps.append(ToTensorV2())
        return A.Compose(steps)

    def setup(self, stage=None):
        """Create all three datasets; *stage* is accepted but not used."""
        train_transform = self._build_transform(train=True)
        eval_transform = self._build_transform(train=False)

        self.train_dataset = TemporalNDVIDataset(
            self.data_dir / 'train_metadata.csv',
            transform=train_transform,
        )
        self.val_dataset = TemporalNDVIDataset(
            self.data_dir / 'val_metadata.csv',
            transform=eval_transform,
        )
        self.test_dataset = TemporalNDVIDataset(
            self.data_dir / 'test_metadata.csv',
            transform=eval_transform,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)
# Build the DataModule over the regression directory: batches of 4, with
# num_workers=0 forwarded to every DataLoader.
datamodule = TemporalNDVIDataModule(data_dir=regression_dir, batch_size=4, num_workers=0)
datamodule.setup('fit')

Step 4: Custom Task with Temporal Conditioning
Modify TerraTorch task to include temporal encoder:
import torch.nn as nn
import torch.nn.functional as F
from terratorch.tasks import PixelwiseRegressionTask
from terratorch.models import PrithviModelFactory
class TemporalRegressionTask(PixelwiseRegressionTask):
    """Pixelwise regression task whose encoder features are shifted by an
    embedding of the sample's temporal encoding.

    Args:
        *args: Forwarded unchanged to ``PixelwiseRegressionTask``.
        temporal_dim: Length of the per-sample temporal vector.
        hidden_dim: Width of the temporal MLP's hidden layer.
        encoder_dim: Channel count of the encoder features the temporal
            embedding is added to. Defaults to 768, the value previously
            hard-coded for the Prithvi backbone; pass a different value when
            using a backbone with another embedding width.
        **kwargs: Forwarded unchanged to ``PixelwiseRegressionTask``.
    """

    def __init__(self, *args, temporal_dim=4, hidden_dim=256,
                 encoder_dim=768, **kwargs):
        super().__init__(*args, **kwargs)
        # Two-layer MLP mapping the temporal vector to one value per encoder
        # channel.
        self.temporal_encoder = nn.Sequential(
            nn.Linear(temporal_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, encoder_dim),
        )
        # The decoder itself is left untouched; fusion happens by addition
        # in forward(). This is a simplification - in practice, modify the
        # decoder or use FiLM.

    def forward(self, batch):
        """Run encoder, add the temporal embedding, then decode.

        Args:
            batch: Mapping with ``'image'`` and ``'temporal'`` tensors.

        Returns:
            ``{'output': <decoder result>}``.
        """
        x = batch['image']
        temporal = batch['temporal']

        # Last element of the encoder's outputs.
        # NOTE(review): the addition below assumes this is (B, C, H, W) with
        # C == encoder_dim - confirm for the chosen backbone.
        spatial_features = self.model.encoder(x)[-1]

        # (B, encoder_dim) -> (B, encoder_dim, 1, 1) so the embedding
        # broadcasts across the two trailing dimensions.
        temporal_features = self.temporal_encoder(temporal)
        temporal_features = temporal_features.unsqueeze(-1).unsqueeze(-1)
        fused_features = spatial_features + temporal_features

        # The decoder is called with a one-element list of feature maps.
        output = self.model.decoder([fused_features])
        return {'output': output}

    def _temporal_mse(self, batch):
        """MSE between the forward output and the label with a channel axis added."""
        outputs = self(batch)
        return F.mse_loss(outputs['output'], batch['label'].unsqueeze(1))

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._temporal_mse(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._temporal_mse(batch)
        self.log('val_loss', loss)
        return loss
# Initialize task
task = TemporalRegressionTask(
model_factory="EncoderDecoderFactory",
model_args={
'backbone': 'prithvi_eo_v1_100',
'backbone_pretrained': True,
'decoder': 'FCNDecoder',
'num_classes': 1, # Single-channel NDVI output
},
freeze_backbone=True,
freeze_decoder=False,
lr=1e-4
)

Step 5: Train
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger
# Keep the single best checkpoint by 'val_loss' plus the most recent one.
checkpoint = ModelCheckpoint(monitor='val_loss', save_top_k=1, save_last=True)
# Stop training when 'val_loss' has gone 10 checks without improving.
early_stop = EarlyStopping(monitor='val_loss', patience=10)
callbacks = [checkpoint, early_stop]

# Metrics are sent to the 'ndvi-forecasting' Weights & Biases project.
wandb_logger = WandbLogger(project='ndvi-forecasting')

# At most 50 epochs on one automatically selected device.
trainer = Trainer(
    max_epochs=50,
    accelerator='auto',
    devices=1,
    callbacks=callbacks,
    logger=wandb_logger,
)
trainer.fit(task, datamodule)

Step 6: Inference
# Restore the checkpoint ranked best by the ModelCheckpoint callback and
# switch the module to evaluation mode.
best_task = TemporalRegressionTask.load_from_checkpoint(checkpoint.best_model_path)
best_task.eval()

# First batch of the test loader.
test_batch = next(iter(datamodule.test_dataloader()))

# Forward pass without gradient tracking.
with torch.no_grad():
    outputs = best_task(test_batch)
predictions = outputs['output']

import matplotlib.pyplot as plt

# One panel each for the most recent input frame, the target and the
# prediction, all taken from the first sample of the batch.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = (
    ('Input (t-1)', test_batch['image'][0, -1]),
    ('Ground Truth (t)', test_batch['label'][0]),
    ('Prediction (t)', predictions[0, 0]),
)
for ax, (title, tensor) in zip(axes, panels):
    ax.imshow(tensor.cpu(), cmap='RdYlGn')
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()

Summary
Key steps:
1. Download multi-year Sentinel-2 data from STAC
2. Compute NDVI and organize in regression format
3. Add cyclic temporal encodings (sin/cos of day-of-year, month)
4. Create custom DataModule that loads temporal features
5. Modify Task to include temporal encoder and fusion
6. Train with PyTorch Lightning
7. Run inference on test set
Temporal encoding enables:
- Model learns seasonal patterns
- Forecasts reflect time-of-year context
- Single-step predictions can be applied recursively for multi-step forecasts