Skip to content

完全分片数据并行

为了在更大的批量上加速训练大型模型,我们可以使用完全分片数据并行模型。 这种数据并行范式通过分片优化器状态、梯度和参数,使模型能够容纳更多的数据和更大的模型。 如需了解更多相关信息和优势,请参阅 完全分片数据并行博客。 我们已经集成了最新的 PyTorch 完全分片数据并行 (FSDP) 训练功能。 你只需要通过配置启用它即可。

开箱即用的工作方式

在你的机器上只需运行:

bash
accelerate config

并回答提出的问题。这将生成一个配置文件,该文件将自动用于在执行时正确设置默认选项。

bash
accelerate launch my_script.py --args_to_my_script

例如,你可以在启用 FSDP 的情况下从仓库的根目录运行 examples/nlp_example.py

bash
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
bash
accelerate launch examples/nlp_example.py

目前,Accelerate 通过 CLI 支持以下配置:

fsdp_sharding_strategy: [1] FULL_SHARD(分片优化器状态、梯度和参数),[2] SHARD_GRAD_OP(分片优化器状态和梯度),[3] NO_SHARD(DDP),[4] HYBRID_SHARD(在每个节点内分片优化器状态、梯度和参数,而每个节点拥有完整副本),[5] HYBRID_SHARD_ZERO2(在每个节点内分片优化器状态和梯度,而每个节点拥有完整副本)。更多信息请参阅官方 PyTorch 文档

fsdp_offload_params : 决定是否将参数和梯度卸载到 CPU

fsdp_auto_wrap_policy: [1] TRANSFORMER_BASED_WRAP, [2] SIZE_BASED_WRAP, [3] NO_WRAP

fsdp_transformer_layer_cls_to_wrap: 仅适用于 Transformer。当使用 fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP 时,用户可以提供一个以逗号分隔的 Transformer 层类名字符串(区分大小写)来包装,例如 BertLayer, GPTJBlock, T5Block, BertLayer,BertEmbeddings,BertSelfOutput。这很重要,因为共享权重的子模块(例如嵌入层)不应最终出现在不同的 FSDP 包装单元中。使用此策略,每个包含多头注意力和几个 MLP 层的块都会进行包装。剩余的层,包括共享的嵌入层,会方便地包装在同一个最外层的 FSDP 单元中。因此,对于基于 Transformer 的模型,请使用此策略。你可以通过回答 Do you want to use the model's _no_split_modules to wrapyes 来使用 model._no_split_modules。它会尽可能使用 model._no_split_modules

fsdp_min_num_params: 使用 fsdp_auto_wrap_policy=SIZE_BASED_WRAP 时的最小参数数量。

fsdp_backward_prefetch_policy: [1] BACKWARD_PRE, [2] BACKWARD_POST, [3] NO_PREFETCH

fsdp_forward_prefetch: 如果为 True,则 FSDP 在前向传递执行时显式预取下一个即将进行的全聚操作。仅适用于静态图模型,因为预取遵循第一次迭代的执行顺序。即,如果模型执行过程中子模块的顺序动态变化,则不要启用此功能。

fsdp_state_dict_type: [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT

fsdp_use_orig_params: 如果为 True,允许在初始化时使用非均匀的 requires_grad,这意味着支持冻结和可训练参数的混合。此设置在参数高效的微调等情况下非常有用,如 这篇文章 所讨论。此选项还允许有多个优化器参数组。在使用 FSDP 准备/包装模型之前创建优化器时,应将此设置为 True

fsdp_cpu_ram_efficient_loading: 仅适用于 Transformer 模型。如果为 True,只有第一个进程加载预训练模型检查点,而所有其他进程的权重为空。如果通过 from_pretrained 方法加载预训练的 Transformer 模型时遇到错误,应将此设置为 False。当此设置为 True 时,fsdp_sync_module_states 也必须为 True,否则所有进程(除主进程外)将具有随机权重,导致训练期间出现意外行为。为此,确保在调用 Transformers from_pretrained 方法之前初始化分布式进程组。使用 Trainer API 时,创建 TrainingArguments 类的实例时会初始化分布式进程组。

fsdp_sync_module_states: 如果为 True,每个单独包装的 FSDP 单元将从 rank 0 广播模块参数。

为了进行更细致的控制,你可以通过 FullyShardedDataParallelPlugin 指定其他 FSDP 参数。创建 FullyShardedDataParallelPlugin 对象时,传递那些不在加速配置中的参数,或者如果你想覆盖它们。FSDP 参数将根据加速配置文件或启动命令参数进行选择,而你通过 FullyShardedDataParallelPlugin 对象直接传递的其他参数将设置/覆盖这些参数。

以下是一个示例:

py
from accelerate import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

保存和加载

使用 FSDP 模型时,新的推荐检查点方法是在设置加速配置时将 StateDictType 设置为 SHARDED_STATE_DICT。 以下是使用加速库的 save_state 工具保存的代码片段。

py
accelerator.save_state("ckpt")

检查检查点文件夹,查看模型和优化器按进程划分的碎片:

ls ckpt
# optimizer_0  pytorch_model_0  random_states_0.pkl  random_states_1.pkl  scheduler.bin

cd ckpt

ls optimizer_0
# __0_0.distcp  __1_0.distcp

ls pytorch_model_0
# __0_0.distcp  __1_0.distcp

要重新加载它们以恢复训练,请使用 accelerate 的 load_state 工具。

py
accelerator.load_state("ckpt")

在使用 transformerssave_pretrained 方法时,传递 state_dict=accelerator.get_state_dict(model) 以保存模型的状态字典。 以下是一个示例:

diff
  unwrapped_model.save_pretrained(
      args.output_dir,
      is_main_process=accelerator.is_main_process,
      save_function=accelerator.save,
+     state_dict=accelerator.get_state_dict(model),
)

状态字典

accelerator.get_state_dict 将使用 FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 上下文管理器调用底层的 model.state_dict 实现,以仅获取 rank 0 的状态字典,并将其卸载到 CPU。

然后,你可以将 state 传递给 save_pretrained 方法。StateDictTypeFullStateDictConfig 有几种模式,你可以使用它们来控制 state_dict 的行为。更多信息,请参见 PyTorch 文档

如果你选择使用 StateDictType.SHARDED_STATE_DICT,在 Accelerator.save_state 期间,模型的权重将被拆分为 n 个文件,每个子拆分一个文件。为了在训练后将它们合并回一个字典,以便重新加载到模型中,你可以使用 merge_weights 工具:

py
from accelerate.utils import merge_fsdp_weights

# Our weights are saved usually in a `pytorch_model_fsdp_{model_number}` folder
merge_fsdp_weights("pytorch_model_fsdp_0", "output_path", safe_serialization=True)

最终输出将保存到 model.safetensorspytorch_model.bin(如果传递了 safe_serialization=False)。

这也可以通过 CLI 调用:

bash
accelerate merge-weights pytorch_model_fsdp_0/ output_path

FSDP 分片策略与 DeepSpeed ZeRO 阶段的映射

  • FULL_SHARD 对应 DeepSpeed ZeRO Stage-3。分片优化器状态、梯度和参数。
  • SHARD_GRAD_OP 对应 DeepSpeed ZeRO Stage-2。分片优化器状态和梯度。
  • NO_SHARD 对应 ZeRO Stage-0。不进行分片,每个 GPU 拥有模型、优化器状态和梯度的完整副本。
  • HYBRID_SHARD 对应 ZeRO++ Stage-3,其中 zero_hpz_partition_size=<num_gpus_per_node>. Here, this will shard optimizer states, gradients and parameters within each node while each node has full copy.

A few caveats to be aware of

  • In case of multiple models, pass the optimizers to the prepare call in the same order as corresponding models else accelerator.save_state() and accelerator.load_state() will result in wrong/unexpected behaviour.
  • This feature is incompatible with --predict_with_generate in the run_translation.py script of Transformers library.

For more control, users can leverage the FullyShardedDataParallelPlugin. After creating an instance of this class, users can pass it to the Accelerator class instantiation. For more information on these options, please refer to the PyTorch FullyShardedDataParallel code.