FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
解决问题:该论文旨在解决在扩展Transformer模型以处理更长序列时,由于注意力层的运行时间和内存消耗呈二次增长,导致性能下降的问题。
关键思路:FlashAttention通过利用GPU内存层次结构的不对称性,实现了线性内存节省和2-4倍的运行时加速,同时没有使用近似方法。然而,FlashAttention仍然比优化后的矩阵乘法操作慢得多,只能达到理论最大FLOPs/s的25-40%。FlashAttention-2通过优化工作分区,解决了FlashAttention的低效率问题。具体来说,FlashAttention-2 (1) 调整算法以减少非矩阵乘法FLOPs数量 (2) 并行计算注意力,即使是单个头部,也跨不同的线程块以增加占用率,并且 (3) 在每个线程块内,将工作分配给线程束以减少通过共享内存的通信。这些改进使得FlashAttention-2相对于FlashAttention的速度提高了约2倍,在A100上达到理论最大FLOPs/s的50-73%,接近于GEMM操作的效率。作者通过实验证明,当用于端到端训练GPT-style模型时,FlashAttention-2的训练速度可达每个A100 GPU的225 TFLOPs/s(72%模型FLOPs利用率)。
其他亮点:该论文的实验结果表明,FlashAttention-2在处理长序列时具有更好的性能和效率,有望在语言建模和高分辨率图像理解等领域得到广泛应用。此外,该论文还开源了代码,可供研究者使用和参考。
关于作者:Tri Dao是Facebook AI Research的研究员,他的研究方向主要是深度学习和自然语言处理。他之前的代表作包括“Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer”等。
相关研究:其他相关的研究包括“Longformer: The Long-Document Transformer”(Iz Beltagy等,Allen Institute for AI),“Reformer: The Efficient Transformer”(Nikita Kitaev等,Google Research)等。这些研究都探索了如何扩展Transformer模型以处理更长的序列,并提出了不同的解决方案。
论文摘要:本文讨论了在过去几年中,将Transformer扩展到更长的序列长度是一个主要问题,扩展后可以提高语言建模和高分辨率图像理解的性能,同时也可以解锁代码、音频和视频生成的新应用。注意力层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存随着序列长度呈二次增加。FlashAttention利用不对称的GPU内存层次结构带来了显著的内存节省(线性而不是二次)和运行时加速(与优化基线相比提高2-4倍),而不需要近似。然而,FlashAttention仍然远远不如优化的矩阵乘法(GEMM)操作快,只达到了理论最大FLOPs/s的25-40%。我们观察到,低效率是由于GPU上不同线程块和warp之间的工作分配不够优化,导致低占用率或不必要的共享内存读/写。我们提出了FlashAttention-2,通过更好的工作分配来解决这些问题。具体来说,我们(1)调整算法以减少非矩阵乘法FLOPs的数量(2)并行化注意力计算,即使是单个头,也可以跨不同的线程块增加占用率,以及(3)在每个线程块内,将工作分配给warp以减少通过共享内存的通信。这些方法相比FlashAttention可以实现大约2倍的加速,A100上可以达到理论最大FLOPs/s的50-73%,接近GEMM操作的效率。我们通过实验证明,当用于端到端训练GPT-style模型时,FlashAttention-2可以达到每个A100 GPU的225 TFLOPs/s的训练速度(72%的模型FLOPs利用率)。
内容中包含的图片若涉及版权问题,请及时与我们联系删除
评论
沙发等你来抢