MIT Undergrad's Guide to Building a Vision Transformer from Scratch: Insights and Results on CIFAR-10
Vision transformers (ViTs) represent a significant shift in computer vision, marking a departure from the long-standing dominance of convolutional neural networks (CNNs). In 2020, researchers at Google introduced the Vision Transformer in a groundbreaking paper titled "An Image Is Worth 16×16 Words," challenging the conventional reliance on CNNs and demonstrating that attention mechanisms, previously successful in natural language processing (NLP), could also excel at image recognition.

1. Background and Intuition

1.1 From Recurrent Models to the Rise of Transformers (in NLP)

Before 2017, recurrent neural networks (RNNs) and long short-term memory (LSTM) networks were the primary tools for NLP tasks such as machine translation and language modeling. However, these models process sequences one token at a time, which makes training difficult to parallelize and long-range information hard to retain. The Transformer architecture, introduced in the 2017 paper "Attention Is All You Need," revolutionized NLP by enabling parallel processing and capturing long-range dependencies through self-attention: each word in a sequence can directly attend to every other word via learned query, key, and value projections.

1.2 Can Attention Replace Convolution? The Shift to Vision

Inspired by the Transformer's success in NLP, researchers hypothesized that attention mechanisms could similarly replace convolutions in computer vision. CNNs are powerful but have limitations: they struggle to capture long-range dependencies and rely on built-in spatial priors and hand-designed architectural choices. The Vision Transformer divides an image into non-overlapping patches and treats them as tokens in a sequence, allowing self-attention to relate any part of the image to any other in a single layer. Initial experiments showed that ViTs could perform comparably to CNNs, especially when trained on large datasets, paving the way for further advances in computer vision.

2. How Vision Transformers Work

2.1 Patch Embedding

The first step in the ViT pipeline is to convert an image into a sequence of patches. Each image is divided into smaller, fixed-size patches, which are then flattened and linearly projected into fixed-size embeddings. This transformation lets the ViT handle images much as a Transformer handles text sequences.

2.2 Class Token and Positional Embeddings

To capture global information and maintain spatial awareness, two key elements are added to the sequence: a [CLS] token and positional embeddings. The [CLS] token is a special learnable token prepended to the sequence of patch embeddings; its final state serves as a representation of the entire image during classification. Positional embeddings are also learned and are added to each token embedding to encode spatial position.

2.3 Multi-Head Self-Attention (MHSA)

At the core of the ViT is the multi-head self-attention mechanism, which lets the model learn relationships between patches. Instead of a single attention function, MHSA uses multiple heads, each attending to different aspects of the input; this parallelism helps the model capture both local and global features. Attention scores are computed from queries (Q) and keys (K), scaled, passed through a softmax, and used to weight the values (V); the per-head outputs are then concatenated and projected back to the original embedding space.
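In equations, each head computes the scaled dot-product attention of the original Transformer paper, where d_k is the per-head query/key dimension. The projection matrices W_i^Q, W_i^K, W_i^V, and W^O below are notation for this post; in the code later on they correspond to the q_proj, k_proj, v_proj, and out_proj layers.

```latex
% Scaled dot-product attention for a single head (d_k = per-head query/key dimension):
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

% Multi-head attention runs h heads in parallel, concatenates them,
% and projects back to the embedding dimension with W^O:
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O},
\qquad \mathrm{head}_i = \mathrm{Attention}(X W_i^{Q},\, X W_i^{K},\, X W_i^{V})
```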
2.4 Transformer Encoder

The Transformer encoder block in ViTs combines self-attention with a multilayer perceptron (MLP) and residual connections. This structure lets the model attend globally and transform features across layers while remaining stable to train. Each block includes:

- Layer normalization before the attention mechanism (pre-norm).
- Multi-head self-attention applied to the normalized input.
- A residual connection adding the attention output back.
- Another layer normalization followed by a small MLP.
- Another residual connection adding the MLP output back.

2.5 Classification Head

After passing through multiple Transformer blocks, the final embedding of the [CLS] token serves as a summary representation of the image. This vector is fed into a linear classification head to produce class logits, yielding the final prediction from the processed image patches.

3. Implementation Walkthrough

3.1 Patch Embedding

To implement patch embedding, a PatchEmbed class is defined. It uses a convolutional layer whose kernel size and stride both equal the patch size, efficiently extracting and projecting non-overlapping patches in one operation. The resulting sequence of embeddings has shape [B, num_patches, embed_dim].

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size extracts and
        # linearly projects non-overlapping patches in a single pass.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2)       # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x
```

3.2 Class Token and Positional Embeddings

The ViTEmbed class handles the class token and positional embeddings. It expands the class token to match the batch size and prepends it to the sequence of patch embeddings; learned positional embeddings are then added to every token to encode spatial information.

```python
class ViTEmbed(nn.Module):
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))                # [1, 1, D]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))  # [1, N+1, D]

    def forward(self, x):
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, D]
        x = torch.cat((cls_tokens, x), dim=1)                   # [B, N+1, D]
        x = x + self.pos_embed                                   # [B, N+1, D]
        return x
```

3.3 Multi-Head Self-Attention

The MyMultiheadAttention class implements the multi-head self-attention mechanism. Each input token is projected into query, key, and value vectors with learnable linear layers; attention scores are computed in parallel across multiple heads, and the per-head outputs are concatenated and projected back to the original embedding dimension.
```python
class MyMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape  # [batch, seq_len, embed_dim]

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Reshape [B, T, C] -> [B, num_heads, T, head_dim] so every head attends in parallel.
        def split_heads(tensor):
            return tensor.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)

        scores = torch.matmul(Q, K.transpose(-2, -1))  # [B, heads, T, T]
        scores = scores / self.head_dim ** 0.5         # scale by sqrt(d_k)
        attn = torch.softmax(scores, dim=-1)

        out = torch.matmul(attn, V)                            # [B, heads, T, head_dim]
        out = out.transpose(1, 2).contiguous().view(B, T, C)   # merge heads back to [B, T, C]
        return self.out_proj(out)
```

3.4 Transformer Encoder Block

The TransformerBlock class encapsulates the core functionality of a ViT block, combining self-attention and an MLP with pre-norm layer normalization and residual connections. This design lets the model reason about global relationships and refine feature representations while remaining stable to train.

```python
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # PyTorch's built-in attention is used here; MyMultiheadAttention above
        # implements the same computation from scratch.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )

    def forward(self, x):
        # Pre-norm attention with a residual connection.
        x_norm = self.norm1(x)
        x = x + self.attn(x_norm, x_norm, x_norm)[0]
        # Pre-norm MLP with a residual connection.
        x = x + self.mlp(self.norm2(x))
        return x
```

3.5 Putting It All Together

The SimpleViT class assembles all the components into a complete Vision Transformer: patch embedding, the class token and positional encoding, and a stack of Transformer blocks. Finally, the model normalizes the [CLS] token embedding and passes it through a linear classification head to produce class logits.

```python
class SimpleViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, depth=12, num_heads=12, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = (img_size // patch_size) ** 2
        self.vit_embed = ViTEmbed(num_patches, embed_dim)
        self.blocks = nn.Sequential(*[TransformerBlock(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):           # x: [batch_size, channels, height, width]
        x = self.patch_embed(x)     # -> [B, N, D]
        x = self.vit_embed(x)       # add CLS token + pos embed
        x = self.blocks(x)          # transformer layers
        x = self.norm(x)            # normalize token embeddings (including CLS)
        return self.head(x[:, 0])   # classification using the CLS token
```

4. Training the ViT

4.1 Dataset: CIFAR-10

The model was trained on the CIFAR-10 dataset, a benchmark for image recognition with 60,000 32x32 pixel images spanning 10 classes. Despite the small image size, the dataset is challenging because of the variety of visual patterns and the similarity between some classes.
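For reference, a minimal torchvision data-loading setup along these lines could look as follows. This is a sketch, not the exact pipeline used in the original run: the normalization statistics are the commonly quoted CIFAR-10 channel mean/std, and the batch sizes and rotation angle are illustrative assumptions; the augmentations mirror those listed in the training setup below.

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Commonly used CIFAR-10 channel statistics (assumed; the post does not list exact values).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

train_tf = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),  # small random rotations (angle is an assumption)
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)
```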
4.2 Model Setup: Adapting ViT for CIFAR-10

To make training feasible on CIFAR-10 with limited computational resources, several adjustments were made to the ViT architecture:

- The smaller 32x32 pixel images were handled by reducing the patch size to 4x4.
- The embedding dimension was reduced to 192.
- Fewer Transformer blocks (6) and attention heads (3) were used.
- The number of classes was set to 10.

4.3 Training Setup

The training setup included:

- Optimizer: AdamW
- Learning rate scheduler: cosine annealing
- Data augmentation: random horizontal flips and rotations
- Normalization: per-channel mean and standard deviation

Training was efficient, averaging about 30 seconds per epoch on a GPU. The model was trained for 30 epochs, reaching approximately 60% accuracy on the CIFAR-10 test set. This is a promising baseline given the simplicity of the model and the dataset's small size. A minimal sketch of this training configuration is included in the appendix at the end of the post.

4.4 Results

Training plots show the model's loss and accuracy over the 30 epochs. The model performed well overall, accurately identifying many samples but struggling with visually similar classes; for example, it often misclassified ships as airplanes. A bar chart of per-class accuracies highlights the model's strengths and the areas with room for improvement.

5. Industry Insights and Evaluation

The introduction of Vision Transformers has been widely recognized as a landmark achievement in deep learning, opening new avenues for research and application in computer vision. Processing images with self-attention has demonstrated potential for handling complex visual tasks and capturing long-range dependencies. Companies and researchers are rapidly adopting ViTs for applications including object detection, segmentation, and generative models. Meta's recent investment in Scale AI, a data-labeling company, underscores the growing importance of high-quality training data in advancing AI models, including those based on Transformer architectures. The shift toward attention-based models is seen as a critical step in the ongoing evolution of computer vision, promising more scalable and efficient solutions in the future.
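Appendix: Minimal CIFAR-10 Training Sketch

The snippet below sketches the configuration described in Sections 4.2 and 4.3: the downsized SimpleViT (4x4 patches, 192-dim embeddings, 6 blocks, 3 heads, 10 classes) trained for 30 epochs with AdamW and cosine annealing. The learning rate and weight decay are illustrative assumptions, as the post does not state exact values, and train_loader/test_loader refer to the data-loading sketch in Section 4.1.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# CIFAR-10-sized ViT: 4x4 patches, 192-dim embeddings, 6 blocks, 3 heads, 10 classes.
model = SimpleViT(img_size=32, patch_size=4, in_chans=3,
                  embed_dim=192, depth=6, num_heads=3, num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
# Learning rate and weight decay are illustrative; the post does not specify them.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    # One pass over the training set.
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # Evaluate test accuracy at the end of each epoch.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch + 1}: test accuracy {correct / total:.3f}")
```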