Overview
This notebook demonstrates the core insight behind Residual Networks (ResNets) by comparing them to a traditional plain CNN of similar depth. You will:
- Visualize how residual connections transform intermediate representations
- Compare training loss and accuracy curves
- Analyze gradient flow to see why ResNets enable deeper models to train effectively (a short autograd sketch right after this list previews the intuition)
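Before building any models, here is a minimal autograd sketch of the core idea (illustrative only, not part of the experiment below; the tensors x and w are made up for the example). A residual mapping y = F(x) + x adds an identity term to the gradient, so the signal reaching earlier layers never has to pass solely through F:

    import torch

    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4)

    # Plain mapping: y = sum(w * x)  ->  dy/dx = w
    plain = (w * x).sum()
    grad_plain, = torch.autograd.grad(plain, x)

    # Residual mapping: y = sum(w * x) + sum(x)  ->  dy/dx = w + 1
    # The "+ 1" identity term keeps the gradient alive even if w is tiny.
    residual = (w * x).sum() + x.sum()
    grad_residual, = torch.autograd.grad(residual, x)

    print(grad_plain)      # equals w
    print(grad_residual)   # equals w + 1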
Setup and Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
Define Plain CNN and ResNet-like CNN
class PlainBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out)
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Identity shortcut; replaced by a 1x1 projection when the spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(identity)
        return F.relu(out)
class PlainCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = PlainBlock(3, 16)
        self.block2 = PlainBlock(16, 32, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)
class ResNetMini(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ResBlock(3, 16)
        self.block2 = ResBlock(16, 32, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)
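As an optional sanity check (it assumes only the two classes defined above), both models should map a CIFAR-10-sized batch to a [batch, 10] tensor of class logits:

    # Optional shape check: a fake CIFAR-10 batch should come out as [batch, 10] logits
    dummy = torch.randn(2, 3, 32, 32)
    print(PlainCNN()(dummy).shape)    # torch.Size([2, 10])
    print(ResNetMini()(dummy).shape)  # torch.Size([2, 10])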
Load CIFAR-10 Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)
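A quick, optional peek at one batch confirms the tensor shapes the models expect:

    # Inspect one training batch: images are [128, 3, 32, 32], labels are [128]
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)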
Training Loop and Gradient Tracking
def train_model(model, name, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    train_loss = []
    train_gradients = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0
        grad_norm = 0
        for inputs, targets in tqdm(train_loader, desc=f"{name} Epoch {epoch+1}"):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # Track the global L2 norm of all parameter gradients for this batch
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            grad_norm += total_norm ** 0.5

            optimizer.step()
            running_loss += loss.item()

        train_loss.append(running_loss / len(train_loader))
        train_gradients.append(grad_norm / len(train_loader))

    return train_loss, train_gradients
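For reference, the value accumulated into grad_norm for each batch is the global L2 norm of the parameter gradients (the square root of the summed squared per-parameter norms), and train_gradients stores its average over the batches of each epoch. This per-epoch average is the quantity compared in the gradient-flow plot below.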
Train and Compare
plain_model = PlainCNN()
resnet_model = ResNetMini()

plain_loss, plain_grads = train_model(plain_model, "PlainCNN")
resnet_loss, resnet_grads = train_model(resnet_model, "ResNetMini")
Plot Training Loss and Gradient Flow
plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
plt.plot(plain_loss, label="PlainCNN")
plt.plot(resnet_loss, label="ResNetMini")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(plain_grads, label="PlainCNN")
plt.plot(resnet_grads, label="ResNetMini")
plt.title("Gradient Norm per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Average Gradient Norm")
plt.legend()

plt.suptitle("Residual Connections Help Optimize Deeper Networks")
plt.tight_layout()
plt.show()
Conclusion
This notebook demonstrates that residual connections help preserve gradient flow, which is what allows deeper networks to train faster and more reliably. Even though our example is relatively shallow (two blocks per model), the difference in convergence and gradient stability should already be visible in the curves above, and the gap grows much larger at the depths for which ResNets were originally designed.