GPT From Scratch #6: Coding Self Attention

Welcome to Part 6 of our GPT From Scratch series, inspired by Karpathy’s Let’s build GPT: from scratch, in code, spelled out.

Links to previous and upcoming posts of the series:

In Part 2 we explained how to create a training set from Shakespeare’s works. Part 3 introduced a basic bigram model, which predicts the next character based solely on its predecessor. Part 4 showed how a very clever matrix multiplication lets us perform operations (like averaging) over the previous characters very efficiently, and Part 5 covered the positional encoding trick.

Now we’ve reached a critical point of the series: the implementation of the self-attention mechanism, which is at the very heart of the transformer architecture. The original Google paper that introduced the Transformer architecture is called “Attention Is All You Need” for a reason.

What is self-attention?

In a separate post (not part of this series) called Decoding Transformers: The Neural Nets Behind LLMs and More, I describe the history of transformers, why they were a game changer, and take a deep dive into their main component: the self-attention mechanism (the one we’ll implement now).

So, as a prerequisite, it is highly recommended to read it. But if you need to understand just one thing about the intuition behind self-attention, it is this:

The main purpose of the self attention mechanism is to adapt the vectors/embeddings of the words based on the context of the sentence/prompt.  

If this sentence doesn’t fully sink in yet, read on and hopefully it will become clearer (and if not, please read the separate post I mentioned above). 

How does it connect to the previous posts of this series?

As we explained in Part 2, section “Framing the Prediction Task”, our goal is to predict the next character, using the previous ones.

Regardless of how complex the computation is, the input/output format is always the same: you start with a bunch of characters (the “context” or “prompt”), and you run some calculation that outputs a vector called the logits, whose size is your vocabulary size (in our case 65, the number of distinct characters), which you ultimately transform (using the softmax function) into probabilities from which you pick the next character.

In part 3, we showed how to compute those logits using only the single previous character as context (which is obviously very limited), via a bigram model.

Now we want to use all the previous characters (which we can call “the context”, or even “the prompt”) in order to predict the next one. Self-attention will provide us a way to do it in a very smart, massively parallel and effective way, by combining the embeddings of the previous characters while putting more weight on the relevant ones.

Let’s visualize at a very high level the steps showing how the next character is generated at prediction time (once the weights of the model have already been learned):

  • It starts with the prompt (in the example below, three letters V, E, and R), with embeddings of size 4 each
  • Then self attention does its “contextualization magic” on the embeddings, producing “Contextualized Embeddings”
  • An additional linear layer and a softmax then give each character of our vocabulary a probability of being the next character (in our case, each of the 65 characters gets its own probability). 
  • Then the predicted next character is sampled according to those probabilities, and in our example the character B is picked, which seems a logical result when the prompt was VER, thus forming the word VERB.

As we said, this is at prediction time. At training time, we do the same on a batch of B examples in parallel, while learning the relevant weights. Let’s dive in, deciphering the self-attention code line by line, and understand how it connects training and prediction time. 

Deciphering self-attention code, line by line 

The main diagram illustrating what happens in self-attention is the one below (originally from this great video), where the learned weights (tuned during the training process) are the matrices Mk, Mq and Mv. Again, it’s best to read the detailed explanations in my post, but in a nutshell:

  • Taking our example above, we have 3 characters in the context
  • V2 represents the embeddings of the second character (in our example above, the character E).
  • V2 ends up as y2, which is a weighted combination of the vectors of all the characters of the context, using normalized affinity scores as weights.
  • The Mk, Mq and Mv matrices produce the Keys, Queries and Values, and are the weights learned in the process.

So here is how Karpathy implements this self attention layer:
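Below is a sketch that stays very close to Karpathy’s version (the values of the hyperparameters n_embd, block_size and dropout are illustrative assumptions, not the ones from the post):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Illustrative hyperparameters (assumed values, for the sketch only)
n_embd = 4      # embedding size, the C dimension
block_size = 8  # maximum context length, the T dimension
dropout = 0.1

class Head(nn.Module):
    """One head of self-attention."""

    def __init__(self, head_size):
        super().__init__()
        # The learnable weights: key, query and value projections
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask so each token only sees its predecessors
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Affinity scores between tokens, scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)  # normalize affinities into weights
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # weighted aggregation: (B, T, head_size)
        return out
```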

Wow. Barely 20 lines of code, but maybe the most important 20 lines of the AI revolution. 

A lot to unpack. Let’s dive in and explain it all, line by line.

  • First, the main object that we’ll manipulate in that code (the variable x) is our good old batch of size BxTxC (as explained in Part 2, section “the logits”). 
  • As a reminder, each line of the batch holds T consecutive characters (a.k.a. an example) from the training set, each such character is associated with its embedding (an array of numbers of fixed size C), and there are B such examples in the batch.
  • In the init, you have the initialization of the learnable weights of the transformer, namely the key, query and value. More details and intuitions on what they mean are in my other post, but at a high level:
  • Each token (in our case a character) in the batch will emit three vectors: the key, the query and the value:
    • The query represents: “what am I looking for”
    • The key represents: “what do I contain”
    • The value is: “what I will communicate” 
  • Note the dimensions of those. E.g. key is a linear layer of dimension (n_embed, head_size), where n_embed is just C and head_size is a hyperparameter that can be set during hyperparameter tuning at training time.
  • Let’s now explain the forward function line by line. 
  • The first line, B, T, C = x.shape, is simply about getting the important dimensions of our batch x, which we explained just above.
    • First, let’s understand the dimensions there. The dimensions of x are (B,T,C) (in the example above B=8, T=3 and C=4). Think of it as BxT vectors, each of dimension (1,C). 
    • What key(x) does, is a matrix multiplication between each of those BxT vectors and the same tensor: key (where key is one of the learned weights as explained above). key is dimension (C,head_size).
    • So, when you multiply BxT vectors of dimension (1,C) with a matrix ( key ) of dimension (C,head_size) , you end up with BxT vectors of dimension (1,head_size), and thus a tensor of dimension (B,T,head_size) .
    • Note that in his comment, Andrej Karpathy wrote (B,T,C), but that’s a small typo: he meant (B,T,head_size).
    • The exact same thing happens with the query object: query(x) produces a tensor of dimension (B,T,head_size).
  • So basically, what you have in hand now are two tensors, k and q, each of dimension (B,T,head_size). 
  • This line hides quite a lot of magic behind the scene. Let’s unpack it.
  • Each line in the batch of examples is like a sentence of T tokens, and for each such sentence, you’d like to create an affinity matrix between each token. This matrix is of size TxT . 
  • The way to create this affinity matrix is to do a dot product between the keys and queries of each of those tokens, that are represented by  k and q .
  • To take a dot product between two tensors of dimension (B,T,head_size), you have to swap the two last dimensions of the second tensor, which is what k.transpose(-2, -1) does. As a reminder, the symbol “@” in PyTorch performs (batched) tensor multiplication.
  • You end up with a tensor of dimension (B,T,T) as explained in the comment
  • As for the scaling part (dividing the scores by the square root of head_size), this is just a scaling trick, and I explain the intuition behind it in my other blog post (look for “scaling embeddings”).
  • Btw, this line of code corresponds mathematically to computing QKᵀ/√d_k, the scaled dot product between the queries and the keys.
  • This part might be intimidating, but in fact it is highly similar to what we introduced in Part 4 of this series: The Mathematical Trick Behind Self Attention. 
  • The main difference is that we’re no longer multiplying our batch by a simple triangular matrix (which would only compute a plain average). 
  • We’re now multiplying by the weights resulting from the dot product of all the tokens in the example! Otherwise, the mathematical trick is exactly the same.
  • Let that sink in for a second. When we multiplied the queries and the keys, what we got is the affinity matrix between the tokens of each example, giving one TxT matrix per example in the batch.
  • This TxT matrix represents the affinities between the tokens, so by applying the mathematical trick, we transform the embedding of each token into a weighted average, with the affinity scores as weights.
  • The only differences you might notice from the mathematical trick are:
    • The dropout, which is a well-known trick to prevent neural networks from overfitting
    • We don’t multiply the wei tensor directly with x, but with value(x), where value is also a learned linear layer (see the init function): it is simply an additional projection applied on top of the raw key/query affinity mechanism.
  • Bottom line: each token embedding, in each example of the batch, is now transformed into a weighted version of itself, with the weights being the affinities between the tokens. This is exactly what was illustrated in the diagram above.
  • And it can all be captured in one formula, which is at the center of the original paper Attention Is All You Need that introduced it all:
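That formula, from “Attention Is All You Need”, reconstructed here in LaTeX:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
```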

Interpreting self-attention as a communication mechanism 

Karpathy made a very interesting note about how to think about self attention.

In some sense, each token (in our case, character) has some vector of information (the embeddings) and has to aggregate it via a weighted sum from all the other tokens that are connected to it. This weighted sum can be interpreted as a kind of communication between each token.

It can even be seen as a directed graph, where each token points to itself and all the previous tokens. For our 8 characters it can look like this (graph created with graphviz, with code created by Gemini).
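As a rough sketch of how such a graph description can be produced, here is plain Python emitting Graphviz DOT text for T=8 tokens, where token t points to itself and all its predecessors (the node names are made up for illustration):

```python
# Build a DOT description of the causal "communication graph":
# an edge s -> t means token t can aggregate information from token s.
T = 8
lines = ["digraph attention {"]
for t in range(T):
    for s in range(t + 1):  # token t only sees tokens 0..t (itself included)
        lines.append(f"  n{s} -> n{t};")
lines.append("}")
dot = "\n".join(lines)
print(dot)
```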

Note that it is only because we’re in an autoregressive setting, where each character can only see the previous characters (the prompt), that a node cannot point to another one further in the sequence. In other settings (where you’re allowed to see the whole sentence, e.g. when analyzing a text for sentiment or anything else), you could have a fully connected graph, where every token can communicate with any other.

Connecting it all to our model

In part 3, we explained the bigram model. Now that we have the implementation of a self-attention head, we can replace the basic embedding table that we had in the bigram model with:

1. The positional encoding trick (explained in part 5)  and 

2. the attention head described above.

It gives this:
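A sketch of what that model can look like (it reuses a compact version of the Head class from above; the hyperparameter values and the exact placement of the layer normalization are illustrative assumptions):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Illustrative hyperparameters (assumed values)
vocab_size = 65
n_embd = 32
block_size = 8

class Head(nn.Module):
    """One head of self-attention (as introduced above, dropout omitted for brevity)."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ self.value(x)

class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Token embeddings (as in the bigram model) + position embeddings (Part 5)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)              # the self-attention head
        self.ln = nn.LayerNorm(n_embd)           # layer normalization (next post)
        self.lm_head = nn.Linear(n_embd, vocab_size)  # final linear layer -> logits

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = tok_emb + pos_emb   # add positional information
        x = self.sa_head(x)     # contextualize the embeddings
        x = self.ln(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:       # inference: nothing to compare against
            loss = None
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss
```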

By now, this code should look much more self-explanatory to you:

  • It is the exact same structure as the bigram model (see Part 3)
  • In the init we add the position embedding table (see Part 5)
  • We also add a layer normalization layer (we’ll explain that in our next post)
  • And a final linear layer of dimension (n_emb, vocab_size) so we end up with logits  
  • About the last few lines: the targets variable is None when the model is invoked for prediction/inference (rather than training); more details on that in the next section. 

Let’s draw the dimensions of the forward pass to see how things work out very nicely.

From there, what you do next really depends on whether you are in the training phase (i.e. tuning the actual parameters of the model) or in the inference phase (i.e. actually predicting the next character). Let’s detail that difference as it is important.

Training vs. Inference

Both in training and inference, the key elements that you need are the logits, which are the raw predictions (one number per possible prediction; in our case, vocab_size numbers). Let’s see how those logits are used in the training vs. the inference phase. 

Training 

In that phase, your goal is to tune the parameters of the model. As explained above, these parameters include the keys, queries and values matrices. In deep learning, the way you tune them is via backpropagation. To do backpropagation, you need three things: the logits, the labels (or targets) and a loss function. This is all captured in those few lines from the previous snippet. 
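Those lines boil down to flattening the batch and time dimensions and calling PyTorch’s cross entropy; here is a minimal standalone illustration (with the shapes from our running example, and random tensors standing in for real logits and targets):

```python
import torch
from torch.nn import functional as F

B, T, C = 2, 3, 65  # batch size, context length, vocabulary size
logits = torch.randn(B, T, C)          # raw predictions, one row per (example, position)
targets = torch.randint(0, C, (B, T))  # the "next character" label for each position

# F.cross_entropy expects (N, C) logits and (N,) targets,
# so we flatten the B and T dimensions together
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
print(loss.item())  # a single scalar, averaged over all B*T predictions
```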

The details of how those few lines beautifully work are covered in one of my “Deep Learning Gymnastics” posts: Master Your (LLM) Cross Entropy. Below is an excerpt from that post; read it fully for more details.

From there, the training loop that tunes the weights via backpropagation is rather straightforward:
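A sketch of such a loop (with a stand-in model and a synthetic get_batch, since the real ones live in the earlier posts; the optimizer choice and learning rate are illustrative):

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
vocab_size, block_size, batch_size = 65, 8, 4
data = torch.randint(0, vocab_size, (1000,))  # stand-in for the encoded Shakespeare text

def get_batch():
    # Sample batch_size random chunks of block_size characters (see Part 2)
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    xb = torch.stack([data[i:i + block_size] for i in ix])
    yb = torch.stack([data[i + 1:i + block_size + 1] for i in ix])  # targets = inputs shifted by one
    return xb, yb

class TinyModel(nn.Module):
    """Stand-in model: any module returning (logits, loss) plugs into this loop."""
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, vocab_size)
    def forward(self, idx, targets):
        logits = self.emb(idx)
        loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss

model = TinyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(100):
    xb, yb = get_batch()                   # sample a batch of examples
    logits, loss = model(xb, yb)           # forward pass -> loss
    optimizer.zero_grad(set_to_none=True)  # reset gradients
    loss.backward()                        # backpropagation
    optimizer.step()                       # update the weights
```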

The get_batch function is what we explained in Part 2: The Training Set .

Inference

Once your model is trained, and your keys, queries, values and other parameters (like the embeddings) hold the proper weights learned through backpropagation, you can start actually doing inference, which in our case corresponds to generating Shakespeare text, one character at a time. The code to do it was presented in part 3, and we’ll put it here again for completeness:
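As a sketch, the generation loop can be written as a standalone function along these lines (the cropping to block_size is needed because of the position embedding table; the standalone signature is my own, not the one from Part 3):

```python
import torch
from torch.nn import functional as F

def generate(model, idx, max_new_tokens, block_size=8):
    """Autoregressively extend idx of shape (B, T) by max_new_tokens characters."""
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]   # crop the context to the last block_size tokens
        logits, _ = model(idx_cond)       # forward pass (no targets, so loss is None)
        logits = logits[:, -1, :]         # focus on the last time step
        probs = F.softmax(logits, dim=-1) # turn logits into probabilities
        idx_next = torch.multinomial(probs, num_samples=1)  # sample the next character
        idx = torch.cat((idx, idx_next), dim=1)             # append it and continue
    return idx
```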

Wow, we now deeply understand how one of the most important pieces of code of the gen AI revolution works!

But to achieve the actual end-goal (of this series) of building a GPT, there are a bunch of very important optimizations that will take the loss to new heights (well, maybe “canyon” is a better term, as the lower the loss, the better). 

Let’s now dive into the grand finale of this series: Building a GPT.

