FlashAttention 是一种高效、内存友好型的注意力算法,由斯坦福大学联合纽约州立大学在 2022 年提出,旨在解决传统 Transformer 模型中自注意力 (Self-Attention) 层的高计算复杂度和显存占用问题。相关论文成果为「FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness」。该算法已被集成到 PyTorch 2.0 中,并被多个开源框架如 triton 、 xformer 等整合实现。它通过重新排序注意力计算,利用 tiling(平铺)和重计算技术显著加快了计算速度,并将内存使用量从序列长度的二次方降低到线性关系。
FlashAttention 的提出,使得像 Meta 的 LLaMA 和阿联酋推出的 Falcon 等开源大模型能够加速计算和节省显存。此外,FlashAttention 的后续版本 FlashAttention-2 在原有基础上进行了改进,提供了更好的并行性和工作分区,由 Tri Dao 在 2023 年 7 月通过论文「FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning」提出。
FlashAttention-3 由 Colfax Research 、 Meta 、 NVIDIA 、 Georgia Tech 、 Princeton University 和 Together AI 的研究团队于 2024 年 7 月联合提出,相关论文为「FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision」。作为系列的最新版本,它在 H100 GPU 上实现了显著的性能提升,速度是 FlashAttention-2 的 1.5-2.0 倍,最高可达 740 TFLOPS,即 H100 理论最大 FLOPS 利用率为 75%,并且在使用 FP8 时接近 1.2 PFLOPS 75 。这些改进使得 LLM 的训练和运行速度大幅提升,同时能够在保持精度的同时使用较低精度的数字 (FP8) ,从而可能降低内存使用量并节省成本。