
Self-Attention in Neural Networks: A Simple Guide with Examples



🚀 Introduction

Self-attention is a key mechanism in deep learning models, especially Transformers. It lets a neural network weigh different parts of an input sequence against each other when making predictions. Unlike traditional encoder-decoder attention, which relates tokens from two different sequences, self-attention operates within a single sequence: each token attends to every other token, including itself.

โ“ Why Do We Need Self-Attention?


๐Ÿ” How Self-Attention Works

Given an input sentence, self-attention assigns each word a score for how relevant every other word is to it. The process involves three steps (a small worked example in PyTorch follows the list):

  1. Create Query (Q), Key (K), and Value (V) Matrices

    • Each input token is projected into three different vectors.
  2. Compute Attention Scores

    • Scores are computed using the formula:

      \text{Attention} = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_k}} \right) V
    • This determines how much focus each word places on the other words in the sequence. Here d_k is the dimensionality of the key vectors; dividing by \sqrt{d_k} keeps the scores in a stable range.

  3. Generate the Output

    • The weighted sum of values forms the new representation of each word.
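
To make these three steps concrete, here is a minimal worked example in PyTorch. The sequence length, embedding size, and random weight matrices are illustrative stand-ins rather than values from a trained model:

import torch

torch.manual_seed(0)

# Toy setup: 3 tokens, embedding size 4 (illustrative values only)
seq_len, d_model = 3, 4
X = torch.randn(seq_len, d_model)        # token embeddings

# Step 1: project each token into Q, K, V with stand-in weight matrices
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Step 2: scaled dot-product scores, then softmax over each row
scores = Q @ K.T / d_model ** 0.5
weights = torch.softmax(scores, dim=-1)  # each row sums to 1

# Step 3: weighted sum of the values gives the new token representations
output = weights @ V

print(weights)  # (3, 3) attention matrix
print(output)   # (3, 4) contextualized embeddings

Each row of weights shows how strongly one token attends to every token in the sequence, which is exactly the pattern the full module below computes in batched form.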

🧮 Mathematical Breakdown

The self-attention mechanism follows these key steps (a short PyTorch check of these formulas appears after the list):

  1. Compute Query (Q), Key (K), and Value (V) Matrices:

    Q = X W_Q, \quad K = X W_K, \quad V = X W_V

    where X is the matrix of input embeddings and W_Q, W_K, W_V are learned weight matrices.

  2. Compute Attention Scores:

    \text{Scores} = \frac{Q K^T}{\sqrt{d_k}}
  3. Apply Softmax to Normalize Scores:

    \alpha = \text{softmax}(\text{Scores})
  4. Compute Final Self-Attention Output:

    \text{Output} = \alpha V
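
These four formulas can be checked directly against PyTorch's built-in scaled dot-product attention. This is a sanity-check sketch, assuming PyTorch 2.x (which provides torch.nn.functional.scaled_dot_product_attention); the tensor shapes are arbitrary:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 8
Q = torch.randn(1, 5, d_k)  # (batch, seq_len, d_k)
K = torch.randn(1, 5, d_k)
V = torch.randn(1, 5, d_k)

# Steps 2-4 written out explicitly
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # Scores = QK^T / sqrt(d_k)
alpha = torch.softmax(scores, dim=-1)          # alpha = softmax(Scores)
output = alpha @ V                             # Output = alpha V

# The fused built-in computes the same thing (no mask, no dropout)
builtin = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(output, builtin, atol=1e-5))  # True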

๐Ÿ› ๏ธ Step-by-Step Example with Python

📌 Installing Dependencies

pip install torch numpy

🔧 Implementing Self-Attention in PyTorch

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.W_q = nn.Linear(embed_size, embed_size, bias=False)
        self.W_k = nn.Linear(embed_size, embed_size, bias=False)
        self.W_v = nn.Linear(embed_size, embed_size, bias=False)
        self.scale = torch.sqrt(torch.tensor(embed_size, dtype=torch.float32))

    def forward(self, x):
        Q = self.W_q(x)  # Query matrix
        K = self.W_k(x)  # Key matrix
        V = self.W_v(x)  # Value matrix

        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

➕ Running the Code

embed_size = 8  # Example embedding size
sequence_length = 5  # Example sequence length
batch_size = 1

# Random input tensor (batch_size, sequence_length, embed_size)
x = torch.randn(batch_size, sequence_length, embed_size)

# Initialize self-attention layer and pass the input
self_attention = SelfAttention(embed_size)
output, attention_weights = self_attention(x)

print("Self-Attention Output:\n", output)
print("Attention Weights:\n", attention_weights)

This implementation processes a sequence of 5 words, each represented by an 8-dimensional embedding. The output contains contextualized word embeddings: each word's representation is now a weighted combination of information from every other word in the sequence.
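
As a quick sanity check (assuming the variables from the snippet above are still in scope), the output keeps the input's shape and every row of the attention matrix is a probability distribution:

print(output.shape)             # torch.Size([1, 5, 8]) -- same shape as the input
print(attention_weights.shape)  # torch.Size([1, 5, 5]) -- one weight per word pair
print(attention_weights.sum(dim=-1))  # each entry is 1.0 (up to floating-point error)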


🎯 Summary
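
Self-attention projects every token into query, key, and value vectors, scores each query against all keys, normalizes the scores with a softmax, and uses the resulting weights to take a weighted sum of the values. The result is a contextualized representation for each token, computed in parallel across the whole sequence, and the PyTorch module above implements this computation in just a few lines.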

