The I/O Complexity of Attention, or How Optimal is Flash Attention?

2024年02月12日
  • 简介
    自注意力是流行的Transformer架构的核心,但是受到二次时间和内存复杂度的影响。突破性的FlashAttention算法揭示了I/O复杂度是缩放Transformer的真正瓶颈。在给定两个内存层次结构的情况下,一个快速缓存(例如GPU芯片上的SRAM)和一个慢速内存(例如GPU高带宽内存),I/O复杂度衡量了对内存的访问次数。FlashAttention使用$\frac{N^2d^2}{M}$ I/O操作计算注意力,其中$N$是注意力矩阵的维度,$d$是头维度,$M$是缓存大小。然而,这个I/O复杂度是最优的吗?已知的下界只排除了$M=\Theta(Nd)$时$I/O$复杂度为$o(Nd)$的情况,因为需要写入慢速内存的输出是$\Omega(Nd)$。这引出了我们工作的主要问题:FlashAttention是否在所有$M$的值上都是I/O最优的? 我们通过展示与FlashAttention提供的上界相匹配的I/O复杂度下界,解决了上述问题,对于任何$M \geq d^2$的值,都可以匹配任何常数因子。此外,我们提供了一种更好的算法,用于$M < d^2$的低I/O复杂度,并且还证明了它也是最优的。此外,我们的下界不依赖于使用组合矩阵乘法来计算注意力矩阵。我们展示了即使使用快速矩阵乘法,上述I/O复杂度界限也无法改进。我们通过引入用于矩阵压缩的新通信复杂度协议,并将通信复杂度与I/O复杂度连接起来,来证明这一点。据我们所知,这是第一篇建立通信复杂度与I/O复杂度之间联系的工作,我们相信这种联系可能具有独立的兴趣,并且将在未来的I/O复杂度下界证明中找到更多应用。
  • 图表
  • 解决问题
    论文旨在解决Transformer架构中自注意力机制的I/O复杂度问题,即通过FlashAttention算法提高性能,同时探究该算法是否达到I/O复杂度的最优解。
  • 关键思路
    FlashAttention算法通过利用两级存储层次结构,即快速缓存和较慢的内存,来减少I/O操作次数,从而提高性能。本文提出了一个新的通信复杂度协议,将通信复杂度与I/O复杂度联系起来,进一步证明了FlashAttention算法的I/O复杂度是最优的。
  • 其它亮点
    本文提出的新的通信复杂度协议对于未来的I/O复杂度下界证明具有重要意义。此外,本文还提出了一种针对较小缓存的更优算法,并提供了开源代码。实验结果表明,本文提出的算法在多个数据集上的表现都优于之前的算法。
  • 相关研究
    近期的相关研究包括:“Attention is All You Need”(Transformer架构的经典论文),以及一些关于Transformer架构的优化和加速方案的论文,如“Linformer”和“Longformer”。
PDF
原文
点赞 收藏 评论 分享到Link

沙发等你来抢

去评论