多头自注意力机制在自然语言处理和视觉任务中取得了最先进的性能。然而,它们对序列长度的二次依赖性使得推理速度受到瓶颈的限制。为了避免这种瓶颈,研究人员提出了各种稀疏多头自注意力模型,其中计算了部分全局注意力。尽管如此,当前的稀疏库和编译器并不支持多样化的稀疏多头自注意力模式的高性能实现,因为它们所操作的稀疏格式通常是为高性能和科学计算应用而设计的,要么是针对极端数量级的随机稀疏性(<1%的非零值),要么是针对特定的稀疏模式。然而,稀疏多头自注意力中的稀疏模式是中等稀疏的(10-50%的非零值)且多样化的,因此现有的稀疏格式在性能和通用性之间进行权衡。
我们通过提出一种新的稀疏格式——仿射压缩稀疏行(ACSR)和支持代码生成方案SPLAT来弥合这一差距,既实现了通用性,又实现了性能。我们提议的格式和代码生成算法的核心观察是,常见的稀疏多头自注意力模式具有独特的规则几何特性。这些特性可以在运行时分析,暴露出新颖的优化和分块策略,SPLAT利用这些策略为多样化的模式生成高性能的实现。为了证明SPLAT的有效性,我们使用它为各种稀疏多头自注意力模型生成代码,在A100 GPU上相对于手写的triton和TVM内核,实现了2.05倍和4.05倍的几何平均加速。此外,SPLAT的接口直观易用,与JAX中现有的MHSA实现配合使用。
提问交流