训练显存计算与优化
训练大模型时,显存是最关键的约束之一。本文拆解训练显存的四大组成部分,并介绍常见的优化手段。
Q:训练显存计算
来源:AI Infra / 抖音搜推架构一面(牛客网)
普通回答:训练显存包括模型参数、梯度和优化器状态。
更好的回答:
训练显存由四部分组成,以 Adam + FP16 混合精度训练为例:
1. 模型参数(Model Parameters)
- FP16 参数:2 bytes × P(P = 参数量)
- FP32 master copy(混合精度需要):4 bytes × P
- 合计:6P bytes
2. 梯度(Gradients)
- FP16 梯度:2 bytes × P
- 合计:2P bytes
3. 优化器状态(Optimizer States)
- Adam 需要一阶矩(momentum)和二阶矩(variance),各 FP32
- 4 bytes × P × 2 = 8P bytes
- SGD 只需 momentum:4P bytes
4. 激活值(Activations)— 见下题详解
总结(不含激活值):
| 组件 | FP32 训练 | 混合精度(Adam) |
|---|---|---|
| 参数 | 4P | 2P + 4P = 6P |
| 梯度 | 4P | 2P |
| 优化器 | 8P | 8P |
| 合计 | 16P | 16P |
举例:7B 模型(P = 7×10⁹)
- 不含激活值:16 × 7B = 112 GB
- 这就是为什么 7B 模型训练至少需要 2 张 A100 80GB
优化手段:
- ZeRO(DeepSpeed):将参数/梯度/优化器状态切分到多卡
- ZeRO-1:切优化器状态 → 省 ~4×
- ZeRO-2:切优化器 + 梯度
- ZeRO-3:全切(参数也分片)
- Offload:将优化器状态 offload 到 CPU 内存
- 梯度累积:减小单步 batch → 减少激活值显存
考察点:能否精确拆解每一项的来源和大小,以及对 ZeRO 等优化策略的理解。
Q:激活值会占多少显存
来源:AI Infra / 抖音搜推架构一面(牛客网)
普通回答:激活值显存和 batch size、序列长度有关,通常比参数大。
更好的回答:
激活值是前向传播中每层的中间输出,需要保留到反向传播时计算梯度。
Transformer 单层激活值显存(FP16,不含 attention score):
每层激活 ≈ seq_len × batch_size × hidden_dim × 约 10~14 × 2 bytes
其中 “10~14” 来自:
- LayerNorm 输入输出(2 份)
- QKV 投影输出(3 × hidden_dim)
- Attention score matrix(seq_len × seq_len × num_heads)← 这一项是平方级
- Attention output
- FFN 中间层(通常 4 × hidden_dim)
- FFN 输出
总激活值显存 = 单层 × num_layers
举例:LLaMA-7B(32 层,hidden=4096,seq_len=2048,batch=1,FP16)
- Attention score:2048 × 2048 × 32 heads × 2 bytes × 32 layers ≈ 8 GB
- 其余激活:约 5-10 GB
- batch=8 → 激活值可超 100 GB
关键特征:
- 激活值随 batch_size 和 seq_len 线性甚至平方增长
- 参数/梯度/优化器与 batch_size 无关,但激活值与 batch_size 成正比
优化手段:
- Activation Checkpointing(梯度重计算):只保存部分层的激活,反向时重新计算中间层。以约 33% 额外计算换取大幅显存节省
- FlashAttention:不显式存储 seq_len × seq_len 的 attention matrix,O(n) 显存
- Sequence Parallelism:将序列维度切分到多卡
考察点:激活值是训练显存中最容易被忽视但占比最大的部分,面试官想看你是否理解这个”隐形大户”。