Welcome to the 4th episode of our Deep Learning Gymnastics series.
Today, we’ll use all the skills learned in our previous lessons (tensor broadcasting, indexing and reshaping) to revisit one of the most famous and important loss functions of supervised machine learning and deep learning: cross entropy.
LLMs? Yes, they are also based on it. We’ll actually get inspired (again) by Andrej Karpathy’s videos around building an LLM from scratch to illustrate how to manipulate the cross entropy function.
A short refresher on Cross Entropy
Entropy in general and cross-entropy in particular are fascinating concepts that lie at the foundation of information theory. If you want to dive a bit deeper and understand the links between the logistic regression cost function, Log Loss, Cross Entropy and Negative Log Likelihood, and are not afraid of some maths formulas, you can read one of my old posts here.
But for today we’ll focus on the essence. Cross-entropy in ML is most often used as a cost function that measures the difference between a probability vector (one probability per predicted class) and a one-hot encoded label. Typically:
Here, O is the raw output of the neural network, often called logits. Before we apply the cross entropy formula, we typically pass those logits through the softmax function so they become a probability vector P, where each entry is the predicted probability of one of the classes. And L is the one-hot encoded vector representing the label.
So in our example, we can see that the cross-entropy is simply –log(0.6), i.e. ~0.51 with the natural log used by deep learning frameworks (~0.22 if you use a base-10 log). As you can see, the higher the probability of the correct class, the closer to 0 the cost will be (when the probability of the correct class is 1, the cost is -log(1), which is 0). The lower the probability of the correct class, the bigger the cost (tending to infinity when the probability is 0). Note that the figure above is inspired by this great short video.
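To make this concrete, here is a minimal PyTorch sketch of that example. The numbers are made up for illustration, except the 0.6 probability of the correct class:
import torch
# P is the probability vector coming out of the softmax (one made-up probability per class)
P = torch.tensor([0.3, 0.6, 0.1])
# L is the one-hot encoded label: the correct class is the one with probability 0.6
L = torch.tensor([0.0, 1.0, 0.0])
# cross entropy = -sum_i L_i * log(P_i), which here boils down to -log(0.6)
print(-(L * P.log()).sum())   # tensor(0.5108)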
Cross Entropy in LLMs
The core capability of Large Language Models (LLMs) is to predict the next word (or, more generally, the next token) given a list of previous words/tokens. In a future blog post, we’ll describe precisely how the training set is built, but for the sake of this post, let’s illustrate a batch of an LLM training set with a picture and explain it:
In episode #2 of our series, we explained what a batch is, and that those numbers represent the index of a token in the vocabulary. Assuming our LLM predicts the next token (out of 27 possible tokens) given a context of at most 3 tokens, here is how to read the figure above:
- The batch on the left represents 8 lines of three tokens each.
- Each token of the batch points to a tensor of size (27,1) representing the prediction of what the next token should be (one logit for each of the 27 possible tokens). So the batch tensor shape is (8,3,27).
- For instance, the (27,1) tensor in the figure represents the prediction for each of the 27 tokens, given the sequence of the three tokens 7,16,18.
- In that example, what is, say, the logit for the next token being token 1? Just look at index 1 of that vector. Here you go: ~0.55 (which seems rather high compared to the others).
- The tensor on the right holds the labels (the actual next token from the training set). It thus has the same shape as the batch, except that it does not contain the prediction logits, so it is just (8,3).
How do we calculate the Cross Entropy of that single prediction logits tensor (in the figure) against the actual label?
Simple: we just follow the diagram we gave above. We pass that vector through the softmax function, which gives us the (27,1) tensor P representing probabilities. Then we have L = (0,1,0,0,0,0,…,0), and we just apply the cross entropy formula.
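To make that single-entry computation concrete, here is a small sketch with a hypothetical logits vector o of 27 values and label token 1 (these numbers are illustrative, not the ones from the figure):
import torch
import torch.nn.functional as F

torch.manual_seed(0)
o = torch.randn(27)        # hypothetical raw logits for one position of the batch
label = torch.tensor(1)    # the actual next token is token 1

P = F.softmax(o, dim=0)    # turn the 27 logits into 27 probabilities
print(-P[label].log())     # -log of the probability of the correct token
# PyTorch computes the exact same thing directly from the logits:
print(F.cross_entropy(o.unsqueeze(0), label.unsqueeze(0)))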
The Gymnastic Exercise
In the previous section, we explained how to compute the Cross Entropy for one single entry of the (8,3) batch of our example. But how do we compute it for the whole batch? To do so, we need to calculate the exact same thing, but for all 8*3 = 24 entries.
Did you recognize the vector we had in the previous section’s figure? Yes, that’s the 7th one from the bottom.
So the gymnastic exercise is to take the initial batch of prediction tensors of shape (8,3,27), stretch it out into the 8*3 = 24 prediction logits (a (24,27) tensor, as in the picture above), do the same for the label tensor, then compute in parallel the cross entropy of the 24 logits/label pairs, and return their mean as the result.
Solving it in PyTorch
First we need to generate all the input tensors:
- X, the batch with predictions, which is an (8,3,27) tensor
- Y, the labels, which is an (8,3) tensor.
The code below will produce the same numbers as the ones shown in the second figure of this post.
import torch
torch.manual_seed(18)
# create the batch of token indexes, shape (8,3)
random_tensor = torch.randint(low=0, high=26, size=(8,3))
# create random logits for each index in the vocabulary
L = torch.randn((27, 27))
#creating the labels
Y = torch.randint(low=0, high=26, size=(8,3))
# creating our batch (8,3,27). C.f https://www.philippeadjiman.com/blog/2023/12/23/deep-learning-gymnastics-tensor-indexing/
X = L[random_tensor]
To fully understand this code, please refer to post #2 of this series, about tensor indexing.
Note that in that other post, we created embeddings of size 4 as an illustration, while here we already have the final logits (of size 27, the vocabulary size). In a fully implemented LLM, those logits only come out after many more steps (stay tuned for a future blog post about it).
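As an optional sanity check (not in the figure, just to tie things together), you can verify the shapes and the fact that each position of the batch holds the 27 logits attached to its token index:
print(X.shape, Y.shape)                              # torch.Size([8, 3, 27]) torch.Size([8, 3])
print(torch.equal(X[6, 2], L[random_tensor[6, 2]]))  # True: position (6,2) holds the logits of token random_tensor[6,2]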
Now, we’d like to use PyTorch’s cross_entropy function. Reading the doc, we see that it expects the logits in the second dimension of its input, which corresponds exactly to what we described in the figure above: stretching out the input batch. Same for the labels. We learned how to do that with views in post #3 of this series, about tensor reshaping. So here you go:
#Reshaping before using cross_entropy. C.f https://www.philippeadjiman.com/blog/2024/02/03/deep-learning-gymnastics-tensor-reshaping/
B,T,C = X.shape          # B=8 (batch size), T=3 (context length), C=27 (vocabulary size)
logits = X.view(B*T,C)   # (8,3,27) -> (24,27)
labels = Y.view(B*T)     # (8,3) -> (24,)
With that, we obtain exactly what we illustrated in the previous figure.
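If you want to double check, here is a quick optional shape verification:
print(logits.shape)   # torch.Size([24, 27])
print(labels.shape)   # torch.Size([24])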
Now that our inputs are in the proper shape, we can compute our cross entropy with the function:
import torch.nn.functional as F
F.cross_entropy(logits, labels)
Which gives 3.7759. Yay, we computed the cross entropy of our LLM batch 💪
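As a side note, and assuming I read PyTorch’s documentation correctly, cross_entropy also accepts higher-dimensional inputs as long as the classes sit in the second dimension, so the same result can be obtained without flattening, by simply permuting the batch. A quick sketch using the tensors defined above:
# move the 27 logits to dimension 1: input is (8,27,3), target is (8,3)
print(F.cross_entropy(X.permute(0, 2, 1), Y))   # should print the same 3.7759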
Calculating Cross Entropy “manually”
Turns out that once we have the logits and labels in the proper shape, as we just did with views, calculating cross entropy without PyTorch’s built-in function is actually quite simple, and it is useful for understanding what happens behind the scenes.
Here is a compact and elegant way to do it (credit again to the code from Karpathy’s videos):
counts = logits.exp()                          # exponentiate the logits
prob = counts / counts.sum(1,keepdims=True)    # normalize each row: this is the softmax
- prob[torch.arange(24),labels].log().mean()   # pluck each correct token's probability, -log, average
Sure enough, it returns the exact same result (3.7759) as when using the PyTorch function 🤩.
So what’s going on in that code?
The first two lines transform the logits into probabilities using the softmax function: we first apply the exponential, then divide each row by its sum of exponentials. Wondering what that keepdims=True means? Please read post #1 of this series, about tensor broadcasting.
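If you want to convince yourself that these two lines really are the softmax, here is a quick optional check against PyTorch’s built-in:
print(torch.allclose(prob, F.softmax(logits, dim=1)))   # True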
Now the last line is interesting.
Remember our initial figure. Let’s look again how cross entropy is calculated:
Given that L is a one-hot encoded vector, it contains a single 1, and thus the cross entropy is just about plucking out the right index of P and taking its -log. In the figure, the 1 is in the second position, so its index is 1 (indexes start at 0), and thus the cross entropy is simply -log(P[1]).
Because in our code the labels are already numbers between 0 and 26 (indexes into the vocabulary), we can use them directly as indexes, extract the right probability from each of the 24 rows of prob, take the log of each, and the negated mean is simply the cross entropy of the whole batch.
So, simply:
- prob[torch.arange(24),labels].log().mean()
Magical, no?
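If that fancy indexing feels too magical, here is the exact same computation written as an explicit (and much slower) loop; just a sketch to demystify it:
# for each of the 24 rows, pluck the probability of the correct token, take -log, then average
losses = [-prob[i, labels[i]].log() for i in range(24)]
print(torch.stack(losses).mean())   # same 3.7759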
If you’re wondering why it is still worth using the built-in cross entropy function, watch this great explanation by Andrej Karpathy.
What about TensorFlow?
As we traditionally do in this series, let’s also look at the equivalent code in TensorFlow.
As with PyTorch, for all the gymnastics preparation (broadcasting, indexing and reshaping), please refer to posts #1, #2 and #3 of our Deep Learning Gymnastics series.
Regarding the cross entropy function in TensorFlow, we can use e.g. sparse_softmax_cross_entropy_with_logits. Note how explicit the name is: it tells you that you need to pass logits, and that it will then apply softmax and cross entropy.
If you’re using Keras, you can also use SparseCategoricalCrossentropy. Note that to do so, you first need to instantiate the loss, explicitly saying that we’re passing logits, and then apply it to the reshaped logits and labels.
Find the full code below, illustrating both cross entropy functions.
import tensorflow as tf
tf.random.set_seed(18)
# Create a random batch of shape (8,3) with token indexes (maxval is exclusive, so each index is between 0 and 25)
random_tensor = tf.random.uniform(shape=(8,3), minval=0, maxval=26, dtype=tf.int32)
# create random logits for each index in the vocabulary
L = tf.random.uniform((27,27), dtype=tf.float32)
#creating the labels
Y = tf.random.uniform(shape=(8,3), minval=0, maxval=26, dtype=tf.int32)
# creating our batch (8,3,27). C.f https://www.philippeadjiman.com/blog/2023/12/23/deep-learning-gymnastics-tensor-indexing/
X = tf.gather(L,random_tensor)
#Reshaping before using cross_entropy. C.f https://www.philippeadjiman.com/blog/2024/02/03/deep-learning-gymnastics-tensor-reshaping/
B,T,C = X.shape
logits = tf.reshape(X, [B*T, C])
labels = tf.reshape(Y, [B*T, 1]) # the 24 token indexes, as a column vector
#Calling cross entropy using sparse_softmax_cross_entropy_with_logits
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels[:, 0],logits=logits)
print(tf.reduce_mean(loss))
#Calling cross entropy using Keras' SparseCategoricalCrossentropy
ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(ce(labels,logits))
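Both prints should show the same value; if you want to assert it programmatically, here is a small optional check:
# raises an error if the two losses differ
tf.debugging.assert_near(tf.reduce_mean(loss), ce(labels, logits))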
That’s it for today.
Hope you’re feeling in better shape with your tensors 🤸. Until our next episode.
Like these posts? Feel free to subscribe here so you don’t miss future ones.