架构层解决方案 - Sparse Attention与Ring Attention
系统分析降低Transformer注意力复杂度的架构方案,包括Sparse Attention变体、Ring Attention分布式计算、Sliding Window优化及MQA/GQA技术。
Sparse Attention:从二次到线性复杂度
稀疏注意力核心思想
Sparse Attention通过限制每个query token只关注部分key token,将注意力复杂度从降低到或。关键在于设计合理的稀疏模式,在降低计算量的同时保持模型能力。
主流稀疏模式对比
flowchart LR
subgraph 全注意力["Full Attention O(n²)"]
F1["●●●●"]
F2["●●●●"]
F3["●●●●"]
F4["●●●●"]
end
subgraph 局部滑动["Sliding Window O(n×w)"]
S1["●●○○"]
S2["●●●○"]
S3["○●●●"]
S4["○○●●"]
end
subgraph 膨胀注意力["Dilated Attention O(n×w/d)"]
D1["●○●○"]
D2["○●○●"]
D3["●○●○"]
D4["○●○●"]
end
subgraph 全局局部["Global+Local O(n×(g+w))"]
G1["●●●●"]
G2["●●○○"]
G3["●○●○"]
G4["●○○●"]
end
Longformer:局部窗口+全局注意力
Longformer采用滑动窗口局部注意力与预定义全局注意力的组合:
技术细节:
- 每个token关注左右各个邻居(通常)
- 预设部分token为”全局”token(如[CLS]、段落首句),关注所有位置
- 复杂度:,其中为全局token数量
性能表现(在4K长度文本分类任务上):
| 配置 | 准确率 | 计算量(相对) |
|---|---|---|
| Full Attention | 89.2% | 100% |
| Longformer (w=256) | 88.7% | 12% |
| Longformer (w=512) | 89.1% | 25% |
| Longformer (w=1024) | 89.2% | 50% |
优势:实现简单,可在现有Transformer上快速部署 局限:全局token的选择依赖任务先验,通用性受限
BigBird:随机+窗口+全局
BigBird通过三种注意力模式的组合实现高效长序列建模:
- 随机注意力:每个token随机关注个位置(提供全局连接)
- 滑动窗口:局部连续关注(捕获局部依赖)
- 全局token:部分token关注所有位置(作为信息枢纽)
理论保证: BigBird论文证明了这种稀疏模式是图灵完备的——可以表达任何可计算函数,这是早期稀疏注意力方案(如Longformer)不具备的理论特性。
实际性能:
- 在64K长度文档级任务上,BigBird达到Full Attention 96%的性能
- 训练速度提升8-10倍,显存占用降低60%
Sparse Transformer:跨步注意力
Sparse Transformer采用跨步模式(Strided Pattern):
- 每层交替使用行注意力(row-wise)和列注意力(column-wise)
- 行注意力关注同一行内的token
- 列注意力关注间隔为的token(如每隔64个位置)
graph TD
subgraph "第1层: 行注意力"
A1[1] --- B1[2] --- C1[3] --- D1[4]
E1[5] --- F1[6] --- G1[7] --- H1[8]
end
subgraph "第2层: 列注意力"
A2[1] -.-> A3[5]
B2[2] -.-> B3[6]
C2[3] -.-> C3[7]
D2[4] -.-> D3[8]
end
Ring Attention:突破单卡显存限制
核心创新
Ring Attention由UC Berkeley于2023年提出,通过**块级计算(Blockwise Computation)和环形通信(Ring Communication)**实现超长序列的分布式训练与推理。
关键洞察: 标准注意力的softmax可以分解为块级计算:
其中为全局最大值,可以在线计算并传播。
工作原理
sequenceDiagram
participant C1 as 计算节点1<br/>Blocks 1-4
participant C2 as 计算节点2<br/>Blocks 5-8
participant C3 as 计算节点3<br/>Blocks 9-12
participant C4 as 计算节点4<br/>Blocks 13-16
Note over C1,C4: 初始化: 每个节点持有1/4序列
loop 环形迭代
C1->>C2: 发送KV Block
C2->>C3: 发送KV Block
C3->>C4: 发送KV Block
C4->>C1: 发送KV Block
Note over C1: 累积局部Attention结果
Note over C2: 累积局部Attention结果
Note over C3: 累积局部Attention结果
Note over C4: 累积局部Attention结果
end
Note over C1,C4: All-Reduce: 汇总全局Softmax归一化因子
计算-通信重叠
Ring Attention的核心优势在于计算与通信的重叠:
- 每个节点在计算当前块的attention时,同时接收下一个块的KV
- 通信延迟被隐藏在计算过程中
- 理想情况下,扩展效率接近线性
性能数据(在256个A100上训练LLaMA-2-7B):
| 序列长度 | 训练时间/步 | 显存/节点 | 扩展效率 |
|---|---|---|---|
| 4K | 2.1s | 12 GB | 100% |
| 32K | 3.8s | 14 GB | 95% |
| 128K | 6.2s | 18 GB | 92% |
| 1M | 18.5s | 32 GB | 88% |
与其他方案的对比
xychart-beta
title "不同方案的最大支持长度与复杂度"
x-axis "最大序列长度 (log scale)" 4096 --> 1048576
y-axis "计算复杂度 (相对值)" 1 --> 1000
line [1, 4, 16, 64, 256, 1024]
line [1, 2, 4, 8, 16, 32]
line [1, 1.2, 1.5, 2, 2.5, 3]
annotation "Full Attention" at 1000000, 256
annotation "Sparse Attention" at 1000000, 32
annotation "Ring Attention" at 1000000, 2.5
关键结论:
- Ring Attention在1M长度上的复杂度仅为Full Attention的0.3%
- 相比Sparse Attention,Ring Attention保持精确的Full Attention语义,无信息损失
Sliding Window Attention:平衡效率与性能
实现机制
Sliding Window Attention(SWA)是Mistral-7B等模型采用的高效注意力方案:
- 每个token只关注最近的个token(如)
- 在Transformer层间交替滑动窗口方向(左→右→左)
- 部分层保留全局注意力(如每4层中的1层)
优势:
- 计算复杂度:,与序列长度无关
- 实现简单:仅需修改attention mask
- 硬件友好:内存访问模式连续,利于缓存
局限:
- 最大依赖距离受限为
- 需要足够的Transformer层数传递远距离信息
Mistral-7B的实际表现
| 任务类型 | 窗口大小 | 准确率 | 延迟(相对) |
|---|---|---|---|
| 短文本分类 | 4K | 91.2% | 1.0x |
| 文档问答 | 4K | 78.5% | 1.2x |
| 代码生成 | 4K | 85.3% | 1.1x |
| 长文档摘要 | 32K | 82.1% | 2.8x |
观察:在32K任务上使用4K滑动窗口+4层全局注意力的配置,达到82.1%的准确率,而计算成本仅为Full Attention的35%。
MQA与GQA:降低KV Cache开销
Multi-Query Attention (MQA)
MQA由Shazeer于2019年提出,核心思想是所有query共享同一组key和value:
其中为第个head的query,和为共享的key/value。
显存节省:
- 标准MHA:
- MQA:
- 对于32头模型,MQA减少KV Cache 96.9%
性能影响: 在PaLM-540B上的实验显示,MQA相比MHA仅造成0.5%的困惑度上升,但推理速度提升2.3倍。
Grouped-Query Attention (GQA)
GQA是MHA和MQA的折中方案:
- 将个query head分为组(如)
- 每组共享一组key/value
- 当时退化为MHA,时为MQA
flowchart LR
subgraph MHA["Multi-Head Attention (32 heads)"]
Q1[Q1] --> K1[K1]
Q1 --> V1[V1]
Q2[Q2] --> K2[K2]
Q2 --> V2[V2]
Q3[Q3] --> K3[K3]
Q3 --> V3[V3]
Qn[...] --> Kn[...]
Qn --> Vn[...]
end
subgraph GQA["Grouped-Query Attention (32Q/8KV)"]
GQ1[Q1-4] --> GK1[K1]
GQ1 --> GV1[V1]
GQ2[Q5-8] --> GK2[K2]
GQ2 --> GV2[V2]
GQ3[Q9-12] --> GK3[K3]
GQ3 --> GV3[V3]
GQn[...] --> GKn[...]
GQn --> GVn[...]
end
LLaMA-2的GQA配置:
- 7B模型:MHA(性能优先)
- 13B模型:MHA
- 70B模型:GQA-8(平衡性能与效率)
在长上下文场景下,GQA-8相比MHA减少KV Cache 87.5%,而困惑度仅增加0.3%。
方案对比与选型建议
综合对比表
| 方案 | 计算复杂度 | 显存需求 | 实现难度 | 适用场景 | 代表模型 |
|---|---|---|---|---|---|
| Full Attention | 极高 | 简单 | 短文本(<4K) | GPT-4 | |
| Longformer | 低 | 中等 | 文档理解 | Longformer | |
| BigBird | 低 | 复杂 | 超长文档 | BigBird | |
| Ring Attention | 中 | 复杂 | 训练>128K | - | |
| Sliding Window | 低 | 简单 | 平衡方案 | Mistral | |
| MQA | 极低 | 简单 | 极致压缩 | PaLM | |
| GQA | 很低 | 简单 | 推荐方案 | LLaMA-2-70B |
部署建议
-
训练阶段(需支持>128K):
- 首选:Ring Attention + GQA
- 备选:Sparse Attention(BigBird)+ GQA
-
推理阶段(8K-128K):
- 首选:Sliding Window + GQA
- 备选:YaRN外推 + GQA
-
资源受限场景(<24GB显存):
- 必须:MQA或GQA + KV Cache量化
- 推荐:4-bit KV Cache压缩
参考资料
-
Longformer: The Long-Document Transformer (Beltagy et al., 2020)
- 滑动窗口+全局注意力的开创性工作
-
Big Bird: Transformers for Longer Sequences (Zaheer et al., 2020)
- 随机+窗口+全局的稀疏模式,理论完备性证明
-
Ring Attention with Blockwise Transformers (Liu et al., 2023)
- 突破性的分布式超长序列训练方案
-
Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)
- MQA原始论文
-
GQA: Training Generalized Multi-Query Transformer Models (Ainslie et al., 2023)
- GQA设计与LLaMA-2应用
-
Mistral 7B (Jiang et al., 2023)
- Sliding Window Attention的工程实践