GPT From Scratch #3: The Bigram Model

Welcome to Part 3 of our GPT From Scratch series, inspired by Andrej Karpathy’s video “Let’s build GPT: from scratch, in code, spelled out.”


Now that we’ve created a training set, we can start training a first model.

To illustrate the overall principles of how we’ll build a GPT, Karpathy starts with a simple Bigram model.

It is mind-blowing that the structure and principles we’ll use to build such a simple model are exactly the same ones that will take us all the way to a full GPT.

Read on.

Bigram?

First, let’s briefly describe what a bigram model is. Since our goal is to predict the next character, one of the simplest and most naive approaches is to count how often each pair of consecutive characters occurs in the data (in our case, in Shakespeare’s works). For instance, given the letter ‘t’, the most likely next letter in English is ‘h’, based on how often each pair of letters occurs together, as illustrated in an old but great post by Peter Norvig presenting a bigram frequency table for English.

Peter Norvig’s English Bi-Gram table

With such a table, a naive but working way to generate the next character is to look up the row of the current character and randomly draw the next one according to the distribution of counts in that row.
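To make the counting approach concrete, here is a minimal sketch in plain Python, with no neural net yet. The toy corpus and the helper names sample_next and generate are hypothetical; the idea is simply to count, for each character, which characters follow it, and then sample the next character in proportion to those counts.

```python
import random
from collections import Counter, defaultdict

# Hypothetical toy corpus standing in for the Shakespeare training text
text = "the cat sat on the mat. that is that."

# bigram_counts[a][b] = number of times character b directly follows character a
bigram_counts = defaultdict(Counter)
for a, b in zip(text, text[1:]):
    bigram_counts[a][b] += 1

def sample_next(current, rng):
    """Draw the next character with probability proportional to its bigram count."""
    row = bigram_counts[current]
    chars = list(row.keys())
    weights = list(row.values())
    return rng.choices(chars, weights=weights, k=1)[0]

def generate(start, length, seed=0):
    """Generate `length` characters after `start`, each conditioned only on the previous one."""
    rng = random.Random(seed)
    out = [start]
    for _ in range(length):
        out.append(sample_next(out[-1], rng))
    return "".join(out)

print(bigram_counts["t"].most_common(3))
print(generate("t", 30))
```

On this toy corpus, ‘h’ is indeed the most frequent successor of ‘t’ (4 times out of 9).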

Bigrams in a neural net?

In a separate, excellent video, Karpathy shows how to do this not only directly, by counting occurrences in a bigram table, but also by learning the table with a neural net. Of course, it doesn’t give better results (in the end, you still use only one character as the context to guess the next one), but learning it with a neural net is far more flexible and powerful, because it lays the groundwork for extending the model all the way up to a full GPT 🤯, as the posts in this series will show.

From here on, we assume the reader has a minimal understanding of PyTorch and of how neural nets work: what a forward pass, backpropagation and a loss function are. If not, you can watch this video, or just read on and hopefully you’ll get the idea along the way.

So let’s jump into Karpathy’s implementation of the bigram model as a neural net in PyTorch.
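For reference, a minimal sketch of such a model in PyTorch could look like the following. It follows the structure discussed in the sections below (a single embedding table, logits of shape (B, T, C), a cross-entropy loss), but it is written here for illustration rather than copied verbatim from Karpathy’s code, and the random batch is a stand-in for real training data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the raw scores (logits) for the character that follows character i
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are (B, T) tensors of integer character ids
        logits = self.token_embedding_table(idx)  # (B, T, C) with C == vocab_size
        if targets is None:
            loss = None
        else:
            # Flatten so every (example, time step) pair is one classification row;
            # note that the returned logits are then (B*T, C)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

torch.manual_seed(1337)
model = BigramLanguageModel(vocab_size=65)
xb = torch.randint(0, 65, (4, 8))  # stand-in batch: B=4 examples of T=8 characters
yb = torch.randint(0, 65, (4, 8))  # stand-in labels: the "next character" at each position
logits, loss = model(xb, yb)
print(logits.shape, loss.item())
```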

Let’s dive into each important line/component of that code. Each section below highlights the line(s) of code of interest and explains them.

The token embedding table

Usually, an embedding refers to a compact latent representation of a concept (a word, an image, etc.). Here, however, it can simply be interpreted as the NxN bigram matrix. N (vocab_size) is the number of possible characters: e.g. 26 for the English alphabet, but 65 in our case, since we count punctuation and upper/lower case letters as distinct characters (see our previous post on the training set).
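A quick way to see this, as a sketch assuming the table is a PyTorch nn.Embedding of size vocab_size x vocab_size with vocab_size = 65: the layer’s only parameter is a 65x65 matrix, and looking up a character id returns its row, i.e. 65 raw scores for the next character.

```python
import torch
import torch.nn as nn

vocab_size = 65  # number of distinct characters in the training set
token_embedding_table = nn.Embedding(vocab_size, vocab_size)

# The only learnable parameter: a vocab_size x vocab_size matrix,
# the learned counterpart of an N x N bigram table
print(token_embedding_table.weight.shape)  # torch.Size([65, 65])

# Looking up character id 7 returns row 7 of that matrix:
# 65 raw scores, one per possible next character
row = token_embedding_table(torch.tensor([7]))
print(row.shape)  # torch.Size([1, 65])
assert torch.equal(row[0], token_embedding_table.weight[7])
```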

The logits: an important BxTxC tensor

In this section, we’ll dive into this line: logits = self.token_embedding_table(idx)

Let’s explain.

In machine learning, and in classification tasks in particular, logits are the raw scores a model assigns to each of the possible classes. In our case, since we are trying to predict the next character, the possible classes are the characters of the vocabulary, so 65 different options.

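So the logits are the raw scores the model gives to each of those 65 options, before they are normalized into probabilities with a softmax. In the case of bigrams, you can think of these raw scores as log-counts: the softmax exponentiates them and normalizes each row, so logits equal to the logarithm of the bigram counts (the current character followed by each candidate next character) would give exactly the bigram frequencies.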
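A small worked example of the log-count interpretation, using hypothetical counts for a single row of the bigram table over a toy three-character vocabulary: taking the log of the counts as logits and applying softmax gives back exactly the normalized frequencies, and shifting every logit in the row by the same constant does not change the result, so learned logits only need to match log-counts up to a per-row constant.

```python
import math

# Hypothetical bigram counts for one current character, over a 3-character vocabulary
counts = [8, 1, 1]

# Interpret the logits as log-counts
logits = [math.log(c) for c in counts]

def softmax(xs):
    """Exponentiate and normalize (subtracting the max first for numerical stability)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)
print(probs)  # approximately [0.8, 0.1, 0.1], i.e. counts / sum(counts)

# Adding the same constant to every logit leaves the probabilities unchanged
shifted = softmax([x + 5.0 for x in logits])
print(shifted)
```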

So, why does token_embedding_table(idx) give us the logits?  

To understand why, you can read my blog post about tensor indexing, but the basic idea (illustrated in the picture below) is this: idx represents the batch of examples (as described in the training set post; left of the picture), token_embedding_table represents the embedding matrix (each row is the embedding of the corresponding character in the vocabulary; middle of the picture), and token_embedding_table(idx) simply returns the same batch idx, but with each character replaced by its embedding, which turns it into a cube (right of the picture).

Augmenting the batch with embeddings using tensor indexing

This BxTxC tensor will be our starting point for everything we’ll do to build a GPT. What you need to remember is that each row contains T consecutive characters from the training set (a.k.a. an example), each such character is associated with its embedding (an array of numbers of fixed size C), and there are B such examples in the batch.
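The lookup can be checked directly in PyTorch. In this sketch, random ids stand in for a real batch, and B = 4 and T = 8 are chosen arbitrarily: calling the embedding table on a (B, T) batch of ids produces a (B, T, C) tensor, identical to indexing the weight matrix with idx.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C = 4, 8, 65  # batch size, time steps, channels (= vocab_size for the bigram model)
token_embedding_table = nn.Embedding(C, C)

idx = torch.randint(0, C, (B, T))    # a (B, T) batch of character ids
logits = token_embedding_table(idx)  # each id replaced by its row of the table

print(idx.shape, logits.shape)  # torch.Size([4, 8]) torch.Size([4, 8, 65])

# The lookup is plain tensor indexing into the weight matrix
assert torch.equal(logits, token_embedding_table.weight[idx])
# e.g. the vector at (b=2, t=5) is the row of the character id idx[2, 5]
assert torch.equal(logits[2, 5], token_embedding_table.weight[idx[2, 5]])
```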

A note on the terminology of the dimensions:

  • B refers to Batch (number of lines/examples in the batch)
  • T refers to Time. The term comes from the number of timesteps in a sequence (these concepts were initially introduced for time series in the context of Recurrent Neural Networks, and the terminology stuck for sequences of characters).
  • C refers to channels, i.e. the size of the embedding. The term comes from image classification, where the channels are the three RGB values of each pixel, but in general it just represents the number of latent features at each timestep.

Cross-entropy loss

As Jeremy Howard likes to say, to train a neural net you just need three things: an input (the batch, in our case), a label (the next character, in our case) and a loss function. Give those three things to a neural net, and it will learn something.

So, what is the loss function in our case? We’ll use cross-entropy, as is standard for classifiers.

A long time ago, I wrote a blog post about the theoretical aspects of entropy, but more recently I wrote about how cross-entropy is applied in LLMs, as Karpathy demonstrates in his video and code. You can find it here: Master Your (LLM) Cross Entropy.

In a nutshell, it explains how the lines of code that unpack B, T, C simply reshape both the logits obtained in the previous section (from (B, T, C) to (B*T, C)) and the labels (the next characters we want to predict, from (B, T) to (B*T)), so that they can be passed directly to PyTorch’s cross_entropy function.

Basically, this is what the tensors look like before we pass them to the cross-entropy function:
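A minimal sketch of this reshaping, using all-zero logits as a hypothetical stand-in for a model that has learned nothing: each of the B*T positions becomes one row of a 65-way classification problem. Since uniform logits give every character probability 1/65, the loss is -ln(1/65) = ln(65), about 4.17, a handy reference value when checking the loss at the start of training.

```python
import math
import torch
import torch.nn.functional as F

B, T, C = 4, 8, 65
logits = torch.zeros(B, T, C)          # hypothetical: every character gets the same raw score
targets = torch.randint(0, C, (B, T))  # the next character for every position

# Unfold: every (example, time step) pair becomes one row of a classification problem
flat_logits = logits.view(B * T, C)    # (32, 65)
flat_targets = targets.view(B * T)     # (32,)
loss = F.cross_entropy(flat_logits, flat_targets)

# With uniform logits the loss equals -ln(1/65) = ln(65)
print(flat_logits.shape, flat_targets.shape)
print(loss.item(), math.log(65))
```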

The training loop

So what we defined above is essentially the model: it computes the logits and the loss (the cross-entropy between the logits and the labels).

Now we need to train it. To learn what? The parameters of the model, which in our case are just the embedding table (the NxN bigram matrix discussed above in the section “The token embedding table”).

So here is a basic yet functioning training loop:

Rather simple: we fetch a batch (using the getBatch function we described in the first post), compute the logits and the loss on it, run the backward pass to compute the gradients (someday we’ll write another post on how that magic works), and let the optimizer update the parameters. That’s it. Rinse, repeat, and watch the loss drop.
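Here is a minimal, self-contained sketch of such a loop. To keep it runnable on its own, it assumes a hypothetical toy dataset (a repeating 5-character pattern instead of the encoded Shakespeare text), a get_batch helper written here for illustration, the AdamW optimizer, and a learning rate of 0.1 chosen for this toy setting. Because each character in the pattern is always followed by the same one, the bigram loss should drop close to zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)

# Hypothetical toy data: a repeating pattern over a 5-character vocabulary (already encoded as ids)
vocab_size = 5
data = torch.tensor([0, 1, 2, 3, 4] * 200)
block_size, batch_size = 8, 4

def get_batch():
    """Sample random windows; the targets are the inputs shifted by one character."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.token_embedding_table(idx)  # (B, T, C)
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

model = BigramLanguageModel(vocab_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-1)

losses = []
for step in range(300):
    xb, yb = get_batch()                   # 1. fetch a batch
    logits, loss = model(xb, yb)           # 2. forward pass: logits and loss
    optimizer.zero_grad(set_to_none=True)  # clear gradients from the previous step
    loss.backward()                        # 3. backward pass: compute gradients
    optimizer.step()                       # 4. update the embedding table
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```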

Generating Text From the model

Once the model is trained, this is how text is generated from the model:

A couple of notes on why this function is a bit overkill for now, though it will set us up for future steps:

  • We don’t really need batches (B), as we start from a single example. The idea is simply to start from some empty context, generate one character at a time by converting the logits into probabilities (using softmax) and sampling from them, then append the sampled character to the current context and repeat.
  • For the bigram model, this function is even more overkill: to generate the next character, the bigram model only looks at the last character, yet here we pass it the full context from the beginning. We write it this generic way only because it will be reusable as-is later on, when we start using the full context of all the previous characters.

Then to generate text, it just looks like this:
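A sketch of such a generation loop, written here as a standalone function rather than a model method. It assumes a model that, like the bigram model described above, returns (logits, loss) when called on a (B, T) batch. To keep the example self-contained, an untrained embedding table stands in for the trained model, and the small vocabulary and decode helper are hypothetical. Starting from token id 0 (the newline character, which sorts first) plays the role of the “empty” context.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def generate(model, idx, max_new_tokens):
    """Autoregressively extend idx, a (B, T) tensor of character ids."""
    for _ in range(max_new_tokens):
        logits, _ = model(idx)                              # (B, T, C): the full context is passed in
        logits = logits[:, -1, :]                           # keep only the last time step -> (B, C)
        probs = F.softmax(logits, dim=-1)                   # raw scores -> probabilities
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one id per example -> (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)             # append to the context -> (B, T+1)
    return idx

# Hypothetical toy vocabulary and decode helper
text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
chars = sorted(set(text))  # '\n' sorts first, so it gets id 0
itos = {i: ch for i, ch in enumerate(chars)}

def decode(ids):
    return "".join(itos[i] for i in ids)

# Untrained stand-in with the (logits, loss) interface of the bigram model
torch.manual_seed(1337)
emb = nn.Embedding(len(chars), len(chars))

def model(idx):
    return emb(idx), None

context = torch.zeros((1, 1), dtype=torch.long)  # B=1, a single starting token (id 0)
with torch.no_grad():
    out = generate(model, context, max_new_tokens=50)
print(decode(out[0].tolist()))
```

Note that the function feeds the whole context to the model at every step even though a bigram model only uses the last position, which is exactly the “overkill” mentioned above.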

Obviously, a Bigram model won’t be able to generate any Shakespeare-like content, but just for fun, let’s use the function above to generate some text:

As expected, it is complete gibberish, and Shakespeare can rest in peace for now. Still, it is interesting to observe that the model did capture some of the structure of Shakespeare’s text (someone speaking, then a comma, then a new line, etc.). As a reminder, this is what the original Shakespeare text looks like in the training set:

Why this framework sets us up for GPT

What is really amazing is that all the building blocks described above for the bigram model will stay exactly the same for a full GPT model: the training loop, the loss computed from the logits, and the text generation.

So what will change? Basically, the way the logits are built. Instead of being computed from the previous character only, they will be computed from the full context of the preceding text (the “prompt”), passing through the beautiful transformer architecture and its self-attention component, which we already explained at length from an intuitive perspective (in that post) and which we’ll build from scratch in the upcoming posts of this series.

What’s next? We’ll now dive into the Mathematical Trick behind Self Attention.

