Attention
Let’s implement some attention modules.
The 3Blue1Brown video on attention is really good for intuition. Sebastian Raschka’s “Build a Large Language Model” is also really nice for more in-the-weeds implementation details without any tricky math.
Here’s the intuition:
- we’ve got some sequence; say it’s the sentence “the quick brown fox jumps over the lazy dog”. For the sake of simplicity let’s pretend that each word is a single token.
- each token in the sequence wants to know which other tokens in the sequence to pay attention to; for example, “fox” should probably pay attention to “quick” and “brown” in order to understand that we are dealing not just with any fox, but a quick and brown one.
- each token has a “query”, which is kind of like that token’s question that it poses to all the other tokens: “how much should I pay attention to you?”
- each token has a “key”, which is like a token answering, “here’s how much you should pay attention to me!” We take the dot product of the query and key, and then normalize the resulting dot products (with softmax) so that we’ve got a probability distribution of how much a token should pay attention to the other tokens.
- and what, specifically, should we pay attention to in a given token? The “value” is what we pay attention to: it can be thought of as the actual information associated with each token. For a given token $i$, we take a weighted sum of all the token values, using token $i$’s attention weights.
This gives us simple self-attention.
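Before wrapping this up in a module, here’s a minimal sketch of the recipe above on a made-up 3-token sequence. For simplicity it uses the raw embeddings as the queries, keys, and values (no learned projections yet); the numbers are just for illustration.
import torch

# three tokens, each with a made-up 2-dimensional embedding
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])

# pretend queries == keys == values == x (no learned projections in this sketch)
queries, keys, values = x, x, x

scores = queries @ keys.T                 # (3, 3): how much token i "asks about" token j
weights = torch.softmax(scores, dim=-1)   # each row is a probability distribution over tokens
context = weights @ values                # weighted sum of values for each token

print(weights)   # rows sum to 1
print(context)   # one context vector per token, shape (3, 2)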
import torch
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# self attention, non-causal, single head
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.attention_dim = d_out
        self.w_q = torch.nn.Linear(d_in, d_out)
        self.w_k = torch.nn.Linear(d_in, d_out)
        self.w_v = torch.nn.Linear(d_in, d_out)

    def forward(self, x):
        # B: batch size
        # T: context length
        # E: embedding dimension
        B, T, E = x.shape
        queries = self.w_q(x)
        keys = self.w_k(x)
        values = self.w_v(x)
        assert queries.shape == keys.shape == values.shape == (B, T, self.attention_dim)
        attention = torch.matmul(queries, keys.transpose(1, 2))
        assert attention.shape == (B, T, T)
        attention = torch.softmax(attention * (self.attention_dim ** -0.5), dim=-1)
        assert torch.all(torch.isclose(attention.sum(dim=-1), torch.ones(T)))
        output = torch.matmul(attention, values)
        assert output.shape == (B, T, self.attention_dim)
        return output
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
inputs, inputs.shape
batch = torch.stack([inputs] * 5, dim=0)
print(batch.shape)
sa = SelfAttention(d_in = batch.shape[-1], d_out = 10)
out = sa(batch)
assert out.shape == (5, 6, 10)
torch.Size([5, 6, 3])
There are a couple of fairly simple extensions to basic self-attention that make it work for language models.
Extension: causal self attention
Since language models generate outputs one token at a time, it doesn’t make sense for tokens to consider future tokens in their attention mechanism.
To account for that, we can apply a mask so that when we compute the weighted sum of the value vectors for token $i$, the weights applied to any value $v_j$ with $j > i$ are 0. In practice that means building a triangular mask that marks the positions above the diagonal (the “future” positions) and using it to fill those entries of the query–key dot-product matrix with $-\infty$ before the softmax (more on why $-\infty$ rather than 0 in the gotchas below), as in the sketch that follows.
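Here’s a small standalone sketch of that masking step (the score matrix is random, just for illustration):
import torch

T = 4
scores = torch.randn(T, T)  # stand-in for the query–key dot products

# 1s strictly above the diagonal mark the "future" positions (j > i)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()

masked_scores = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights)              # strictly upper-triangular entries are exactly 0
print(weights.sum(dim=-1))  # each row still sums to 1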
Extension: multiheaded attention
We can think of the queries, keys, and values as containing not all contextual information about a sequence, but just some aspect of that context. Kind of like how different layers in an image CNN can capture different visual features of images (see Visualizing and Understanding Convolutional Networks, 2014), we can use multiple sets of (queries, keys, values) to capture different aspects of the context.
As an illustrative example, the smallest version of GPT2 used 12 attention heads in its self-attention block.
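Before the full implementation below, here’s a toy sketch (made-up numbers) of the reshape-and-transpose that splits a combined projection of size d_out into per-head tensors; the MultiheadAttention class below does this inside forward.
import torch

B, T, num_heads, d_head = 1, 2, 2, 2

# one batch, two tokens, with a combined 4-dim projection per token (toy values)
q = torch.tensor([[[0., 1., 2., 3.],
                   [4., 5., 6., 7.]]])                        # shape (B, T, d_out)

# split the last dim into heads, then move the head dim in front of the token dim
q_heads = q.reshape(B, T, num_heads, d_head).transpose(1, 2)  # (B, num_heads, T, d_head)

print(q_heads[0, 0])  # head 0 gets columns [0, 1] of every token: [[0., 1.], [4., 5.]]
print(q_heads[0, 1])  # head 1 gets columns [2, 3] of every token: [[2., 3.], [6., 7.]]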
class CausalSelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, attn_bias=True):
        super().__init__()
        self.w_q = torch.nn.Linear(d_in, d_out, bias=attn_bias)
        self.w_k = torch.nn.Linear(d_in, d_out, bias=attn_bias)
        self.w_v = torch.nn.Linear(d_in, d_out, bias=attn_bias)
        self.d_in = d_in
        self.d_out = d_out
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, T, E = x.shape
        assert E == self.d_in
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)
        assert q.shape == k.shape == v.shape == (B, T, self.d_out)
        attn = q @ k.transpose(-2, -1)
        assert attn.shape == (B, T, T)
        # no guarantee that our input x takes up the full context_length, so trim down causal_mask
        mask = self.causal_mask[:T, :T]
        attn.masked_fill_(mask.bool(), -torch.inf)
        attn = torch.softmax(attn * (self.d_out ** -0.5), dim=-1)
        assert torch.all(torch.isclose(attn.sum(-1), torch.ones(T)))
        out = attn @ v
        assert out.shape == (B, T, self.d_out)
        return out
class MultiheadAttention(torch.nn.Module):
    def __init__(self, d_in: int, d_out: int, num_heads: int, context_length: int, qkv_bias=True):
        super().__init__()
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads for multiheaded attention")
        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = d_out // num_heads
        self.w_q = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_k = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_v = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(0.1)
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        B, T, E = x.shape
        queries = self.w_q(x)
        keys = self.w_k(x)
        values = self.w_v(x)
        assert queries.shape == keys.shape == values.shape
        assert queries.shape == (B, T, self.d_out)
        """
        Example with shape (2, 2, 4) and 2 heads:

        [[[0, 1, 0, 1],
          [2, 3, 2, 3]],
         [[4, 5, 4, 5],
          [6, 7, 6, 7]]]

        reshape(B, T, num_heads, d_head) gives shape (2, 2, 2, 2):

        [[[[0, 1], [0, 1]],
          [[2, 3], [2, 3]]],
         [[[4, 5], [4, 5]],
          [[6, 7], [6, 7]]]]

        but the per-head matrix we actually want is e.g. [[0, 1], [2, 3]],
        so we transpose(1, 2) to get shape (B, num_heads, T, d_head):

        [[[[0, 1], [2, 3]],
          [[0, 1], [2, 3]]],
         [[[4, 5], [6, 7]],
          [[4, 5], [6, 7]]]]
        """
        queries = queries.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        keys = keys.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        values = values.reshape(B, T, self.num_heads, -1).transpose(1, 2)
        assert values.shape == queries.shape == keys.shape == (B, self.num_heads, T, self.d_head)
        attn = queries @ keys.transpose(2, 3)
        assert attn.shape == (B, self.num_heads, T, T)
        mask = self.causal_mask[:T, :T]
        attn.masked_fill_(mask.bool(), -torch.inf)
        attn = torch.softmax(attn * (self.d_head ** -0.5), dim=-1)
        attn = self.dropout(attn)
        output = attn @ values
        assert output.shape == (B, self.num_heads, T, self.d_head)
        output = output.transpose(1, 2).flatten(-2)
        assert output.shape == (B, T, self.d_out)
        output = self.out_proj(output)
        return output
Tricky things to watch out for when implementing:
- Make sure the shapes are correct when doing the reshaping in multi-headed attention. It’s easy to mess this up; I found it helpful to write it out on paper with little matrices, with a clear view of which part of the full `queries` tensor belongs to each head.
- In multi-headed attention, make sure you scale the attention scores that are input to softmax by $\frac{1}{\sqrt{d_{\text{head}}}}$ (the per-head dimension), not $\frac{1}{\sqrt{d_{\text{out}}}}$.
- Don’t forget the final projection (`out_proj`) on the concatenated head outputs.
- Don’t forget Dropout!
- Don’t mess up the causal masking. It’s easy to do if you’re not careful. You could mask with 0s after applying softmax, but masking with 0s before applying softmax is not valid because $e^0 = 1$, not 0, not to mention that the attention scores (the raw $q \cdot k$ values) could be negative. So if you mask before applying softmax, you need to mask with $-\infty$, since $e^{-\infty} = 0$; see the short demo after this list. Also, the torch APIs are not super intuitive, at least for me: `Tensor.masked_fill_` applies an in-place fill where the mask is `True`, and `torch.triu(..., diagonal=1)` has 0s on the main diagonal, so each token can still attend to itself.
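Here’s the short demo mentioned above, showing why masking with 0s before softmax still leaks attention while $-\infty$ zeroes it out exactly (the scores are made up):
import torch

scores = torch.tensor([[2.0, -1.0, 3.0]])     # pretend the last position is "in the future"
mask = torch.tensor([[False, False, True]])

# wrong: a 0 score still gets weight after softmax, since e^0 = 1
bad = torch.softmax(scores.masked_fill(mask, 0.0), dim=-1)

# right: -inf becomes exactly 0 after softmax, and the remaining weights renormalize
good = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(bad)   # the masked position still has a nonzero weight
print(good)  # the masked position is exactly 0; the first two weights sum to 1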
csa = CausalSelfAttention(d_in=batch.shape[-1], d_out=10, context_length=batch.shape[1])
out = csa(batch)
print(out.shape)
out[0]
torch.Size([5, 6, 10])
tensor([[ 0.2054, 0.5257, -0.5775, -0.2834, -0.0609, -0.4216, 0.1994, 0.1902,
0.6867, 0.1970],
[ 0.3684, 0.4443, -0.5592, -0.3495, -0.2503, -0.5115, 0.2472, 0.0793,
0.9065, 0.2939],
[ 0.4259, 0.4201, -0.5537, -0.3667, -0.3142, -0.5477, 0.2652, 0.0475,
0.9835, 0.3291],
[ 0.4567, 0.3811, -0.5431, -0.3463, -0.3845, -0.5307, 0.2477, 0.0206,
0.9687, 0.2794],
[ 0.4708, 0.4200, -0.5508, -0.2757, -0.3772, -0.5811, 0.2584, 0.0929,
0.9694, 0.2843],
[ 0.4792, 0.3727, -0.5398, -0.3112, -0.4214, -0.5391, 0.2420, 0.0302,
0.9618, 0.2543]], grad_fn=<SelectBackward0>)
mha = MultiheadAttention(d_in=batch.shape[-1], d_out=10, num_heads=2, context_length=batch.shape[1])
out = mha(batch)
print(out.shape)
out[0]
torch.Size([5, 6, 10])
tensor([[-0.8100, -0.1408, -1.3714, -0.5464, -0.2073, -1.0994, 0.4570, -0.4845,
-0.4971, -0.5315],
[-0.6426, 0.0212, -0.6967, 0.0186, -0.0093, -0.4048, 0.1778, -0.3883,
-0.2246, 0.2713],
[-0.8835, 0.1633, -1.4187, -0.4690, -0.2264, -0.9220, 0.3543, -0.2235,
-0.3431, -0.3388],
[-0.3317, -0.0097, -0.2685, -0.0109, -0.1006, -0.3995, 0.3175, -0.2946,
-0.2642, 0.1643],
[-0.3648, 0.0114, -0.2059, 0.0931, 0.0510, -0.2593, 0.1591, -0.2490,
-0.1970, 0.3774],
[-0.2307, 0.0296, -0.0651, 0.0476, -0.1115, -0.3113, 0.2992, -0.2186,
-0.2413, 0.2497]], grad_fn=<SelectBackward0>)