Attention

Zach Ocean

Let’s implement some attention modules.

The 3Blue1Brown video on attention is great for building intuition. Sebastian Raschka’s “Build a Large Language Model” is also very good if you want more in-the-weeds implementation details without any tricky math.

Here’s the intuition:

  • we’ve got some sequence; say it’s the sentence “the quick brown fox jumps over the lazy dog”. For the sake of simplicity let’s pretend that each word is a single token.
  • each token in the sequence wants to know which other tokens in the sequence to pay attention to; for example, “fox” should probably pay attention to “quick” and “brown” in order to understand that we are dealing not just with any fox, but a quick and brown one.
  • each token has a “query”, which is kind of like that token’s question that it poses to all the other tokens: “how much should I pay attention to you?”
  • each token has a “key”, which is like a token answering, “here’s how much you should pay attention to me!” We take the dot product of a token’s query with every token’s key, then normalize those dot products with softmax so that we’ve got a probability distribution over how much that token should pay attention to each token in the sequence.
  • and what, specifically, should we pay attention to in a given token? The “value” is what we pay attention to: it can be thought of as the actual information associated with each token. For a given token $i$, we take a weighted sum of all the token values, using token $i$’s attention weights.

This gives us simple self-attention.
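In matrix form, stacking the queries, keys, and values of all $T$ tokens as the rows of $Q$, $K$, and $V$, the steps above are commonly written as follows, where $d$ is the query/key dimension (d_out in the code that follows) and $\alpha_{ij}$ is how much token $i$ attends to token $j$. The $1/\sqrt{d}$ factor isn't part of the intuition above; it's the standard scaling that keeps the dot products from growing with $d$ and saturating the softmax, and the classes below apply it as well. The softmax is taken row by row.

```latex
\begin{aligned}
\alpha_{ij} &= \frac{\exp\left(q_i \cdot k_j / \sqrt{d}\right)}{\sum_{l=1}^{T} \exp\left(q_i \cdot k_l / \sqrt{d}\right)},
\qquad \text{output}_i = \sum_{j=1}^{T} \alpha_{ij} \, v_j \\
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V
\end{aligned}
```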

import torch
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline
# self attention, non-causal, single head
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.attention_dim = d_out
        self.w_q = torch.nn.Linear(d_in, d_out)
        self.w_k = torch.nn.Linear(d_in, d_out)
        self.w_v = torch.nn.Linear(d_in, d_out)

    def forward(self, x):
        # B: batch size
        # T: context length
        # E: embedding dimension

        B, T, E = x.shape
        queries = self.w_q(x)
        keys = self.w_k(x)
        values = self.w_v(x)  # values use their own projection, w_v
        assert queries.shape == keys.shape == values.shape == (B, T, self.attention_dim)

        attention = torch.matmul(queries, keys.transpose(1, 2))
        assert attention.shape == (B, T, T)
        attention = torch.softmax(attention * (self.attention_dim ** -0.5), dim=-1)
        assert torch.all(torch.isclose(attention.sum(dim=-1), torch.ones(T)))
        output = torch.matmul(attention, values)
        assert output.shape == (B, T, self.attention_dim)
        return output
inputs = torch.tensor(
    [[0.43, .15, .89], # your
     [.55, .87, .66], # journey
     [.57, .85, .64], # starts
     [.22, .58, .33], # with
     [.77, .25, .10], # one
     [.05, .80, .55] # step
     ]
)
inputs, inputs.shape
batch = torch.stack([inputs] * 5, dim=0)
print(batch.shape)
sa = SelfAttention(d_in = batch.shape[-1], d_out = 10)
out = sa(batch)
assert out.shape == (5, 6, 10)
torch.Size([5, 6, 3])
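One way to sanity-check the attention math is to compare a from-scratch version against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0 and later), whose default scale factor is $1/\sqrt{d}$ with $d$ the last dimension of the query. This is a minimal sketch that skips the linear projections and feeds random tensors straight in as queries, keys, and values, so it checks the attention computation only, not the SelfAttention class itself. It should print True.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, T, d = 5, 6, 10  # batch size, context length, query/key/value dimension
q = torch.randn(B, T, d)
k = torch.randn(B, T, d)
v = torch.randn(B, T, d)

# From-scratch scaled dot-product attention, same steps as SelfAttention.forward
scores = q @ k.transpose(1, 2) / math.sqrt(d)  # (B, T, T)
weights = torch.softmax(scores, dim=-1)  # each row is a probability distribution
manual = weights @ v  # (B, T, d)

# PyTorch's fused implementation; its default scale is 1/sqrt(d)
reference = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, reference, atol=1e-5))
```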

There are a couple of fairly simple extensions to basic self-attention that make it work for language models.

Extension: causal self attention

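Since language models generate outputs one token at a time, a token can’t attend to tokens that haven’t been generated yet. To keep training consistent with generation (and to stop the model from simply looking ahead at the tokens it’s supposed to predict), each token should only attend to itself and the tokens before it.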

To account for that, we apply a mask so that when we take the weighted sum of the value vectors to get the output for token $i$, the weight on any value $j$ with $j > i$ is 0. In practice that means building a mask that marks the entries strictly above the main diagonal of the matrix of query-key dot products, and filling those entries with $-\infty$ before the softmax, so they come out as exactly 0 afterward. The entries that survive form the lower triangle ($j \le i$), which is where each token is allowed to attend.
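A small sketch of that masking step, using random numbers as a stand-in for the query-key dot products (the scores tensor is not output from a real model):

```python
import torch

T = 4  # context length for the demo

# True strictly above the main diagonal: the positions j > i (future tokens)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()

torch.manual_seed(0)
scores = torch.randn(T, T)  # stand-in for the query-key dot products

# Fill future positions with -inf so that exp(-inf) = 0 after the softmax
masked = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(masked, dim=-1)

# Row i has nonzero weights only in columns 0..i, and each row still sums to 1
print(weights)
```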

Extension: multiheaded attention

We can think of a set of queries, keys, and values as capturing not all of the contextual information about a sequence, but just some aspect of that context. Much like how different layers in an image CNN capture different visual features (see Visualizing and Understanding Convolutional Networks, 2014), we can use multiple sets of (queries, keys, values), called heads, to capture different aspects of the context.

As an illustrative example, the smallest version of GPT-2 uses 12 attention heads in each of its self-attention blocks.
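To make the head split concrete with GPT-2 small's sizes (768-dimensional embeddings, so 12 heads of 64 dimensions each), here is the reshape-and-transpose trick that the MultiheadAttention class below uses, applied to a random stand-in for the projected queries. The batch size and sequence length are arbitrary demo values.

```python
import torch

# GPT-2 small sizes: 768-dim embeddings split across 12 heads -> 64 dims per head
d_model, num_heads = 768, 12
d_head = d_model // num_heads

B, T = 2, 8  # small batch and sequence length, just for the demo
queries = torch.randn(B, T, d_model)  # stand-in for the output of w_q

# (B, T, d_model) -> (B, T, num_heads, d_head) -> (B, num_heads, T, d_head)
per_head = queries.reshape(B, T, num_heads, d_head).transpose(1, 2)
print(per_head.shape)  # torch.Size([2, 12, 8, 64])

# Head h sees a contiguous slice of the feature dimension
h = 3
assert torch.equal(per_head[:, h], queries[:, :, h * d_head:(h + 1) * d_head])

# Undo the split: (B, num_heads, T, d_head) -> (B, T, num_heads, d_head) -> (B, T, d_model)
merged = per_head.transpose(1, 2).flatten(-2)
assert torch.equal(merged, queries)
```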


class CausalSelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, attn_bias=True):
        super().__init__()
        self.w_q = torch.nn.Linear(d_in, d_out, bias=attn_bias)
        self.w_k = torch.nn.Linear(d_in, d_out, bias=attn_bias)
        self.w_v = torch.nn.Linear(d_in, d_out, bias=attn_bias)

        self.d_in = d_in
        self.d_out = d_out

        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, T, E = x.shape
        assert E == self.d_in
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        assert q.shape == k.shape == v.shape == (B, T, self.d_out)
        attn = q @ k.transpose(-2, -1)
        assert attn.shape == (B, T, T)

        # no guarantee that our input x takes up the full context_length, so trim down causal_mask
        mask = self.causal_mask[:T, :T]
        attn.masked_fill_(mask.bool(), -torch.inf)
        attn = torch.softmax(attn * (self.d_out ** -0.5), dim=-1)
        assert torch.all(torch.isclose(attn.sum(-1), torch.ones(T)))
        out = attn @ v
        assert out.shape == (B, T, self.d_out)
        return out
        



class MultiheadAttention(torch.nn.Module):
    def __init__(self, d_in: int, d_out: int, num_heads: int, context_length: int, qkv_bias=True):
        super().__init__()
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads for multiheaded attention")

        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = d_out // num_heads
        self.w_q = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_k = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_v = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = torch.nn.Dropout(0.1)

        self.out_proj = torch.nn.Linear(d_out, d_out)

        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, T, E = x.shape
        queries = self.w_q(x)
        keys = self.w_k(x)
        values = self.w_v(x)
        assert queries.shape == keys.shape == values.shape
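        assert queries.shape == (B, T, self.d_out)
        # Split d_out into num_heads chunks of size d_head, then move the head dimension
        # in front of the token dimension so each head gets its own (T, d_head) matrix.
        # Example with B=1, T=2, d_out=4, num_heads=2 (so d_head=2):
        #   queries, shape (1, 2, 4):
        #     [[[0, 1, 2, 3],     <- token 0
        #       [4, 5, 6, 7]]]    <- token 1
        #   reshape to (1, 2, 2, 2), indexed [batch][token][head]:
        #     [[[[0, 1], [2, 3]],
        #       [[4, 5], [6, 7]]]]
        #   Head 0's matrix should be [[0, 1], [4, 5]], so transpose(1, 2) to get
        #   shape (1, 2, 2, 2), indexed [batch][head][token]:
        #     [[[[0, 1], [4, 5]],     <- head 0
        #       [[2, 3], [6, 7]]]]    <- head 1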
        queries = queries.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        keys = keys.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        values = values.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        assert values.shape == queries.shape == keys.shape == (B, self.num_heads, T, self.d_head)
        attn = queries @ keys.transpose(2, 3)
        assert attn.shape == (B, self.num_heads, T, T)
        mask = self.causal_mask[:T, :T]
        attn.masked_fill_(mask.bool(), -torch.inf)
        attn = torch.softmax(attn * (self.d_head ** -0.5), dim=-1)
        attn = self.dropout(attn)
        output = attn @ values
        assert output.shape == (B, self.num_heads, T, self.d_head)
        output = output.transpose(1, 2).flatten(-2)
        assert output.shape == (B, T, self.d_out)
        output = self.out_proj(output)
        return output

Tricky things to watch out for when implementing:

  • Make sure the shapes are correct when computing $QK^T$ in multiheaded attention. It’s easy to mess this up; I found it helpful to write it out on paper with small matrices, to get a clear view of which part of the full queries tensor belongs to each head.

  • In multiheaded attention, make sure you scale the attention scores that go into the softmax by $\frac{1}{\sqrt{d_{head}}}$, using the per-head dimension rather than $d_{out}$.

  • Don’t forget the final output projection (out_proj), applied after the heads’ outputs are concatenated.

  • Don’t forget Dropout!

  • Don’t mess up the causal masking; it’s easy to get wrong if you’re not careful. You could mask with 0s after applying softmax (and then renormalize each row so it sums to 1 again), but masking with 0s before applying softmax does not work, because $\exp(0) = 1$, not 0, not to mention that the attention scores (the $q \cdot k$ values) can be negative, so 0 isn’t necessarily even a low score. If you mask before applying softmax, you need to mask with $-\infty$, since $\exp(-\infty) = 0$. Also, the torch APIs are not super intuitive, at least for me. Tensor.masked_fill_ fills in place wherever the mask is True. torch.triu(..., diagonal=1) keeps only the entries strictly above the main diagonal, so it has 0s on and below the main diagonal.
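The masking points in the last bullet are easy to check directly. This sketch uses random scores as a stand-in for the $q \cdot k$ values; it shows what torch.triu and torch.tril produce, that masking with $-\infty$ before the softmax gives the same weights as zeroing after the softmax and renormalizing, and why zeroing before the softmax doesn't work.

```python
import torch

T = 4
ones = torch.ones(T, T)

# triu(diagonal=1): 1s strictly above the main diagonal, 0s on and below it
upper = torch.triu(ones, diagonal=1)
# tril(diagonal=0): 1s on and below the main diagonal (where a token may attend)
lower = torch.tril(ones)
assert torch.equal(upper + lower, ones)  # the two masks are complementary

torch.manual_seed(0)
scores = torch.randn(T, T)

# Option 1: mask with -inf BEFORE the softmax (what the classes above do)
before = torch.softmax(scores.masked_fill(upper.bool(), float("-inf")), dim=-1)

# Option 2: softmax first, zero out future positions, then renormalize each row
after = torch.softmax(scores, dim=-1) * lower
after = after / after.sum(dim=-1, keepdim=True)

print(torch.allclose(before, after))  # True

# The tempting-but-wrong version: zeroing the scores before the softmax still
# gives future tokens nonzero weight, since exp(0) = 1
wrong = torch.softmax(scores * lower, dim=-1)
print(wrong[0])  # row 0 puts weight on tokens 1..3
```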

csa = CausalSelfAttention(d_in=batch.shape[-1], d_out=10, context_length=batch.shape[1])
out = csa(batch)
print(out.shape)
out[0]
torch.Size([5, 6, 10])





tensor([[ 0.2054,  0.5257, -0.5775, -0.2834, -0.0609, -0.4216,  0.1994,  0.1902,
          0.6867,  0.1970],
        [ 0.3684,  0.4443, -0.5592, -0.3495, -0.2503, -0.5115,  0.2472,  0.0793,
          0.9065,  0.2939],
        [ 0.4259,  0.4201, -0.5537, -0.3667, -0.3142, -0.5477,  0.2652,  0.0475,
          0.9835,  0.3291],
        [ 0.4567,  0.3811, -0.5431, -0.3463, -0.3845, -0.5307,  0.2477,  0.0206,
          0.9687,  0.2794],
        [ 0.4708,  0.4200, -0.5508, -0.2757, -0.3772, -0.5811,  0.2584,  0.0929,
          0.9694,  0.2843],
        [ 0.4792,  0.3727, -0.5398, -0.3112, -0.4214, -0.5391,  0.2420,  0.0302,
          0.9618,  0.2543]], grad_fn=<SelectBackward0>)
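The key property of CausalSelfAttention is that the output at position $i$ can't depend on any token after $i$. A quick way to test that: change only the last token and confirm every earlier output stays the same. This sketch re-implements the forward pass as a standalone function so it runs on its own; causal_attention is a hypothetical helper, not part of the class above, and the sizes mirror the example batch (d_in=3, d_out=10, six tokens).

```python
import torch

torch.manual_seed(0)

d_in, d_out, T = 3, 10, 6
# Stand-ins for the w_q / w_k / w_v layers in CausalSelfAttention
w_q = torch.nn.Linear(d_in, d_out)
w_k = torch.nn.Linear(d_in, d_out)
w_v = torch.nn.Linear(d_in, d_out)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()


def causal_attention(x):
    """Same steps as CausalSelfAttention.forward, written as a plain function."""
    q, k, v = w_q(x), w_k(x), w_v(x)
    scores = q @ k.transpose(-2, -1)
    scores = scores.masked_fill(mask[: x.shape[1], : x.shape[1]], float("-inf"))
    weights = torch.softmax(scores * d_out ** -0.5, dim=-1)
    return weights @ v


x = torch.rand(1, T, d_in)
x_changed = x.clone()
x_changed[0, -1] = torch.rand(d_in)  # change only the last token

with torch.no_grad():
    out, out_changed = causal_attention(x), causal_attention(x_changed)

# Earlier positions can't see the last token, so their outputs are unchanged
print(torch.allclose(out[0, :-1], out_changed[0, :-1]))  # True
# The last position does see it, so its output is expected to change
print(torch.allclose(out[0, -1], out_changed[0, -1]))
```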
mha = MultiheadAttention(d_in=batch.shape[-1], d_out=10, num_heads=2, context_length=batch.shape[1])
out = mha(batch)
print(out.shape)
out[0]
torch.Size([5, 6, 10])





tensor([[-0.8100, -0.1408, -1.3714, -0.5464, -0.2073, -1.0994,  0.4570, -0.4845,
         -0.4971, -0.5315],
        [-0.6426,  0.0212, -0.6967,  0.0186, -0.0093, -0.4048,  0.1778, -0.3883,
         -0.2246,  0.2713],
        [-0.8835,  0.1633, -1.4187, -0.4690, -0.2264, -0.9220,  0.3543, -0.2235,
         -0.3431, -0.3388],
        [-0.3317, -0.0097, -0.2685, -0.0109, -0.1006, -0.3995,  0.3175, -0.2946,
         -0.2642,  0.1643],
        [-0.3648,  0.0114, -0.2059,  0.0931,  0.0510, -0.2593,  0.1591, -0.2490,
         -0.1970,  0.3774],
        [-0.2307,  0.0296, -0.0651,  0.0476, -0.1115, -0.3113,  0.2992, -0.2186,
         -0.2413,  0.2497]], grad_fn=<SelectBackward0>)
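As a final cross-check, the multiheaded computation can be compared with PyTorch's own torch.nn.MultiheadAttention by copying the weights across. This is a minimal standalone sketch under a few assumptions: d_in == d_out (torch.nn.MultiheadAttention uses the same embed_dim for its input and output), no dropout, and freshly initialized weights; split_heads is a hypothetical helper. PyTorch stores the query, key, and value projections stacked in a single in_proj_weight of shape (3 * embed_dim, embed_dim), and a boolean attn_mask marks with True the positions that may not be attended to. If the head splitting and merging match PyTorch's convention, the comparison passes.

```python
import torch

torch.manual_seed(0)

B, T, E, H = 5, 6, 8, 2  # batch, context length, embedding dim (d_in == d_out), heads
d_head = E // H
x = torch.rand(B, T, E)

# Projections laid out like the MultiheadAttention class above (with biases, no dropout)
w_q, w_k, w_v = (torch.nn.Linear(E, E) for _ in range(3))
out_proj = torch.nn.Linear(E, E)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()


def split_heads(t):
    # (B, T, E) -> (B, H, T, d_head)
    return t.reshape(B, T, H, d_head).transpose(1, 2)


q, k, v = split_heads(w_q(x)), split_heads(w_k(x)), split_heads(w_v(x))
scores = (q @ k.transpose(2, 3)).masked_fill(mask, float("-inf"))
weights = torch.softmax(scores * d_head ** -0.5, dim=-1)
ours = out_proj((weights @ v).transpose(1, 2).flatten(-2))  # (B, T, E)

# PyTorch's module, with our weights copied into its packed layout
ref = torch.nn.MultiheadAttention(embed_dim=E, num_heads=H, batch_first=True)
with torch.no_grad():
    ref.in_proj_weight.copy_(torch.cat([w_q.weight, w_k.weight, w_v.weight]))
    ref.in_proj_bias.copy_(torch.cat([w_q.bias, w_k.bias, w_v.bias]))
    ref.out_proj.weight.copy_(out_proj.weight)
    ref.out_proj.bias.copy_(out_proj.bias)

# Boolean attn_mask: True marks positions that may NOT be attended to
theirs, _ = ref(x, x, x, attn_mask=mask, need_weights=False)

print(torch.allclose(ours, theirs, atol=1e-5))
```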