MIT本科生从零构建视觉Transformer:详解图像预测的新时代模型
从图像到预测:构建图像变换器 MIT的一位大二学生详细讲述了如何从头开始实现图像变换器(Vision Transformer,简称ViT)的过程,探讨了其架构背后的原理。2020年,谷歌研究团队在论文《An Image Is Worth 16×16 Words》中首次提出了ViT,这是一个大胆的想法:将图像处理为句子,通过将图像分割成若干个16x16像素的小块,然后输入标准的Transformer模型中进行处理。这一创新开启了深度学习的新时代,尤其是在计算机视觉领域。 背景与灵感 1.1 从递归模型到Transformer崛起(NLP领域的变革) 2017年之前,自然语言处理(NLP)主要依赖于循环神经网络(RNN)和长短时记忆网络(LSTM),但由于这些模型只能逐词处理序列数据,难以并行化训练,且处理长序列时容易忘记前面的信息,限制了其规模化应用。2017年,Google的研究人员在《Attention Is All You Need》论文中提出了一种新架构——Transformer,通过自注意力机制解决了这些问题。从此,Transformer在NLP领域迅速崛起,成为了BERT、GPT和T5等重要模型的基础。 1.2 自注意力机制能否替代卷积? 尽管在NLP中取得了巨大成功,但那时计算机视觉领域仍由卷积神经网络(CNN)主导。CNN虽然强大,但在捕捉远距离特征方面存在不足,且需要大量的手工特征工程。因此,一个自然而然的问题出现了:如果自注意力机制能够在NLP中替代递归模型,那么它是否也能在计算机视觉领域替代卷积? 2020年,Dosovitskiy等人提出了ViT,将图像处理为序列,通过自注意力机制来理解图像中的不同区域之间的关系。早期的ViT在小数据集上的性能不佳,但这一方法证明了基于注意力的模型同样适用于图像处理。 ViT的工作原理 2.1 图像分块嵌入 Transformer最初设计用于处理1D序列,如句子。而图像则是2D像素网格。ViT通过将图像分成多个不重叠的小块(例如16x16像素),然后将每个小块展平为1D向量,并线性投影到固定大小的嵌入空间中。这样,一幅图像就变成了一个序列的嵌入表示。 2.2 分类标记和位置嵌入 为了使模型能够正确处理2D图像,还需引入两个额外的部分:一个特殊的[CLS]标记用于汇聚全局信息,以及位置嵌入用于编码图像的空间结构。[CLS]标记被附加到输入序列的开始,通过自注意力机制汇聚所有图像小块的信息。位置嵌入则为每一块嵌入添加一个唯一的空间位置信息,使其能够理解顺序。 2.3 多头自注意力机制 多头自注意力机制是ViT的核心。与单一的注意力函数不同,多头自注意力机制将输入拆分为多个“头”,每个头学习关注不同的输入特征,如边缘、纹理或空间布局。最后,这些“头”的输出被拼接并投影回原始嵌入空间。通过这种方式,模型可以并行地理解复杂的图像关系,不仅限于邻近区域,还包括语义结构。 2.4 Transformer编码器 每个Transformer编码器块结合了自注意力机制和前馈神经网络(MLP),并通过残差连接和归一化层保持模型的稳定性。这个块是ViT的基本构建单元,使得模型能够在全球范围内推理并转换特征。 2.5 分类头 在经过多个Transformer编码器块之后,模型需要一个分类头来进行最终预测。分类头主要利用[CLS]标记的最终嵌入向量作为整个图像的摘要表示,并通过一个简单的线性层输出类别概率。 实现过程 3.1 图像分块嵌入实现 通过一个巧妙的Conv2d层,可以一次性提取和投影非重叠的图像小块,得到形状为[B, num_patches, embed_dim]的嵌入序列,其中B为批量大小,num_patches为小块数量,embed_dim为嵌入维度。 3.2 分类标记和位置嵌入实现 定义了一个ViTEmbed模块,用于添加[CLS]标记和位置嵌入。最终得到的序列形状为[B, num_patches + 1, embed_dim],准备好进入Transformer编码器。 3.3 多头自注意力机制实现 通过将每个输入标记线性投影为查询(Q)、键(K)和值(V)向量,然后在多个“头”上并行计算注意力得分,并将输出拼接后投影回原始嵌入维度,实现了多头自注意力机制。 3.4 Transformer编码器实现 Transformer编码器块将自注意力机制和MLP层组合在一起,通过残差连接和层归一化来保持稳定。最终构建了一个完整的Transformer块。 3.5 组装完整模型 综合以上各个部分,实现了一个完整的SimpleViT模型,包括图像分块嵌入、位置嵌入、多头自注意力机制、Transformer编码器块以及分类头。 训练ViT 4.1 数据集:CIFAR-10 我们使用了众所周知的CIFAR-10基准数据集,该数据集包含60,000张32x32像素的图像,共10个类别(如飞机、猫、船等)。 4.2 模型配置:适应CIFAR-10 为了在CIFAR-10这样的小数据集和有限计算资源下进行有效训练,我们调整了SimpleViT的配置,包括使用4x4的小块(生成64个标记)、192维的嵌入空间、6个Transformer编码器块和3个注意力头。 4.3 训练设置 模型使用PyTorch进行训练,大约每30秒完成一个epoch,总训练时间为15分钟左右。训练结果显示,经过30个epoch后,模型在CIFAR-10测试集上达到了约60%的准确率,对于简单模型来说,这是一个不错的基线性能。测试结果显示,模型在区分视觉相似的类别上仍有困难,但对于大多数样本已经能够正确识别。 行业评价与公司背景 ViT的提出标志着计算机视觉领域的一个重大突破。许多业内人士认为,这是一种非常有前景的方法,尤其是对于处理大规模和复杂图像的任务。谷歌作为深度学习研究的先驱,一直致力于推动前沿技术的发展,ViT的推出再次展示了其在人工智能领域的领导地位。MIT的学生通过从头实现ViT,加深了对这一模型的理解,也为进一步研究奠定了基础。