HyperAI

ConvNeXt: How Meta's Tuned CNN Outperforms Vision Transformers on ImageNet

10 hours ago

Researchers from Meta have challenged the notion that traditional convolutional neural networks (CNNs) have been made obsolete by Vision Transformers (ViTs). Their paper, "A ConvNet for the 2020s," introduces a CNN architecture called ConvNeXt. It matches or exceeds the ImageNet accuracy of comparable state-of-the-art transformers while keeping the simplicity and efficiency of CNNs.

The Core Thesis and Experiments

The researchers argue that the strong performance of ViTs does not come only from the transformer architecture itself. Modern training recipes and design choices also play a large part. To test this, they applied the training setup and design conventions used in ViTs to ResNet, a staple of deep learning since 2015.

Initial Design Changes

The starting point was ResNet-50. Training it with a transformer-style recipe alone raised its ImageNet top-1 accuracy from 76.1% to 78.8%. That recipe used a longer schedule, the AdamW optimizer, and augmentations such as Mixup, CutMix and RandAugment. From this baseline, the team modified the architecture in five areas:

Macro Design: They changed the number of blocks per stage from (3, 4, 6, 3) to (3, 3, 9, 3), which matches Swin-T's 1:1:3:1 stage compute ratio. This improved accuracy from 78.8% to 79.4%.

Input Patching: Inspired by ViTs, they replaced the first convolution layer with a "patchify" stem. It splits the image into non-overlapping 4×4 patches using a 4×4 kernel with stride 4, and raised accuracy to 79.5%.

ResNeXt-ification: They adopted ResNeXt's grouped convolution in its extreme form, depthwise convolution, where the number of groups equals the number of channels. On its own this lowered both accuracy (to 78.3%) and computational cost. Widening the network to Swin-T's channel count (from 64 to 96 channels in the first stage) then raised accuracy to 80.5%.

Inverted Bottleneck: ResNet uses a wide → narrow → wide bottleneck. ConvNeXt instead uses a narrow → wide → narrow structure, with a hidden dimension four times the input width, similar to the MLP block in transformers. This gave a small improvement to 80.6%.
Kernel Size Exploration: They moved the depthwise convolution to the start of the block, ahead of the two pointwise layers. This mirrors how self-attention precedes the MLP in a transformer block. The move temporarily lowered accuracy to 79.9% but also reduced FLOPs. Enlarging the depthwise kernel to 7×7 then restored accuracy to 80.6%, the same as before the move but at lower computational cost.

Micro Design Adjustments

Further improvements were made at a finer granularity:

Activation Function: Replacing ReLU with GELU left accuracy unchanged at 80.6%, but GELU was kept for subsequent experiments.

Fewer Activation Functions: Keeping only a single GELU, between the two pointwise convolutions, increased accuracy to 81.3%. This matched Swin-T while using fewer GFLOPs.

Fewer Normalization Layers: Keeping a single batch normalization layer, placed before the first pointwise convolution, improved accuracy to 81.4%, surpassing Swin-T. Substituting layer normalization for that batch normalization then gave 81.5%.

Downsampling: Strided convolutions inside the residual blocks were replaced by a separate 2×2, stride-2 convolution layer between stages. On its own this made training diverge. Adding layer normalization wherever the spatial resolution changes stabilized training and pushed accuracy to 82.0%. The added layers sit before each downsampling layer, after the stem, and after the final global average pooling.

The ConvNeXt Architecture

ConvNeXt combines the strengths of traditional CNNs and modern transformers by integrating transformer-style design choices into the ResNet framework. The implementation below uses the following components:

Stem Layer: The first convolution layer treats the input image as non-overlapping patches using a 4×4 kernel with stride 4, followed by layer normalization.

ConvNeXt Block: Each block follows the inverted bottleneck structure, wrapped in a residual connection. It consists of a 7×7 depthwise convolution, a layer normalization, and two pointwise (1×1) convolutions with a GELU activation between them.

ConvNeXt Block Transition: In this tutorial's implementation, the first block of each later stage changes the channel count and resolution. On the residual path it uses a strided 1×1 projection. On the main path it uses a separate 2×2, stride-2 downsampling layer preceded by layer normalization. This is a simplification: in the paper, downsampling is a standalone layer placed between stages, outside the residual blocks. It consists of a layer normalization followed by a 2×2, stride-2 convolution.
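The ResNeXt-ification and inverted-bottleneck steps above are easier to judge with concrete numbers. The sketch below counts convolution weights for a first-stage width of 96 channels (the Swin-T width from the text). It compares a dense 3×3 convolution against depthwise 3×3 and 7×7 convolutions, and adds the two pointwise layers of a 4× inverted bottleneck. Ignoring biases and normalization parameters is a simplifying assumption, and `conv_weights` is a hypothetical helper.

```python
# Weight counts for the convolution choices discussed above.
# Biases and normalization parameters are ignored (a simplifying assumption).

def conv_weights(c_in: int, c_out: int, k: int, groups: int = 1) -> int:
    """Number of weights in a k x k convolution with the given group count."""
    assert c_in % groups == 0 and c_out % groups == 0
    return c_out * (c_in // groups) * k * k

C = 96  # first-stage width, matching Swin-T

dense_3x3 = conv_weights(C, C, 3)                # ordinary ResNet-style 3x3 conv
depthwise_3x3 = conv_weights(C, C, 3, groups=C)  # one 3x3 filter per channel
depthwise_7x7 = conv_weights(C, C, 7, groups=C)  # ConvNeXt's spatial mixing layer

# Inverted bottleneck: narrow (C) -> wide (4C) -> narrow (C) pointwise layers
expand = conv_weights(C, 4 * C, 1)
project = conv_weights(4 * C, C, 1)

print(f"dense 3x3:     {dense_3x3}")       # 82944
print(f"depthwise 3x3: {depthwise_3x3}")   # 864
print(f"depthwise 7x7: {depthwise_7x7}")   # 4704
print(f"block total:   {depthwise_7x7 + expand + project}")  # 78432
```

Even at 7×7, the depthwise layer has roughly 5.7% of the weights of a dense 3×3 layer (4,704 vs. 82,944). Nearly all of a block's parameters therefore sit in the two pointwise layers, which is why enlarging the kernel is cheap once the convolution is depthwise.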
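To see how the stem, stages and downsampling layers fit together, the sketch below traces feature-map sizes for a 224×224 input. The input size is an assumption, chosen because it is the usual ImageNet resolution. The block counts (3, 3, 9, 3) and widths (96, 192, 384, 768) match the implementation below. The helpers `conv_out` and `trace_shapes` are hypothetical, and `conv_out` uses the standard floor formula for convolution output size with dilation 1.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output spatial size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_shapes(image_size: int = 224):
    """Return (channels, height, width, num_blocks) for each of the four stages."""
    depths = (3, 3, 9, 3)         # blocks per stage, a 1:1:3:1 ratio
    widths = (96, 192, 384, 768)  # channels per stage
    size = conv_out(image_size, kernel=4, stride=4)  # patchify stem
    shapes = []
    for stage, (depth, width) in enumerate(zip(depths, widths)):
        if stage > 0:
            # 2x2, stride-2 downsampling halves the resolution between stages
            size = conv_out(size, kernel=2, stride=2)
        # A 7x7 depthwise conv with padding 3 keeps the size fixed within a stage
        assert conv_out(size, kernel=7, stride=1, padding=3) == size
        shapes.append((width, size, size, depth))
    return shapes

for width, h, w, depth in trace_shapes():
    print(f"{depth} blocks at {width} x {h} x {w}")
```

The spatial size shrinks from 56 to 28, 14 and finally 7 as the channel count doubles. Per-stage layers should therefore not hard-code a single resolution.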
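The implementation below permutes tensors to channels-last layout (N, H, W, C) before each LayerNorm, so normalization runs over the channel axis. With `nn.LayerNorm(C)`, the C values at each spatial position are normalized independently of every other position. The pure-Python sketch spells out that computation on a toy 2×2 feature map with 3 channels. The helper name `layer_norm_channels` is hypothetical, eps = 1e-5 follows PyTorch's default, and the learnable scale and shift are omitted.

```python
import math

def layer_norm_channels(values, eps=1e-5):
    """Normalize one position's channel vector to zero mean and unit variance.

    Mirrors LayerNorm over the last (channel) axis of a channels-last tensor,
    without the learnable scale and shift.
    """
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in values]

# Toy feature map with H = W = 2 and C = 3, stored channels-last: [row][col][channel]
feature_map = [[[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]],
               [[0.0, 10.0, 20.0], [-1.0, 0.0, 1.0]]]

# Each pixel's channel vector is normalized on its own
normalized = [[layer_norm_channels(pixel) for pixel in row] for row in feature_map]
for row in normalized:
    print([[round(v, 3) for v in pixel] for pixel in row])
```

Because statistics are computed per position, a constant pixel such as [4, 4, 4] maps to all zeros, and [1, 2, 3] and [-1, 0, 1] map to the same output.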
Implementation Details

Stem Layer

```python
import torch
import torch.nn as nn


class ConvNeXtStem(nn.Module):
    def __init__(self):
        super().__init__()
        # "Patchify" stem: non-overlapping 4x4 patches, 3 -> 96 channels
        self.stem = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=4, stride=4)
        # LayerNorm over the channel dimension (applied in channels-last layout)
        self.normstem = nn.LayerNorm(normalized_shape=96)

    def forward(self, x):
        x = self.stem(x)           # (N, 96, H/4, W/4)
        x = x.permute(0, 2, 3, 1)  # to channels-last: (N, H, W, C)
        x = self.normstem(x)
        x = x.permute(0, 3, 1, 2)  # back to channels-first: (N, C, H, W)
        return x
```

ConvNeXt Block

```python
class ConvNeXtBlock(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        hidden_channels = num_channels * 4  # inverted bottleneck: expand by 4
        # 7x7 depthwise convolution (one filter per channel); padding keeps H and W
        self.conv0 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                               kernel_size=7, stride=1, padding=3, groups=num_channels)
        # Normalize over channels only, so the block works at any resolution
        self.norm = nn.LayerNorm(normalized_shape=num_channels)
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=hidden_channels,
                               kernel_size=1, stride=1, padding=0)
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=num_channels,
                               kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = x
        x = self.conv0(x)
        x = x.permute(0, 2, 3, 1)  # channels-last for LayerNorm
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)          # narrow -> wide
        x = self.gelu(x)           # the block's single activation
        x = self.conv2(x)          # wide -> narrow
        return x + residual
```

ConvNeXt Block Transition

```python
class ConvNeXtBlockTransition(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = out_channels * 4
        # Residual path: strided 1x1 projection matches channels and resolution
        self.projection = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                    kernel_size=1, stride=2)
        # 7x7 grouped conv (groups = in_channels) that also changes the channel count;
        # requires out_channels to be a multiple of in_channels
        self.conv0 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=7, stride=1, padding=3, groups=in_channels)
        self.norm0 = nn.LayerNorm(normalized_shape=out_channels)
        self.conv1 = nn.Conv2d(in_channels=out_channels, out_channels=hidden_channels,
                               kernel_size=1, stride=1, padding=0)
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=out_channels,
                               kernel_size=1, stride=1, padding=0)
        # LayerNorm before the separate 2x2, stride-2 downsampling layer
        self.norm1 = nn.LayerNorm(normalized_shape=out_channels)
        self.downsample = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                                    kernel_size=2, stride=2)

    def forward(self, x):
        residual = self.projection(x)
        x = self.conv0(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm0(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)
        x = x.permute(0, 2, 3, 1)
        x = self.norm1(x)
        x = x.permute(0, 3, 1, 2)
        x = self.downsample(x)  # halves H and W to match the projected residual
        return x + residual
```

Full Architecture

```python
class ConvNeXt(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = ConvNeXtStem()
        # Stage depths (3, 3, 9, 3) and widths (96, 192, 384, 768), as in ConvNeXt-T
        self.res2 = nn.ModuleList([ConvNeXtBlock(96) for _ in range(3)])
        self.res3 = nn.ModuleList([ConvNeXtBlockTransition(96, 192)]
                                  + [ConvNeXtBlock(192) for _ in range(2)])
        self.res4 = nn.ModuleList([ConvNeXtBlockTransition(192, 384)]
                                  + [ConvNeXtBlock(384) for _ in range(8)])
        self.res5 = nn.ModuleList([ConvNeXtBlockTransition(384, 768)]
                                  + [ConvNeXtBlock(768) for _ in range(2)])
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.normpool = nn.LayerNorm(normalized_shape=768)
        self.fc = nn.Linear(in_features=768, out_features=1000)

    def forward(self, x):
        x = self.stem(x)
        for stage in (self.res2, self.res3, self.res4, self.res5):
            for block in stage:
                x = block(x)
        x = self.avgpool(x)      # (N, 768, 1, 1)
        x = torch.flatten(x, 1)  # (N, 768)
        x = self.normpool(x)     # LayerNorm after global average pooling
        return self.fc(x)        # (N, 1000) class logits


# Quick shape check with a batch of two 224x224 RGB images
model = ConvNeXt()
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 1000])
```

Evaluation and Industry Insights

Industry experts have praised ConvNeXt for adopting transformer-style design choices while retaining the efficiency and simplicity of CNNs. The result shows that a pure convolutional network, once modernized, can still compete with the latest transformer models. The success of ConvNeXt underscores the importance of architectural flexibility. It also suggests that careful training recipes and design adjustments can further advance CNN performance.
Company Profile

Meta, formerly known as Facebook, is a leading technology company known for its contributions to AI, particularly in computer vision and natural language processing. The development of ConvNeXt reflects Meta's commitment to advancing the state of the art in deep learning and exploring new ways to improve model performance.

This article provides a beginner-friendly introduction to the ConvNeXt architecture and a step-by-step guide to implementing it in PyTorch. For the official implementation, refer to the GitHub repository maintained by Meta. That version also includes details omitted here, such as layer scale and stochastic depth.

Related Links

Towards Data Science