# geogfm.evaluation
## Course Roadmap Mapping

This week's work in the broader GFM plan.
| Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
|---|---|---|---|---|---|
| 6 | Stage 2: Train Foundation Model | Evaluation & Analysis | `evaluation/visualization.py`; (optional) `evaluation/metrics.py` | matplotlib; torchmetrics (optional) | Reconstruction visuals; track validation loss/PSNR |
## Weekly goals
- Implement reconstruction visualization utilities
- Add a simple metric (e.g., PSNR) and validation loop hooks
- Interpret embeddings or reconstructions qualitatively
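For the qualitative goal, one widely used technique (not implemented elsewhere on this page) is to project patch embeddings onto their top three principal components and display them as RGB. A minimal sketch, assuming embeddings arrive as an `(N, D)` tensor with `N >= 3` and `D >= 3`; the function name `pca_rgb` is hypothetical:

```python
import torch


def pca_rgb(embeddings: torch.Tensor) -> torch.Tensor:
    """Project (N, D) embeddings onto their top-3 principal components, scaled to [0, 1].

    Hypothetical helper: each component is min-max scaled independently
    so the result can be displayed as RGB.
    """
    x = embeddings.float()
    x = x - x.mean(dim=0, keepdim=True)  # center each feature
    # Rows of vh are the principal directions (right singular vectors)
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    proj = x @ vh[:3].T  # (N, 3)
    lo = proj.min(dim=0, keepdim=True).values
    hi = proj.max(dim=0, keepdim=True).values
    return (proj - lo) / (hi - lo).clamp_min(1e-8)


# Example: 16 patch embeddings from a 4x4 patch grid (random stand-in data)
emb = torch.randn(16, 64)
rgb_grid = pca_rgb(emb).reshape(4, 4, 3)  # (grid_h, grid_w, 3), ready for plt.imshow
```

Patches that the encoder treats as similar end up with similar colors, which makes spatial structure in the embeddings easy to eyeball.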
## Session Outline (and Tangled Code)
- Concepts → Components mapping
  - Reconstruction plots and side-by-sides → `evaluation/visualization.py`
  - Simple PSNR metric (optional) → `evaluation/metrics.py`
- Package inits
## 1) Visualization Utilities
```python
from __future__ import annotations

import matplotlib.pyplot as plt
import torch


def show_reconstruction_grid(images: torch.Tensor, recon_patches: torch.Tensor, max_items: int = 4) -> None:
    """Show input images and reconstructed images side-by-side.

    images: (B, C, H, W)
    recon_patches: (B, N, C, P, P) reconstructed patches; reassembled as a naive
        row-major grid, so N must equal (H // P) * (W // P).
    """
    images = images.detach().cpu()
    recon_patches = recon_patches.detach().cpu()
    b, c, h, w = images.shape
    p = recon_patches.shape[-1]
    grid_h = h // p
    grid_w = w // p
    if recon_patches.shape[1] != grid_h * grid_w:
        raise ValueError(
            f"expected {grid_h * grid_w} patches of size {p}, got {recon_patches.shape[1]}"
        )

    def assemble(patches: torch.Tensor) -> torch.Tensor:
        # patches: (N, C, P, P) -> (C, H, W); patch r * grid_w + cidx sits at row r, column cidx
        rows = []
        for r in range(grid_h):
            row = torch.cat([patches[r * grid_w + cidx] for cidx in range(grid_w)], dim=-1)
            rows.append(row)
        full = torch.cat(rows, dim=-2)
        return full

    num = min(max_items, b)
    # squeeze=False keeps axes 2-D (num x 2) even when num == 1
    fig, axes = plt.subplots(num, 2, figsize=(8, 4 * num), squeeze=False)
    for i in range(num):
        recon_full = assemble(recon_patches[i])  # (C, H, W)
        axes[i][0].imshow(images[i][0].numpy(), cmap="viridis")
        axes[i][0].set_title("Input (band 1)")
        axes[i][0].axis("off")
        axes[i][1].imshow(recon_full[0].numpy(), cmap="viridis")
        axes[i][1].set_title("Reconstruction (band 1)")
        axes[i][1].axis("off")
    plt.tight_layout()
    plt.show()
```
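The reassembly in `show_reconstruction_grid` only works if patches are ordered row-major (patch index `r * grid_w + cidx`). A minimal sketch of a hypothetical `patchify`/`unpatchify` pair that uses that ordering, assuming square patches and image sides divisible by the patch size; it can be used to build `target_patches` for the usage snippet, provided the model also emits patches in this order (an assumption to verify against your model):

```python
import torch


def patchify(images: torch.Tensor, p: int) -> torch.Tensor:
    """Split (B, C, H, W) images into (B, N, C, P, P) patches in row-major order.

    Hypothetical helper: assumes H and W are divisible by p.
    """
    b, c, h, w = images.shape
    gh, gw = h // p, w // p
    # (B, C, gh, P, gw, P) -> (B, gh, gw, C, P, P) -> (B, gh * gw, C, P, P)
    x = images.reshape(b, c, gh, p, gw, p).permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, gh * gw, c, p, p)


def unpatchify(patches: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Inverse of patchify: (B, N, C, P, P) -> (B, C, H, W)."""
    b, n, c, p, _ = patches.shape
    gh, gw = h // p, w // p
    # (B, gh, gw, C, P, P) -> (B, C, gh, P, gw, P) -> (B, C, H, W)
    x = patches.reshape(b, gh, gw, c, p, p).permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c, h, w)
```

Because the two functions are exact inverses, `unpatchify(patchify(x, p), H, W)` returns `x` unchanged. If a model emits patches in a different order, the reassembled image looks scrambled rather than merely blurry, which is a quick way to spot an ordering mismatch.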
## 2) Simple Metrics (optional)
```python
from __future__ import annotations

import torch


def psnr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8, max_val: float = 1.0) -> torch.Tensor:
    """Peak Signal-to-Noise Ratio over reconstructed patches.

    pred/target: (B, N, C, P, P) in [0, max_val]
    Returns: (B,) PSNR in dB, one value per sample.
    """
    mse = (pred - target) ** 2
    mse = mse.mean(dim=(-1, -2, -3, -4))  # per-sample MSE over N, C, P, P
    # PSNR = 10 * log10(max_val**2 / MSE) = 20 * log10(max_val) - 10 * log10(MSE)
    return 20 * torch.log10(torch.tensor(max_val, device=pred.device)) - 10 * torch.log10(mse + eps)
```
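To sanity-check the formula used by `psnr`, here is a small worked example in plain Python; `psnr_from_mse` is a hypothetical helper written only for this illustration. A uniform error of 0.1 on data in [0, 1] gives MSE = 0.01 and therefore PSNR = 20 dB:

```python
import math


def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    # PSNR = 10 * log10(max_val**2 / MSE) = 20 * log10(max_val) - 10 * log10(MSE)
    return 20 * math.log10(max_val) - 10 * math.log10(mse)


pred = [0.5, 0.6, 0.7, 0.8]
target = [0.4, 0.5, 0.6, 0.7]  # every value off by 0.1
mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
print(round(mse, 6))                     # 0.01
print(round(psnr_from_mse(mse), 2))      # 20.0
print(round(psnr_from_mse(mse / 4), 2))  # 26.02: halving the error adds about 6 dB
```

Two consequences follow from the formula: every halving of the error adds about 6.02 dB, and because `psnr` adds `eps = 1e-8` to the MSE, a perfect reconstruction with `max_val = 1.0` reports about 80 dB rather than infinity.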
### Usage snippet (non-tangled)

```python
# After obtaining a batch of images and model outputs:
# show_reconstruction_grid(images[:4], outputs["reconstructions"][:4])
# s_psnr = psnr(outputs["reconstructions"], target_patches).mean().item()
# print("PSNR:", s_psnr)
```
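The usage snippet scores a single batch. To track validation loss and PSNR across a whole validation set, a minimal sketch of a validation hook follows. It assumes, hypothetically, that each batch is the target tensor itself and that `model(batch)` returns a prediction of the same shape; `run_validation` is an illustrative name, not part of `geogfm`. Any metric returning one value per sample, such as `psnr` above for `(B, N, C, P, P)` patch tensors, can be passed as `metric_fn`:

```python
from typing import Callable, Dict, Iterable

import torch
from torch import nn


@torch.no_grad()
def run_validation(
    model: nn.Module,
    loader: Iterable[torch.Tensor],
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    metric_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Dict[str, float]:
    """Average loss and per-sample metric over a validation loader (hypothetical hook)."""
    was_training = model.training
    model.eval()
    loss_sum, metric_sum, n = 0.0, 0.0, 0
    for target in loader:
        pred = model(target)
        bsz = target.shape[0]
        loss_sum += loss_fn(pred, target).item() * bsz  # loss_fn returns a batch mean
        metric_sum += metric_fn(pred, target).sum().item()  # metric_fn returns (B,)
        n += bsz
    model.train(was_training)  # restore the caller's train/eval mode
    return {"val_loss": loss_sum / n, "val_psnr": metric_sum / n}


def flat_psnr(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-sample PSNR for data in [0, 1], flattening all non-batch dims
    return -10 * torch.log10(((pred - target) ** 2).flatten(1).mean(dim=1) + 1e-8)


# Toy check: an identity "model" reconstructs perfectly, so the loss is 0 and
# PSNR sits at the eps-limited ceiling of about 80 dB.
toy_loader = [torch.rand(2, 3, 8, 8) for _ in range(3)]
print(run_validation(nn.Identity(), toy_loader, nn.MSELoss(), flat_psnr))
```

Weighting each batch by its size keeps the averages correct when the last batch is smaller than the others.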