低精度训练方法
新硬件的发布促使了新的训练范式的出现,这些范式更好地利用了这些硬件。目前,这些范式主要以 8 位精度训练的形式出现,使用如 TransformersEngine (TE) 或 MS-AMP 等包。
为了更好地理解今天讨论的主题,我们建议你先阅读 低精度使用指南,因为本文档将频繁引用该指南。
快速图表
以下是 MS-AMP 文档中的一张图表,显示了每种解决方案在训练过程中使用的不同位精度:
优化级别 | 计算(GEMM) | 通信 | 权重 | 主权重 | 权重梯度 | 优化器状态 |
---|---|---|---|---|---|---|
FP16 AMP | FP16 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
Nvidia TE | FP8 | FP32 | FP32 | N/A | FP32 | FP32+FP32 |
MS-AMP O1 | FP8 | FP8 | FP16 | N/A | FP8 | FP32+FP32 |
MS-AMP O2 | FP8 | FP8 | FP16 | N/A | FP8 | FP8+FP16 |
MS-AMP O3 | FP8 | FP8 | FP8 | FP16 | FP8 | FP8+FP16 |
TransformersEngine
TransformersEngine
是第一个尝试在 8 位浮点精度下进行训练的解决方案。它通过使用某些模型层的替代层来实现,这些替代层利用其 FP8 引擎将位数(例如从 32 位减少到 8 位)减少,而不会降低模型的最终精度。
具体来说,Accelerate 会找到并替换以下层为 TransformersEngine
版本:
nn.LayerNorm
替换为te.LayerNorm
nn.Linear
替换为te.Linear
因此,我们最终得到的模型大部分层使用 BF16,而某些层使用 FP8,从而减少了一些内存使用。
据我们观察,使用 TransformerEngine
时,性能提升只有在模型中大多数层被替换为这两个层时才会显现。因此,只有参数数量在几亿及以上的大型模型在使用 TransformerEngine
时才会表现出性能提升。
TransformerEngine
可以接收许多不同的参数,以自定义其如何执行 FP8 计算及其功能。完整的参数列表如下:
margin
: 用于梯度缩放的边距。interval
: 用于重新计算缩放因子的间隔。fp8_format
: 用于 FP8 配方的格式。必须是HYBRID
或E4M3
。(通常训练时使用HYBRID
,评估时使用E4M3
)amax_history_len
: 用于缩放因子计算的历史长度。amax_compute_algo
: 用于缩放因子计算的算法。必须是max
或most_recent
。override_linear_precision
: 是否以更高精度执行fprop
、dgrad
和wgrad
GEMMS。
你可以通过 [utils.FP8RecipeKwargs
] 自定义这些参数,以优化模型的性能。
从前面提到的图表中可以看出,TE 仅将计算层转换为 FP8,而其他部分仍为 FP32。因此,这会占用最多的内存,但好处是保证了训练过程中最终精度的损失最小。
MS-AMP
MS-AMP 采取了与 TransformersEngine
不同的方法,提供了三种不同的优化级别,以将更多操作转换为 FP8 或 FP16。
基本优化级别 (
O1
),将权重通信(如在 DDP 中)转换为 FP8,将模型权重存储为 FP16,并将优化器状态保持在 FP32。这种优化级别的主要好处是通信带宽可以减少一半。此外,由于所有内容的一半被转换为 FP8,权重被转换为 FP16,因此节省了更多的 GPU 内存。值得注意的是,优化器状态仍保持在 FP32。第二个优化级别 (
O2
) 在此基础上进一步减少了优化器状态的精度。其中一个状态为 FP8,另一个为 FP16。通常情况下,这只会带来最终精度不下降、训练速度提升和内存减少的净收益,因为现在每个状态要么是 FP16,要么是 FP8。最后,MS-AMP 提供了第三个优化级别 (
O3
),在 DDP 场景(如 DeepSpeed)中特别有用。模型在内存中的权重完全转换为 FP8,主权重存储为 FP16。这将内存减少到最大程度,因为现在几乎所有内容都是 FP8,只有两个状态保持在 FP16。目前,只有 DeepSpeed 0.9.2 及以下版本支持此功能,因此此功能未包含在 Accelerate 集成中。
结合两者
虽然还需要进行更多实验,但已经注意到结合使用 MS-AMP 和 TransformersEngine 可以通过依赖 NVIDIA 优化的 FP8 操作并利用 MS-AMP 减少内存开销,从而实现最高的吞吐量。