📚 Table of Contents
- Introduction
- Why Do We Need Batch Normalization?
- How Batch Normalization Works
- Mathematical Breakdown
- Step-by-Step Example with Python
- Summary
- References
🚀 Introduction
Batch Normalization (BN) is a technique used in deep learning to normalize activations within a network, improving training speed, stability, and performance. It was introduced in 2015 by Ioffe and Szegedy and has since become a standard component in Convolutional Neural Networks (CNNs).
❓ Why Do We Need Batch Normalization?
Training deep neural networks can be challenging due to issues such as internal covariate shift and exploding/vanishing gradients. Batch Normalization helps by:
- Standardizing inputs to each layer to prevent large variations in activations.
- Accelerating training by allowing higher learning rates.
- Reducing sensitivity to weight initialization.
- Acting as a regularizer, reducing the need for dropout in some cases.
🔍 How Batch Normalization Works
Batch Normalization is applied after a convolutional layer (or fully connected layer) but before the activation function. It standardizes activations in three steps, spelled out in the mathematical breakdown below.
🧮 Mathematical Breakdown
- Compute the mean and variance for each feature in a batch:

\( \mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 \)

- Normalize the activations:

\( \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \)

- Scale and shift using learnable parameters \( \gamma \) and \( \beta \):

\( y_i = \gamma \hat{x}_i + \beta \)

where:
- \( \gamma \) (scale) and \( \beta \) (shift) are learnable parameters.
- \( \epsilon \) is a small constant for numerical stability.
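To make the three steps concrete, here is a small PyTorch sketch (the tensor shape and eps value are chosen purely for illustration) that normalizes a random batch by hand and compares the result with nn.BatchNorm2d, whose γ and β start at their default values of 1 and 0:

import torch
import torch.nn as nn

# Random batch: 4 samples, 3 channels, 8x8 feature maps (illustrative shape)
x = torch.randn(4, 3, 8, 8)
eps = 1e-5

# Step 1: per-channel mean and variance over the batch and spatial dimensions
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

# Step 2: normalize
x_hat = (x - mean) / torch.sqrt(var + eps)

# Step 3: scale and shift (gamma = 1, beta = 0, matching a freshly initialized layer)
gamma, beta = 1.0, 0.0
y_manual = gamma * x_hat + beta

# Compare with PyTorch's built-in layer in training mode
bn = nn.BatchNorm2d(3, eps=eps)
y_builtin = bn(x)
print(torch.allclose(y_manual, y_builtin, atol=1e-6))  # Expected: True (up to float tolerance)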
🛠️ Step-by-Step Example with Python
📌 Installing Dependencies
pip install torch torchvision numpy matplotlib
🔧 Implementing Batch Normalization in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define a simple CNN with Batch Normalization
class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithBatchNorm, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # BatchNorm for 16 feature maps
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # Halves spatial size: 32 -> 16 -> 8
        self.fc1 = nn.Linear(32 * 8 * 8, 10)  # Fully connected layer (32 maps of size 8x8)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))  # Conv -> BatchNorm -> ReLU -> Pool
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        return x

# Initialize model
model = CNNWithBatchNorm()
print(model)
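Before running the model, it can help to peek at what BatchNorm actually learns. In PyTorch, the γ and β of a BatchNorm2d layer are stored as weight and bias, while the running statistics used at inference time are kept as buffers; the snippet below simply inspects bn1 of the model defined above:

# gamma and beta: one value per feature map, learned during training
print(model.bn1.weight.shape)        # torch.Size([16]) -> gamma, initialized to ones
print(model.bn1.bias.shape)          # torch.Size([16]) -> beta, initialized to zeros
# Running statistics, updated during training and used in eval mode
print(model.bn1.running_mean.shape)  # torch.Size([16])
print(model.bn1.running_var.shape)   # torch.Size([16])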
➕ Running the Code
# Generate random input tensor simulating an image batch
input_tensor = torch.randn(4, 3, 32, 32) # Batch of 4 images, 3 channels (RGB), 32x32 size
# Forward pass through the network
output = model(input_tensor)
print("Output shape:", output.shape) # Should be (4, 10) for 10 classes
🎯 Summary
- Batch Normalization stabilizes training by normalizing activations within a batch.
- It speeds up training and allows for higher learning rates.
- It acts as a regularizer, reducing overfitting in deep networks.
- PyTorch provides nn.BatchNorm2d() for CNNs and nn.BatchNorm1d() for fully connected layers.
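For completeness, here is a minimal sketch of nn.BatchNorm1d() placed between fully connected layers; the layer sizes are arbitrary and only mirror the 2D example above:

import torch
import torch.nn as nn

# A small fully connected block with 1D Batch Normalization
mlp = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),   # normalizes each of the 64 features across the batch
    nn.ReLU(),
    nn.Linear(64, 10),
)

features = torch.randn(16, 128)   # batch of 16 feature vectors
print(mlp(features).shape)        # torch.Size([16, 10])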