FlexAttention 是由 PyTorch 团队于 2024 年 7 月公开的一个新 API,它提供了一个灵活的接口,允许在几行典型的 PyTorch 代码中实现许多注意力变体,并通过 torch.compile
将其降低到一个融合的 FlashAttention 内核,从而在不牺牲性能的同时提供灵活性。相关论文成果为「FlexAttention for Efficient High-Resolution Vision-Language Models」,已被 ECCV 2024 接受。
FlexAttention 是一种灵活的注意力机制,旨在提高高分辨率视觉语言模型的效率。该机制通过编码高分辨率和低分辨率的图像标记,并仅使用低分辨率标记和少数选定的高分辨率标记来计算注意力图,从而显著降低了计算成本。高分辨率标记的选择是通过一个高分辨率选择模块进行的,该模块可以根据输入的注意力图检索相关区域的标记。然后,选定的高分辨率标记与低分辨率标记和文本标记一起输入到分层自注意力层,该层生成的注意力图将用于下一步的高分辨率标记选择。这个过程在每个注意力层迭代进行。实验表明,FlexAttention 在多模态基准测试中优于现有的高分辨率视觉语言模型,同时显著减少了近 40% 的计算成本。