The Attention Mechanism in Neural Networks Explained with Examples


🚀 Introduction

The attention mechanism has revolutionized NLP and deep learning by allowing models to focus on the relevant parts of the input when generating each output. This overcomes the bottleneck of the fixed-size context vector in classic encoder-decoder architectures.


❓ Why Do We Need Attention?

In vanilla Seq2Seq models, the encoder compresses the entire input into a single fixed-size vector. This works fine for short sentences but degrades on longer or more complex ones, because every detail of the input has to fit into that one vector. Attention solves this by dynamically weighting all encoder outputs at every decoding step.


🔍 How Attention Works

At each decoding step:

  1. Compare the decoder’s current state with all encoder outputs.
  2. Compute alignment scores.
  3. Apply softmax to get attention weights.
  4. Compute a context vector as the weighted sum of encoder outputs.
  5. Combine it with the decoder state to predict the next token.
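
The five steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the implementation built later in this article: it assumes a plain dot product as the scoring function, drops the batch dimension, and uses a hypothetical output layer with a made-up vocabulary size of 100.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed toy shapes: 5 encoder timesteps, hidden size 16, no batch dimension
encoder_outputs = torch.randn(5, 16)
decoder_state = torch.randn(16)

# Steps 1-2: compare the decoder state with every encoder output.
# A dot product is used here as a stand-in scoring function.
scores = encoder_outputs @ decoder_state            # shape (5,)

# Step 3: softmax turns the scores into weights that sum to 1
weights = F.softmax(scores, dim=0)                   # shape (5,)

# Step 4: context vector = weighted sum of encoder outputs
context = weights @ encoder_outputs                  # shape (16,)

# Step 5: combine context and decoder state, then project to the vocabulary.
# The projection layer and the vocabulary size of 100 are hypothetical.
output_layer = torch.nn.Linear(32, 100)
logits = output_layer(torch.cat((context, decoder_state)))  # shape (100,)
next_token = logits.argmax().item()
```

Only the scoring function in steps 1-2 changes between attention variants; the softmax and the weighted sum stay the same.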

🧠 Types of Attention

We’ll implement Additive Attention as it’s easier to understand.
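
For reference, the additive score can be written in its concatenation form, which matches the `nn.Linear` over concatenated states used in the implementation in this article. The symbols are notation chosen here for illustration: s_t is the decoder state at step t, h_i is the i-th encoder output, W and b are the weights and bias of the linear layer, and v is the learnable scoring vector.

```latex
e_{t,i} = v^{\top} \tanh\!\left( W \, [\, s_t \, ; \, h_i \,] + b \right), \qquad
\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j} \exp(e_{t,j})}, \qquad
c_t = \sum_{i} \alpha_{t,i} \, h_i
```

Here e_{t,i} is the alignment score, \alpha_{t,i} the attention weight after softmax, and c_t the resulting context vector.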


🛠️ Step-by-Step Example with Python

📌 Installing Dependencies

pip install torch numpy

📌 Preparing Data

We’ll use random tensors in place of a real dataset, for clarity.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Simulate encoder outputs (sequence of hidden states) and decoder hidden state
encoder_outputs = torch.randn(5, 1, 16)  # 5 timesteps, batch size 1, hidden size 16
decoder_hidden = torch.randn(1, 1, 16)   # current decoder hidden state
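
In a real model these two tensors would come from a recurrent encoder rather than `torch.randn`. The sketch below shows one way that might look, assuming a single-layer `nn.GRU` encoder and a hypothetical 8-dimensional input embedding; neither is specified in this article.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 8-dim input embeddings, 16-dim hidden state
encoder = nn.GRU(input_size=8, hidden_size=16)

# 5 timesteps, batch size 1, 8 features (nn.GRU defaults to sequence-first layout)
embedded_input = torch.randn(5, 1, 8)

# encoder_outputs holds the hidden state at every timestep;
# final_hidden holds only the last one
encoder_outputs, final_hidden = encoder(embedded_input)

# A common choice is to initialise the decoder with the encoder's final hidden state
decoder_hidden = final_hidden

print(encoder_outputs.shape)  # torch.Size([5, 1, 16])
print(decoder_hidden.shape)   # torch.Size([1, 1, 16])
```

The shapes match the dummy tensors above, so the attention module can be applied to either.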

🔧 Implementing Additive Attention

class AdditiveAttention(nn.Module):
    def __init__(self, hidden_size):
        super(AdditiveAttention, self).__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        # hidden: (1, batch, hidden_size); encoder_outputs: (seq_len, batch, hidden_size)
        # Repeat decoder hidden state across sequence length
        seq_len = encoder_outputs.size(0)
        hidden = hidden.repeat(seq_len, 1, 1)

        # Calculate energy: (seq_len, batch, hidden_size)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        # Move batch first for batch matrix multiplication: (batch, hidden_size, seq_len)
        energy = energy.permute(1, 2, 0)

        # Learnable vector for scoring: (batch, 1, hidden_size)
        v = self.v.repeat(encoder_outputs.size(1), 1).unsqueeze(1)

        # Compute one score per timestep, then normalize over the sequence: (batch, seq_len)
        scores = torch.bmm(v, energy).squeeze(1)
        return F.softmax(scores, dim=1)

➕ Using It in Practice

# Instantiate and apply attention mechanism
attention = AdditiveAttention(hidden_size=16)
weights = attention(decoder_hidden, encoder_outputs)
print("Attention weights:", weights)

This prints one attention weight per encoder timestep, showing how much the decoder focuses on each input position during prediction. The weights are non-negative and sum to 1.
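
The weights are usually not the end product: they are used to build the context vector from step 4 of the walkthrough. Below is a minimal sketch of that step. To keep it self-contained, it uses stand-in weights drawn from random scores instead of the output of the attention module.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

encoder_outputs = torch.randn(5, 1, 16)   # (seq_len, batch, hidden), as in the article

# Stand-in weights of shape (batch, seq_len); in practice these would come
# from the attention module rather than from random scores.
weights = F.softmax(torch.randn(1, 5), dim=1)

# (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, 1, hidden)
context = torch.bmm(weights.unsqueeze(1), encoder_outputs.transpose(0, 1))
context = context.squeeze(1)              # (batch, hidden)

print(context.shape)  # torch.Size([1, 16])
```

The context vector has the same size as a single encoder output, so it can be concatenated with the decoder state before predicting the next token.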


🎯 Summary

Attention allows neural networks to dynamically focus on relevant parts of the input. It’s the foundation of modern architectures like Transformers and improves performance on tasks like translation, summarization, and more.
