
Batch Normalization in CNNs: How It Works with Examples


🚀 Introduction

Batch Normalization (BN) is a technique used in deep learning to normalize activations within a network, improving training speed, stability, and performance. It was introduced in 2015 by Ioffe and Szegedy and has since become a standard component in Convolutional Neural Networks (CNNs).


❓ Why Do We Need Batch Normalization?

Training deep neural networks can be challenging due to issues such as internal covariate shift and exploding/vanishing gradients. Batch Normalization helps by:

- Reducing internal covariate shift, so each layer receives inputs with a more stable distribution
- Allowing higher learning rates, which speeds up convergence
- Making training less sensitive to weight initialization
- Acting as a mild regularizer, in some cases reducing the need for dropout


🔍 How Batch Normalization Works

Batch Normalization is applied after a convolutional (or fully connected) layer but before the activation function. It standardizes activations in three steps:

  1. Compute the mean and variance of each feature over the batch:

     $$\mu = \frac{1}{m} \sum_{i=1}^{m} x_i \qquad \sigma^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu)^2$$

  2. Normalize the activations:

     $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

  3. Scale and shift using learnable parameters $\gamma$ and $\beta$:

     $$y_i = \gamma \hat{x}_i + \beta$$

     where:

     - $\gamma$ (scale) and $\beta$ (shift) are learnable parameters.
     - $\epsilon$ is a small constant for numerical stability.
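To make these steps concrete, here is a minimal sketch that reproduces them by hand for a convolutional feature map and checks the result against PyTorch's `nn.BatchNorm2d`. For Conv2d outputs, the statistics are computed per channel, over the batch and spatial dimensions (the tensor shapes and values here are made up for illustration):

import torch
import torch.nn as nn

# Fake batch of Conv2d activations: 4 images, 16 channels, 32x32 each
x = torch.randn(4, 16, 32, 32)

# Step 1: per-channel mean and (biased) variance over batch and spatial dims
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)

# Step 2: normalize
eps = 1e-5
x_hat = (x - mu) / torch.sqrt(var + eps)

# Step 3: scale and shift; gamma=1 and beta=0 match a freshly initialized layer
gamma = torch.ones(1, 16, 1, 1)
beta = torch.zeros(1, 16, 1, 1)
y = gamma * x_hat + beta

# Compare with the built-in layer (training mode, so it uses batch statistics)
bn = nn.BatchNorm2d(16)
bn.train()
print(torch.allclose(y, bn(x), atol=1e-5))  # True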

🛠️ Step-by-Step Example with Python

📌 Installing Dependencies

pip install torch torchvision numpy matplotlib

🔧 Implementing Batch Normalization in PyTorch

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define a simple CNN with Batch Normalization
class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithBatchNorm, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)  # BatchNorm for 16 feature maps
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)  # Halves spatial size: 32 -> 16 -> 8
        self.fc1 = nn.Linear(32 * 8 * 8, 10)  # Fully connected layer (32 channels x 8x8)

    def forward(self, x):
        x = self.pool(torch.relu(self.bn1(self.conv1(x))))  # Conv -> BN -> ReLU -> Pool
        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
        x = torch.flatten(x, start_dim=1)  # Flatten to (batch, 32 * 8 * 8)
        x = self.fc1(x)
        return x

# Initialize model
model = CNNWithBatchNorm()
print(model)

➕ Running the Code

# Generate random input tensor simulating an image batch
input_tensor = torch.randn(4, 3, 32, 32)  # Batch of 4 images, 3 channels (RGB), 32x32 size

# Forward pass through the network
output = model(input_tensor)
print("Output shape:", output.shape)  # Should be (4, 10) for 10 classes

🎯 Summary

Batch Normalization standardizes each feature using per-batch statistics, then rescales the result with the learnable parameters $\gamma$ and $\beta$. Placed between a convolutional (or fully connected) layer and its activation, it stabilizes and speeds up training, permits higher learning rates, and adds a mild regularization effect. In PyTorch, `nn.BatchNorm2d` applies this per channel; remember to switch between `model.train()` and `model.eval()` so the correct statistics are used.

📖 References

- Ioffe, S., & Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. Proceedings of the 32nd International Conference on Machine Learning (ICML). arXiv:1502.03167.