from __future__ import annotations
import torch
import torch.nn.functional as F
def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Compute MSE only over masked tokens.
pred/target: (B, N, C, P, P)
mask: (B, N) with 1 for masked, 0 for visible tokens
"""
= pred.shape
b, n, c, p, q = pred.reshape(b, n, -1)
pred = target.reshape(b, n, -1)
target = mask.bool() # (B, N)
mask = pred[mask] - target[mask] # -> (num_masked, C*P*P)
diff if diff.numel() == 0:
return torch.tensor(0.0, device=pred.device, requires_grad=True)
return (diff ** 2).mean()
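A quick way to meet the "verify gradients" goal listed under Weekly goals below is to check that gradients reach only masked positions. This is a minimal sketch with illustrative shapes (the (2, 16, 3, 4, 4) tensors are arbitrary, not course data):

```python
import torch

# Illustrative shapes only: B=2, N=16 tokens, 3x4x4 patches (an assumption).
pred = torch.randn(2, 16, 3, 4, 4, requires_grad=True)
target = torch.randn(2, 16, 3, 4, 4)
mask = torch.zeros(2, 16)
mask[:, :8] = 1.0  # mask the first 8 tokens of each sample

loss = masked_mse_loss(pred, target, mask)
loss.backward()

# Gradient magnitude per token: nonzero exactly where mask == 1.
grad_per_token = pred.grad.reshape(2, 16, -1).abs().sum(dim=-1)
assert torch.equal((grad_per_token > 0).float(), mask)
print("masked-only gradient check passed, loss =", float(loss))
```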
Course Roadmap Mapping
This week's work in the broader GFM plan.

| Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
|---|---|---|---|---|---|
| 4 | Stage 2: Train Foundation Model | MAE Pretraining | models/gfm_mae.py; modules/losses/mae_loss.py | torch masking, numpy | Masking + reconstruction; loss decreases on toy batch |
Weekly goals
- Implement MAE wrapper with masking strategy
- Build masked reconstruction loss; verify gradients
- Overfit a toy batch to confirm learning
Session Outline (and Tangled Code)
- Concepts → Components mapping
  - Random masking of patch tokens → inside MaskedAutoencoder
  - Reconstruction of masked patches → ReconstructionHead
  - Loss on masked tokens only → modules/losses/mae_loss.py
1) MAE Loss (masked MSE) → geogfm/modules/losses/mae_loss.py (implementation shown at the top of this section)
2) Masked Autoencoder Wrapper → geogfm/models/gfm_mae.py
from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn as nn
from geogfm.models.gfm_vit import GeoViTBackbone, ViTBackboneConfig
from geogfm.modules.heads.reconstruction_head import ReconstructionHead
@dataclass
class MAEConfig:
    vit: ViTBackboneConfig
    out_channels: int = 3
    patch_size: int = 16
    mask_ratio: float = 0.75
class MaskedAutoencoder(nn.Module):
    def __init__(self, cfg: MAEConfig):
        super().__init__()
        self.cfg = cfg
        # Encoder backbone and reconstruction head
        self.encoder = GeoViTBackbone(cfg.vit)
        self.head = ReconstructionHead(cfg.vit.embed_dim, cfg.out_channels, cfg.patch_size)
    @torch.no_grad()
    def _random_mask(self, num_tokens: int, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return keep_indices and mask matrix (B, N) with 1 for masked tokens."""
        n_keep = int(round(num_tokens * (1.0 - self.cfg.mask_ratio)))
        idx = torch.stack([torch.randperm(num_tokens, device=device) for _ in range(batch_size)], dim=0)  # (B, N)
        keep = idx[:, :n_keep]  # first n_keep shuffled indices stay visible
        mask = torch.ones(batch_size, num_tokens, device=device)
        mask.scatter_(1, keep, 0.0)  # 0 for visible, 1 for masked
        return keep, mask

    def forward(self, images: torch.Tensor) -> dict:
        """Forward MAE.

        images: (B, C, H, W)
        Returns dict with: latent, reconstructions, mask
        """
        tokens = self.encoder(images)  # (B, N, D)
        b, n, d = tokens.shape
        keep, mask = self._random_mask(n, b, tokens.device)
        # Gather visible tokens
        batch_indices = torch.arange(b, device=tokens.device).unsqueeze(-1).expand(b, keep.shape[1])
        visible_tokens = tokens[batch_indices, keep]  # (B, N_keep, D)
        # For simplicity, decode only the visible tokens and scatter the results
        # into a zero-initialized output; masked positions remain zeros.
        decoded_all = torch.zeros(b, n, self.cfg.out_channels, self.cfg.patch_size, self.cfg.patch_size, device=tokens.device)
        decoded_visible = self.head(visible_tokens)  # (B, N_keep, C, P, P)
        decoded_all[batch_indices, keep] = decoded_visible
        return {"latent": tokens, "reconstructions": decoded_all, "mask": mask}
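One caveat worth flagging: because masked positions in reconstructions are zero constants, the masked-only loss receives no gradient from the network through those positions, so training cannot reduce it. A common MAE-style remedy is to decode masked slots from a learnable placeholder embedding. The sketch below is one hedged way to do this; forward_with_mask_token and mask_token are hypothetical names, not part of the course codebase, and it assumes ReconstructionHead maps (B, N, D) tokens to (B, N, C, P, P) patches.

```python
import torch
import torch.nn as nn

# Hypothetical variant (not the tangled course code): fill masked slots with a
# shared learnable embedding before decoding every token, so the masked-only
# loss back-propagates into the head and the placeholder embedding.
def forward_with_mask_token(mae: MaskedAutoencoder, mask_token: nn.Parameter,
                            images: torch.Tensor) -> dict:
    tokens = mae.encoder(images)                      # (B, N, D)
    b, n, d = tokens.shape
    keep, mask = mae._random_mask(n, b, tokens.device)
    full = mask_token.expand(b, n, d).clone()         # placeholder at every slot
    batch_indices = torch.arange(b, device=tokens.device).unsqueeze(-1).expand(b, keep.shape[1])
    full[batch_indices, keep] = tokens[batch_indices, keep]  # restore visible tokens
    recon = mae.head(full)                            # (B, N, C, P, P)
    return {"latent": tokens, "reconstructions": recon, "mask": mask}

# Usage sketch: mask_token must be registered with the optimizer as well.
# mask_token = nn.Parameter(torch.zeros(1, 1, 128))
# outputs = forward_with_mask_token(mae, mask_token, images)
```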
Quick Toy Batch (non-tangled)
import torch
# Use locally defined classes/functions and construct target patches matching shapes
from geogfm.models.gfm_vit import ViTBackboneConfig

images = torch.randn(2, 3, 64, 64)
mae = MaskedAutoencoder(MAEConfig(vit=ViTBackboneConfig(in_channels=3, image_size=64, patch_size=16, embed_dim=128, depth=2, num_heads=4),
                                  out_channels=3, patch_size=16, mask_ratio=0.5))
outputs = mae(images)

# Build target patches to match (B, N, C, P, P)
ps = 16
patches = images.unfold(2, ps, ps).unfold(3, ps, ps)  # (B, C, H', W', ps, ps)
patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous().view(images.size(0), -1, images.size(1), ps, ps)

loss = masked_mse_loss(outputs["reconstructions"], patches, outputs["mask"])  # simple target patches
print("latent:", outputs["latent"].shape, "recon:", outputs["reconstructions"].shape, "loss:", float(loss))
latent: torch.Size([2, 16, 128]) recon: torch.Size([2, 16, 3, 16, 16]) loss: 1.010398268699646
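To close out the "overfit a toy batch" goal, a minimal training loop like the following can be run on the batch above (a sketch; the AdamW settings and step count are illustrative, not course-specified). Note that the loss only decreases if masked positions in reconstructions depend on trainable parameters, for example via the mask-token variant sketched earlier; with the zero-filled masked slots above, the masked loss stays roughly constant.

```python
import torch

# Illustrative optimizer settings (an assumption).
opt = torch.optim.AdamW(mae.parameters(), lr=1e-3)
for step in range(200):
    opt.zero_grad()
    outputs = mae(images)  # a new random mask is drawn each forward pass
    loss = masked_mse_loss(outputs["reconstructions"], patches, outputs["mask"])
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(f"step {step}: loss {float(loss):.4f}")
```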