Striped Attention: Faster Ring Attention for Causal Transformers

2023年11月15日
  • 简介
    为了应对转换器模型中越来越长的序列长度的不断增长的需求,Liu等人最近提出了环形注意力(Ring Attention),这是一种精确的注意力算法,能够通过在多个设备之间分配自注意力来克服每个设备内存瓶颈的问题。在本文中,我们研究了环形注意力在因果变换器模型的重要特殊情况下的性能特征,并确定了由因果关注计算的三角形结构引起的关键工作负载不平衡。我们提出了一个简单的环形注意力扩展,称为条纹注意力(Striped Attention),以解决这种不平衡。每个设备不再拥有连续的子序列,而是拥有均匀分布在整个序列中的令牌子集,我们证明这样可以导致更均匀的工作负载。在A100 GPU和TPUv4上运行条纹注意力的实验中,我们能够在序列长度为256k时实现高达1.45倍的端到端吞吐量改进,比原始的环形注意力算法更快。此外,在16个TPUv4芯片上,我们在序列长度为786k时能够实现1.65倍的加速。我们将我们的实验代码作为开源发布。
  • 作者讲解
  • 图表
  • 解决问题
    本文试图解决在transformer模型中,由于序列长度不断增加导致的内存瓶颈问题,提出了Ring Attention算法。本文研究Ring Attention在因果transformer模型中的性能特征,并发现了由于因果注意力计算的三角形结构而导致的工作负载不平衡问题。
  • 关键思路
    本文提出了一种名为Striped Attention的简单扩展,可以解决因果transformer模型中的工作负载不平衡问题,通过将每个设备的令牌子集均匀分布在整个序列中来实现。
  • 其它亮点
    本文在A100 GPU和TPUv4上运行Striped Attention算法,与原始的Ring Attention算法相比,能够在序列长度为256k时实现高达1.45倍的端到端吞吐量提升,在16个TPUv4芯片上,能够在序列长度为786k时实现1.65倍的加速。本文的代码已经开源。
  • 相关研究
    最近在这个领域中,还有一些相关研究。例如,Liu等人提出了Ring Attention算法,用于解决transformer模型中的内存瓶颈问题。还有一些研究关注于改进transformer模型的性能,例如,Vaswani等人提出了self-attention机制,用于计算输入序列中所有位置之间的相关性。
许愿开讲
PDF
原文
点赞 收藏
向作者提问
NEW
分享到Link

提问交流

提交问题,平台邀请作者,轻松获得权威解答~

向作者提问