# geogfm.modules (Week 2)
## Course Roadmap Mapping

This week's work in the broader GFM plan:
| Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
|---|---|---|---|---|---|
| 2 | Stage 1: Build GFM Architecture | Attention & Blocks | `modules/embeddings/{patch_embedding.py, positional_encoding.py}`; `modules/attention/multihead_attention.py`; `modules/blocks/{mlp.py, transformer_block.py}` | `torch.nn` (compare with `torch.nn.MultiheadAttention`) | Blocks run forward with stable shapes; unit tests green |
## Weekly goals
- Implement patch/positional embeddings and MHA from scratch
- Assemble a PreNorm transformer block (MHA + MLP)
- Validate tensor shapes and simple layer tests
## Session Outline (and Tangled Code)

- Concepts → Components mapping
  - Patch tokenization → `modules/embeddings/patch_embedding.py`
  - Positional encoding → `modules/embeddings/positional_encoding.py`
  - Scaled dot-product multihead self-attention → `modules/attention/multihead_attention.py`
  - Feedforward/MLP and Transformer block (PreNorm) → `modules/blocks/{mlp.py, transformer_block.py}`
- What you will implement
  - Patch embedding (Conv2d patchify + linear projection)
  - Sinusoidal positional encodings (1D for tokens; optional 2D helper)
  - Multihead self-attention from scratch
  - MLP block and a PreNorm Transformer block
## Package inits

```python
# Package init for `geogfm.modules.attention` (Week 2)
# geogfm.modules.attention

# Package init for `geogfm.modules.embeddings` (Week 2)
# geogfm.modules.embeddings

# Package init for `geogfm.modules.blocks` (Week 2)
# geogfm.modules.blocks

# Package init for `geogfm.modules.heads` (Week 2)
# geogfm.modules.heads

# Package init for `geogfm.modules.losses` (Week 2)
# geogfm.modules.losses
```
## 1) Patch Embedding → `geogfm/modules/embeddings/patch_embedding.py`
```python
# Patch embedding layer (Week 2)
from __future__ import annotations
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class PatchEmbedConfig:
    in_channels: int = 3
    embed_dim: int = 256
    patch_size: int = 16


class PatchEmbedding(nn.Module):
    """Conv2d-based patchifier producing token embeddings.

    Input:  (B, C, H, W)
    Output: (B, N, D) where N = (H/ps)*(W/ps), D = embed_dim
    """

    def __init__(self, cfg: PatchEmbedConfig):
        super().__init__()
        self.cfg = cfg
        self.proj = nn.Conv2d(cfg.in_channels, cfg.embed_dim,
                              kernel_size=cfg.patch_size, stride=cfg.patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, D, H/ps, W/ps) -> (B, D, N) -> (B, N, D)
        x = self.proj(x)
        b, d, gh, gw = x.shape  # grid dims, unpacked for clarity
        x = x.flatten(2).transpose(1, 2)
        return x
```
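As a side check (a non-tangled sketch, not part of the `geogfm` code), a Conv2d with `kernel_size == stride == patch_size` computes the same map as flattening each patch and applying a single `nn.Linear`. The sketch below copies the conv weights into a linear layer to show the two paths agree numerically; the `unfold`/reshape code is illustrative only.

```python
import torch
import torch.nn as nn

ps, c, d = 16, 3, 256
conv = nn.Conv2d(c, d, kernel_size=ps, stride=ps)
linear = nn.Linear(c * ps * ps, d)
# Reuse the conv weights so both layers compute the same projection.
linear.weight.data = conv.weight.data.reshape(d, -1)
linear.bias.data = conv.bias.data

x = torch.randn(1, c, 64, 64)
tokens_conv = conv(x).flatten(2).transpose(1, 2)            # (1, 16, 256)

patches = x.unfold(2, ps, ps).unfold(3, ps, ps)             # (1, c, 4, 4, ps, ps)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, c * ps * ps)
tokens_lin = linear(patches)                                # (1, 16, 256)

print(torch.allclose(tokens_conv, tokens_lin, atol=1e-5))   # True
```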
## 2) Positional Encoding (Sinusoidal) → `geogfm/modules/embeddings/positional_encoding.py`
```python
# Sinusoidal positional encodings (Week 2)
from __future__ import annotations

import math

import torch


def sinusoidal_positional_encoding(seq_len: int, dim: int, device: torch.device | None = None) -> torch.Tensor:
    """Return (seq_len, dim) sinusoidal positional encodings."""
    device = device or torch.device("cpu")
    pe = torch.zeros(seq_len, dim, device=device)
    position = torch.arange(0, seq_len, dtype=torch.float32, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, device=device).float() * (-math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


def sinusoidal_positional_encoding_2d(height: int, width: int, dim: int, device: torch.device | None = None) -> torch.Tensor:
    """Return (height*width, dim) 2D positional encodings by concatenating two 1D encodings.

    dim must be even.
    """
    assert dim % 2 == 0, "dim must be even for 2D positional encoding"
    device = device or torch.device("cpu")
    pe_h = sinusoidal_positional_encoding(height, dim // 2, device)  # (H, D/2)
    pe_w = sinusoidal_positional_encoding(width, dim // 2, device)   # (W, D/2)
    pe_h = pe_h.unsqueeze(1).expand(height, width, dim // 2)
    pe_w = pe_w.unsqueeze(0).expand(height, width, dim // 2)
    pe = torch.cat([pe_h, pe_w], dim=-1).reshape(height * width, dim)
    return pe
```
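A quick non-tangled usage sketch: with 16×16 patches on a 64×64 image the patch grid is 4×4, so the optional 2D helper yields the same number of position vectors as the 1D encoding over 16 tokens.

```python
import torch
from geogfm.modules.embeddings.positional_encoding import (
    sinusoidal_positional_encoding,
    sinusoidal_positional_encoding_2d,
)

pos_1d = sinusoidal_positional_encoding(16, 128)        # (16, 128), one row per token
pos_2d = sinusoidal_positional_encoding_2d(4, 4, 128)   # (4*4, 128), one row per grid cell
print(pos_1d.shape, pos_2d.shape)                       # torch.Size([16, 128]) torch.Size([16, 128])
```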
## 3) Multihead Self-Attention → `geogfm/modules/attention/multihead_attention.py`
```python
# Multihead self-attention from scratch (Week 2)
from __future__ import annotations

from typing import Optional
import math

import torch
import torch.nn as nn


class MultiheadSelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0, proj_dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bsz, num_tokens, dim = x.shape
        qkv = self.qkv(x).reshape(bsz, num_tokens, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, H, N, Hd)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = (q @ k.transpose(-2, -1)) * scale  # (B, H, N, N)
        if attn_mask is not None:
            attn_scores = attn_scores.masked_fill(attn_mask == 0, float("-inf"))
        attn = attn_scores.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v  # (B, H, N, Hd)
        out = out.transpose(1, 2).reshape(bsz, num_tokens, dim)  # (B, N, D)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
```
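For the roadmap's "compare with `torch.nn.MultiheadAttention`" step, a minimal non-tangled comparison sketch is below. The two layers are initialized independently, so only the output shapes (not the values) are expected to match.

```python
import torch
import torch.nn as nn
from geogfm.modules.attention.multihead_attention import MultiheadSelfAttention

x = torch.randn(2, 16, 128)                                        # (B, N, D)
ours = MultiheadSelfAttention(embed_dim=128, num_heads=4)
ref = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)

out_ours = ours(x)                                                 # (2, 16, 128)
out_ref, _ = ref(x, x, x)                                          # (2, 16, 128)
print(out_ours.shape, out_ref.shape)
```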
## 4) MLP and Transformer Block (PreNorm) → `geogfm/modules/blocks/`
`geogfm/modules/blocks/mlp.py`:

```python
# MLP feedforward block (Week 2)
from __future__ import annotations

import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```

`geogfm/modules/blocks/transformer_block.py`:

```python
# PreNorm Transformer block (Week 2)
from __future__ import annotations

import torch
import torch.nn as nn

from geogfm.modules.attention.multihead_attention import MultiheadSelfAttention
from geogfm.modules.blocks.mlp import MLP


class TransformerBlock(nn.Module):
    """PreNorm Transformer block: LN → MHA → Residual → LN → MLP → Residual."""

    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadSelfAttention(embed_dim, num_heads, proj_dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
```
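Because `TransformerBlock.forward` maps a `(B, N, D)` tensor to the same shape, blocks can be stacked directly. The sketch below is an assumption about how the blocks might later be composed (not code tangled this week) and uses a plain `nn.Sequential`.

```python
import torch
import torch.nn as nn
from geogfm.modules.blocks.transformer_block import TransformerBlock

depth, embed_dim, num_heads = 4, 128, 4
# Each block takes and returns a single tensor, so nn.Sequential suffices here.
encoder = nn.Sequential(*[TransformerBlock(embed_dim, num_heads) for _ in range(depth)])

tokens = torch.randn(2, 16, embed_dim)
out = encoder(tokens)
print(out.shape)  # torch.Size([2, 16, 128]); shape preserved block to block
```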
## Quick Shape Checks (non-tangled)

```python
import torch

from geogfm.modules.embeddings.patch_embedding import PatchEmbedding, PatchEmbedConfig
from geogfm.modules.embeddings.positional_encoding import sinusoidal_positional_encoding
from geogfm.modules.blocks.transformer_block import TransformerBlock

# Dummy image batch: 2 images, 3 channels, 64x64 pixels
x = torch.randn(2, 3, 64, 64)
pe = PatchEmbedding(PatchEmbedConfig(in_channels=3, embed_dim=128, patch_size=16))
tokens = pe(x)  # (2, 16, 128)
pos = sinusoidal_positional_encoding(tokens.shape[1], 128, tokens.device)
tokens = tokens + pos.unsqueeze(0)

block = TransformerBlock(embed_dim=128, num_heads=4)
out = block(tokens)
print("tokens ->", tokens.shape, "block out ->", out.shape)
```

Expected output:

```
tokens -> torch.Size([2, 16, 128]) block out -> torch.Size([2, 16, 128])
```
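To move from the ad-hoc print above toward the "unit tests green" outcome in the roadmap, a minimal pytest sketch could look like the following. The file name and test names are assumptions for illustration, not part of the tangled code.

```python
# tests/test_week2_shapes.py  (hypothetical test module, not tangled by this page)
import torch

from geogfm.modules.embeddings.patch_embedding import PatchEmbedding, PatchEmbedConfig
from geogfm.modules.embeddings.positional_encoding import sinusoidal_positional_encoding
from geogfm.modules.blocks.transformer_block import TransformerBlock


def test_patch_embedding_shape():
    pe = PatchEmbedding(PatchEmbedConfig(in_channels=3, embed_dim=128, patch_size=16))
    tokens = pe(torch.randn(2, 3, 64, 64))
    assert tokens.shape == (2, 16, 128)  # (H/ps) * (W/ps) = 4 * 4 tokens


def test_positional_encoding_shape():
    pos = sinusoidal_positional_encoding(16, 128)
    assert pos.shape == (16, 128)


def test_transformer_block_preserves_shape():
    block = TransformerBlock(embed_dim=128, num_heads=4)
    x = torch.randn(2, 16, 128)
    assert block(x).shape == x.shape
```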