DeiT: Making Vision Transformers More Data-Efficient and Cost-Effective
Vision Transformer on a Budget

The vanilla Vision Transformer (ViT) is highly effective but requires an enormous amount of labeled training data. The original ViT, introduced in October 2020, was pre-trained on roughly 300 million images from the proprietary JFT-300M dataset to reach its best performance. This data requirement poses a significant computational and resource challenge, hence the need for a more efficient alternative.

Touvron et al. addressed this issue in their December 2020 paper "Training data-efficient image transformers & distillation through attention." They proposed the Data-efficient image Transformer (DeiT), a model that uses knowledge distillation to drastically reduce data requirements while maintaining high accuracy. Knowledge distillation transfers the knowledge of a pre-trained teacher model to a student model during training; here the teacher is RegNet, a Convolutional Neural Network (CNN), and the student is DeiT.

The Concept of DeiT

DeiT improves upon the original ViT by adding a distillation token alongside the class token. While the class token captures the overall class information, the distillation token is supervised by the teacher model, so the network learns from both the ground-truth labels and the teacher's predictions. This makes DeiT far more data-efficient: whereas ViT relies on 300 million JFT images, DeiT achieves comparable results with only the roughly 1.3 million images of ImageNet-1K, a reduction of more than two orders of magnitude.
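To make this training objective concrete, below is a minimal sketch of the hard-label distillation loss described in the DeiT paper: the class head is trained against the ground-truth labels, while the distillation head is trained against the teacher's hard predictions, with the two cross-entropy terms weighted equally. The helper name `hard_distillation_loss` and the random tensors standing in for real student and teacher outputs are illustrative, not the paper's reference code.

```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(class_logits, dist_logits, teacher_logits, labels):
    """Hard-label distillation: the class head learns from the ground truth,
    the distillation head learns from the teacher's predicted labels."""
    teacher_labels = teacher_logits.argmax(dim=-1)             # teacher's hard predictions
    loss_class = F.cross_entropy(class_logits, labels)         # supervised term
    loss_dist = F.cross_entropy(dist_logits, teacher_labels)   # distillation term
    return 0.5 * loss_class + 0.5 * loss_dist

# Example with random tensors in place of real DeiT and RegNet outputs:
batch, num_classes = 4, 1000
class_logits = torch.randn(batch, num_classes)    # from DeiT's class head
dist_logits = torch.randn(batch, num_classes)     # from DeiT's distillation head
teacher_logits = torch.randn(batch, num_classes)  # from the frozen RegNet teacher
labels = torch.randint(0, num_classes, (batch,))
loss = hard_distillation_loss(class_logits, dist_logits, teacher_logits, labels)
```

The paper also evaluates a soft variant based on the KL divergence between the student's and teacher's softened outputs, but reports that the hard-label version works better in practice.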
DeiT Variants

The paper introduces three variants of DeiT: DeiT-Ti (Tiny), DeiT-S (Small), and DeiT-B (Base). Each variant balances model size against accuracy. The Base variant, DeiT-B, is the largest and most capable of the three, and it matches the architecture of the smallest ViT variant, ViT-B. This makes the comparison with ViT a test of data efficiency rather than model capacity.

Experimental Results

The DeiT paper presents several experiments on the ImageNet-1K dataset. Trained without any external data, DeiT clearly outperforms a ViT trained on the same data and, with distillation, is competitive with strong CNNs such as EfficientNet on the accuracy/throughput trade-off. Notably, DeiT-B⚗↑384, the Base variant trained with the distillation token (marked by the alembic symbol ⚗) and fine-tuned on upscaled 384×384 images, achieves the strongest results. These findings highlight DeiT's ability to leverage limited data far more effectively than its predecessor.

DeiT Architecture and Implementation

To implement the DeiT architecture, it helps to break it down into key components:

- Patching mechanism: DeiT treats an image as a sequence of patches. A Patcher module extracts non-overlapping patches with a 2D convolution whose kernel size and stride both equal the patch size, followed by flattening and permuting operations that yield a sequence of patch embeddings.
- Transformer encoder: The core of DeiT is a stack of transformer encoder layers, each consisting of a multi-head self-attention block, a feed-forward network (FFN), layer normalization, and residual connections. Each layer processes the patch tokens together with the class and distillation tokens so that those two tokens accumulate meaningful information.
- Class and distillation tokens: In addition to the class token, DeiT prepends a distillation token. Both are trainable parameters concatenated with the patch sequence, and a learned position embedding is added to the full sequence to encode spatial relationships.
- Multiple encoder layers: The token sequence passes through a stack of 12 encoder layers, each one progressively enriching the class and distillation tokens with contextual information.
- Output heads: After the encoder stack, the class and distillation tokens are fed to separate linear heads (output heads) that produce two sets of logits. During inference, the predictions of the two heads are averaged to make the final prediction (a sketch of this step follows the implementation below).

Here is a simplified version of the implementation:

```python
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

# Configuration (DeiT-B fine-tuned at 384x384 resolution)
BATCH_SIZE = 1
IMAGE_SIZE = 384
IN_CHANNELS = 3
PATCH_SIZE = 16
EMBED_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 12
FFN_SIZE = EMBED_DIM * 4
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
NUM_CLASSES = 1000


class Patcher(nn.Module):
    """Splits the image into non-overlapping patches and embeds them."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=IN_CHANNELS, out_channels=EMBED_DIM,
                              kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        x = self.conv(x)         # (B, EMBED_DIM, H/P, W/P)
        x = self.flatten(x)      # (B, EMBED_DIM, NUM_PATCHES)
        x = x.permute(0, 2, 1)   # (B, NUM_PATCHES, EMBED_DIM)
        return x


class Encoder(nn.Module):
    """Pre-norm transformer encoder layer: attention + FFN with residuals."""

    def __init__(self):
        super().__init__()
        self.norm_0 = nn.LayerNorm(EMBED_DIM)
        self.multihead_attention = nn.MultiheadAttention(EMBED_DIM, NUM_HEADS,
                                                         batch_first=True)
        self.norm_1 = nn.LayerNorm(EMBED_DIM)
        self.ffn = nn.Sequential(
            nn.Linear(EMBED_DIM, FFN_SIZE),
            nn.GELU(),
            nn.Linear(FFN_SIZE, EMBED_DIM),
        )

    def forward(self, x):
        residual = x
        x = self.norm_0(x)
        x = self.multihead_attention(x, x, x)[0]
        x = x + residual

        residual = x
        x = self.norm_1(x)
        x = self.ffn(x)
        x = x + residual
        return x


class DeiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patcher = Patcher()

        # Class and distillation tokens, shared across the batch.
        self.class_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        trunc_normal_(self.class_token, std=.02)
        trunc_normal_(self.dist_token, std=.02)

        # Position embedding for NUM_PATCHES patch tokens + 2 special tokens.
        self.pos_embedding = nn.Parameter(torch.zeros(1, NUM_PATCHES + 2, EMBED_DIM))
        trunc_normal_(self.pos_embedding, std=.02)

        self.encoders = nn.ModuleList([Encoder() for _ in range(NUM_LAYERS)])
        self.norm_out = nn.LayerNorm(EMBED_DIM)
        self.class_head = nn.Linear(EMBED_DIM, NUM_CLASSES)
        self.dist_head = nn.Linear(EMBED_DIM, NUM_CLASSES)

    def forward(self, x):
        x = self.patcher(x)
        class_token = self.class_token.expand(x.shape[0], -1, -1)
        dist_token = self.dist_token.expand(x.shape[0], -1, -1)
        x = torch.cat([class_token, dist_token, x], dim=1)
        x = x + self.pos_embedding

        for encoder in self.encoders:
            x = encoder(x)
        x = self.norm_out(x)

        class_out = self.class_head(x[:, 0])  # logits from the class token
        dist_out = self.dist_head(x[:, 1])    # logits from the distillation token
        return class_out, dist_out


# Test the DeiT model
deit = DeiT()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
class_out, dist_out = deit(x)   # each of shape (BATCH_SIZE, NUM_CLASSES)
```
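With IMAGE_SIZE = 384 and PATCH_SIZE = 16, the patcher produces (384 / 16)² = 576 patch tokens, so the encoder actually operates on sequences of length 578 once the class and distillation tokens are prepended. As noted in the component breakdown, the two output heads are fused at inference time. The snippet below is a minimal sketch of that step, reusing the `deit` model and input `x` defined above; it averages the softmax outputs of the two heads, which matches the fusion described in the DeiT paper up to a constant factor.

```python
import torch
import torch.nn.functional as F

# Fuse the class and distillation heads at inference time.
deit.eval()
with torch.no_grad():
    class_out, dist_out = deit(x)
    probs = (F.softmax(class_out, dim=-1) + F.softmax(dist_out, dim=-1)) / 2
    prediction = probs.argmax(dim=-1)   # predicted ImageNet-1K class index
print(prediction)  # a single arbitrary class index, since the weights here are untrained
```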
Evaluation by Industry Insiders

Industry experts praise DeiT for its innovative approach to reducing the data requirements of transformer-based vision models, which makes them more accessible and practical for real-world applications. The distillation token and the accompanying knowledge distillation mechanism are seen as significant advances in computer vision. DeiT's architecture, while derived from ViT, includes crucial modifications that improve its efficiency and performance.

Facebook Research, the organization behind DeiT, continues to support and improve the model, making it a valuable tool for researchers and practitioners alike. The authors' work on DeiT has opened new avenues for exploring data-efficient models, particularly in scenarios where large labeled datasets are not readily available. The model's simplicity and its effectiveness with limited data make it a compelling choice for a wide range of image recognition tasks, pushing the boundaries of what transformers can achieve in computer vision.