# geogfm.training
Course Roadmap Mapping
This week's work in the broader GFM plan.
| Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
|---|---|---|---|---|---|
| 5 | Stage 2: Train Foundation Model | Training Optimization | training/optimizer.py ; training/loop.py | torch.optim.AdamW ; schedulers, AMP optional | Single-epoch run; basic checkpoint save/restore |
Weekly goals
- Build `fit`, `train_step`, `eval_step` with logging
- Configure AdamW; optionally add LR scheduler/AMP
- Save and restore a basic checkpoint
Session Outline (and Tangled Code)
- Concepts → Components mapping
  - Optimizer builder → training/optimizer.py
  - Minimal training loop with logging/validation → training/loop.py
  - Package inits
1) Optimizer Builder
from __future__ import annotations
from typing import Dict, Any
import torch
def build_optimizer(model: torch.nn.Module, cfg: Dict[str, Any]) -> torch.optim.Optimizer:
    """Construct a torch optimizer for *model* from a small config dict.

    Args:
        model: Module whose parameters will be optimized.
        cfg: Mapping with optional keys:
            - "name": "adamw" (default) or "adam"; matched case-insensitively.
            - "lr": learning rate (default 2e-4).
            - "weight_decay": weight decay coefficient (default 0.05).

    Returns:
        A configured ``torch.optim.Optimizer`` over ``model.parameters()``.

    Raises:
        ValueError: If ``cfg["name"]`` names an unsupported optimizer.
    """
    # "or" (not just .get default) so an explicit None also falls back to adamw.
    name = (cfg.get("name") or "adamw").lower()
    lr = float(cfg.get("lr", 2e-4))
    weight_decay = float(cfg.get("weight_decay", 0.05))
    if name == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif name == "adam":
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {name}")
2) Training Loop (fit/train/eval)
from __future__ import annotations
from typing import Tuple, Callable, Optional
import time
import torch
from torch.utils.data import DataLoader
@torch.no_grad()
def evaluate(model: torch.nn.Module, loader: DataLoader, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> float:
    """Compute the mean per-sample loss of *model* over *loader*.

    The model is expected to return a dict containing a
    ``"reconstructions"`` tensor of shape (B, num_patches, C, p, p); the
    target is built by patchifying the input images with non-overlapping
    square patches of side ``p`` (stride == patch size).

    Args:
        model: Module to evaluate; put into eval mode for the duration.
        loader: Iterable yielding image batches, or (images, ...) tuples.
        loss_fn: Loss taking (preds, target) or (preds, target, mask).

    Returns:
        Sample-weighted average loss as a Python float (0.0 for an
        empty loader, guarded by ``max(1, count)``).

    Raises:
        RuntimeError: If the model output is not a dict with
            ``"reconstructions"``.
    """
    model.eval()
    total_loss, count = 0.0, 0
    for batch in loader:
        # Accept either bare image tensors or (images, labels, ...) tuples.
        if isinstance(batch, (list, tuple)):
            images = batch[0]
        else:
            images = batch
        outputs = model(images)
        if isinstance(outputs, dict) and "reconstructions" in outputs:
            preds = outputs["reconstructions"]
            # Target as non-overlapping patches (assumes square patches and stride=patch).
            b, c, h, w = images.shape
            p = preds.shape[-1]
            target = images.unfold(2, p, p).unfold(3, p, p).contiguous().view(b, -1, c, p, p)
            try:
                # Prefer the masked signature (preds, target, mask) if supported.
                loss = loss_fn(preds, target, outputs.get("mask"))
            except TypeError:
                loss = loss_fn(preds, target)
        else:
            raise RuntimeError("Model output not supported for evaluation")
        # Weight by batch size so the mean is per-sample, not per-batch.
        total_loss += float(loss) * images.size(0)
        count += images.size(0)
    return total_loss / max(1, count)
def fit(model: torch.nn.Module,
loaders: Tuple[DataLoader, Optional[DataLoader]],
optimizer: torch.optim.Optimizer,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],int = 1,
epochs: | str = "cpu") -> None:
device: torch.device = torch.device(device)
device
model.to(device)
= loaders
train_loader, val_loader for epoch in range(1, epochs + 1):
model.train()= time.time()
start = 0.0, 0
running_loss, count for batch in train_loader:
if isinstance(batch, (list, tuple)):
= batch[0].to(device)
images else:
= batch.to(device)
images =True)
optimizer.zero_grad(set_to_none= model(images)
outputs if isinstance(outputs, dict) and "reconstructions" in outputs:
= outputs["reconstructions"]
preds = images.shape
b, c, h, w = preds.shape[-1]
p = images.unfold(2, p, p).unfold(3, p, p).contiguous().view(b, -1, c, p, p)
target try:
= loss_fn(preds, target, outputs.get("mask"))
loss except TypeError:
= loss_fn(preds, target)
loss else:
raise RuntimeError("Model output not supported for training")
loss.backward()
optimizer.step()+= float(loss) * images.size(0)
running_loss += images.size(0)
count = running_loss / max(1, count)
train_loss
= f"Epoch {epoch:03d} | train_loss={train_loss:.4f}"
msg if val_loader is not None:
= evaluate(model, val_loader, loss_fn)
val_loss += f" | val_loss={val_loss:.4f}"
msg = time.time() - start
elapsed print(msg + f" | time={elapsed:.1f}s")
Usage snippet (non-tangled)
# Example (after Weeks 1β4 have produced datasets + MAE):
# from geogfm.training.optimizer import build_optimizer
# from geogfm.training.loop import fit
# optimizer = build_optimizer(mae, {"name": "adamw", "lr": 2e-4})
# fit(mae, (train_loader, val_loader), optimizer, masked_mse_loss, epochs=1, device="cpu")