GPT From Scratch #4: The Mathematical Trick Behind Self Attention

Welcome to Part 4 of our GPT From Scratch series, inspired by Karpathy’s “Let’s build GPT: from scratch, in code, spelled out.”

In Part 2 we explained how to create a training set from Shakespeare’s works. Part 3 then introduced a basic bigram model, predicting the next character based solely on its predecessor. 

However, this approach is fundamentally limited. To achieve the capabilities of models like GPT, we need to go beyond just one character (or word, or token) back, and consider the broader context of the preceding sequence. This vital interaction is enabled by self-attention, a mechanism underpinned by a very elegant mathematical trick for efficient context awareness. 

Let’s dive in.

The simplest kind of communication in our batch

Back to our good old batch of size BxTxC (as explained in Part 3, section “the logits”).

As a reminder, each line holds T consecutive characters from the training set (a.k.a. an example), each character is associated with its embedding (an array of numbers of fixed size C), and there are B such examples in the batch.

Our goal, in order to illustrate communication between characters, is the following: for each character at index i (in an example of the batch), average the embeddings of all the characters up to and including index i.

You may wonder why a plain average is interesting, but we’ll see later that it is the basis for building the powerful self-attention mechanism.

So, let’s illustrate with an example what we mean by doing an average of the embeddings. 

We first generate a random batch of size (4,8,2), i.e. 4 examples, each with 8 characters, and each character with an embedding of size 2.

Let’s look at the first example:

Those numbers represent the 8 characters (one character per line) of the first example, and for each, their embeddings (2 numbers).

So our goal is to produce a tensor such that each line holds the average of the numbers in the respective column, from the first line up to and including the current one.

In our example, we’re looking to get the tensor below. Look for example at the second line: the first number there is 0.3507, which is the average of the first number of the first two lines in the original tensor above (0.0783 and 0.6231). The same goes for all the other numbers in the resulting tensor: each is the average of the numbers above it in the original tensor, itself included.
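As a tiny sanity check of that arithmetic, here is a sketch in plain Python that computes the running average of a single column. Only the two values quoted above come from the example; the helper name running_average is made up for illustration.

```python
def running_average(values):
    """Return, for each index i, the mean of values[0..i] (inclusive)."""
    averages = []
    total = 0.0
    for i, v in enumerate(values):
        total += v
        averages.append(total / (i + 1))
    return averages


# First embedding channel of the first two characters, as quoted in the text.
first_column = [0.0783, 0.6231]
print([round(a, 4) for a in running_average(first_column)])
# prints [0.0783, 0.3507]
```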

Now the challenge is: how to produce that in a very efficient way so it can scale.

The brute force way

It is always useful to start with the brute force solution as a baseline.

Here is Karpathy’s code for the brute force way of solving this:

A few notes on that code:

  • xprev refers to the bth example of the batch, and to all its characters from 0 to t (included). Each of those has an embedding of size C, so xprev is of dimension (t+1,C).
  • Then, torch.mean(xprev,0) averages over the characters, separately for each channel of the embedding (in our case, there are 2).
  • bow means bag of words (a common term when just averaging things out).
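For readers who want something to run, here is a sketch of such a brute force loop, reconstructed from the notes above. It reuses the names xprev and xbow mentioned there; the input name x, the seed, and the exact layout are assumptions, not necessarily the original code.

```python
import torch

torch.manual_seed(1337)  # arbitrary seed, only for reproducibility
B, T, C = 4, 8, 2        # batch size, characters per example, embedding size
x = torch.randn(B, T, C)

# xbow[b, t] will hold the mean of x[b, 0], ..., x[b, t]
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]              # shape (t + 1, C)
        xbow[b, t] = torch.mean(xprev, 0)  # average over the characters

print(x[0])     # first example: 8 characters, 2 numbers each
print(xbow[0])  # running averages of the first example
```

The two nested Python loops run B*T times, which is exactly what makes this approach slow.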

And sure enough, it works and produces the right result. The problem is that it is highly inefficient and won’t scale, neither at training nor at inference, when talking about huge models like GPTs.

The trick: a (very) clever matrix multiplication

Now let’s describe the trick that enabled scaling self attention, and that arguably is at the core of the generative AI revolution.

First, a small reminder on how matrix multiplication works.

Each element in c is the dot product of the corresponding row in a and column in b.

E.g., to obtain the element of c at the 2nd row and 1st column, you just take the dot product of the 2nd row in a (which is [4,6,5]) and the 1st column in b (which is [3,4,4]). Thus [4,6,5] . [3,4,4] = 4*3+6*4+5*4 = 56, which is indeed what we see in c at the 2nd row and 1st column.
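The same rule can be spelled out in a few lines of plain Python. This is only a sketch: the 2nd row of a and the 1st column of b are the ones quoted above, while every other entry and the helper name matmul are made up for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


a = [[1, 2, 3],
     [4, 6, 5],   # 2nd row, as in the text
     [7, 8, 9]]
b = [[3, 1],      # 1st column is [3, 4, 4], as in the text
     [4, 2],
     [4, 0]]

c = matmul(a, b)
print(c[1][0])  # 2nd row, 1st column: 4*3 + 6*4 + 5*4 = 56
```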

Now, in the example above, if instead of multiplying b by a random matrix we multiply it by a lower triangular matrix of ones (ones on and below the diagonal, zeros above), something magic happens:

Can you see what happened? It turns out that each element in c is now the sum of the elements of b in the same column, from the first row down to the current row!

For instance, the 7 in c (2nd row, 1st column) corresponds to the sum of the elements of the 1st column in b up to the second row: 3+4 = 7. This works for every element in c.
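Here is a minimal PyTorch sketch of that observation. The first column of b is the one from the text; its second column is made up for illustration.

```python
import torch

a = torch.tril(torch.ones(3, 3))    # ones on and below the diagonal, zeros above
b = torch.tensor([[3.0, 1.0],
                  [4.0, 2.0],
                  [4.0, 0.0]])       # 1st column from the text, 2nd column made up
c = a @ b

print(a)
print(c)
# c is [[3, 1], [7, 3], [11, 3]]: each row is the running sum of the rows of b
```

Row i of a selects rows 0..i of b with weight 1 and ignores the rest, which is why the product is a running sum.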

Why is it magical and why is it helping in what we try to achieve?

Because if you add one last ingredient to the magic and, instead of multiplying by the triangular matrix of ones, you multiply by that same matrix with each number divided by the sum of its row, you get exactly the average we wanted! See it below with the associated code:
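A minimal sketch of that normalization in PyTorch, reusing the illustrative b from before (only its first column comes from the text):

```python
import torch

a = torch.tril(torch.ones(3, 3))
# a.sum(1, keepdim=True) has shape (3, 1): one sum per row.
# Broadcasting divides every element of a row by that row's sum.
a = a / a.sum(1, keepdim=True)

b = torch.tensor([[3.0, 1.0],
                  [4.0, 2.0],
                  [4.0, 0.0]])  # 1st column from the text, 2nd column made up
c = a @ b

print(a)  # rows: [1, 0, 0], [1/2, 1/2, 0], [1/3, 1/3, 1/3]
print(c)  # 2nd row is [3.5, 1.5], the average of the first two rows of b
```

Each row of a now sums to 1, so the matrix product computes a weighted average instead of a sum.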

If you want to understand what the line

is doing, you can read my blog post on tensor broadcasting.

Those were two-dimensional matrices. Let’s get back to our initial problem, a tensor of dimension BxTxC, and see how it works for those dimensions:

  • We’re multiplying a TxT matrix with a BxTxC tensor.
  • PyTorch will broadcast the TxT matrix along a new B dimension, which yields a multiplication between BxTxT and BxTxC: for each batch element, it performs a TxT times TxC multiplication (in parallel), giving a BxTxC result. In code, it gives:
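The two bullets above can be sketched as follows. The names wei and xbow2 follow the ones used in this post; x, the seed, and per_example are assumptions made for the illustration.

```python
import torch

torch.manual_seed(1337)  # arbitrary seed
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)  # normalized triangular matrix, shape (T, T)

# (T, T) @ (B, T, C): wei is broadcast over the batch dimension -> (B, T, C)
xbow2 = wei @ x

# The same thing written explicitly, one (T, T) @ (T, C) product per example
per_example = torch.stack([wei @ x[b] for b in range(B)])

print(xbow2.shape)
print(torch.allclose(xbow2, per_example, atol=1e-5))
```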

Does it really work? 

To check it, we can compare the result of xbow2 with xbow that we obtained in the brute force way section:
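A sketch of that comparison is shown here, with both methods recomputed so the snippet runs on its own. The seed and the tolerance atol are assumptions: the two results agree up to floating-point rounding, so a small tolerance is passed to torch.allclose.

```python
import torch

torch.manual_seed(1337)  # arbitrary seed
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Brute force: explicit loops over examples and positions
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, : t + 1].mean(0)

# Matrix multiplication with the normalized triangular matrix
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x

same = torch.allclose(xbow, xbow2, atol=1e-5)
print(same)
```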

Sure enough, it magically works and we indeed obtain the exact same output in both methods 🤯 .

The difference? The matrix multiplication replaces the nested Python loops with a single batched operation that runs in parallel, which makes it far more efficient and thus a game changer given the scale of what it takes to build GPT.

A softmax version

We saw that the key part of the trick is to produce this normalized triangular tensor (the wei in the code above).

Turns out there is an equivalent way to produce it using the softmax function!

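This works because softmax is actually a normalization: it exponentiates every element of a row and then divides each one by the sum of that row. If the allowed positions hold 0 and the forbidden ones hold -inf, exponentiation turns them into 1 and 0, and the division by the row sum yields exactly the normalized triangular matrix.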
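To make this concrete, here is a small plain-Python sketch of softmax applied to single rows. It assumes the allowed positions hold 0 and the forbidden ones hold -inf; the helper name softmax is written here for illustration and is not the PyTorch function.

```python
import math


def softmax(row):
    """Exponentiate every element, then divide by the sum of the exponentials."""
    exps = [math.exp(v) for v in row]  # exp(0) = 1, exp(-inf) = 0
    total = sum(exps)
    return [e / total for e in exps]


NEG_INF = float("-inf")
print(softmax([0.0, NEG_INF, NEG_INF]))  # [1.0, 0.0, 0.0]
print(softmax([0.0, 0.0, NEG_INF]))      # [0.5, 0.5, 0.0]
print(softmax([0.0, 0.0, 0.0]))          # three equal weights of 1/3
```

These are exactly the rows of a 3x3 normalized triangular matrix.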

Why use softmax instead of what we did in the previous example? To be honest, I’m not sure, but I assume it is a matter of elegance and interpretation, because, as Karpathy explains, the triangular matrix before applying softmax looks like this:

If you interpret this matrix as the “communication allowance” for each character of an example (because this is what it ends up being through the matrix multiplication; remember that each line of the batch contains 8 training examples), then it says that a character is only allowed to communicate with itself and past characters, while communication with future characters is forbidden (because when we generate characters, we only have access to the previous ones, not the upcoming ones).

And, magically enough, it just works, and produces the same result as in the brute force way.
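A sketch of the softmax version in PyTorch follows. The names tril, xbow3 and reference, the seed, and the tolerance are assumptions; the reference is computed as a cumulative sum divided by the number of characters seen so far, which is the same quantity the brute force loop produces.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)  # arbitrary seed
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))  # future positions are forbidden
wei = F.softmax(wei, dim=-1)                     # each row now sums to 1
xbow3 = wei @ x

# Reference: running sum over the characters divided by 1, 2, ..., T
counts = torch.arange(1, T + 1).view(1, T, 1)
reference = torch.cumsum(x, dim=1) / counts

print(wei)
print(torch.allclose(xbow3, reference, atol=1e-5))
```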

What was achieved and what’s next?

We started with our standard batch of examples of dimension (B,T,C), and our goal was to produce another batch of the same dimension in which each character at index i (in an example of the batch) is replaced by the average of the embeddings of all the characters up to and including index i.

We discovered it could be done in a crazy effective way by simply multiplying the tensor of the batch by a triangular normalized matrix. 

This is what it looks like in one example of our batch:

And this works exactly the same if we apply it to the whole batch of examples: you start with the tensor (B,T,C) and you end up with a tensor of the same dimension, but this time with all the averaged examples.

Why does it matter? 

Because:

  1. It is an extremely efficient operation, and that’s what will allow scaling GPT to huge amounts of data
  2. We’ll replace the simple averaging by a very smart aggregation of all the previous characters (the context)
  3. The operation will be exactly the same: a simple matrix multiplication, we’ll just replace the triangular matrix by the smart aggregation.
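To give a feel for points 2 and 3, here is a sketch in which the uniform weights are swapped for arbitrary ones. The random scores are only a placeholder assumption, not the actual aggregation that later parts of the series build; the point is that the final matrix multiplication stays the same.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)  # arbitrary seed
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
scores = torch.randn(T, T)  # placeholder scores, random for illustration only
wei = scores.masked_fill(tril == 0, float("-inf"))  # still no access to the future
wei = F.softmax(wei, dim=-1)  # rows sum to 1, but the weights are no longer uniform

out = wei @ x  # same operation as before: (T, T) @ (B, T, C) -> (B, T, C)
print(out.shape)
```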

We have now illustrated the foundation of what lies behind GPT.

From there, we’ll just introduce an additional fundamental concept, called “positional encoding” and then we’ll implement the famous self-attention mechanism which is the backbone of GPT. 

So let’s do it and dive into Part 5: Positional Encodings.

