# GeogFM package init (Week 1)
# Tangles to `geogfm/__init__.py`
# Provides base package visibility for data + core modules
Introduction
Today we're building a complete pipeline that transforms raw satellite imagery into model-ready embeddings. This mirrors how Large Language Models process text: raw text → tokens → embeddings. Our geospatial version: raw GeoTIFF → patches → embeddings.
Course Roadmap Mapping
This week's work in the broader GFM plan.
Week | Stage | Focus | You will build (geogfm) | Library tools | Outcome |
---|---|---|---|---|---|
1 | Stage 1: Build GFM Architecture | Data Foundations | core/config.py ; data/datasets/stac_dataset.py ; data/transforms/{normalization.py, patchify.py} ; data/loaders.py |
torch.utils.data.Dataset /DataLoader , rasterio , numpy |
Config-driven dataloaders that yield normalized patches |
Weekly goals
- Implement a minimal dataset, transforms, and dataloaders
- Normalize channels; extract patches deterministically
- Verify shapes/CRS/stats prints; run a tiny DataLoader
Session Outline (and Tangled Code)
- Concepts → Components mapping
- Configuration schemas β
core/config.py
- Normalization and patchifying transforms β
data/transforms/{normalization.py, patchify.py}
- Minimal STAC-like dataset β
data/datasets/stac_dataset.py
- DataLoader builders β
data/loaders.py
- Package init files β
geogfm/__init__.py
and subpackages
- Configuration schemas β
Package inits
# Core subpackage init (Week 1)
# Tangles to `geogfm/core/__init__.py`
# Exposes config schemas
# Data subpackage init (Week 1)
# Tangles to `geogfm/data/__init__.py`
# Exposes datasets, loaders, and transforms
# Datasets init (Week 1)
# Tangles to `geogfm/data/datasets/__init__.py`
# Exposes minimal STAC-like dataset
# Transforms init (Week 1)
# Tangles to `geogfm/data/transforms/__init__.py`
# Exposes normalization and patchify transforms
1) Typed Configs → geogfm/core/config.py
# Typed configuration schemas (Week 1)
# Tangles to `geogfm/core/config.py`
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
@dataclass
class ModelConfig:
    """Typed defaults describing the model to build.

    Every field has a default, so ``ModelConfig()`` is a valid configuration
    and individual values can be overridden by keyword.
    """

    architecture: str = "gfm_vit"  # identifier of the model family
    in_channels: int = 3           # number of input bands/channels
    image_size: int = 64           # input height == width, in pixels
    patch_size: int = 16           # side length of one square patch, in pixels
    embed_dim: int = 128           # embedding width
    depth: int = 4                 # presumably the number of transformer blocks -- TODO confirm
    num_heads: int = 4             # presumably attention heads per block -- TODO confirm
    mlp_ratio: float = 4.0         # presumably MLP hidden width / embed_dim -- TODO confirm
@dataclass
class DataConfig:
    """Typed defaults describing which data to load and how to batch it.

    ``root_dir``, ``image_size``, ``in_channels``, ``length`` and ``seed`` are
    forwarded to ``StacLikeDataset`` by ``build_dataloader``; ``batch_size``
    and ``num_workers`` are forwarded to the ``DataLoader``.
    """

    dataset: str = "stac"
    root_dir: str = "data/out"  # or a small sample dir
    split: str = "train"
    image_size: int = 64
    in_channels: int = 3
    length: int = 64  # synthetic length fallback
    num_workers: int = 0
    batch_size: int = 8
    seed: int = 42
@dataclass
class TrainConfig:
    """Typed defaults for a training run.

    ``optimizer`` uses ``default_factory`` so that every instance gets its own
    dictionary rather than sharing one mutable default.
    """

    epochs: int = 1
    batch_size: int = 8
    optimizer: Dict[str, Any] = field(
        default_factory=lambda: {"name": "adamw", "lr": 2e-4})
    device: str = "cpu"
2) Normalization Transform → geogfm/data/transforms/normalization.py
# Channel-wise normalization utilities (Week 1)
# Tangles to `geogfm/data/transforms/normalization.py`
from __future__ import annotations
import numpy as np
from typing import Tuple, Dict, Any, Optional
Array = np.ndarray


def minmax_normalize(data: Array, global_min: Optional[Array] = None, global_max: Optional[Array] = None) -> tuple[Array, dict]:
    """Min-max scale each band of a (bands, height, width) array.

    Args:
        data: 3-D array; unpacking ``data.shape`` raises ``ValueError`` otherwise.
        global_min: optional per-band minimums (array or sequence, length ``bands``).
        global_max: optional per-band maximums (array or sequence, length ``bands``).
            Per-band statistics are computed from ``data`` unless BOTH are given.

    Returns:
        normalized: float32 array shaped like ``data``. Bands whose range is
            not positive are set to 0. With global statistics the values are
            not clipped, so they may fall outside [0, 1].
        stats: dict with ``source`` ("local" or "global"), ``mins``, ``maxs``
            (float32 arrays) and ``output_range`` (a tuple of Python floats).
    """
    bands, _, _ = data.shape
    normalized = np.zeros_like(data, dtype=np.float32)
    if global_min is None or global_max is None:
        mins = np.array([data[i].min() for i in range(bands)], dtype=np.float32)
        maxs = np.array([data[i].max() for i in range(bands)], dtype=np.float32)
        src = "local"
    else:
        # asarray (rather than .astype) also accepts plain lists/tuples.
        mins = np.asarray(global_min, dtype=np.float32)
        maxs = np.asarray(global_max, dtype=np.float32)
        src = "global"
    for i in range(bands):
        rng = maxs[i] - mins[i]
        if rng > 0:
            normalized[i] = (data[i] - mins[i]) / rng
        else:
            # Constant band (or inverted statistics): avoid dividing by <= 0.
            normalized[i] = 0
    stats = {"source": src, "mins": mins, "maxs": maxs, "output_range": (float(normalized.min()), float(normalized.max()))}
    return normalized, stats
def zscore_normalize(data: Array, global_mean: Optional[Array] = None, global_std: Optional[Array] = None) -> tuple[Array, dict]:
    """Standardize each band of a (bands, height, width) array.

    Args:
        data: 3-D array; unpacking ``data.shape`` raises ``ValueError`` otherwise.
        global_mean: optional per-band means (array or sequence, length ``bands``).
        global_std: optional per-band standard deviations (same layout).
            Per-band statistics are computed from ``data`` unless BOTH are given.

    Returns:
        normalized: float32 array shaped like ``data``; bands whose standard
            deviation is not positive are set to 0.
        stats: dict with ``source`` ("local" or "global"), ``means``, ``stds``
            (float32 arrays), ``output_mean`` and ``output_std`` (Python floats).
    """
    bands, _, _ = data.shape
    normalized = np.zeros_like(data, dtype=np.float32)
    if global_mean is None or global_std is None:
        means = np.array([data[i].mean() for i in range(bands)], dtype=np.float32)
        stds = np.array([data[i].std() for i in range(bands)], dtype=np.float32)
        src = "local"
    else:
        # asarray (rather than .astype) also accepts plain lists/tuples.
        means = np.asarray(global_mean, dtype=np.float32)
        stds = np.asarray(global_std, dtype=np.float32)
        src = "global"
    for i in range(bands):
        if stds[i] > 0:
            normalized[i] = (data[i] - means[i]) / stds[i]
        else:
            # Constant band: avoid dividing by zero.
            normalized[i] = 0
    stats = {"source": src, "means": means, "stds": stds, "output_mean": float(normalized.mean()), "output_std": float(normalized.std())}
    return normalized, stats
4) Minimal STAC-like Dataset → geogfm/data/datasets/stac_dataset.py
# Minimal STAC-like dataset (Week 1)
# Tangles to `geogfm/data/datasets/stac_dataset.py`
#| auto-imports: true
from __future__ import annotations
from typing import List, Optional
from pathlib import Path
import random
import numpy as np
import torch
from torch.utils.data import Dataset
import rasterio as rio
class StacLikeDataset(Dataset):
    """Minimal dataset reading GeoTIFF files under a directory, or generating synthetic data.

    Every item is a float32 tensor of shape (in_channels, image_size, image_size).
    Whether image_size is divisible by a patch size is the caller's concern;
    it is not checked here.

    Split rule (applied to the sorted list of ``*.tif`` files under ``root_dir``):
    ``"val"`` keeps every 5th file, ``"train"`` keeps the others, and any other
    split name keeps all files. When no file remains, items are synthesized.
    """

    def __init__(self, root_dir: str, split: str = "train", image_size: int = 64, in_channels: int = 3, length: int = 64, seed: int = 42):
        self.root = Path(root_dir)
        self.split = split
        self.image_size = int(image_size)
        self.in_channels = int(in_channels)
        self.length = int(length)
        self.seed = int(seed)
        # Kept for backward compatibility with code that reads ``.rng``.
        self.rng = random.Random(seed)
        self.files: List[Path] = []
        if self.root.exists():
            # Sort so the train/val partition does not depend on the
            # filesystem's directory enumeration order.
            self.files = sorted(self.root.rglob("*.tif"))
        if split == "val":
            self.files = self.files[::5]
        elif split == "train":
            self.files = [p for i, p in enumerate(self.files) if i % 5 != 0]

    def __len__(self) -> int:
        """Return the larger of the file count and the configured ``length``."""
        return max(len(self.files), self.length)

    def _load_or_synthesize(self, idx: int) -> np.ndarray:
        """Return item ``idx`` as a float32 (in_channels, image_size, image_size) array.

        Files are cycled with ``idx % len(self.files)``; missing bands are
        zero-padded. Without files, the array is drawn from a generator seeded
        by ``(seed, idx)`` so the same index always yields the same data.
        """
        if self.files:
            path = self.files[idx % len(self.files)]
            with rio.open(path) as src:
                n_read = min(self.in_channels, src.count)
                # Name the bands explicitly so the band count of out_shape
                # always matches what is read (rasterio band indexes are 1-based).
                arr = src.read(
                    indexes=list(range(1, n_read + 1)),
                    out_shape=(n_read, self.image_size, self.image_size),
                )
            if arr.shape[0] < self.in_channels:
                pad = np.zeros((self.in_channels - arr.shape[0], self.image_size, self.image_size), dtype=arr.dtype)
                arr = np.concatenate([arr, pad], axis=0)
            return arr.astype(np.float32)
        # synthetic fallback: a per-item generator, independent of global RNG
        # state and of the order in which items are requested.
        item_rng = np.random.default_rng([self.seed % 2**32, idx % 2**32])
        return item_rng.random((self.in_channels, self.image_size, self.image_size), dtype=np.float32)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return item ``idx`` as a torch tensor (shares memory with the loaded array)."""
        arr = self._load_or_synthesize(idx)
        return torch.from_numpy(arr)
5) DataLoader Builders → geogfm/data/loaders.py
# DataLoader builders (Week 1)
# Tangles to `geogfm/data/loaders.py`
#| eval: false
from __future__ import annotations
from typing import Tuple
from torch.utils.data import DataLoader
from geogfm.core.config import DataConfig
from geogfm.data.datasets.stac_dataset import StacLikeDataset
def build_dataloader(cfg: DataConfig) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation DataLoaders described by ``cfg``.

    Both datasets read from ``cfg.root_dir`` with the same image size, channel
    count and seed. The validation dataset's synthetic length is
    ``max(8, cfg.length // 5)``. Only the training loader shuffles.

    Returns:
        (train_dl, val_dl)
    """
    train_ds = StacLikeDataset(cfg.root_dir, split="train", image_size=cfg.image_size, in_channels=cfg.in_channels, length=cfg.length, seed=cfg.seed)
    val_ds = StacLikeDataset(cfg.root_dir, split="val", image_size=cfg.image_size, in_channels=cfg.in_channels, length=max(8, cfg.length // 5), seed=cfg.seed)
    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=False)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=False)
    return train_dl, val_dl
Learning Objectives
By building this pipeline, you will:
- Implement GeoTIFF loading and preprocessing functions
- Create patch extraction with spatial metadata
- Build tensor normalization and encoding functions
- Construct a PyTorch DataLoader for model training
- Connect to a simple embedding layer to verify end-to-end functionality
Session Roadmap
flowchart TD A["Setup & GeoTIFF Loading"] --> B["Geo Preprocessing Functions"] B --> C["Patch Extraction with Metadata"] C --> D["Tensor Operations & Normalization"] D --> E["DataLoader Construction"] E --> F["Embedding Layer Integration"] F --> G["End-to-End Pipeline Test"]
Setting Up
Let's establish our development environment and define the core constants we'll use throughout.
Imports and Configuration
import os
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import rasterio as rio
import matplotlib.pyplot as plt
from typing import Tuple, Dict, Any
# Set seeds for reproducibility (Python, NumPy and PyTorch each keep their own RNG)
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pipeline constants
PATCH_SIZE = 64
STRIDE = 32  # 50% overlap
BATCH_SIZE = 8
EMBEDDING_DIM = 256

print(f"β Environment setup complete")
print(f"β Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"β Stride: {STRIDE} (overlap: {(PATCH_SIZE-STRIDE)/PATCH_SIZE*100:.0f}%)")
β Environment setup complete
β Patch size: 64x64
β Stride: 32 (overlap: 50%)
Data Preparation
# Set up data paths - use book/data for course sample data
if "__file__" in globals():
    # From chapters folder, go up 2 levels to book folder, then to data
    DATA_DIR = Path(__file__).parent.parent / "data"
else:
    # Fallback for interactive environments - look for book folder
    current = Path.cwd()
    # Walk upwards until a known project folder or the filesystem root is hit.
    while current.name not in ["book", "geoAI"] and current.parent != current:
        current = current.parent
    if current.name == "book":
        DATA_DIR = current / "data"
    elif current.name == "geoAI":
        DATA_DIR = current / "book" / "data"
    else:
        DATA_DIR = Path("data")

# parents=True so a missing intermediate folder does not abort the setup.
DATA_DIR.mkdir(parents=True, exist_ok=True)
SAMPLE_PATH = DATA_DIR / "landcover_sample.tif"

# Verify data file exists
if not SAMPLE_PATH.exists():
    raise FileNotFoundError(f"Data file not found at {SAMPLE_PATH}. Please ensure the landcover_sample.tif file is available in the data directory.")
print(f"β Data ready: {SAMPLE_PATH.name}")
print(f"β File size: {SAMPLE_PATH.stat().st_size / 1024:.1f} KB")
print(f"β Full path: {SAMPLE_PATH}")
β Data ready: landcover_sample.tif
β File size: 12.6 KB
β Full path: /Users/kellycaylor/dev/geoAI/book/data/landcover_sample.tif
Step 1: GeoTIFF Loading and Inspection
Goal: Build a function that loads and extracts essential information from any GeoTIFF.
🛠️ Build It: GeoTIFF Loader Function
Your task: Complete this function to load a GeoTIFF and return both the data and metadata.
def load_geotiff(file_path: Path) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Load a GeoTIFF and extract data + metadata.

    Args:
        file_path: path of the raster to open with rasterio.

    Returns:
        data: (bands, height, width) array
        metadata: dict with keys 'crs', 'transform', 'shape', 'dtype',
            'resolution' and 'bounds'
    """
    with rio.open(file_path) as src:
        # Load every band of the raster into one array
        data = src.read()

        # Collect metadata while the dataset is still open
        metadata = {
            'crs': src.crs,              # coordinate reference system
            'transform': src.transform,  # geospatial transform
            'shape': data.shape,         # array dimensions
            'dtype': data.dtype,         # data type
            'resolution': src.res,       # pixel resolution
            'bounds': src.bounds,        # spatial bounds
        }
    return data, metadata
# Test your function
data, metadata = load_geotiff(SAMPLE_PATH)
print(f"β Loaded shape: {data.shape}")
print(f"β Data type: {metadata['dtype']}")
print(f"β Resolution: {metadata['resolution']}")
print(f"β CRS: {metadata['crs']}")
β Loaded shape: (3, 64, 64)
β Data type: uint8
β Resolution: (0.25, 0.25)
β CRS: PROJCS["Projection: Transverse Mercator; Datum: WGS84; Ellipsoid: WGS84",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",19],PARAMETER["scale_factor",0.9993],PARAMETER["false_easting",500000],PARAMETER["false_northing",-5300000],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH]]
π Verify It: Inspect Your Data
# Examine the data you loaded
bands, height, width = data.shape
print(f"Image dimensions: {height}Γ{width} pixels")
print(f"Number of bands: {bands}")
print(f"Value ranges per band:")
for i, band in enumerate(data):
    print(f" Band {i+1}: {band.min():.0f} to {band.max():.0f}")
# Quick visualization
fig, axes = plt.subplots(1, bands, figsize=(12, 4))
if bands == 1:
    # subplots returns a bare Axes (not an array) for a single panel
    axes = [axes]

for i, band in enumerate(data):
    axes[i].imshow(band, cmap='viridis')
    axes[i].set_title(f'Band {i+1}')
    axes[i].axis('off')

plt.suptitle('Raw Satellite Bands')
plt.tight_layout()
plt.show()
Image dimensions: 64Γ64 pixels
Number of bands: 3
Value ranges per band:
Band 1: 0 to 254
Band 2: 0 to 254
Band 3: 0 to 254
Step 2: Geo Preprocessing Functions
Goal: Build preprocessing functions that operate on the full image before patch extraction.
🛠️ Build It: Normalization Functions
We'll create two normalization functions that can work with either local statistics (calculated from the input data) or global statistics (pre-computed from a training dataset). Global statistics ensure consistent normalization across different image tiles and are crucial for foundation model training.
Why use global statistics? When training on multiple images, each tile might have different value ranges. Using global statistics ensures that the same pixel value represents the same relative intensity across all training data.
Min-Max Normalization Function
def minmax_normalize(data: np.ndarray,
                     global_min: np.ndarray = None,
                     global_max: np.ndarray = None) -> tuple[np.ndarray, dict]:
    """
    Min-max normalize spectral bands.

    Args:
        data: (bands, height, width) array
        global_min: Optional (bands,) array of global minimums per band
        global_max: Optional (bands,) array of global maximums per band

    Returns:
        normalized: (bands, height, width) float32 array. With local
            statistics the values lie in [0,1]; with global statistics they
            are NOT clipped and may fall outside [0,1].
        stats: Dictionary containing the min/max values used
    """
    bands, height, width = data.shape
    normalized = np.zeros_like(data, dtype=np.float32)

    # Use global stats if provided, otherwise calculate from data
    if global_min is None or global_max is None:
        # Calculate per-band statistics from this data
        mins = np.array([data[i].min() for i in range(bands)])
        maxs = np.array([data[i].max() for i in range(bands)])
        stats_source = "local (calculated from input)"
    else:
        # Use provided global statistics
        mins = global_min
        maxs = global_max
        stats_source = "global (provided)"

    # Apply normalization per band
    for i in range(bands):
        # Do the arithmetic in floating point: subtracting statistics from an
        # unsigned-integer raster in its native dtype can wrap around
        # (e.g. uint8 0 - 100 -> 156) instead of going negative.
        low = float(mins[i])
        high = float(maxs[i])
        band_range = high - low
        if band_range > 0:  # Avoid division by zero
            normalized[i] = (data[i].astype(np.float64) - low) / band_range
        else:
            normalized[i] = 0  # Handle constant bands

    # Package statistics for inspection
    stats = {
        'source': stats_source,
        'mins': mins,
        'maxs': maxs,
        'output_range': (normalized.min(), normalized.max())
    }
    return normalized, stats
Z-Score Normalization Function
def zscore_normalize(data: np.ndarray,
                     global_mean: np.ndarray = None,
                     global_std: np.ndarray = None) -> tuple[np.ndarray, dict]:
    """
    Z-score normalize spectral bands.

    Args:
        data: (bands, height, width) array
        global_mean: Optional (bands,) array of global means per band
        global_std: Optional (bands,) array of global standard deviations per band

    Returns:
        normalized: (bands, height, width) float32 array. With local
            statistics each non-constant band has mean 0 and std 1; with
            global statistics the output distribution depends on how the data
            relate to the supplied statistics.
        stats: Dictionary containing the mean/std values used
    """
    bands, height, width = data.shape
    normalized = np.zeros_like(data, dtype=np.float32)

    # Use global stats if provided, otherwise calculate from data
    if global_mean is None or global_std is None:
        # Calculate per-band statistics from this data
        means = np.array([data[i].mean() for i in range(bands)])
        stds = np.array([data[i].std() for i in range(bands)])
        stats_source = "local (calculated from input)"
    else:
        # Use provided global statistics
        means = global_mean
        stds = global_std
        stats_source = "global (provided)"

    # Apply normalization per band
    for i in range(bands):
        # Do the arithmetic in floating point: subtracting integer statistics
        # from an unsigned-integer raster in an integer dtype can wrap around
        # instead of producing a negative value.
        band_mean = float(means[i])
        band_std = float(stds[i])
        if band_std > 0:  # Avoid division by zero
            normalized[i] = (data[i].astype(np.float64) - band_mean) / band_std
        else:
            normalized[i] = 0  # Handle constant bands

    # Package statistics for inspection
    stats = {
        'source': stats_source,
        'means': means,
        'stds': stds,
        'output_mean': normalized.mean(),
        'output_std': normalized.std()
    }
    return normalized, stats
# Announce that both normalization helpers are now defined, one line each.
for summary_line in (
    "β Normalization functions created",
    " - minmax_normalize: scales to [0,1] range",
    " - zscore_normalize: standardizes to mean=0, std=1",
):
    print(summary_line)
β Normalization functions created
- minmax_normalize: scales to [0,1] range
- zscore_normalize: standardizes to mean=0, std=1
Test Both Functions with Local Statistics
# Test min-max normalization with local statistics
minmax_data, minmax_stats = minmax_normalize(data)
print("π Min-Max Normalization (local stats):")
print(f" Source: {minmax_stats['source']}")
print(f" Original range: {data.min():.0f} to {data.max():.0f}")
print(f" Normalized range: {minmax_stats['output_range'][0]:.3f} to {minmax_stats['output_range'][1]:.3f}")
print(f" Per-band mins: {minmax_stats['mins']}")
print(f" Per-band maxs: {minmax_stats['maxs']}")
print()
# Test z-score normalization with local statistics
zscore_data, zscore_stats = zscore_normalize(data)
print("π Z-Score Normalization (local stats):")
print(f" Source: {zscore_stats['source']}")
print(f" Output mean: {zscore_stats['output_mean']:.6f}")
print(f" Output std: {zscore_stats['output_std']:.6f}")
print(f" Per-band means: {zscore_stats['means']}")
print(f" Per-band stds: {zscore_stats['stds']}")
π Min-Max Normalization (local stats):
Source: local (calculated from input)
Original range: 0 to 254
Normalized range: 0.000 to 1.000
Per-band mins: [0 0 0]
Per-band maxs: [254 254 254]
π Z-Score Normalization (local stats):
Source: local (calculated from input)
Output mean: 0.000000
Output std: 1.000000
Per-band means: [126.14306641 126.14306641 126.14306641]
Per-band stds: [73.10237725 73.10237725 73.10237725]
Test with Global Statistics
# Simulate global statistics from a larger dataset
# In practice, these would be pre-computed from your entire training corpus
global_mins = np.array([100, 150, 200])  # Example global minimums per band
global_maxs = np.array([1500, 2000, 2500])  # Example global maximums per band
global_means = np.array([800, 1200, 1600])  # Example global means per band
global_stds = np.array([300, 400, 500])  # Example global standard deviations per band

print("π Testing with Global Statistics:")
print(f" Global mins: {global_mins}")
print(f" Global maxs: {global_maxs}")
print(f" Global means: {global_means}")
print(f" Global stds: {global_stds}")
print()
# Test with global statistics
minmax_global, minmax_global_stats = minmax_normalize(data, global_mins, global_maxs)
zscore_global, zscore_global_stats = zscore_normalize(data, global_means, global_stds)

print("π Min-Max with Global Stats:")
print(f" Source: {minmax_global_stats['source']}")
print(f" Output range: {minmax_global_stats['output_range'][0]:.3f} to {minmax_global_stats['output_range'][1]:.3f}")
print()
print("π Z-Score with Global Stats:")
print(f" Source: {zscore_global_stats['source']}")
print(f" Output mean: {zscore_global_stats['output_mean']:.3f}")
print(f" Output std: {zscore_global_stats['output_std']:.3f}")
π Testing with Global Statistics:
Global mins: [100 150 200]
Global maxs: [1500 2000 2500]
Global means: [ 800 1200 1600]
Global stds: [300 400 500]
π Min-Max with Global Stats:
Source: global (provided)
Output range: 0.000 to 0.182
π Z-Score with Global Stats:
Source: global (provided)
Output mean: 168.496
Output std: 36.333
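What to notice: When using global statistics, the output ranges and distributions differ from local normalization. This is expected and ensures consistency across different image tiles in your dataset. Caution: the specific figures printed above (a min-max minimum of exactly 0.000 and a z-score mean of 168.496) are not what (x − min) / (max − min) and (x − mean) / std give for uint8 values 0–254 with these statistics; they are consistent with unsigned-integer wraparound during the subtraction, so convert the data to a floating-point type before normalizing.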
🛠️ Build It: Spatial Cropping Function
# Patch extraction utilities (Week 1)
# Tangles to `geogfm/data/transforms/patchify.py`
#| auto-imports: true
from __future__ import annotations
import numpy as np
Array = np.ndarray


def crop_to_patches(data: Array, patch_size: int, stride: int) -> Array:
    """
    Crop image to dimensions that allow complete patch extraction.

    The crop keeps the top-left corner and trims the bottom/right edges that a
    sliding window of ``patch_size`` moved by ``stride`` would never cover.

    Args:
        data: (bands, height, width) array
        patch_size: size of patches to extract
        stride: step size between patches

    Returns:
        cropped: (bands, new_height, new_width) array (a view of ``data``)

    Raises:
        ValueError: if ``patch_size`` or ``stride`` is not positive, or if the
            image is smaller than a single patch.
    """
    bands, height, width = data.shape
    if patch_size <= 0 or stride <= 0:
        raise ValueError(
            f"patch_size and stride must be positive, got {patch_size} and {stride}")
    if height < patch_size or width < patch_size:
        # Without this guard the arithmetic below yields a zero or negative
        # patch count and a misleading crop.
        raise ValueError(
            f"image of size {height}x{width} is smaller than patch_size {patch_size}")

    # How many complete patches fit along each axis
    patches_h = (height - patch_size) // stride + 1
    patches_w = (width - patch_size) // stride + 1

    # Extent actually covered by those patches
    new_height = (patches_h - 1) * stride + patch_size
    new_width = (patches_w - 1) * stride + patch_size

    cropped = data[:, :new_height, :new_width]

    print(f"β Cropped from {height}Γ{width} to {new_height}Γ{new_width}")
    print(
        f"β Will generate {patches_h}Γ{patches_w} = {patches_h*patches_w} patches")
    return cropped
# Test your cropping function
cropped_data = crop_to_patches(data, 8, STRIDE)
β Cropped from 64Γ64 to 40Γ40
β Will generate 2Γ2 = 4 patches
Step 3: Patch Extraction with Metadata
Goal: Extract patches while preserving spatial context information.
🛠️ Build It: Patch Extraction Function
# Patch extraction with spatial metadata (Week 1)
# Tangles (append) to `geogfm/data/transforms/patchify.py`
#| tangle-mode: append
#| auto-import: true
def extract_patches(
    data: Array,
    patch_size: int,
    stride: int
) -> Array:
    """
    Extract patches.

    Windows are visited in row-major order (left to right, then top to
    bottom); only windows that fit entirely inside the image are returned.

    Args:
        data: (bands, height, width) normalized array
        patch_size: size of patches
        stride: step between patches

    Returns:
        patches: (n_patches, bands, patch_size, patch_size) array; when no
            window fits, an empty array with ``n_patches == 0``.
    """
    bands, height, width = data.shape
    patches = [
        data[:, r:r+patch_size, c:c+patch_size]
        for r in range(0, height - patch_size + 1, stride)
        for c in range(0, width - patch_size + 1, stride)
    ]
    if not patches:
        # np.stack refuses an empty list; return a correctly shaped empty array.
        return np.empty((0, bands, patch_size, patch_size), dtype=data.dtype)
    return np.stack(patches, axis=0)
def extract_patches_with_metadata(
    data: Array,
    patch_size: int,
    stride: int,
    transform
) -> tuple[Array, Array]:
    """
    Extract patches with their spatial coordinates.

    Args:
        data: (bands, height, width) normalized array
        patch_size: size of patches
        stride: step between patches
        transform: rasterio transform object; only ``transform * (col, row)``
            is used, which must return an ``(x, y)`` pair

    Returns:
        patches: (n_patches, bands, patch_size, patch_size) array
        coordinates: (n_patches, 4) array of [min_x, min_y, max_x, max_y]
    """
    bands, height, width = data.shape
    patches = []
    coordinates = []

    for row in range(0, height - patch_size + 1, stride):
        for col in range(0, width - patch_size + 1, stride):
            # Extract patch from all bands
            patches.append(data[:, row:row+patch_size, col:col+patch_size])

            # Map the top-left and bottom-right pixel corners to world coordinates
            x0, y0 = transform * (col, row)
            x1, y1 = transform * (col + patch_size, row + patch_size)
            # Order each axis explicitly so the box is valid whichever way the
            # transform's axes point. NOTE(review): for a rotated transform
            # all four corners would be needed -- confirm inputs are unrotated.
            coordinates.append([min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)])

    if patches:
        patches = np.array(patches)
        coordinates = np.array(coordinates)
    else:
        # Keep the documented shapes even when no window fits.
        patches = np.empty((0, bands, patch_size, patch_size), dtype=data.dtype)
        coordinates = np.empty((0, 4), dtype=np.float64)

    print(f"β Extracted {len(patches)} patches")
    print(f"β Patch shape: {patches.shape}")
    print(f"β Coordinate shape: {coordinates.shape}")
    return patches, coordinates
# Test your patch extraction
patches, coords = extract_patches_with_metadata(
    data, 8, 4, metadata['transform']
)
# Visualize a few patches
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(8):
    row, col = i // 4, i % 4
    # Show first band of each patch
    axes[row, col].imshow(patches[i, 0], cmap='viridis')
    axes[row, col].set_title(f'Patch {i}')
    axes[row, col].axis('off')

plt.suptitle('Sample Extracted Patches (Band 1)')
plt.tight_layout()
plt.show()
β Extracted 225 patches
β Patch shape: (225, 3, 8, 8)
β Coordinate shape: (225, 4)
# Reconstruction from patches (Week 1)
# Tangles (append) to `geogfm/data/transforms/patchify.py`
#| tangle-mode: append
def reconstruct_from_patches(
    patches: Array,
    height: int,
    width: int,
    patch_size: int
) -> Array:
    """
    Reassemble a (bands, H, W) image from non-overlapping square patches.

    Parameters
    ----------
    patches : Array
        Patches in row-major scan order with shape (N, bands, patch_size, patch_size),
        where N must be at least (height // patch_size) * (width // patch_size);
        patches beyond that count are ignored.
    height : int
        Target image height in pixels.
    width : int
        Target image width in pixels.
    patch_size : int
        Size of each square patch in pixels. Assumes stride == patch_size (no overlap).

    Returns
    -------
    Array
        Reconstructed image of shape (bands, height, width).

    Raises
    ------
    IndexError
        If fewer patches are supplied than the grid needs.

    Notes
    -----
    - Assumes patches are laid out left-to-right, top-to-bottom (row-major).
    - Ignores any remainder if `height` or `width` is not divisible by `patch_size`
      (i.e., only the `grid_h * patch_size` by `grid_w * patch_size` area is filled;
      the rest stays zero).
    - No blending is performed (because there is no overlap).

    Examples
    --------
    >>> # Suppose height=width=64, patch_size=32, bands=13
    >>> # patches.shape == (4, 13, 32, 32), ordered row-major
    >>> img = reconstruct_from_patches(patches, 64, 64, 32)
    >>> img.shape
    (13, 64, 64)
    """
    bands = patches.shape[1]
    grid_h = height // patch_size
    grid_w = width // patch_size
    needed = grid_h * grid_w
    if patches.shape[0] < needed:
        # Fail up front with a readable message instead of part-way through.
        raise IndexError(
            f"need {needed} patches for a {grid_h}x{grid_w} grid, got {patches.shape[0]}")
    out = np.zeros((bands, height, width), dtype=patches.dtype)
    idx = 0
    for r in range(grid_h):
        for c in range(grid_w):
            out[:, r*patch_size:(r+1)*patch_size,
                c*patch_size:(c+1)*patch_size] = patches[idx]
            idx += 1
    return out
import numpy as np
# Let's assume you have already defined reconstruct_from_patches from above
# The demo uses its own variable names so it does not overwrite the pipeline's
# `patches`, `bands`, `height` and `width`, which later steps still rely on.
# Example parameters
demo_bands = 3  # e.g., RGB
demo_height = 8  # image height in pixels
demo_width = 8  # image width in pixels
demo_patch_size = 4  # each patch is 4x4 pixels

# Number of patches needed to cover the image (row-major order)
demo_grid_h = demo_height // demo_patch_size
demo_grid_w = demo_width // demo_patch_size
demo_num_patches = demo_grid_h * demo_grid_w

# Create random patches (N, bands, patch_size, patch_size)
demo_patches = np.random.randint(
    low=0, high=256,
    size=(demo_num_patches, demo_bands, demo_patch_size, demo_patch_size),
    dtype=np.uint8
)
print("Patches shape:", demo_patches.shape)
print("First patch (band 0):\n", demo_patches[0, 0])
# Reconstruct the image
reconstructed = reconstruct_from_patches(
    demo_patches, demo_height, demo_width, demo_patch_size)

print("\nReconstructed image shape:", reconstructed.shape)
print("Top-left 4x4 of reconstructed image, band 0:\n",
      reconstructed[0, :4, :4])

# Verify reconstruction matches original patches in correct positions
# The top-left 4x4 region should be identical to the first patch (band 0)
assert np.array_equal(reconstructed[0, :4, :4], demo_patches[0, 0])
print("\nVerification passed: top-left patch matches the original.")
Patches shape: (4, 3, 4, 4)
First patch (band 0):
[[102 220 225 95]
[179 61 234 203]
[ 92 3 98 243]
[ 14 149 245 46]]
Reconstructed image shape: (3, 8, 8)
Top-left 4x4 of reconstructed image, band 0:
[[102 220 225 95]
[179 61 234 203]
[ 92 3 98 243]
[ 14 149 245 46]]
Verification passed: top-left patch matches the original.
Step 4: Tensor Operations & Metadata Encoding
Goal: Convert numpy arrays to PyTorch tensors and encode metadata.
🛠️ Build It: Metadata Encoder
def encode_metadata(coordinates: np.ndarray) -> np.ndarray:
    """
    Encode spatial metadata as features.

    The four output columns are, in order: standardized box-center x,
    standardized box-center y, standardized box area, and aspect ratio
    (width / height, left unstandardized).

    Args:
        coordinates: (n_patches, 4) array of [min_x, min_y, max_x, max_y]

    Returns:
        encoded: (n_patches, 4) float32 array
    """
    center_x = (coordinates[:, 0] + coordinates[:, 2]) / 2
    center_y = (coordinates[:, 1] + coordinates[:, 3]) / 2
    width = coordinates[:, 2] - coordinates[:, 0]
    height = coordinates[:, 3] - coordinates[:, 1]
    area = width * height

    def safe_normalize(values):
        """Standardize values; an (almost) constant input maps to zeros."""
        mean_val = values.mean()
        std_val = values.std()
        # Compare against a small relative tolerance rather than exactly 0:
        # a feature that is constant up to floating-point rounding would
        # otherwise be blown up into meaningless +/-1 noise.
        if std_val > 1e-8 * max(abs(mean_val), 1.0):
            return (values - mean_val) / std_val
        return np.zeros_like(values)  # All values are the same

    # Degenerate (zero-height) boxes get aspect ratio 0 instead of inf/nan.
    aspect_ratio = np.divide(
        width, height,
        out=np.zeros(width.shape, dtype=np.float64),
        where=height != 0,
    )

    features = np.column_stack([
        safe_normalize(center_x),  # Normalized center X
        safe_normalize(center_y),  # Normalized center Y
        safe_normalize(area),      # Normalized area (handles constant area)
        aspect_ratio,              # Aspect ratio
    ])
    print(f"β Encoded metadata shape: {features.shape}")
    print(f"β Feature statistics:")
    feature_names = ['center_x', 'center_y', 'area', 'aspect_ratio']
    for i, name in enumerate(feature_names):
        print(
            f" {name}: mean={features[:, i].mean():.3f}, std={features[:, i].std():.3f}")
    return features.astype(np.float32)
# Test metadata encoding
encoded_metadata = encode_metadata(coords)
β Encoded metadata shape: (225, 4)
β Feature statistics:
center_x: mean=0.000, std=1.000
center_y: mean=0.000, std=1.000
area: mean=0.000, std=0.000
aspect_ratio: mean=1.000, std=0.000
🛠️ Build It: Tensor Conversion
def create_tensors(patches: np.ndarray, metadata: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert numpy arrays to PyTorch tensors.

    Both outputs are float32. The two inputs are converted independently;
    this function does not check that they describe the same number of patches.

    Args:
        patches: (n_patches, bands, height, width) array
        metadata: (n_patches, n_features) array

    Returns:
        patch_tensors: (n_patches, bands, height, width) tensor
        metadata_tensors: (n_patches, n_features) tensor
    """
    patch_tensors = torch.from_numpy(patches).float()
    metadata_tensors = torch.from_numpy(metadata).float()

    print(f"β Patch tensors: {patch_tensors.shape}, dtype: {patch_tensors.dtype}")
    print(f"β Metadata tensors: {metadata_tensors.shape}, dtype: {metadata_tensors.dtype}")
    return patch_tensors, metadata_tensors
# Create tensors
patch_tensors, metadata_tensors = create_tensors(patches, encoded_metadata)
β Patch tensors: torch.Size([4, 3, 4, 4]), dtype: torch.float32
β Metadata tensors: torch.Size([225, 4]), dtype: torch.float32
Step 5: DataLoader Construction
Goal: Build a PyTorch Dataset and DataLoader for training.
🛠️ Build It: Custom Dataset Class
class GeospatialDataset(Dataset):
    """Dataset for geospatial patches with metadata.

    The dataset length is the number of patches. Labels are random
    placeholders drawn at construction time from torch's global RNG.
    NOTE(review): nothing here checks that ``metadata_tensors`` has one row
    per patch -- confirm that callers pass matching lengths.
    """

    def __init__(self, patch_tensors: torch.Tensor, metadata_tensors: torch.Tensor):
        """
        Args:
            patch_tensors: (n_patches, bands, height, width)
            metadata_tensors: (n_patches, n_features)
        """
        self.patches = patch_tensors
        self.metadata = metadata_tensors
        # Dummy labels for demonstration (in real use, load from file)
        self.labels = torch.randint(0, 5, (len(patch_tensors),))  # 5 land cover classes

    def __len__(self) -> int:
        """Return number of patches."""
        return len(self.patches)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get a single item.

        Returns:
            patch: (bands, height, width) tensor
            metadata: (n_features,) tensor
            label: scalar tensor
        """
        return self.patches[idx], self.metadata[idx], self.labels[idx]
# Test your dataset
dataset = GeospatialDataset(patch_tensors, metadata_tensors)
print(f"β Dataset length: {len(dataset)}")
# Test getting an item
sample_patch, sample_metadata, sample_label = dataset[0]
print(f"β Sample patch shape: {sample_patch.shape}")
print(f"β Sample metadata shape: {sample_metadata.shape}")
print(f"β Sample label: {sample_label.item()}")
β Dataset length: 4
β Sample patch shape: torch.Size([3, 4, 4])
β Sample metadata shape: torch.Size([4])
β Sample label: 2
🛠️ Build It: DataLoader
# Create DataLoader with appropriate batch size and shuffling
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Set to 0 for compatibility
    pin_memory=torch.cuda.is_available(),
)
print(f"β DataLoader created with batch size {BATCH_SIZE}")
print(f"β Number of batches: {len(dataloader)}")
# Test the DataLoader
# The loop variables carry a batch_ prefix so they do not overwrite the
# module-level `patches` array and `metadata` dict created in earlier steps.
for batch_idx, (batch_patches, batch_metadata, batch_labels) in enumerate(dataloader):
    print(f"β Batch {batch_idx}:")
    print(f" Patches: {batch_patches.shape}")
    print(f" Metadata: {batch_metadata.shape}")
    print(f" Labels: {batch_labels.shape}")
    if batch_idx == 1:  # Show first two batches
        break
β DataLoader created with batch size 8
β Number of batches: 1
β Batch 0:
Patches: torch.Size([4, 3, 4, 4])
Metadata: torch.Size([4, 4])
Labels: torch.Size([4])
Step 6: Embedding Layer Integration
Goal: Connect to a simple embedding layer to verify end-to-end functionality.
🛠️ Build It: Simple GFM Embedding Layer
class SimpleGFMEmbedding(nn.Module):
    """Simple embedding layer for geospatial patches.

    A small CNN encodes the patch, an MLP encodes the metadata vector into 32
    features, and a fusion MLP maps their concatenation to ``embed_dim``.
    """

    def __init__(self, input_channels: int, metadata_features: int, embed_dim: int, patch_size: int = 64):
        """
        Args:
            input_channels: number of channels in each patch
            metadata_features: length of each metadata vector
            embed_dim: size of the output embedding
            patch_size: side length of the (square) patches; selects the
                encoder variant and its kernel sizes
        """
        super().__init__()
        if patch_size >= 32:
            # Larger patches: multi-layer CNN
            kernel1 = min(8, patch_size // 4)
            kernel2 = min(4, patch_size // 8)
            self.patch_encoder = nn.Sequential(
                nn.Conv2d(input_channels, 32, kernel_size=kernel1, stride=kernel1//2),
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=kernel2, stride=kernel2//2),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
            )
        else:
            # Smaller patches: simpler encoder
            # max(1, ...) keeps the kernel valid for patch_size == 1.
            kernel = max(1, min(4, patch_size // 2))
            self.patch_encoder = nn.Sequential(
                nn.Conv2d(input_channels, 64, kernel_size=kernel, stride=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
            )
        # Metadata encoder
        self.metadata_encoder = nn.Sequential(
            nn.Linear(metadata_features, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
        )
        # Probe the patch encoder once to learn its output width. A zero
        # tensor is enough for a shape probe and, unlike random input, does
        # not advance the global RNG between layer initializations.
        with torch.no_grad():
            dummy_patch = torch.zeros(1, input_channels, patch_size, patch_size)
            patch_feat_size = self.patch_encoder(dummy_patch).shape[1]

        # Fusion layer over [patch features | 32 metadata features]
        self.fusion = nn.Sequential(
            nn.Linear(patch_feat_size + 32, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, patches: torch.Tensor, metadata: torch.Tensor) -> torch.Tensor:
        """
        Args:
            patches: (batch, channels, height, width)
            metadata: (batch, n_features)

        Returns:
            embeddings: (batch, embed_dim)
        """
        patch_features = self.patch_encoder(patches)
        metadata_features = self.metadata_encoder(metadata)

        # Concatenate along the feature axis, then fuse
        combined = torch.cat([patch_features, metadata_features], dim=1)
        embeddings = self.fusion(combined)

        return embeddings
# Create and test the model, sized from the data prepared in earlier steps
model = SimpleGFMEmbedding(
    input_channels=bands,
    metadata_features=encoded_metadata.shape[1],
    embed_dim=EMBEDDING_DIM,
    patch_size=PATCH_SIZE
)
print(f"✓ Model created")
print(f"✓ Model parameters: {sum(p.numel() for p in model.parameters()):,}")
✓ Model created
✓ Model parameters: 130,848
Step 7: End-to-End Pipeline Test
Goal: Run the complete pipeline and verify everything works together.
🛠️ Build It: Complete Pipeline Function
def geotiff_to_embeddings_pipeline(
    file_path: Path,
    patch_size: int = 8,
    stride: int = 4,
    batch_size: int = 8,
    embed_dim: int = 256
) -> torch.Tensor:
    """
    Complete pipeline from GeoTIFF to embeddings.

    Chains the helpers built in the earlier steps: load, normalize, crop,
    extract patches, encode metadata, convert to tensors, batch with a
    DataLoader, and embed with a freshly initialised ``SimpleGFMEmbedding``.
    The model is not trained here, so the embeddings come from random weights.

    Args:
        file_path: Path to GeoTIFF file
        patch_size: Size of patches to extract
        stride: Step between patches
        batch_size: Batch size for processing
        embed_dim: Embedding dimension
    Returns:
        all_embeddings: (n_patches, embed_dim) tensor
    """
    print("π Starting GeoTIFF → Embeddings Pipeline")

    # Step 1: Load data
    print("π Loading GeoTIFF...")
    data, metadata = load_geotiff(file_path)

    # Step 2: Preprocess (the normalization statistics are not needed here)
    print("π§ Preprocessing...")
    norm_data, _ = minmax_normalize(data)
    cropped_data = crop_to_patches(norm_data, patch_size, stride)

    # Step 3: Extract patches
    print("βοΈ Extracting patches...")
    patches, coords = extract_patches_with_metadata(cropped_data, patch_size, stride, metadata['transform'])

    # Step 4: Encode metadata
    print("π Encoding metadata...")
    encoded_meta = encode_metadata(coords)

    # Step 5: Create tensors
    print("π’ Creating tensors...")
    patch_tensors, meta_tensors = create_tensors(patches, encoded_meta)

    # Step 6: Create dataset and dataloader
    # shuffle=False keeps the embeddings in patch-extraction order.
    print("π¦ Creating DataLoader...")
    dataset = GeospatialDataset(patch_tensors, meta_tensors)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Step 7: Create model and generate embeddings
    print("π§ Generating embeddings...")
    model = SimpleGFMEmbedding(
        input_channels=data.shape[0],
        metadata_features=encoded_meta.shape[1],
        embed_dim=embed_dim,
        patch_size=patch_size
    )
    model.eval()

    # Inference only: no_grad skips autograd bookkeeping. Labels are ignored.
    batch_embeddings = []
    with torch.no_grad():
        for patches_batch, meta_batch, _ in dataloader:
            batch_embeddings.append(model(patches_batch, meta_batch))

    all_embeddings = torch.cat(batch_embeddings, dim=0)
    print(f"✅ Pipeline complete! Generated {len(all_embeddings)} embeddings")
    return all_embeddings
# Run the complete pipeline on the sample file with default settings
embeddings = geotiff_to_embeddings_pipeline(SAMPLE_PATH)

# Summary statistics give a quick sanity check on the embedding values
print(f"\nπ Final Result:")
print(f"✓ Embeddings shape: {embeddings.shape}")
print(f"✓ Embedding statistics:")
print(f" Mean: {embeddings.mean().item():.4f}")
print(f" Std: {embeddings.std().item():.4f}")
print(f" Min: {embeddings.min().item():.4f}")
print(f" Max: {embeddings.max().item():.4f}")
π Starting GeoTIFF → Embeddings Pipeline
π Loading GeoTIFF...
π§ Preprocessing...
✓ Cropped from 64×64 to 64×64
✓ Will generate 15×15 = 225 patches
βοΈ Extracting patches...
✓ Extracted 225 patches
✓ Patch shape: (225, 3, 8, 8)
✓ Coordinate shape: (225, 4)
π Encoding metadata...
✓ Encoded metadata shape: (225, 4)
✓ Feature statistics:
center_x: mean=0.000, std=1.000
center_y: mean=0.000, std=1.000
area: mean=0.000, std=0.000
aspect_ratio: mean=1.000, std=0.000
π’ Creating tensors...
✓ Patch tensors: torch.Size([225, 3, 8, 8]), dtype: torch.float32
✓ Metadata tensors: torch.Size([225, 4]), dtype: torch.float32
π¦ Creating DataLoader...
π§ Generating embeddings...
✅ Pipeline complete! Generated 225 embeddings
π Final Result:
✓ Embeddings shape: torch.Size([225, 256])
✓ Embedding statistics:
Mean: -0.0052
Std: 0.0725
Min: -0.2700
Max: 0.3388
π Verify It: Pipeline Output Analysis
# Visualize embedding similarities
print("π Analyzing embedding relationships...")

# Check if we have enough embeddings for analysis (at most 10 are compared)
if len(embeddings) < 10:
    print(f"⚠️ Only {len(embeddings)} embeddings available, using all of them")
    sample_size = len(embeddings)
else:
    print(f"✓ Using first 10 of {len(embeddings)} embeddings for similarity analysis")
    sample_size = 10

if sample_size > 1:
    # Compute pairwise cosine similarities
    from torch.nn.functional import cosine_similarity

    sample_embeddings = embeddings[:sample_size]
    similarity_matrix = torch.zeros(sample_size, sample_size)

    for i in range(sample_size):
        for j in range(sample_size):
            if i == j:
                similarity_matrix[i, j] = 1.0  # Perfect self-similarity
            else:
                # Slices i:i+1 / j:j+1 keep a leading batch axis of size 1
                sim = cosine_similarity(sample_embeddings[i:i+1], sample_embeddings[j:j+1], dim=1)
                similarity_matrix[i, j] = sim.item()

    # Plot similarity matrix; the colour scale spans the full cosine range
    plt.figure(figsize=(8, 6))
    plt.imshow(similarity_matrix.numpy(), cmap='viridis', vmin=-1, vmax=1)
    plt.colorbar(label='Cosine Similarity')
    plt.title(f'Embedding Similarity Matrix (First {sample_size} Patches)')
    plt.xlabel('Patch Index')
    plt.ylabel('Patch Index')
    plt.show()

    print(f"✓ Average similarity: {similarity_matrix.mean().item():.4f}")
    print(f"✓ Similarity range: {similarity_matrix.min().item():.4f} to {similarity_matrix.max().item():.4f}")
else:
    print("⚠️ Not enough embeddings for similarity analysis")
π Analyzing embedding relationships...
✓ Using first 10 of 225 embeddings for similarity analysis
✓ Average similarity: 0.9737
✓ Similarity range: 0.8781 to 1.0000
Conclusion
π Congratulations! You’ve successfully built a complete pipeline that transforms raw satellite imagery into model-ready embeddings.
What You Built:
- GeoTIFF Loader: Extracts both pixel data and spatial metadata
- Preprocessing Functions: Normalization and spatial cropping
- Patch Extractor: Creates patches while preserving spatial context
- Metadata Encoder: Transforms coordinates into learned features
- PyTorch Integration: Dataset, DataLoader, and model components
- Embedding Generator: Simple CNN that produces vector representations
Key Insights:
- Spatial Context Matters: Each patch carries location information
- Preprocessing is Critical: Normalization ensures stable training
- Modular Design: Each step can be optimized independently
- End-to-End Testing: Verify the complete pipeline works
What’s Next:
In the following sessions, you’ll enhance each component:
- Week 2: Advanced attention mechanisms for spatial relationships
- Week 3: Complete GFM architecture with transformer blocks
- Week 4: Pretraining strategies and masked autoencoding
The pipeline you built today forms the foundation for everything that follows! π