# geogfm.models
# Package init for model modules
## Course Roadmap Mapping
This week's work in the broader GFM plan.
| Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
|---|---|---|---|---|---|
| 3 | Stage 1: Build GFM Architecture | Complete Architecture | `models/gfm_vit.py`; `modules/heads/reconstruction_head.py` | `torch.nn` (`timm` as reference) | Encoder assembled; end-to-end forward on dummy input |
## Weekly goals
- Wire blocks into a GeoViT-style encoder
- Add a simple reconstruction head
- Run end-to-end forward pass on dummy input
## Session Outline (and Tangled Code)
- Concepts → Components mapping
  - Token pipeline → `PatchEmbedding` + positional encoding + stack of `TransformerBlock`s
  - Encoder backbone → `models/gfm_vit.py`
  - Decoder/readout for MAE → `modules/heads/reconstruction_head.py`
- Package inits (a minimal sketch follows below)
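The final outline item, package inits, keeps the new modules importable from the package roots. A minimal sketch of what the two `__init__.py` files could re-export is shown below; the exact export list is an assumption, not something this session fixes.

```python
# geogfm/models/__init__.py -- sketch only; export list is an assumption
from geogfm.models.gfm_vit import GeoViTBackbone, ViTBackboneConfig

__all__ = ["GeoViTBackbone", "ViTBackboneConfig"]

# geogfm/modules/heads/__init__.py -- sketch only
from geogfm.modules.heads.reconstruction_head import ReconstructionHead

__all__ = ["ReconstructionHead"]
```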
## 1) GeoViT Backbone → `geogfm/models/gfm_vit.py`
```python
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import torch
import torch.nn as nn

from geogfm.modules.embeddings.patch_embedding import PatchEmbedding, PatchEmbedConfig
from geogfm.modules.embeddings.positional_encoding import sinusoidal_positional_encoding
from geogfm.modules.blocks.transformer_block import TransformerBlock


@dataclass
class ViTBackboneConfig:
    in_channels: int = 3
    image_size: int = 224
    patch_size: int = 16
    embed_dim: int = 256
    depth: int = 8
    num_heads: int = 8
    mlp_ratio: float = 4.0


class GeoViTBackbone(nn.Module):
    def __init__(self, cfg: ViTBackboneConfig):
        super().__init__()
        self.cfg = cfg
        # Tokenization: Conv2d-based patchify + linear projection
        self.patch_embed = PatchEmbedding(PatchEmbedConfig(cfg.in_channels, cfg.embed_dim, cfg.patch_size))
        num_patches = (cfg.image_size // cfg.patch_size) ** 2
        # Fixed positional encodings for stability and speed in the session
        self.pos_embed = nn.Parameter(sinusoidal_positional_encoding(num_patches, cfg.embed_dim), requires_grad=False)
        # Encoder: stack of PreNorm Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.embed_dim, cfg.num_heads, mlp_ratio=cfg.mlp_ratio) for _ in range(cfg.depth)
        ])
        self.norm = nn.LayerNorm(cfg.embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return latent token sequence of shape (batch, num_tokens, embed_dim)."""
        tokens = self.patch_embed(x)  # (B, N, D)
        tokens = tokens + self.pos_embed.unsqueeze(0)
        for blk in self.blocks:
            tokens = blk(tokens)
        tokens = self.norm(tokens)
        return tokens  # (B, N, D)
```
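The backbone imports `PatchEmbedding`, `PatchEmbedConfig`, `sinusoidal_positional_encoding`, and `TransformerBlock` from earlier sessions; `TransformerBlock` is the standard pre-norm attention + MLP block from Week 2. For readers running this page standalone, here is a minimal sketch of the interfaces the backbone assumes (a Conv2d patchifier and a fixed sin/cos table with an even `embed_dim`); the actual Week 2 implementations may differ in detail.

```python
import math
from dataclasses import dataclass
import torch
import torch.nn as nn


@dataclass
class PatchEmbedConfig:
    in_channels: int
    embed_dim: int
    patch_size: int


class PatchEmbedding(nn.Module):
    """Conv2d patchify + projection: (B, C, H, W) -> (B, N, D), with N = (H // P) * (W // P)."""
    def __init__(self, cfg: PatchEmbedConfig):
        super().__init__()
        self.proj = nn.Conv2d(cfg.in_channels, cfg.embed_dim,
                              kernel_size=cfg.patch_size, stride=cfg.patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                      # (B, D, H/P, W/P)
        return x.flatten(2).transpose(1, 2)   # (B, N, D)


def sinusoidal_positional_encoding(num_positions: int, embed_dim: int) -> torch.Tensor:
    """Fixed sin/cos table of shape (num_positions, embed_dim); embed_dim assumed even."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)   # (N, 1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float32)
                         * (-math.log(10000.0) / embed_dim))                   # (D/2,)
    table = torch.zeros(num_positions, embed_dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table
```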
## 2) Reconstruction Head → `geogfm/modules/heads/reconstruction_head.py`
```python
from __future__ import annotations
import torch
import torch.nn as nn


class ReconstructionHead(nn.Module):
    """Token-wise MLP to reconstruct patch pixels from latent tokens."""

    def __init__(self, embed_dim: int, out_channels: int, patch_size: int):
        super().__init__()
        self.out_channels = out_channels
        self.patch_size = patch_size
        # Two-layer MLP mapping from token dim D -> (C * P * P)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, out_channels * patch_size * patch_size),
        )

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Transform tokens: (B, N, D) -> (B, N, C*P*P) -> (B, N, C, P, P)
        b, n, d = tokens.shape
        x = self.linear(tokens)
        x = x.view(b, n, self.out_channels, self.patch_size, self.patch_size)
        return x
```
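The head emits per-patch pixels of shape (B, N, C, P, P) rather than a full image. For visualization or image-space comparison, those patches can be folded back into (B, C, H, W); a small helper along these lines is sketched below. It assumes a square, row-major patch grid and is not part of the tangled module.

```python
import torch


def unpatchify(patches: torch.Tensor, image_size: int) -> torch.Tensor:
    """Fold (B, N, C, P, P) patch reconstructions back into (B, C, H, W).

    Assumes a square grid of patches in row-major order, i.e. N == (image_size // P) ** 2.
    """
    b, n, c, p, _ = patches.shape
    g = image_size // p                      # patches per side
    assert g * g == n, "token count must match the patch grid"
    x = patches.view(b, g, g, c, p, p)       # (B, Gh, Gw, C, P, P)
    x = x.permute(0, 3, 1, 4, 2, 5)          # (B, C, Gh, P, Gw, P)
    return x.reshape(b, c, g * p, g * p)     # (B, C, H, W)
```

For the quick check below, `unpatchify(recon, image_size=64)` would return a `(2, 3, 64, 64)` tensor.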
## Quick Forward Check (non-tangled)
```python
import torch

# Use locally defined classes in this session (avoid importing from geogfm here)
x = torch.randn(2, 3, 64, 64)
vit = GeoViTBackbone(ViTBackboneConfig(in_channels=3, image_size=64, patch_size=16, embed_dim=128, depth=2, num_heads=4))
latent = vit(x)
head = ReconstructionHead(embed_dim=128, out_channels=3, patch_size=16)
recon = head(latent)
print("latent:", latent.shape, "recon:", recon.shape)
```

```
latent: torch.Size([2, 16, 128]) recon: torch.Size([2, 16, 3, 16, 16])
```
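The session outline lists a decoder/readout for MAE. As a preview of how the backbone and head could slot into masked-autoencoder pretraining, the sketch below computes a reconstruction loss over randomly masked patches, reusing `x` and `recon` from the quick check. The `patchify` helper, the 75% mask ratio, and the loss averaging are illustrative assumptions, not the course's training code; in a real MAE step the encoder would also drop masked tokens before encoding, which is left for the pretraining session.

```python
import torch
import torch.nn.functional as F


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, C, P, P) target patches, matching the head's output layout."""
    b, c, h, w = images.shape
    g = h // patch_size
    x = images.view(b, c, g, patch_size, g, patch_size)
    return x.permute(0, 2, 4, 1, 3, 5).reshape(b, g * g, c, patch_size, patch_size)


# Targets and a random ~75% patch mask (1 = masked, i.e. to be reconstructed)
targets = patchify(x, patch_size=16)                        # (B, N, C, P, P)
mask = (torch.rand(targets.shape[:2]) < 0.75).float()       # (B, N)

# Per-patch MSE, averaged over masked patches only
per_patch = F.mse_loss(recon, targets, reduction="none").mean(dim=(2, 3, 4))  # (B, N)
loss = (per_patch * mask).sum() / mask.sum().clamp(min=1)
print("masked MSE:", float(loss))
```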