Why this cheatsheet
Quick, student-facing reference for how our Geospatial Foundation Model (GFM) is organized, what each part does, and where to look as we go from MVP to more capable systems.
Roadmap at a glance
```mermaid
graph TD
  A[Week 1: Data Foundations] --> B[Week 2: Attention & Blocks]
  B --> C[Week 3: Full Encoder]
  C --> D[Week 4: MAE Pretraining]
  D --> E[Week 5: Training Loop]
  E --> F[Week 6: Eval & Viz]
  F --> G[Week 7+: Interop & Scale]
```
Minimal structure you'll use
```
geogfm/
  core/
    config.py                  # Minimal typed configs for model, data, training
  data/
    loaders.py                 # build_dataloader(...)
    datasets/
      stac_dataset.py          # Simple STAC-backed dataset
    transforms/
      normalization.py         # Per-channel normalization
      patchify.py              # Extract fixed-size patches
  modules/
    attention/
      multihead_attention.py   # Standard MHA (from scratch)
    embeddings/
      patch_embedding.py       # Conv patch embedding
      positional_encoding.py   # Simple positional encoding
    blocks/
      transformer_block.py     # PreNorm block (MHA + MLP)
    heads/
      reconstruction_head.py   # Lightweight decoder/readout
    losses/
      mae_loss.py              # Masked reconstruction loss
  models/
    gfm_vit.py                 # GeoViT-style encoder
    gfm_mae.py                 # MAE wrapper (masking + encoder + head)
  training/
    optimizer.py               # AdamW builder
    loop.py                    # fit/train_step/eval_step with basic checkpointing
  evaluation/
    visualization.py           # Visualize inputs vs reconstructions

# Outside the package (repo root)
configs/                       # Small YAML/JSON run configs
tests/                         # Unit tests
data/                          # Datasets, splits, stats, build scripts
```
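For orientation, here is a minimal sketch of what core/config.py could look like, assuming plain dataclasses. The field names mirror the quick-start example later in this cheatsheet; anything beyond those fields (e.g., in_channels, root) is an illustrative assumption, not the definitive course API.

```python
# Sketch of geogfm/core/config.py: minimal typed configs, assuming plain dataclasses.
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class ModelConfig:
    architecture: str = "gfm_vit"
    image_size: int = 224
    patch_size: int = 16
    in_channels: int = 3           # assumption: override for multispectral imagery
    embed_dim: int = 768
    depth: int = 12
    num_heads: int = 12            # assumption: not shown in the quick start

@dataclass
class DataConfig:
    dataset: str = "stac"
    root: str = "data/"            # repo-root data directory (see "Where data lives" below)
    patch_size: int = 16
    num_workers: int = 8

@dataclass
class TrainConfig:
    epochs: int = 1
    batch_size: int = 8
    optimizer: Dict[str, Any] = field(default_factory=lambda: {"name": "adamw", "lr": 2e-4})
```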
What each part does (one-liners)
- core/config.py: Typed configs for model/data/training; keeps parameters organized.
- data/datasets/stac_dataset.py: Reads imagery + metadata (e.g., STAC), returns tensors.
- data/transforms/normalization.py: Normalizes channels using precomputed stats.
- data/transforms/patchify.py: Turns large images into uniform patches for ViT.
- data/loaders.py: Builds PyTorch DataLoaders for train/val.
- modules/embeddings/patch_embedding.py: Projects image patches into token vectors.
- modules/embeddings/positional_encoding.py: Adds position info to tokens.
- modules/attention/multihead_attention.py: Lets tokens attend to each other.
- modules/blocks/transformer_block.py: Core transformer layer (attention + MLP).
- modules/heads/reconstruction_head.py: Reconstructs pixels from encoded tokens.
- modules/losses/mae_loss.py: Computes masked reconstruction loss for MAE.
- models/gfm_vit.py: Assembles the encoder backbone from blocks.
- models/gfm_mae.py: Wraps encoder with masking + reconstruction for pretraining.
- training/optimizer.py: Creates AdamW with common defaults.
- training/loop.py: Runs epochs, backprop, validation, and simple checkpoints.
- evaluation/visualization.py: Plots sample inputs and reconstructions.
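To make the modules/ one-liners above concrete, here is a rough sketch of the conv patch embedding and the pre-norm transformer block. Shapes and argument names are illustrative, and nn.MultiheadAttention stands in for the MHA we build from scratch in Week 2.

```python
# Sketch of modules/embeddings/patch_embedding.py and modules/blocks/transformer_block.py.
# Argument names and shapes are assumptions, not the definitive course API.
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Conv patch embedding: (B, C, H, W) -> (B, N, D) token sequence."""
    def __init__(self, in_channels: int = 3, patch_size: int = 16, embed_dim: int = 768):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                      # (B, D, H/ps, W/ps)
        return x.flatten(2).transpose(1, 2)   # (B, N, D)

class TransformerBlock(nn.Module):
    """Pre-norm block: LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual."""
    def __init__(self, embed_dim: int = 768, num_heads: int = 12, mlp_ratio: float = 4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Stand-in for the from-scratch MHA built in Week 2.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden), nn.GELU(), nn.Linear(hidden, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))
```

The pre-norm ordering (normalize before attention/MLP, then add the residual) is what the "PreNorm block" comment in the structure tree refers to.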
From-scratch vs library-backed
- Use PyTorch for Dataset/DataLoader, AdamW, schedulers, AMP, checkpointing.
- Build core blocks from scratch first: PatchEmbedding, MHA, TransformerBlock, MAE loss/head.
- Later, swap in optimized options when needed:
  - torch.nn.MultiheadAttention, timm ViT blocks, FlashAttention
  - TorchGeo datasets/transforms, torchvision/kornia/albumentations
  - torchmetrics for metrics; accelerate/lightning for training scale-up
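As an example of the library-backed side, data/loaders.py and training/optimizer.py can stay thin wrappers over stock PyTorch. A minimal sketch, assuming the real versions are config-driven (the quick-start build_dataloader, for instance, takes a DataConfig and returns train/val loaders); names and defaults here are assumptions.

```python
# Sketch of the library-backed pieces: lean on stock PyTorch rather than from-scratch code.
import torch
from torch.utils.data import DataLoader, Dataset

def build_dataloader(dataset: Dataset, batch_size: int = 8, num_workers: int = 8,
                     shuffle: bool = True) -> DataLoader:
    # Stock PyTorch DataLoader; the course version constructs the dataset from a DataConfig first.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                      shuffle=shuffle, pin_memory=True)

def build_optimizer(model: torch.nn.Module, lr: float = 2e-4,
                    weight_decay: float = 0.05) -> torch.optim.AdamW:
    # AdamW with common ViT/MAE-style defaults; the weight decay value is an assumption.
    return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
```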
Quick start (conceptual)
```python
from geogfm.core.config import ModelConfig, DataConfig, TrainConfig
from geogfm.models.gfm_vit import GeoViTBackbone
from geogfm.models.gfm_mae import MaskedAutoencoder
from geogfm.data.loaders import build_dataloader
from geogfm.training.loop import fit

model_cfg = ModelConfig(architecture="gfm_vit", embed_dim=768, depth=12, image_size=224)
data_cfg = DataConfig(dataset="stac", patch_size=16, num_workers=8)
train_cfg = TrainConfig(epochs=1, batch_size=8, optimizer={"name": "adamw", "lr": 2e-4})

encoder = GeoViTBackbone(model_cfg)
model = MaskedAutoencoder(model_cfg, encoder)
train_dl, val_dl = build_dataloader(data_cfg)
fit(model, (train_dl, val_dl), train_cfg)
```
What to notice:
- The encoder and MAE wrapper are separate, so the encoder can be reused for other tasks.
- Data transforms (normalize/patchify) are decoupled from the model and driven by config.
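Under the hood, the MAE wrapper optimizes the masked reconstruction loss from modules/losses/mae_loss.py. A minimal sketch, assuming pred and target are per-patch pixel tensors of shape (B, N, patch_dim) and mask is (B, N) with 1 marking the hidden patches:

```python
# Sketch of modules/losses/mae_loss.py: mean squared error averaged over masked patches only.
import torch

def mae_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    per_patch = ((pred - target) ** 2).mean(dim=-1)             # (B, N) error per patch
    return (per_patch * mask).sum() / mask.sum().clamp(min=1)   # ignore visible patches
```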
Where data lives vs dataset code
- data/ (repo root): datasets, splits, stats, caches, and build scripts (e.g., STAC builders). No Python package imports here.
- geogfm/data/datasets/: pure Python classes (subclassing torch.utils.data.Dataset) that read from paths provided via configs. No real data inside the package.
- Why: separates large mutable artifacts (datasets) from installable, testable code.
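A hypothetical sketch of how geogfm/data/datasets/stac_dataset.py keeps that separation: the class only knows about paths handed to it (e.g., from DataConfig.root), while the chips and index files themselves live under the repo-root data/ directory. File names and fields here are illustrative assumptions.

```python
# Hypothetical sketch: a Dataset that reads from configured paths; no data ships in the package.
import json
from pathlib import Path

import torch
from torch.utils.data import Dataset

class STACDataset(Dataset):
    def __init__(self, root: str, split: str = "train"):
        self.root = Path(root)
        # Index file produced by build scripts under the repo-root data/ directory (assumption).
        with open(self.root / f"{split}_items.json") as f:
            self.items = json.load(f)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> torch.Tensor:
        item = self.items[idx]
        return torch.load(self.root / item["path"])   # e.g., a pre-chipped, normalized tensor
```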
MVP vs later phases
- MVP (Weeks 1–6): files shown above; single-node training; basic logging/viz.
- Phase 2 (Weeks 5–7+): AMP, scheduler, simple metrics (PSNR/SSIM), samplers, light registry.
- Phase 3 (Weeks 7–10): interop (HF/timm/TorchGeo), task heads, inference tiling, model hub/compat.
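To illustrate the Phase 2 AMP upgrade, a sketch of what a mixed-precision train_step could look like; the MVP version is the same step without autocast and GradScaler. The assumption here is that the MAE wrapper returns its loss directly from forward().

```python
# Sketch of a Phase 2 train_step with automatic mixed precision (AMP).
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, batch, optimizer, device="cuda"):
    batch = batch.to(device)
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = model(batch)              # assumption: the MAE wrapper returns the masked loss
    scaler.scale(loss).backward()        # scale gradients to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```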
Extended reference structure (for context)
```
geogfm/
  core/{registry.py, config.py, types.py, utils.py}
  data/{loaders.py, samplers.py, datasets/*, transforms/*, tokenizers/*}
  modules/{attention/*, embeddings/*, blocks/*, losses/*, heads/*, adapters/*}
  models/{gfm_vit.py, gfm_mae.py, prithvi_compat.py, hub/*}
  tasks/{pretraining_mae.py, classification.py, segmentation.py, change_detection.py, retrieval.py}
  training/{loop.py, optimizer.py, scheduler.py, mixed_precision.py, callbacks.py, ema.py, checkpointing.py}
  evaluation/{metrics.py, probes.py, visualization.py, nearest_neighbors.py}
  inference/{serving.py, tiling.py, sliding_window.py, postprocess.py}
  interoperability/{huggingface.py, timm.py, torchgeo.py}
  utils/{logging.py, distributed.py, io.py, profiling.py, seed.py}
```
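core/registry.py in this extended layout is the "light registry" mentioned in Phase 2. A hypothetical sketch of the pattern, so configs can select models or datasets by name (e.g., architecture="gfm_vit"); the implementation shown is an assumption.

```python
# Hypothetical sketch of core/registry.py: a tiny name -> class registry used by configs.
from typing import Any, Callable, Dict, Type

_REGISTRY: Dict[str, Type] = {}

def register(name: str) -> Callable[[Type], Type]:
    """Class decorator that records a component under a string key."""
    def wrap(cls: Type) -> Type:
        _REGISTRY[name] = cls
        return cls
    return wrap

def build(name: str, *args: Any, **kwargs: Any) -> Any:
    """Instantiate a registered component by name."""
    return _REGISTRY[name](*args, **kwargs)
```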
Week mapping (quick reference)
- Week 1: data (data/datasets, data/transforms, data/loaders, core/config.py)
- Week 2: attention/embeddings/blocks (modules/)
- Week 3: architecture (models/gfm_vit.py, modules/heads/...)
- Week 4: MAE (models/gfm_mae.py, modules/losses/mae_loss.py)
- Week 5: training (training/optimizer.py, training/loop.py)
- Week 6: viz/metrics (evaluation/visualization.py)
- Weeks 7–10: interop, tasks, inference, larger models (e.g., Prithvi)