DeiT: Making Vision Transformers More Data-Efficient with Knowledge Distillation
Overview of DeiT (Data-efficient Image Transformer)

The Vision Transformer (ViT), introduced in a 2020 paper by Dosovitskiy et al., demonstrated remarkable success in image recognition tasks but required an enormous amount of labeled training data, specifically 300 million images from the JFT-300M dataset. This data-intensive requirement posed significant challenges for practical applications. To address this issue, Hugo Touvron and colleagues published "Training Data-Efficient Image Transformers & Distillation Through Attention" in December 2020, introducing the Data-efficient Image Transformer (DeiT).

Key Innovations in DeiT

Knowledge Distillation

The primary innovation in DeiT is the use of knowledge distillation, a technique in which a smaller model (the student) learns from a larger, pre-trained model (the teacher). In this context, the teacher is RegNet, a convolutional neural network (CNN) known for its efficiency. By leveraging the teacher's knowledge, DeiT significantly reduces the amount of training data required: it reaches accuracy comparable to ViT using only the roughly 1.3 million images of the ImageNet-1K dataset, a reduction of more than two orders of magnitude compared to ViT's data needs.

Modified Architecture

To better facilitate knowledge distillation, Touvron et al. modified the ViT architecture by introducing a distillation token alongside the existing class token. The class token is responsible for the final classification decision, while the distillation token learns from the teacher during training. Both tokens are appended to the sequence of image patches and processed together through the transformer layers.

DeiT Variants

Three variants of DeiT were proposed: DeiT-Ti (Tiny), DeiT-S (Small), and DeiT-B (Base). The largest variant, DeiT-B, matches the size of the smallest ViT variant, ViT-B, which highlights DeiT's focus on computational efficiency. DeiT-B contains 86 million trainable parameters and maps each image patch to a 768-dimensional embedding, which is processed by a stack of 12 transformer encoder layers, each with 12 attention heads.

Implementation Details

Treating Images as Sequences of Patches

In DeiT, images are divided into non-overlapping patches using a convolutional layer whose kernel size and stride both equal the patch size. The resulting feature map is flattened into a sequence of patch tokens. For a 384×384 image with a patch size of 16×16, this yields (384 / 16)² = 576 patches, each represented as a 768-dimensional vector.

Transformer Encoder Layer

The transformer encoder layer in DeiT is the same as in ViT, comprising a multi-head attention layer, a feed-forward network (FFN), and normalization layers. However, the inclusion of the distillation token requires adjustments to the forward pass and to the loss function (a sketch of the distillation loss follows the forward-pass steps below). The distillation token and class token are concatenated with the patch tokens before the positional embeddings are added, so that every token carries position information.

Forward Pass

1. Patching: Convert the input image into a sequence of patch tokens.
2. Token Concatenation: Prepend the class and distillation tokens to the patch token sequence.
3. Positional Embedding: Add positional embeddings to the concatenated token sequence.
4. Encoder Stack: Pass the sequence through a stack of 12 transformer encoder layers.
5. Normalization: Normalize the output of the last encoder layer.
6. Head Extraction: Extract the class and distillation tokens.
7. Output Processing: Pass the extracted tokens through separate linear heads to produce the final classification and distillation outputs.
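The two outputs are trained against different targets: the class head is supervised by the ground-truth labels, while the distillation head is supervised by the teacher's predictions. The sketch below illustrates the hard-label distillation objective and the test-time fusion of the two heads described in the DeiT paper, assuming a pretrained RegNetY teacher loaded from timm; the specific checkpoint name and the equal weighting of the two loss terms are illustrative assumptions rather than the paper's exact training recipe.

```python
import torch
import torch.nn.functional as F
import timm

# Assumed teacher: a pretrained RegNetY from timm (the checkpoint name is an
# illustrative choice; the DeiT paper uses a RegNetY teacher from this family).
teacher = timm.create_model("regnety_160", pretrained=True).eval()

def hard_distillation_loss(class_out, dist_out, images, labels):
    """Hard-label distillation: the class head is trained on the ground-truth
    labels, the distillation head on the teacher's predicted (hard) labels."""
    with torch.no_grad():
        teacher_labels = teacher(images).argmax(dim=-1)   # teacher's hard predictions
    loss_class = F.cross_entropy(class_out, labels)        # supervised by ground truth
    loss_dist = F.cross_entropy(dist_out, teacher_labels)  # supervised by the teacher
    return 0.5 * loss_class + 0.5 * loss_dist              # assumed equal weighting

def fused_prediction(class_out, dist_out):
    """At inference time the two classifiers act as an ensemble: their softmax
    outputs are averaged to produce the final prediction."""
    return (class_out.softmax(dim=-1) + dist_out.softmax(dim=-1)) / 2
```

Keeping the ground-truth supervision and the teacher supervision on separate tokens is what lets the convolutional teacher's inductive biases guide the transformer student without overriding the true labels seen by the class head.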
Experimental Results

Experiments with EfficientNet, ViT, and DeiT on the ImageNet-1K dataset showed that, without any additional training data, DeiT clearly outperformed ViT trained on the same data. When the distillation mechanism was applied and the model was fine-tuned on upscaled 384×384 images, performance improved further and became competitive with EfficientNet, confirming DeiT's strength in data-limited scenarios.

Implementation Example

Here is a simplified implementation of the DeiT-B architecture:

```python
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

# Configuration parameters (DeiT-B at 384x384 resolution)
BATCH_SIZE = 1
IMAGE_SIZE = 384
IN_CHANNELS = 3
PATCH_SIZE = 16
EMBED_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 12
FFN_SIZE = EMBED_DIM * 4
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
NUM_CLASSES = 1000


# Patcher: splits the image into non-overlapping patches and embeds them
class Patcher(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=IN_CHANNELS,
            out_channels=EMBED_DIM,
            kernel_size=PATCH_SIZE,
            stride=PATCH_SIZE
        )
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x):
        x = self.conv(x)        # (B, EMBED_DIM, 24, 24)
        x = self.flatten(x)     # (B, EMBED_DIM, 576)
        x = x.permute(0, 2, 1)  # (B, 576, EMBED_DIM)
        return x


# Encoder: a single pre-norm transformer encoder layer
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm_0 = nn.LayerNorm(EMBED_DIM)
        self.multihead_attention = nn.MultiheadAttention(
            EMBED_DIM, num_heads=NUM_HEADS, batch_first=True
        )
        self.norm_1 = nn.LayerNorm(EMBED_DIM)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=EMBED_DIM, out_features=FFN_SIZE),
            nn.GELU(),
            nn.Linear(in_features=FFN_SIZE, out_features=EMBED_DIM)
        )

    def forward(self, x):
        residual = x
        x = self.norm_0(x)
        x = self.multihead_attention(x, x, x)[0]
        x = x + residual

        residual = x
        x = self.norm_1(x)
        x = self.ffn(x)
        x = x + residual
        return x


# DeiT: patch embedding + class/distillation tokens + encoder stack + two heads
class DeiT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patcher = Patcher()

        self.class_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, EMBED_DIM))
        trunc_normal_(self.class_token, std=.02)
        trunc_normal_(self.dist_token, std=.02)

        self.pos_embedding = nn.Parameter(torch.zeros(1, NUM_PATCHES + 2, EMBED_DIM))
        trunc_normal_(self.pos_embedding, std=.02)

        self.encoders = nn.ModuleList([Encoder() for _ in range(NUM_LAYERS)])
        self.norm_out = nn.LayerNorm(EMBED_DIM)
        self.class_head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)
        self.dist_head = nn.Linear(in_features=EMBED_DIM, out_features=NUM_CLASSES)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patcher(x)

        # Prepend the class and distillation tokens, then add positional embeddings
        class_token = self.class_token.expand(batch_size, -1, -1)
        dist_token = self.dist_token.expand(batch_size, -1, -1)
        x = torch.cat([class_token, dist_token, x], dim=1)
        x = x + self.pos_embedding

        for encoder in self.encoders:
            x = encoder(x)
        x = self.norm_out(x)

        # The heads read the class token (index 0) and the distillation token (index 1)
        class_out = self.class_head(x[:, 0])
        dist_out = self.dist_head(x[:, 1])
        return class_out, dist_out


# Testing the DeiT model
deit = DeiT()
x = torch.randn(BATCH_SIZE, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
class_out, dist_out = deit(x)
print("Classification Output:", class_out.size())
print("Distillation Output:", dist_out.size())
```

Industry Insights and Evaluation

Industry experts have praised DeiT for making transformer-based models more accessible and practical for real-world applications. By reducing data requirements and improving training efficiency, DeiT opens up possibilities for deploying advanced vision models in settings with limited resources.
Facebook Research, the team behind DeiT, emphasizes the model's significance in advancing data efficiency in deep learning and has provided extensive documentation and support for it, reinforcing its credibility and utility.

Conclusion

DeiT represents a significant step toward making vision transformers practical for a broader range of applications, particularly image recognition in resource-constrained environments. By combining the strengths of transformers with knowledge distillation and a modified architecture, it reduces the data requirements of the original ViT while maintaining high accuracy, making it a valuable tool in the field of computer vision. For readers interested in delving deeper into the specifics, the DeiT paper and the official GitHub repository offer comprehensive resources.