Logo
热心市民王先生

架构层解决方案 - Sparse Attention与Ring Attention

Sparse Attention Ring Attention 分布式训练 模型架构

系统分析降低Transformer注意力复杂度的架构方案,包括Sparse Attention变体、Ring Attention分布式计算、Sliding Window优化及MQA/GQA技术。

Sparse Attention:从二次到线性复杂度

稀疏注意力核心思想

Sparse Attention通过限制每个query token只关注部分key token,将注意力复杂度从O(n2)O(n^2)降低到O(nk)O(n \cdot k)O(nlogn)O(n \cdot \log n)。关键在于设计合理的稀疏模式,在降低计算量的同时保持模型能力。

主流稀疏模式对比

flowchart LR
    subgraph 全注意力["Full Attention O(n²)"]
        F1["●●●●"]
        F2["●●●●"]
        F3["●●●●"]
        F4["●●●●"]
    end
    
    subgraph 局部滑动["Sliding Window O(n×w)"]
        S1["●●○○"]
        S2["●●●○"]
        S3["○●●●"]
        S4["○○●●"]
    end
    
    subgraph 膨胀注意力["Dilated Attention O(n×w/d)"]
        D1["●○●○"]
        D2["○●○●"]
        D3["●○●○"]
        D4["○●○●"]
    end
    
    subgraph 全局局部["Global+Local O(n×(g+w))"]
        G1["●●●●"]
        G2["●●○○"]
        G3["●○●○"]
        G4["●○○●"]
    end

Longformer:局部窗口+全局注意力

Longformer采用滑动窗口局部注意力预定义全局注意力的组合:

技术细节

  • 每个token关注左右各ww个邻居(通常w=512w=512
  • 预设部分token为”全局”token(如[CLS]、段落首句),关注所有位置
  • 复杂度:O(n(2w+g))O(n \cdot (2w + g)),其中gg为全局token数量

性能表现(在4K长度文本分类任务上):

配置准确率计算量(相对)
Full Attention89.2%100%
Longformer (w=256)88.7%12%
Longformer (w=512)89.1%25%
Longformer (w=1024)89.2%50%

优势:实现简单,可在现有Transformer上快速部署 局限:全局token的选择依赖任务先验,通用性受限

BigBird:随机+窗口+全局

BigBird通过三种注意力模式的组合实现高效长序列建模:

  1. 随机注意力:每个token随机关注rr个位置(提供全局连接)
  2. 滑动窗口:局部连续关注(捕获局部依赖)
  3. 全局token:部分token关注所有位置(作为信息枢纽)

理论保证: BigBird论文证明了这种稀疏模式是图灵完备的——可以表达任何可计算函数,这是早期稀疏注意力方案(如Longformer)不具备的理论特性。

实际性能

  • 在64K长度文档级任务上,BigBird达到Full Attention 96%的性能
  • 训练速度提升8-10倍,显存占用降低60%

Sparse Transformer:跨步注意力

Sparse Transformer采用跨步模式(Strided Pattern)

  • 每层交替使用行注意力(row-wise)和列注意力(column-wise)
  • 行注意力关注同一行内的token
  • 列注意力关注间隔为dd的token(如每隔64个位置)
graph TD
    subgraph "第1层: 行注意力"
        A1[1] --- B1[2] --- C1[3] --- D1[4]
        E1[5] --- F1[6] --- G1[7] --- H1[8]
    end
    
    subgraph "第2层: 列注意力"
        A2[1] -.-> A3[5]
        B2[2] -.-> B3[6]
        C2[3] -.-> C3[7]
        D2[4] -.-> D3[8]
    end

Ring Attention:突破单卡显存限制

核心创新

Ring Attention由UC Berkeley于2023年提出,通过**块级计算(Blockwise Computation)环形通信(Ring Communication)**实现超长序列的分布式训练与推理。

关键洞察: 标准注意力的softmax可以分解为块级计算:

softmax(QKT)V=iexp(QKiTm)Viiexp(QKiTm)\text{softmax}(QK^T)V = \frac{\sum_i \exp(QK_i^T - m)V_i}{\sum_i \exp(QK_i^T - m)}

其中mm为全局最大值,可以在线计算并传播。

工作原理

sequenceDiagram
    participant C1 as 计算节点1<br/>Blocks 1-4
    participant C2 as 计算节点2<br/>Blocks 5-8
    participant C3 as 计算节点3<br/>Blocks 9-12
    participant C4 as 计算节点4<br/>Blocks 13-16
    
    Note over C1,C4: 初始化: 每个节点持有1/4序列
    
    loop 环形迭代
        C1->>C2: 发送KV Block
        C2->>C3: 发送KV Block
        C3->>C4: 发送KV Block
        C4->>C1: 发送KV Block
        
        Note over C1: 累积局部Attention结果
        Note over C2: 累积局部Attention结果
        Note over C3: 累积局部Attention结果
        Note over C4: 累积局部Attention结果
    end
    
    Note over C1,C4: All-Reduce: 汇总全局Softmax归一化因子

计算-通信重叠

Ring Attention的核心优势在于计算与通信的重叠

  • 每个节点在计算当前块的attention时,同时接收下一个块的KV
  • 通信延迟被隐藏在计算过程中
  • 理想情况下,扩展效率接近线性

性能数据(在256个A100上训练LLaMA-2-7B):

序列长度训练时间/步显存/节点扩展效率
4K2.1s12 GB100%
32K3.8s14 GB95%
128K6.2s18 GB92%
1M18.5s32 GB88%

与其他方案的对比

xychart-beta
    title "不同方案的最大支持长度与复杂度"
    x-axis "最大序列长度 (log scale)" 4096 --> 1048576
    y-axis "计算复杂度 (相对值)" 1 --> 1000
    
    line [1, 4, 16, 64, 256, 1024]
    line [1, 2, 4, 8, 16, 32]
    line [1, 1.2, 1.5, 2, 2.5, 3]
    
    annotation "Full Attention" at 1000000, 256
    annotation "Sparse Attention" at 1000000, 32
    annotation "Ring Attention" at 1000000, 2.5

关键结论

  • Ring Attention在1M长度上的复杂度仅为Full Attention的0.3%
  • 相比Sparse Attention,Ring Attention保持精确的Full Attention语义,无信息损失

Sliding Window Attention:平衡效率与性能

实现机制

Sliding Window Attention(SWA)是Mistral-7B等模型采用的高效注意力方案:

  • 每个token只关注最近的ww个token(如w=4096w=4096
  • 在Transformer层间交替滑动窗口方向(左→右→左)
  • 部分层保留全局注意力(如每4层中的1层)

优势

  • 计算复杂度:O(nw)O(n \cdot w),与序列长度无关
  • 实现简单:仅需修改attention mask
  • 硬件友好:内存访问模式连续,利于缓存

局限

  • 最大依赖距离受限为ww
  • 需要足够的Transformer层数传递远距离信息

Mistral-7B的实际表现

任务类型窗口大小准确率延迟(相对)
短文本分类4K91.2%1.0x
文档问答4K78.5%1.2x
代码生成4K85.3%1.1x
长文档摘要32K82.1%2.8x

观察:在32K任务上使用4K滑动窗口+4层全局注意力的配置,达到82.1%的准确率,而计算成本仅为Full Attention的35%。

MQA与GQA:降低KV Cache开销

Multi-Query Attention (MQA)

MQA由Shazeer于2019年提出,核心思想是所有query共享同一组key和value

MQA(Qh,K,V)=softmax(QhKTdk)V\text{MQA}(Q_h, K, V) = \text{softmax}\left(\frac{Q_h K^T}{\sqrt{d_k}}\right)V

其中QhQ_h为第hh个head的query,KKVV为共享的key/value。

显存节省

  • 标准MHA:2nheadsdheadnseq2 \cdot n_{\text{heads}} \cdot d_{\text{head}} \cdot n_{\text{seq}}
  • MQA:21dmodelnseq2 \cdot 1 \cdot d_{\text{model}} \cdot n_{\text{seq}}
  • 对于32头模型,MQA减少KV Cache 96.9%

性能影响: 在PaLM-540B上的实验显示,MQA相比MHA仅造成0.5%的困惑度上升,但推理速度提升2.3倍

Grouped-Query Attention (GQA)

GQA是MHA和MQA的折中方案:

  • nheadsn_{\text{heads}}个query head分为gg组(如g=8g=8
  • 每组共享一组key/value
  • g=nheadsg=n_{\text{heads}}时退化为MHA,g=1g=1时为MQA
flowchart LR
    subgraph MHA["Multi-Head Attention (32 heads)"]
        Q1[Q1] --> K1[K1]
        Q1 --> V1[V1]
        Q2[Q2] --> K2[K2]
        Q2 --> V2[V2]
        Q3[Q3] --> K3[K3]
        Q3 --> V3[V3]
        Qn[...] --> Kn[...]
        Qn --> Vn[...]
    end
    
    subgraph GQA["Grouped-Query Attention (32Q/8KV)"]
        GQ1[Q1-4] --> GK1[K1]
        GQ1 --> GV1[V1]
        GQ2[Q5-8] --> GK2[K2]
        GQ2 --> GV2[V2]
        GQ3[Q9-12] --> GK3[K3]
        GQ3 --> GV3[V3]
        GQn[...] --> GKn[...]
        GQn --> GVn[...]
    end

LLaMA-2的GQA配置

  • 7B模型:MHA(性能优先)
  • 13B模型:MHA
  • 70B模型:GQA-8(平衡性能与效率)

在长上下文场景下,GQA-8相比MHA减少KV Cache 87.5%,而困惑度仅增加0.3%

方案对比与选型建议

综合对比表

方案计算复杂度显存需求实现难度适用场景代表模型
Full AttentionO(n2)O(n^2)极高简单短文本(<4K)GPT-4
LongformerO(nw)O(n \cdot w)中等文档理解Longformer
BigBirdO(n(w+r+g))O(n \cdot (w+r+g))复杂超长文档BigBird
Ring AttentionO(n2/p)O(n^2/p)复杂训练>128K-
Sliding WindowO(nw)O(n \cdot w)简单平衡方案Mistral
MQAO(n2)O(n^2)极低简单极致压缩PaLM
GQAO(n2)O(n^2)很低简单推荐方案LLaMA-2-70B

部署建议

  1. 训练阶段(需支持>128K):

    • 首选:Ring Attention + GQA
    • 备选:Sparse Attention(BigBird)+ GQA
  2. 推理阶段(8K-128K):

    • 首选:Sliding Window + GQA
    • 备选:YaRN外推 + GQA
  3. 资源受限场景(<24GB显存):

    • 必须:MQA或GQA + KV Cache量化
    • 推荐:4-bit KV Cache压缩

参考资料

  1. Longformer: The Long-Document Transformer (Beltagy et al., 2020)

    • 滑动窗口+全局注意力的开创性工作
  2. Big Bird: Transformers for Longer Sequences (Zaheer et al., 2020)

    • 随机+窗口+全局的稀疏模式,理论完备性证明
  3. Ring Attention with Blockwise Transformers (Liu et al., 2023)

    • 突破性的分布式超长序列训练方案
  4. Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)

    • MQA原始论文
  5. GQA: Training Generalized Multi-Query Transformer Models (Ainslie et al., 2023)

    • GQA设计与LLaMA-2应用
  6. Mistral 7B (Jiang et al., 2023)

    • Sliding Window Attention的工程实践