08. PyTorch Paper Replicating¶
Welcome to Milestone Project 2: PyTorch Paper Replicating!
In this project, we're going to be replicating a machine learning research paper and creating a Vision Transformer (ViT) from scratch using PyTorch.
We'll then see how ViT, a state-of-the-art computer vision architecture, performs on our FoodVision Mini problem.

For Milestone Project 2 we're going to focus on recreating the Vision Transformer (ViT) computer vision architecture and applying it to our FoodVision Mini problem to classify different images of pizza, steak and sushi.
What is paper replicating?¶
It's no secret machine learning is advancing fast.
Many of these advances get published in machine learning research papers.
And the goal of paper replicating is to replicate these advances with code so you can use the techniques for your own problem.
For example, let's say a new model architecture gets released that performs better than any other architecture before on various benchmarks, wouldn't it be nice to try that architecture on your own problems?

Machine learning paper replicating involves turning a machine learning paper comprised of images/diagrams, math and text into usable code, in our case, usable PyTorch code. Diagrams, math equations and text from the ViT paper.
What is a machine learning research paper?¶
A machine learning research paper is a scientific paper that details findings of a research group on a specific area.
The contents of a machine learning research paper can vary from paper to paper but they generally follow the structure:
Section | Contents |
---|---|
Abstract | An overview/summary of the paper's main findings/contributions. |
Introduction | What's the paper's main problem and details of previous methods used to try and solve it. |
Method | How did the researchers go about conducting their research? For example, what model(s), data sources, training setups were used? |
Results | What are the outcomes of the paper? If a new type of model or training setup was used, how did the results or findings compare to previous works? (this is where experiment tracking comes in handy) |
Conclusion | What are the limitations of the suggested methods? What are some next steps for the research community? |
References | What resources/other papers did the researchers look at to build their own body of work? |
Appendix | Are there any extra resources/findings to look at that weren't included in any of the above sections? |
Why replicate a machine learning research paper?¶
A machine learning research paper is often a presentation of months of work and experiments done by some of the best machine learning teams in the world condensed into a few pages of text.
And if these experiments lead to better results in an area related to the problem you're working on, it'd be nice to check them out.
Also, replicating the work of others is a fantastic way to practice your skills.
George Hotz is the founder of comma.ai, a self-driving car company, and livestreams machine learning coding on Twitch, with those videos posted in full to YouTube. I pulled this quote from one of his livestreams. The accompanying note is to say that machine learning engineering often involves the extra step(s) of preprocessing data and making your models available for others to use (deployment).
When you first start trying to replicate research papers, you'll likely be overwhelmed.
That's normal.
Research teams spend weeks, months and sometimes years creating these works, so it makes sense if it takes you some time to even read them, let alone reproduce them.
Replicating research is such a tough problem that phenomenal machine learning libraries and tools such as HuggingFace, PyTorch Image Models (the timm library) and fast.ai have been born out of making machine learning research more accessible.
Where can you find code examples for machine learning research papers?¶
One of the first things you'll notice when it comes to machine learning research is: there's a lot of it.
So beware, trying to stay on top of it is like trying to outrun a hamster wheel.
Follow your interest, pick a few things that stand out to you.
In saying this, there are several places to find and read machine learning research papers (and code):
Resource | What is it? |
---|---|
arXiv | Pronounced "archive", arXiv is a free and open resource for reading technical articles on everything from physics to computer science (including machine learning). |
AK Twitter | The AK Twitter account publishes machine learning research highlights, often with live demos almost every day. I don't understand 9/10 posts but I find it fun to explore every so often. |
Papers with Code | A curated collection of trending, active and greatest machine learning papers, many of which include code resources attached. Also includes a collection of common machine learning datasets, benchmarks and current state-of-the-art models. |
lucidrains' vit-pytorch GitHub repository | Less of a place to find research papers and more of an example of what paper replicating with code on a larger scale and with a specific focus looks like. The vit-pytorch repository is a collection of Vision Transformer model architectures from various research papers replicated with PyTorch code (much of the inspiration for this notebook was gathered from this repository). |
Note: This list is far from exhaustive. I only list a few places, the ones I use most frequently personally. So beware the bias. However, I've noticed that even this short list often fully satisfies my needs for knowing what's going on in the field. Any more and I might go crazy.
What we're going to cover¶
Rather than talk about replicating a paper, we're going to get hands-on and actually replicate a paper.
The process for replicating each paper will be slightly different, but by seeing what it's like to do one, we'll get the momentum to do more.
More specifically, we're going to be replicating the machine learning research paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ViT paper) with PyTorch.
The Transformer neural network architecture was originally introduced in the machine learning research paper Attention is all you need.
And the original Transformer architecture was designed to work on one-dimensional (1D) sequences of text.
A Transformer architecture is generally considered to be any neural network that uses the attention mechanism as its primary learning layer. Similar to how a convolutional neural network (CNN) uses convolutions as its primary learning layer.
Like the name suggests, the Vision Transformer (ViT) architecture was designed to adapt the original Transformer architecture to vision problem(s) (classification being the first and since many others have followed).
The original Vision Transformer has been through several iterations over the past couple of years, however, we're going to focus on replicating the original, otherwise known as the "vanilla Vision Transformer". Because if you can recreate the original, you can adapt it to the others.
We're going to be focusing on building the ViT architecture as per the original ViT paper and applying it to FoodVision Mini.
Topic | Contents |
---|---|
0. Getting setup | We've written a fair bit of useful code over the past few sections, let's download it and make sure we can use it again. |
1. Get data | Let's get the pizza, steak and sushi image classification dataset we've been using and build a Vision Transformer to try and improve FoodVision Mini model's results. |
2. Create Datasets and DataLoaders | We'll use the data_setup.py script we wrote in chapter 05. PyTorch Going Modular to set up our DataLoaders. |
3. Replicating the ViT paper: an overview | Replicating a machine learning research paper can be a fair challenge, so before we jump in, let's break the ViT paper down into smaller chunks, so we can replicate the paper chunk by chunk. |
4. Equation 1: The Patch Embedding | The ViT architecture is comprised of four main equations, the first being the patch and position embedding. Or turning an image into a sequence of learnable patches. |
5. Equation 2: Multi-Head Attention (MSA) | The self-attention/multi-head self-attention (MSA) mechanism is at the heart of every Transformer architecture, including the ViT architecture, let's create an MSA block using PyTorch's in-built layers. |
6. Equation 3: Multilayer Perceptron (MLP) | The ViT architecture uses a multilayer perceptron as part of its Transformer Encoder and for its output layer. Let's start by creating an MLP for the Transformer Encoder. |
7. Creating the Transformer Encoder | A Transformer Encoder is typically comprised of alternating layers of MSA (equation 2) and MLP (equation 3) joined together via residual connections. Let's create one by stacking the layers we created in sections 5 & 6 on top of each other. |
8. Putting it all together to create ViT | We've got all the pieces of the puzzle to create the ViT architecture, let's put them all together into a single class we can call as our model. |
9. Setting up training code for our ViT model | Training our custom ViT implementation is similar to all of the other models we've trained previously. And thanks to our train() function in engine.py we can start training with a few lines of code. |
10. Using a pretrained ViT from torchvision.models | Training a large model like ViT usually takes a fair amount of data. Since we're only working with a small amount of pizza, steak and sushi images, let's see if we can leverage the power of transfer learning to improve our performance. |
11. Make predictions on a custom image | The magic of machine learning is seeing it work on your own data, so let's take our best performing model and put FoodVision Mini to the test on the infamous pizza-dad image (a photo of my dad eating pizza). |
Note: Despite the fact we're going to be focused on replicating the ViT paper, avoid getting too bogged down on any particular paper, as newer and better methods often come along quickly. The skill to build is remaining curious whilst developing the fundamentals of turning math and words on a page into working code.
Terminology¶
There are going to be a fair few acronyms throughout this notebook.
In light of this, here are some definitions:
- ViT - Stands for Vision Transformer (the main neural network architecture we're going to be focused on replicating).
- ViT paper - Shorthand for the original machine learning research paper that introduced the ViT architecture, An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. Anytime the ViT paper is mentioned, you can be assured it is referencing this paper.
Where can you get help?¶
All of the materials for this course are available on GitHub.
If you run into trouble, you can ask a question on the course GitHub Discussions page.
And of course, there's the PyTorch documentation and PyTorch developer forums, a very helpful place for all things PyTorch.
0. Getting setup¶
As we've done previously, let's make sure we've got all of the modules we'll need for this section.
We'll import the Python scripts (such as data_setup.py and engine.py) we created in 05. PyTorch Going Modular.
To do so, we'll download the going_modular directory from the pytorch-deep-learning repository (if we don't already have it).
We'll also get the torchinfo package if it's not available. torchinfo will help later on to give us a visual representation of our model.
And since later on we'll be using the torchvision v0.13 package (available as of July 2022), we'll make sure we've got the latest versions.
# For this notebook to run with updated APIs, we need torch 1.12+ and torchvision 0.13+
try:
    import torch
    import torchvision
    assert int(torch.__version__.split(".")[1]) >= 12, "torch version should be 1.12+"
    assert int(torchvision.__version__.split(".")[1]) >= 13, "torchvision version should be 0.13+"
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
except:
    print(f"[INFO] torch/torchvision versions not as required, installing updated versions.")
    !pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
    import torch
    import torchvision
    print(f"torch version: {torch.__version__}")
    print(f"torchvision version: {torchvision.__version__}")
torch version: 1.12.0+cu102 torchvision version: 0.13.0+cu102
Note: If you're using Google Colab and the cell above starts to install various software packages, you may have to restart your runtime after running the above cell. After restarting, you can run the cell again and verify you've got the right versions of torch and torchvision.
Now we'll continue with the regular imports, setting up device agnostic code and this time we'll also get the helper_functions.py script from GitHub.
The helper_functions.py script contains several functions we created in previous sections:
- set_seeds() to set the random seeds (created in 07. PyTorch Experiment Tracking section 0).
- download_data() to download a data source given a link (created in 07. PyTorch Experiment Tracking section 1).
- plot_loss_curves() to inspect our model's training results (created in 04. PyTorch Custom Datasets section 7.8).
Note: It may be a better idea for many of the functions in the helper_functions.py script to be merged into going_modular/going_modular/utils.py, perhaps that's an extension you'd like to try.
# Continue with regular imports
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms
# Try to get torchinfo, install it if it doesn't work
try:
    from torchinfo import summary
except:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary
# Try to import the going_modular directory, download it from GitHub if it doesn't work
try:
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
except:
    # Get the going_modular scripts
    print("[INFO] Couldn't find going_modular or helper_functions scripts... downloading them from GitHub.")
    !git clone https://github.com/mrdbourke/pytorch-deep-learning
    !mv pytorch-deep-learning/going_modular .
    !mv pytorch-deep-learning/helper_functions.py . # get the helper_functions.py script
    !rm -rf pytorch-deep-learning
    from going_modular.going_modular import data_setup, engine
    from helper_functions import download_data, set_seeds, plot_loss_curves
Note: If you're using Google Colab, and you don't have a GPU turned on yet, it's now time to turn one on via Runtime -> Change runtime type -> Hardware accelerator -> GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device
'cuda'
1. Get Data¶
Since we're continuing on with FoodVision Mini, let's download the pizza, steak and sushi image dataset we've been using.
To do so we can use the download_data() function from helper_functions.py that we created in 07. PyTorch Experiment Tracking section 1.
We'll set the source to the raw GitHub link of the pizza_steak_sushi.zip data and the destination to pizza_steak_sushi.
# Download pizza, steak, sushi images from GitHub
image_path = download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
destination="pizza_steak_sushi")
image_path
[INFO] data/pizza_steak_sushi directory exists, skipping download.
PosixPath('data/pizza_steak_sushi')
Beautiful! Data downloaded, let's setup the training and test directories.
# Setup directory paths to train and test images
train_dir = image_path / "train"
test_dir = image_path / "test"
2. Create Datasets and DataLoaders¶
Now we've got some data, let's turn it into DataLoaders.
To do so we can use the create_dataloaders() function in data_setup.py.
First, we'll create a transform to prepare our images.
This is where one of the first references to the ViT paper will come in.
In Table 3, the training resolution is mentioned as being 224 (height=224, width=224).
You can often find various hyperparameter settings listed in a table. In this case we're still preparing our data, so we're mainly concerned with things like image size and batch size. Source: Table 3 in ViT paper.
So we'll make sure our transform resizes our images appropriately.
And since we'll be training our model from scratch (no transfer learning to begin with), we won't provide a normalize transform like we did in 06. PyTorch Transfer Learning section 2.1.
2.1 Prepare transforms for images¶
# Create image size (from Table 3 in the ViT paper)
IMG_SIZE = 224
# Create transform pipeline manually
manual_transforms = transforms.Compose([
transforms.Resize((IMG_SIZE, IMG_SIZE)),
transforms.ToTensor(),
])
print(f"Manually created transforms: {manual_transforms}")
Manually created transforms: Compose( Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None) ToTensor() )
2.2 Turn images into DataLoaders¶
Transforms created!
Let's now create our DataLoaders.
The ViT paper states the use of a batch size of 4096 which is 128x the size of the batch size we've been using (32).
However, we're going to stick with a batch size of 32.
Why?
Because some hardware (including the free tier of Google Colab) may not be able to handle a batch size of 4096.
Having a batch size of 4096 means that 4096 images need to fit into the GPU memory at a time.
This works when you've got the hardware to handle it, like a research team from Google often does, but when you're running on a single GPU (such as on Google Colab), making sure things work with a smaller batch size first is a good idea.
An extension of this project could be to try a higher batch size value and see what happens.
Note: We're using the pin_memory=True parameter in the create_dataloaders() function to speed up computation. pin_memory=True avoids unnecessary copying of memory between the CPU and GPU memory by "pinning" examples that have been seen before. Though the benefits of this will likely be seen with larger dataset sizes (our FoodVision Mini dataset is quite small). However, setting pin_memory=True doesn't always improve performance (this is another one of those scenarios in machine learning where some things work sometimes and don't other times), so it's best to experiment, experiment, experiment. See the PyTorch torch.utils.data.DataLoader documentation or Making Deep Learning Go Brrrr From First Principles by Horace He for more.
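To make the note above concrete, here's a minimal sketch (for illustration only, the example_dataset/example_dataloader names are throwaway) of what passing pin_memory=True to a plain torch.utils.data.DataLoader looks like, along with the non_blocking=True transfer it enables. Our create_dataloaders() function handles this for us.
# Illustration only: pin_memory on a plain PyTorch DataLoader
from torch.utils.data import DataLoader
from torchvision import datasets

example_dataset = datasets.ImageFolder(root=train_dir,
                                       transform=manual_transforms)
example_dataloader = DataLoader(dataset=example_dataset,
                                batch_size=32,
                                shuffle=True,
                                pin_memory=True) # pin batches in page-locked CPU memory for faster CPU -> GPU copies

# Pinned batches can be copied to the GPU asynchronously with non_blocking=True
example_images, example_labels = next(iter(example_dataloader))
example_images = example_images.to(device, non_blocking=True)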
# Set the batch size
BATCH_SIZE = 32 # this is lower than the ViT paper but it's because we're starting small
# Create data loaders
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
train_dir=train_dir,
test_dir=test_dir,
transform=manual_transforms, # use manually created transforms
batch_size=BATCH_SIZE
)
train_dataloader, test_dataloader, class_names
(<torch.utils.data.dataloader.DataLoader at 0x7f18845ff0d0>, <torch.utils.data.dataloader.DataLoader at 0x7f17f3f5f520>, ['pizza', 'steak', 'sushi'])
2.3 Visualize a single image¶
Now we've loaded our data, let's visualize, visualize, visualize!
An important step in the ViT paper is preparing the images into patches.
We'll get to what this means in section 4 but for now, let's view a single image and its label.
To do so, let's get a single image and label from a batch of data and inspect their shapes.
# Get a batch of images
image_batch, label_batch = next(iter(train_dataloader))
# Get a single image from the batch
image, label = image_batch[0], label_batch[0]
# View the batch shapes
image.shape, label
(torch.Size([3, 224, 224]), tensor(2))
Wonderful!
Now let's plot the image and its label with matplotlib
.
# Plot image with matplotlib
plt.imshow(image.permute(1, 2, 0)) # rearrange image dimensions to suit matplotlib [color_channels, height, width] -> [height, width, color_channels]
plt.title(class_names[label])
plt.axis(False);
Nice!
Looks like our images are importing correctly, let's continue with the paper replication.
3. Replicating the ViT paper: an overview¶
Before we write any more code, let's discuss what we're doing.
We'd like to replicate the ViT paper for our own problem, FoodVision Mini.
So our model inputs are: images of pizza, steak and sushi.
And our ideal model outputs are: predicted labels of pizza, steak or sushi.
No different to what we've been doing throughout the previous sections.
The question is: how do we go from our inputs to the desired outputs?
3.1 Inputs and outputs, layers and blocks¶
ViT is a deep learning neural network architecture.
And any neural network architecture is generally comprised of layers.
And a collection of layers is often referred to as a block.
And stacking many blocks together is what gives us the whole architecture.
A layer takes an input (say an image tensor), performs some kind of function on it (for example what's in the layer's forward()
method) and then returns an output.
So if a single layer takes an input and gives an output, then a collection of layers or a block also takes an input and gives an output.
Let's make this concrete:
- Layer - takes an input, performs a function on it, returns an output.
- Block - a collection of layers, takes an input, performs a series of functions on it, returns an output.
- Architecture (or model) - a collection of blocks, takes an input, performs a series of functions on it, returns an output.
This ideology is what we're going to be using to replicate the ViT paper.
We're going to take it layer by layer, block by block, function by function putting the pieces of the puzzle together like Lego to get our desired overall architecture.
The reason we do this is because looking at a whole research paper can be intimidating.
So for a better understanding, we'll break it down, starting with the inputs and outputs of single layer and working up to the inputs and outputs of the whole model.

A modern deep learning architecture is usually a collection of layers and blocks. Layers take an input (data as a numerical representation) and manipulate it using some kind of function (for example, the self-attention formula pictured above, though this function could be almost anything) and then output it. Blocks are generally stacks of layers on top of each other, doing something similar to a single layer but multiple times.
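To make the layer/block/architecture idea concrete, here's a tiny toy example in PyTorch (ToyBlock and ToyModel are made-up names purely for illustration, they have nothing to do with ViT yet).
# A toy example of a layer, a block and a model
import torch
from torch import nn

# A layer: takes an input, performs a function on it (its forward() method), returns an output
toy_layer = nn.Linear(in_features=10, out_features=10)

# A block: a collection of layers
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(10, 10),
                                    nn.ReLU(),
                                    nn.Linear(10, 10))
    def forward(self, x):
        return self.layers(x)

# An architecture (or model): a collection of blocks
class ToyModel(nn.Module):
    def __init__(self, num_blocks=3):
        super().__init__()
        self.blocks = nn.Sequential(*[ToyBlock() for _ in range(num_blocks)])
    def forward(self, x):
        return self.blocks(x)

toy_model = ToyModel()
toy_model(torch.randn(1, 10)).shape # -> torch.Size([1, 10])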
3.2 Getting specific: What's ViT made of?¶
There are many little details about the ViT model sprinkled throughout the paper.
Finding them all is like one big treasure hunt!
Remember, a research paper is often months of work compressed into a few pages, so it's understandable for it to take a fair bit of practice to replicate.
However, the main three resources we'll be looking at for the architecture design are:
- Figure 1 - This gives an overview of the model in a graphical sense, you could almost recreate the architecture with this figure alone.
- Four equations in section 3.1 - These equations give a little bit more of a mathematical grounding to the coloured blocks in Figure 1.
- Table 1 - This table shows the various hyperparameter settings (such as number of layers and number of hidden units) for different ViT model variants. We'll be focused on the smallest version, ViT-Base.
3.2.1 Exploring Figure 1¶
Let's start by going through Figure 1 of the ViT Paper.
The main things we'll be paying attention to are:
- Layers - takes an input, performs an operation or function on the input, produces an output.
- Blocks - a collection of layers, which in turn also takes an input and produces an output.
Figure 1 from the ViT Paper showcasing the different inputs, outputs, layers and blocks that create the architecture. Our goal will be to replicate each of these using PyTorch code.
The ViT architecture is comprised of several stages:
- Patch + Position Embedding (inputs) - Turns the input image into a sequence of image patches and adds a position number to specify in what order the patch comes in.
- Linear projection of flattened patches (Embedded Patches) - The image patches get turned into an embedding, the benefit of using an embedding rather than just the image values is that an embedding is a learnable representation (typically in the form of a vector) of the image that can improve with training.
- Norm - This is short for "Layer Normalization" or "LayerNorm", a technique for regularizing (reducing overfitting in) a neural network. You can use LayerNorm via the PyTorch layer torch.nn.LayerNorm().
- Multi-Head Attention - This is a Multi-Headed Self-Attention layer or "MSA" for short. You can create an MSA layer via the PyTorch layer torch.nn.MultiheadAttention().
- MLP (or Multilayer Perceptron) - An MLP can often refer to any collection of feedforward layers (or in PyTorch's case, a collection of layers with a forward() method). In the ViT Paper, the authors refer to the MLP as the "MLP block" and it contains two torch.nn.Linear() layers with a torch.nn.GELU() non-linearity activation in between them (section 3.1) and a torch.nn.Dropout() layer after each (Appendix B.1).
- Transformer Encoder - The Transformer Encoder is a collection of the layers listed above. There are two skip connections inside the Transformer Encoder (the "+" symbols), meaning the layer's inputs are fed directly to immediate layers as well as subsequent layers. The overall ViT architecture is comprised of a number of Transformer Encoders stacked on top of each other.
- MLP Head - This is the output layer of the architecture, it converts the learned features of an input to a class output. Since we're working on image classification, you could also call this the "classifier head". The structure of the MLP Head is similar to the MLP block.
You might notice that many of the pieces of the ViT architecture can be created with existing PyTorch layers.
This is because of how PyTorch is designed, it's one of the main purposes of PyTorch to create reusable neural network layers for both researchers and machine learning practitioners.
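For example, every layer named in the list above already exists in torch.nn. Here's a quick sketch instantiating them with the ViT-Base sizes we'll meet in Table 1 (the dropout value here is just a placeholder, we'll confirm the paper's values when we build the blocks properly).
# The ViT layers named above already exist in torch.nn (sizes from Table 1 for ViT-Base)
from torch import nn

embedding_dim = 768 # Hidden size D
num_heads = 12      # Heads
mlp_size = 3072     # MLP size

norm_layer = nn.LayerNorm(normalized_shape=embedding_dim)
attention_layer = nn.MultiheadAttention(embed_dim=embedding_dim,
                                        num_heads=num_heads,
                                        batch_first=True) # expects [batch, sequence, embedding_dim]
mlp_layers = nn.Sequential(nn.Linear(in_features=embedding_dim, out_features=mlp_size),
                           nn.GELU(),
                           nn.Dropout(p=0.1), # placeholder dropout value
                           nn.Linear(in_features=mlp_size, out_features=embedding_dim),
                           nn.Dropout(p=0.1))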
Question: Why not code everything from scratch?
You could definitely do that by reproducing all of the math equations from the paper with custom PyTorch layers and that would certainly be an educative exercise, however, using pre-existing PyTorch layers is usually favoured as pre-existing layers have often been extensively tested and performance checked to make sure they run correctly and fast.
Note: We're going to be focused on writing PyTorch code to create these layers. For the background on what each of these layers does, I'd suggest reading the ViT Paper in full or reading the linked resources for each layer.
Let's take Figure 1 and adapt it to our FoodVision Mini problem of classifying images of food into pizza, steak or sushi.
Figure 1 from the ViT Paper adapted for use with FoodVision Mini. An image of food goes in (pizza), the image gets turned into patches and then projected to an embedding. The embedding then travels through the various layers and blocks and (hopefully) the class "pizza" is returned.
3.2.2 Exploring the Four Equations¶
The next main part(s) of the ViT paper we're going to look at are the four equations in section 3.1.
These four equations represent the math behind the four major parts of the ViT architecture.
Section 3.1 describes each of these (some of the text has been omitted for brevity, bolded text is mine):
Equation number | Description from ViT paper section 3.1 |
---|---|
1 | ...The Transformer uses constant latent vector size $D$ through all of its layers, so we flatten the patches and map to $D$ dimensions with a trainable linear projection (Eq. 1). We refer to the output of this projection as the patch embeddings... Position embeddings are added to the patch embeddings to retain positional information. We use standard learnable 1D position embeddings... |
2 | The Transformer encoder (Vaswani et al., 2017) consists of alternating layers of multiheaded selfattention (MSA, see Appendix A) and MLP blocks (Eq. 2, 3). Layernorm (LN) is applied before every block, and residual connections after every block (Wang et al., 2019; Baevski & Auli, 2019). |
3 | Same as equation 2. |
4 | Similar to BERT's [ class ] token, we prepend a learnable embedding to the sequence of embedded patches $\left(\mathbf{z}_{0}^{0}=\mathbf{x}_{\text {class }}\right)$, whose state at the output of the Transformer encoder $\left(\mathbf{z}_{L}^{0}\right)$ serves as the image representation $\mathbf{y}$ (Eq. 4)... |
Let's map these descriptions to the ViT architecture in Figure 1.
Connecting Figure 1 from the ViT paper to the four equations from section 3.1 describing the math behind each of the layers/blocks.
There's a lot happening in the image above but following the coloured lines and arrows reveals the main concepts of the ViT architecture.
How about we break down each equation further (it will be our goal to recreate these with code)?
In all equations (except equation 4), "$\mathbf{z}$" is the raw output of a particular layer:
- $\mathbf{z}_{0}$ is "z zero" (this is the output of the initial patch embedding layer).
- $\mathbf{z}_{\ell}^{\prime}$ is "z of a particular layer prime" (or an intermediary value of z).
- $\mathbf{z}_{\ell}$ is "z of a particular layer".
And $\mathbf{y}$ is the overall output of the architecture.
3.2.3 Equation 1 overview¶
$$ \begin{aligned} \mathbf{z}_{0} &=\left[\mathbf{x}_{\text {class }} ; \mathbf{x}_{p}^{1} \mathbf{E} ; \mathbf{x}_{p}^{2} \mathbf{E} ; \cdots ; \mathbf{x}_{p}^{N} \mathbf{E}\right]+\mathbf{E}_{\text {pos }}, & & \mathbf{E} \in \mathbb{R}^{\left(P^{2} \cdot C\right) \times D}, \mathbf{E}_{\text {pos }} \in \mathbb{R}^{(N+1) \times D} \end{aligned} $$
This equation deals with the class token, patch embedding and position embedding ($\mathbf{E}$ is for embedding) of the input image.
In vector form, the embedding might look something like:
x_input = [class_token, image_patch_1, image_patch_2, image_patch_3...] + [class_token_position, image_patch_1_position, image_patch_2_position, image_patch_3_position...]
Where each of the elements in the vector is learnable (their requires_grad=True).
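As a rough sketch in PyTorch (using random tensors to stand in for the real patch embeddings we'll build in section 4), equation 1 boils down to a concatenation and an addition.
# A rough sketch of equation 1 with random tensors standing in for real patch embeddings
import torch
from torch import nn

batch_size = 1
number_of_patches = 196 # N = (224*224)/16**2
embedding_dim = 768     # D (Hidden size for ViT-Base from Table 1)

# Pretend patch embeddings (the output of the trainable linear projection E): [batch_size, N, D]
patch_embeddings = torch.randn(batch_size, number_of_patches, embedding_dim)

# Learnable class token prepended to the sequence: [batch_size, 1, D]
class_token = nn.Parameter(torch.randn(batch_size, 1, embedding_dim), requires_grad=True)

# Learnable position embedding added to the sequence: [batch_size, N+1, D]
position_embedding = nn.Parameter(torch.randn(batch_size, number_of_patches+1, embedding_dim), requires_grad=True)

# Equation 1: prepend the class token, then add the position embedding
z_0 = torch.cat((class_token, patch_embeddings), dim=1) + position_embedding
z_0.shape # -> torch.Size([1, 197, 768])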
3.2.4 Equation 2 overview¶
$$ \begin{aligned} \mathbf{z}_{\ell}^{\prime} &=\operatorname{MSA}\left(\operatorname{LN}\left(\mathbf{z}_{\ell-1}\right)\right)+\mathbf{z}_{\ell-1}, & & \ell=1 \ldots L \end{aligned} $$
This says that for every layer from $1$ through to $L$ (the total number of layers), there's a Multi-Head Attention layer (MSA) wrapping a LayerNorm layer (LN).
The addition on the end is the equivalent of adding the input to the output and forming a skip/residual connection.
We'll call this layer the "MSA block".
In pseudocode, this might look like:
x_output_MSA_block = MSA_layer(LN_layer(x_input)) + x_input
Notice the skip connection on the end (adding the input of the layers to the output of the layers).
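A minimal sketch of this in PyTorch might look like the following (MSABlock is a throwaway name for illustration, we'll build the real block in section 5).
# A minimal sketch of equation 2: z'_l = MSA(LN(z_{l-1})) + z_{l-1}
import torch
from torch import nn

class MSABlock(nn.Module):
    def __init__(self, embedding_dim=768, num_heads=12):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.msa = nn.MultiheadAttention(embed_dim=embedding_dim,
                                         num_heads=num_heads,
                                         batch_first=True) # expects [batch, sequence, embedding_dim]
    def forward(self, x):
        x_norm = self.layer_norm(x)             # LN(z_{l-1})
        attn_output, _ = self.msa(query=x_norm, # MSA(LN(z_{l-1}))
                                  key=x_norm,
                                  value=x_norm,
                                  need_weights=False)
        return attn_output + x                  # + z_{l-1} (skip/residual connection)

MSABlock()(torch.randn(1, 197, 768)).shape # -> torch.Size([1, 197, 768])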
3.2.5 Equation 3 overview¶
$$ \begin{aligned} \mathbf{z}_{\ell} &=\operatorname{MLP}\left(\operatorname{LN}\left(\mathbf{z}_{\ell}^{\prime}\right)\right)+\mathbf{z}_{\ell}^{\prime}, & & \ell=1 \ldots L \\ \end{aligned} $$
This says that for every layer from $1$ through to $L$ (the total number of layers), there's also a Multilayer Perceptron layer (MLP) wrapping a LayerNorm layer (LN).
The addition on the end is showing the presence of a skip/residual connection.
We'll call this layer the "MLP block".
In pseudocode, this might look like:
x_output_MLP_block = MLP_layer(LN_layer(x_output_MSA_block)) + x_output_MSA_block
Notice the skip connection on the end (adding the input of the layers to the output of the layers).
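And a minimal sketch of the MLP block (again, MLPBlock is a throwaway name, we'll build the real version in section 6, and the dropout value is a placeholder).
# A minimal sketch of equation 3: z_l = MLP(LN(z'_l)) + z'_l
import torch
from torch import nn

class MLPBlock(nn.Module):
    def __init__(self, embedding_dim=768, mlp_size=3072, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        self.mlp = nn.Sequential(nn.Linear(embedding_dim, mlp_size),
                                 nn.GELU(),
                                 nn.Dropout(p=dropout),
                                 nn.Linear(mlp_size, embedding_dim),
                                 nn.Dropout(p=dropout))
    def forward(self, x):
        return self.mlp(self.layer_norm(x)) + x # MLP(LN(z'_l)) + z'_l (skip/residual connection)

MLPBlock()(torch.randn(1, 197, 768)).shape # -> torch.Size([1, 197, 768])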
3.2.6 Equation 4 overview¶
$$ \begin{aligned} \mathbf{y} &=\operatorname{LN}\left(\mathbf{z}_{L}^{0}\right) & & \end{aligned} $$
This says for the last layer $L$, the output $y$ is the 0 index token of $z$ wrapped in a LayerNorm layer (LN).
Or in our case, the 0 index of x_output_MLP_block:
y = Linear_layer(LN_layer(x_output_MLP_block[0]))
Of course there are some simplifications above but we'll take care of those when we start to write PyTorch code for each section.
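For now, here's a rough sketch of equation 4 plus the classifier head in PyTorch (using a random tensor to stand in for the output of the final Transformer Encoder layer).
# A rough sketch of equation 4 plus the linear classifier head
import torch
from torch import nn

embedding_dim = 768
num_classes = 3 # pizza, steak, sushi

z_L = torch.randn(1, 197, embedding_dim) # pretend final Transformer Encoder output: [batch_size, N+1, D]

layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
classifier = nn.Linear(in_features=embedding_dim, out_features=num_classes)

y = classifier(layer_norm(z_L[:, 0])) # index 0 = the class token position
y.shape # -> torch.Size([1, 3])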
Note: The above section covers a lot of information. But don't forget, if something doesn't make sense you can always research it further, by asking questions like "what is a residual connection?".
3.2.7 Exploring Table 1¶
The final piece of the ViT architecture puzzle we'll focus on (for now) is Table 1.
Model | Layers | Hidden size $D$ | MLP size | Heads | Params |
---|---|---|---|---|---|
ViT-Base | 12 | 768 | 3072 | 12 | $86M$ |
ViT-Large | 24 | 1024 | 4096 | 16 | $307M$ |
ViT-Huge | 32 | 1280 | 5120 | 16 | $632M$ |
This table showcases the various hyperparameters of each of the ViT architectures.
You can see the numbers gradually increase from ViT-Base to ViT-Huge.
We're going to focus on replicating ViT-Base (start small and scale up when necessary) but we'll be writing code that could easily scale up to the larger variants.
Breaking the hyperparameters down:
- Layers - How many Transformer Encoder blocks are there? (each of these will contain a MSA block and MLP block)
- Hidden size $D$ - This is the embedding dimension throughout the architecture, this will be the size of the vector that our image gets turned into when it gets patched and embedded. Generally, the larger the embedding dimension, the more information can be captured, the better results. However, a larger embedding comes at the cost of more compute.
- MLP size - How many hidden units are in the MLP layers?
- Heads - How many heads are there in the Multi-Head Attention layers?
- Params - What is the total number of parameters of the model? Generally, more parameters leads to better performance but at the cost of more compute. You'll notice even ViT-Base has far more parameters than any other model we've used so far.
We'll use these values as the hyperparameter settings for our ViT architecture.
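For reference, we could collect the ViT-Base values in one place now (vit_base_config is just a throwaway name, we'll set these up as proper hyperparameters when we build the model).
# Collect the ViT-Base hyperparameters from Table 1 of the ViT paper
vit_base_config = {
    "img_size": 224,              # training resolution from Table 3
    "patch_size": 16,             # ViT-B/16
    "num_transformer_layers": 12, # Layers
    "embedding_dim": 768,         # Hidden size D
    "mlp_size": 3072,             # MLP size
    "num_heads": 12,              # Heads
    "num_classes": 3              # pizza, steak, sushi (FoodVision Mini specific, not from the paper)
}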
3.3 My workflow for replicating papers¶
When I start working on replicating a paper, I go through the following steps:
- Read the whole paper end-to-end once (to get an idea of the main concepts).
- Go back through each section and see how they line up with each other and start thinking about how they might be turned into code (just like above).
- Repeat step 2 until I've got a fairly good outline.
- Use mathpix.com (a very handy tool) to turn any sections of the paper into markdown/LaTeX to put into notebooks.
- Replicate the simplest version of the model possible.
- If I get stuck, look up other examples.
Turning the four equations from the ViT paper into editable LaTeX/markdown using mathpix.com.
We've already gone through the first few steps above (and if you haven't read the full paper yet, I'd encourage you to give it a go) but what we'll be focusing on next is step 5: replicating the simplest version of the model possible.
This is why we're starting with ViT-Base.
Replicate the smallest version of the architecture possible, get it working, and then scale up if we want to.
Note: If you've never read a research paper before, many of the above steps can be intimidating. But don't worry, like anything, your skills at reading and replicating papers will improve with practice. Don't forget, a research paper is often months of work by many people compressed into a few pages. So trying to replicate it on your own is no small feat.
4. Equation 1: Split data into patches and creating the class, position and patch embedding¶
I remember one of my machine learning engineer friends used to say "it's all about the embedding."
As in, if you can represent your data in a good, learnable way (as embeddings are learnable representations), chances are, a learning algorithm will be able to perform well on them.
With that being said, let's start by creating the class, position and patch embeddings for the ViT architecture.
We'll start with the patch embedding.
This means we'll be turning our input images into a sequence of patches and then embedding those patches.
Recall that an embedding is a learnable representation of some form and is often a vector.
The term learnable is important because this means the numerical representation of an input image (that the model sees) can be improved over time.
We'll begin by following the opening paragraph of section 3.1 of the ViT paper (bold mine):
The standard Transformer receives as input a 1D sequence of token embeddings. To handle 2D images, we reshape the image $\mathbf{x} \in \mathbb{R}^{H \times W \times C}$ into a sequence of flattened 2D patches $\mathbf{x}_{p} \in \mathbb{R}^{N \times\left(P^{2} \cdot C\right)}$, where $(H, W)$ is the resolution of the original image, $C$ is the number of channels, $(P, P)$ is the resolution of each image patch, and $N=H W / P^{2}$ is the resulting number of patches, which also serves as the effective input sequence length for the Transformer. The Transformer uses constant latent vector size $D$ through all of its layers, so we flatten the patches and map to $D$ dimensions with a trainable linear projection (Eq. 1). We refer to the output of this projection as the patch embeddings.
And since we're dealing with image shapes, let's keep in mind the line from Table 3 of the ViT paper:
Training resolution is 224.
Let's break down the text above.
- $D$ is the size of the patch embeddings, different values for $D$ for various sized ViT models can be found in Table 1.
- The image starts as 2D with size ${H \times W \times C}$.
- $(H, W)$ is the resolution of the original image (height, width).
- $C$ is the number of channels.
- The image gets converted to a sequence of flattened 2D patches with size ${N \times\left(P^{2} \cdot C\right)}$.
- $(P, P)$ is the resolution of each image patch (patch size).
- $N=H W / P^{2}$ is the resulting number of patches, which also serves as the input sequence length for the Transformer.
Mapping the patch and position embedding portion of the ViT architecture from Figure 1 to Equation 1. The opening paragraph of section 3.1 describes the different input and output shapes of the patch embedding layer.
4.1 Calculating patch embedding input and output shapes by hand¶
How about we start by calculating these input and output shape values by hand?
To do so, let's create some variables to mimic each of the terms (such as $H$, $W$ etc) above.
We'll use a patch size ($P$) of 16 since it's what the best performing version of ViT-Base uses (see column "ViT-B/16" of Table 5 in the ViT paper for more).
# Create example values
height = 224 # H ("The training resolution is 224.")
width = 224 # W
color_channels = 3 # C
patch_size = 16 # P
# Calculate N (number of patches)
number_of_patches = int((height * width) / patch_size**2)
print(f"Number of patches (N) with image height (H={height}), width (W={width}) and patch size (P={patch_size}): {number_of_patches}")
Number of patches (N) with image height (H=224), width (W=224) and patch size (P=16): 196
We've got the number of patches, how about we create the image output size as well?
Better yet, let's replicate the input and output shapes of the patch embedding layer.
Recall:
- Input: The image starts as 2D with size ${H \times W \times C}$.
- Output: The image gets converted to a sequence of flattened 2D patches with size ${N \times\left(P^{2} \cdot C\right)}$.
# Input shape (this is the size of a single image)
embedding_layer_input_shape = (height, width, color_channels)
# Output shape
embedding_layer_output_shape = (number_of_patches, patch_size**2 * color_channels)
print(f"Input shape (single 2D image): {embedding_layer_input_shape}")
print(f"Output shape (single 2D image flattened into patches): {embedding_layer_output_shape}")
Input shape (single 2D image): (224, 224, 3) Output shape (single 2D image flattened into patches): (196, 768)
Input and output shapes acquired!
4.2 Turning a single image into patches¶
Now we know the ideal input and output shapes for our patch embedding layer, let's move towards making it.
What we're doing is breaking down the overall architecture into smaller pieces, focusing on the inputs and outputs of individual layers.
So how do we create the patch embedding layer?
We'll get to that shortly, first, let's visualize, visualize, visualize! what it looks like to turn an image into patches.
Let's start with our single image.
# View single image
plt.imshow(image.permute(1, 2, 0)) # adjust for matplotlib
plt.title(class_names[label])
plt.axis(False);
We want to turn this image into patches of itself, in line with Figure 1 of the ViT paper.
How about we start by just visualizing the top row of patched pixels?
We can do this by indexing on the different image dimensions.
# Change image shape to be compatible with matplotlib (color_channels, height, width) -> (height, width, color_channels)
image_permuted = image.permute(1, 2, 0)
# Index to plot the top row of patched pixels
patch_size = 16
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(image_permuted[:patch_size, :, :]);
Now we've got the top row, let's turn it into patches.
We can do this by iterating through the number of patches there'd be in the top row.
# Setup hyperparameters and make sure img_size and patch_size are compatible
img_size = 224
patch_size = 16
num_patches = img_size/patch_size
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
print(f"Number of patches per row: {num_patches}\nPatch size: {patch_size} pixels x {patch_size} pixels")
# Create a series of subplots
fig, axs = plt.subplots(nrows=1,
ncols=img_size // patch_size, # one column for each patch
figsize=(num_patches, num_patches),
sharex=True,
sharey=True)
# Iterate through number of patches in the top row
for i, patch in enumerate(range(0, img_size, patch_size)):
    axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :]); # keep height index constant, alter the width index
    axs[i].set_xlabel(i+1) # set the label
    axs[i].set_xticks([])
    axs[i].set_yticks([])
Number of patches per row: 14.0 Patch size: 16 pixels x 16 pixels
Those are some nice looking patches!
How about we do it for the whole image?
This time we'll iterate through the indices for height and width and plot each patch as its own subplot.
# Setup hyperparameters and make sure img_size and patch_size are compatible
img_size = 224
patch_size = 16
num_patches = img_size/patch_size
assert img_size % patch_size == 0, "Image size must be divisible by patch size"
print(f"Number of patches per row: {num_patches}\
\nNumber of patches per column: {num_patches}\
\nTotal patches: {num_patches*num_patches}\
\nPatch size: {patch_size} pixels x {patch_size} pixels")
# Create a series of subplots
fig, axs = plt.subplots(nrows=img_size // patch_size, # need int not float
ncols=img_size // patch_size,
figsize=(num_patches, num_patches),
sharex=True,
sharey=True)
# Loop through height and width of image
for i, patch_height in enumerate(range(0, img_size, patch_size)): # iterate through height
    for j, patch_width in enumerate(range(0, img_size, patch_size)): # iterate through width

        # Plot the permuted image patch (image_permuted -> (Height, Width, Color Channels))
        axs[i, j].imshow(image_permuted[patch_height:patch_height+patch_size, # iterate through height
                                        patch_width:patch_width+patch_size, # iterate through width
                                        :]) # get all color channels

        # Set up label information, remove the ticks for clarity and set labels to outside
        axs[i, j].set_ylabel(i+1,
                             rotation="horizontal",
                             horizontalalignment="right",
                             verticalalignment="center")
        axs[i, j].set_xlabel(j+1)
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
        axs[i, j].label_outer()
# Set a super title
fig.suptitle(f"{class_names[label]} -> Patchified", fontsize=16)
plt.show()
Number of patches per row: 14.0 Number of patches per column: 14.0 Total patches: 196.0 Patch size: 16 pixels x 16 pixels