Skip to content

Accelerate 的内部机制

内部而言,Accelerate 通过首先分析脚本启动的环境来确定使用了哪种分布式设置、有多少个不同的进程以及当前脚本位于哪个进程中。所有这些信息都存储在 [~AcceleratorState] 中。

这个类在你第一次实例化一个 [~Accelerator] 时初始化,并执行你的分布式设置所需的任何特定初始化。其状态随后在所有 [~state.AcceleratorState] 实例中唯一共享。(也可以使用继承自它的更简单的版本 [PartialState] 来实现相同的功能)

然后,当你调用 [~Accelerator.prepare] 时,库会:

  • 将你的模型(们)包装在适应分布式设置的容器中,
  • 将你的优化器(们)包装在 [~optimizer.AcceleratedOptimizer] 中,
  • 将你的学习率调度器(们)包装在 [~scheduler.AcceleratedScheduler] 中,
  • 在 [~data_loader.DataLoaderShard] 或 [~data_loader.DataLoaderDispatcher] 中创建你的数据加载器(们)的新版本。

虽然模型(们)、优化器(们)和学习率调度器(们)只是被简单地包装,但数据加载器(们)是重新创建的。这主要是因为 PyTorch 不允许用户在数据加载器创建后更改其 batch_sampler,而库通过更改 batch_sampler 来处理数据在进程之间的分片(如果启用的话)。

[~data_loader.DataLoaderShard] 继承自 DataLoader,并添加了以下功能:

  • 在每次新的迭代中同步所有进程的适当随机数生成器,以确保任何随机化(如洗牌)在所有进程中完全一致。
  • 在生成批次之前将其放置在适当的设备上(除非你选择不使用 device_placement=True)。

[~data_loader.DataLoaderDispatcher] 与 [~data_loader.DataLoaderShard] 的不同之处在于,当遍历 DataLoader 时,数据从进程 0 开始,然后被分割并发送到每个进程,而不是在数据集级别进行处理。

随机数生成器同步默认会同步:

  • 给定采样器(如 PyTorch 的 RandomSampler)的 generator 属性(对于 PyTorch >= 1.6)
  • PyTorch <= 1.5.1 中的主要随机数生成器

你可以通过主 [Accelerator] 的 rng_types 参数选择要同步的随机数生成器。在 PyTorch >= 1.6 中,建议依赖本地 generator 以避免在所有进程中设置相同的主要随机数生成器种子。

如果你已安装了 torchdata>=0.8.0,并且在你的 [~utils.DataLoaderConfiguration] 中设置了 use_stateful_dataloader=True,这些类将直接继承自 StatefulDataLoader,并维护一个 state_dict

有关内部实现的更多详细信息,请参阅 内部实现页面