At the heart of the implementation of modern deep learning models (yes, including LLMs) there always lie some subtle and critical techniques and tricks that are important to know and master. Tensor broadcasting is one of them.
Official documentation exists (e.g. for PyTorch or TensorFlow), but in this post we’ll introduce the topic in a simple and intuitive way, using a motivating example inspired by Andrej Karpathy’s amazing series of videos on language modeling.
Example of broadcasting in action
Suppose you have a tensor of size 3×4 (a tensor with 2 dimensions can also simply be called a matrix), where each row represents a set of counts over 4 options you are trying to choose from (the higher the count, the more likely it is the right option), and your goal is to efficiently transform those counts into probability distributions. Concretely, you want to go from the raw counts on the left to the normalized probabilities on the right:
The matrix on the left holds our raw counts, and the one on the right is what we’d like to get. So we’d like to find an efficient (vectorized) way to sum up each row separately and divide each count by the sum of its row. We first need to create a matrix of shape 3×1 which contains the sum of each row:
\[ \begin{bmatrix} 150 \\ 50 \\ 100 \end{bmatrix} \]
The question then is whether we are allowed to simply divide the 3×4 counts matrix by this 3×1 matrix of sums (for the sake of the explanation, we’re assuming that none of the row sums is equal to 0).
This is where broadcasting comes into play. When presented with such an operation, broadcasting finds a way to adapt the second matrix so that it has the same dimensions as the first one, by duplicating its single column, and then performs an efficient element-wise division, as follows:
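To make this concrete, here is a minimal PyTorch sketch of that operation. The counts below are made up for illustration; only their row sums (150, 50 and 100) match the example above:

```python
import torch

# hypothetical counts, made up for illustration; rows sum to 150, 50 and 100
counts = torch.tensor([[60., 30., 45., 15.],
                       [10.,  5., 20., 15.],
                       [25., 25., 10., 40.]])   # shape [3, 4]

row_sums = torch.tensor([[150.],
                         [ 50.],
                         [100.]])               # shape [3, 1]

# broadcasting duplicates the single column of row_sums across the 4 columns
# of counts, then performs an efficient element-wise division
probs = counts / row_sums
print(probs.sum(dim=1))  # each row now sums to 1 (up to float precision)
```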
Are your tensors broadcastable?
Whether you’re broadcasting with NumPy, PyTorch or TensorFlow, to know if two tensors are “broadcastable” you just need to align the shapes (or dimensions) of the two tensors from right to left and, for each dimension, check that the sizes are either equal, or that one of them is 1, or that one of them does not exist. If that is the case for all dimensions, then the two tensors are broadcastable. What is the shape of the resulting tensor? Just take the maximum of the two sizes along each dimension.
Let’s try it on our example. The shape of the first tensor is [3,4] and the second one (before broadcasting) is [3,1]. Aligning the shapes from right to left: the rightmost sizes are 4 and 1 (one of them is 1, so it’s fine), and the next ones are 3 and 3 (equal, so fine too). The two tensors are therefore broadcastable, and the result has shape [3,4].
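If you prefer to see the rule as code, here is a small hypothetical helper (the name and implementation are mine, not part of any library) that applies it to two shapes:

```python
import itertools

def broadcast_result_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise if they are not broadcastable."""
    result = []
    # align the shapes from right to left; a missing dimension acts like a 1
    for dim_a, dim_b in itertools.zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if dim_a == dim_b or dim_a == 1 or dim_b == 1:
            result.append(max(dim_a, dim_b))  # the result takes the max along each dimension
        else:
            raise ValueError(f"shapes {shape_a} and {shape_b} are not broadcastable")
    return list(reversed(result))

broadcast_result_shape([3, 4], [3, 1])  # [3, 4]
```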
This method also works for tensors of any shape. Let’s check a couple of other examples:
Example 1: two tensors with shapes A.shape = [4,3,2] and B.shape = [3,1]
Example 2: two tensors with shapes A.shape = [4,3,2] and B.shape = [3,1,2]
Which of the two examples gives broadcastable tensors and which does not? Let’s start with Example 1:
Aligning from right to left: 2 and 1 (one is 1), 3 and 3 (equal), and 4 with nothing (the dimension does not exist for B). All good, you can broadcast those two tensors, and the result has shape [4,3,2]. Note that for the leftmost dimension, since it does not exist for the second tensor, it simply acts as if it were a 1.
What about Example 2?
Here the leftmost dimensions of the two tensors both exist but are not equal (4 vs 3), and neither of them is 1, so the conditions for broadcastability are broken.
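You can double-check both examples with PyTorch’s torch.broadcast_shapes, which implements exactly this rule:

```python
import torch

# Example 1: broadcastable, and the result takes the max size along each dimension
torch.broadcast_shapes((4, 3, 2), (3, 1))     # torch.Size([4, 3, 2])

# Example 2: not broadcastable, so this raises a RuntimeError
# torch.broadcast_shapes((4, 3, 2), (3, 1, 2))
```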
Tensor broadcasting in PyTorch and TensorFlow
Let’s see broadcasting in action with PyTorch, on the example of a 3×3 tensor of counts that we want to normalize in the same way as in our previous example:
```python
import torch

N = torch.tensor([[10, 20, 10],
                  [20, 5, 25],
                  [10, 60, 30]], dtype=torch.int32)

# calculate sum along rows
row_sums = N.sum(dim=1, keepdim=True)

# normalize each row
N_normalized = N / row_sums
```
The parameter `dim=1` says that we want to sum along dimension 1, i.e. compute the sum of each row. As for the `keepdim` parameter, wait for the next section to see why we used it and why it is critical. Let’s now print `N`, `row_sums` and `N_normalized` respectively:
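For reference, here are the values you should see (the comments show the content of each tensor rather than PyTorch’s exact print formatting):

```python
print(N)             # [[10, 20, 10], [20, 5, 25], [10, 60, 30]]
print(row_sums)      # [[40], [50], [100]]           -> shape [3, 1]
print(N_normalized)  # [[0.25, 0.50, 0.25],
                     #  [0.40, 0.10, 0.50],
                     #  [0.10, 0.60, 0.30]]          -> each row sums to 1
```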
As we can see, the broadcast operation worked as expected, since the sum of each row of the result is indeed equal to 1.
Let’s see how the code looks in TensorFlow:
```python
import tensorflow as tf

N = tf.constant([[10, 20, 10],
                 [20, 5, 25],
                 [10, 60, 30]], dtype=tf.int32)

# calculate sum along rows
row_sums = tf.reduce_sum(N, axis=1, keepdims=True)

# normalize each row
N_normalized = N / row_sums
```
As you can see, the code is rather similar, up to some differences like the need to use the `tf.reduce_sum` function rather than calling the sum directly on the tensor, and the fact that the `keepdim` parameter is now plural (`keepdims`) 😅. But printing `N_normalized` returns the same result as with the PyTorch code.
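As a quick sanity check (again, the comments show the values rather than TensorFlow’s exact print formatting):

```python
print(N_normalized)                         # [[0.25, 0.5, 0.25], [0.4, 0.1, 0.5], [0.1, 0.6, 0.3]]
print(tf.reduce_sum(N_normalized, axis=1))  # each row sums to 1 (up to float precision)
```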
When things go wrong
So, what was this `keepdim=True` (or `keepdims=True` in TensorFlow) all about?
If you run, say, the exact same PyTorch code as above but without `keepdim=True`, this is what you’ll get when printing `N`, `row_sums` and `N_normalized`:
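Here is that variant as a sketch; the commented values show what each tensor should contain:

```python
# same N as before, but this time we drop keepdim=True
row_sums = N.sum(dim=1)          # [40, 50, 100]          -> shape [3], not [3, 1]
N_normalized = N / row_sums      # [[0.25, 0.40, 0.10],
                                 #  [0.50, 0.10, 0.25],
                                 #  [0.25, 1.20, 0.30]]   -> rows sum to 0.75, 0.85, 1.75
```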
As you can see, `N_normalized` is completely messed up and the rows don’t sum to 1 anymore 🤦. But how did that happen? What did broadcasting actually do?
First, was the operation broadcastable? Well, now you know how to check that from the previous section. `N` is of shape [3,3], and the trick is that `row_sums` is now of shape [3], because PyTorch squeezed the dimension and created a row vector. Using the method explained before, you can see that the two tensors are still broadcastable.
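You can see the difference directly by printing the two shapes:

```python
print(N.sum(dim=1, keepdim=True).shape)  # torch.Size([3, 1])
print(N.sum(dim=1).shape)                # torch.Size([3])
```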
And practically, what happens now is that `row_sums` is treated as a row vector and gets copied down across the rows, instead of the column vector being copied across the columns as before. As a result, each count ends up divided by the sum whose index matches its column rather than by the sum of its own row. In other words, during the operation `N / row_sums`, this is what happened to `row_sums` in the process:
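If you want to see it explicitly, torch.broadcast_to materializes the broadcast tensor; this is what `row_sums` effectively becomes in the two cases:

```python
# without keepdim: the [3] vector is treated as a row and copied down the rows
torch.broadcast_to(N.sum(dim=1), (3, 3))
# [[ 40,  50, 100],
#  [ 40,  50, 100],
#  [ 40,  50, 100]]

# with keepdim=True: the [3, 1] column is copied across the columns
torch.broadcast_to(N.sum(dim=1, keepdim=True), (3, 3))
# [[ 40,  40,  40],
#  [ 50,  50,  50],
#  [100, 100, 100]]
```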
So as you can see, in that case the `keepdim` parameter was critical to keep `row_sums` with the same number of dimensions as the initial tensor, and thus the right shape for proper broadcasting.
ChatBots can help, but only when you know what you’re doing
This statement holds for any code generated by chatbots like Bard or ChatGPT.
Specifically on this one, depending on the version of the chatbot you’re using and how you phrase your prompt, sometimes you’ll get the right code (using `keepdims=True`) and sometimes not. But now, for any broadcasting-related question, you can no longer be fooled 🤩.
Conclusion
Broadcasting is a critical technique that every deep learning developer needs to master in order to implement state-of-the-art models properly and efficiently. And you’d better understand the nuances and subtleties we discussed (like the `keepdims` param), otherwise you might silently introduce bugs that render your whole model useless.