🧠 Overview
This notebook demonstrates the core insight behind Residual Networks (ResNets) by comparing them to a traditional plain CNN of similar depth. You will:
- Visualize how residual connections transform intermediate representations
- Compare training loss and accuracy curves
- Analyze gradient flow to see why ResNets enable deeper models to train effectively
📦 Setup and Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
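Before building the models, here is the core idea in one minimal, standalone sketch (illustrative only, not part of the networks defined below): a residual block computes y = x + F(x), so the layers only need to learn the residual F(x), and the identity path gives gradients a direct route back to earlier layers.
f = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))  # F(x): the residual branch
x = torch.randn(4, 8)
y = x + f(x)    # skip connection: add the input back to the transformed output
print(y.shape)  # torch.Size([4, 8])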
📊 Define Plain CNN and ResNet-like CNN
class PlainBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out)
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Identity shortcut by default; use a 1x1 conv projection when the spatial
        # resolution or channel count changes, so shapes match for the addition.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(identity)  # residual connection
        return F.relu(out)
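The projection shortcut is the only subtle part: when a block changes resolution or channel count, the 1x1 convolution reshapes the identity so the addition is valid. A quick shape check on random data (illustrative only):
block = ResBlock(16, 32, stride=2)   # downsamples and widens
x = torch.randn(1, 16, 32, 32)
print(block(x).shape)                # torch.Size([1, 32, 16, 16]) -- shortcut matches via the 1x1 conv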
class PlainCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = PlainBlock(3, 16)
        self.block2 = PlainBlock(16, 32, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)
class ResNetMini(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = ResBlock(3, 16)
        self.block2 = ResBlock(16, 32, stride=2)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.fc(x)
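As a quick sanity check (illustrative only), the two models have the same depth and roughly the same parameter count; the only extras in ResNetMini are the small 1x1 projection shortcuts. Any difference we observe during training therefore comes from the skip connections rather than extra capacity.
def count_params(model):
    return sum(p.numel() for p in model.parameters())

dummy = torch.randn(1, 3, 32, 32)  # CIFAR-10 sized input
for name, model in [("PlainCNN", PlainCNN()), ("ResNetMini", ResNetMini())]:
    print(f"{name}: {count_params(model):,} parameters, output shape {tuple(model(dummy).shape)}")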
🖼️ Load CIFAR-10 Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per-channel normalization for RGB CIFAR-10
])
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False, num_workers=2)
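A quick look at one batch (optional) to confirm tensor shapes and the normalization range:
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)                # torch.Size([128, 3, 32, 32]) torch.Size([128])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0 after normalization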
🚀 Training Loop and Gradient Tracking
def train_model(model, name, epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    train_loss = []
    train_gradients = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0
        grad_norm = 0
        for inputs, targets in tqdm(train_loader, desc=f"{name} Epoch {epoch+1}"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # Track the total L2 norm of all parameter gradients for this batch
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            grad_norm += total_norm ** 0.5
            optimizer.step()
            running_loss += loss.item()
        train_loss.append(running_loss / len(train_loader))
        train_gradients.append(grad_norm / len(train_loader))  # average gradient norm per batch
    return train_loss, train_gradients
📈 Train and Compare
plain_model = PlainCNN()
resnet_model = ResNetMini()
plain_loss, plain_grads = train_model(plain_model, "PlainCNN")
resnet_loss, resnet_grads = train_model(resnet_model, "ResNetMini")
📊 Plot Training Loss and Gradient Flow
plt.figure(figsize=(14,5))
plt.subplot(1,2,1)
plt.plot(plain_loss, label="PlainCNN")
plt.plot(resnet_loss, label="ResNetMini")
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.subplot(1,2,2)
plt.plot(plain_grads, label="PlainCNN")
plt.plot(resnet_grads, label="ResNetMini")
plt.title("Gradient Norm per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Average Gradient Norm")
plt.legend()
plt.suptitle("Residual Connections Help Optimize Deeper Networks")
plt.tight_layout()
plt.show()
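🔍 Visualize Intermediate Representations
The overview also promised a look at the intermediate representations themselves. Here is a minimal sketch (illustrative only; the layer and channel choices are arbitrary) that registers forward hooks on the trained models to capture and display one feature map from each network's first block:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activations = {}

def save_activation(name):
    def hook(module, inputs, output):
        activations[name] = output.detach().cpu()
    return hook

plain_model.block1.register_forward_hook(save_activation("PlainCNN block1"))
resnet_model.block1.register_forward_hook(save_activation("ResNetMini block1"))

plain_model.eval()
resnet_model.eval()
images, _ = next(iter(test_loader))
with torch.no_grad():
    plain_model(images.to(device))
    resnet_model(images.to(device))

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, name in zip(axes, ["PlainCNN block1", "ResNetMini block1"]):
    ax.imshow(activations[name][0, 0], cmap="viridis")  # first channel of the first image
    ax.set_title(name)
    ax.axis("off")
plt.suptitle("First-block feature maps (one channel each)")
plt.show()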
✅ Conclusion
This notebook shows that residual connections help preserve gradient flow, which is what lets deeper networks train quickly and reliably. Even though our example is shallow (two blocks per model), the differences in convergence and gradient norms already hint at why skip connections become essential as depth grows.