Accelerate 的内部机制
内部而言,Accelerate 通过首先分析脚本启动的环境来确定使用了哪种分布式设置、有多少个不同的进程以及当前脚本位于哪个进程中。所有这些信息都存储在 [~AcceleratorState
] 中。
这个类在你第一次实例化一个 [~Accelerator
] 时初始化,并执行你的分布式设置所需的任何特定初始化。其状态随后在所有 [~state.AcceleratorState
] 实例中唯一共享。(也可以使用继承自它的更简单的版本 [PartialState
] 来实现相同的功能)
然后,当你调用 [~Accelerator.prepare
] 时,库会:
- 将你的模型(们)包装在适应分布式设置的容器中,
- 将你的优化器(们)包装在 [
~optimizer.AcceleratedOptimizer
] 中, - 将你的学习率调度器(们)包装在 [
~scheduler.AcceleratedScheduler
] 中, - 在 [
~data_loader.DataLoaderShard
] 或 [~data_loader.DataLoaderDispatcher
] 中创建你的数据加载器(们)的新版本。
虽然模型(们)、优化器(们)和学习率调度器(们)只是被简单地包装,但数据加载器(们)是重新创建的。这主要是因为 PyTorch 不允许用户在数据加载器创建后更改其 batch_sampler
,而库通过更改 batch_sampler
来处理数据在进程之间的分片(如果启用的话)。
[~data_loader.DataLoaderShard
] 继承自 DataLoader
,并添加了以下功能:
- 在每次新的迭代中同步所有进程的适当随机数生成器,以确保任何随机化(如洗牌)在所有进程中完全一致。
- 在生成批次之前将其放置在适当的设备上(除非你选择不使用
device_placement=True
)。
[~data_loader.DataLoaderDispatcher
] 与 [~data_loader.DataLoaderShard
] 的不同之处在于,当遍历 DataLoader
时,数据从进程 0 开始,然后被分割并发送到每个进程,而不是在数据集级别进行处理。
随机数生成器同步默认会同步:
- 给定采样器(如 PyTorch 的
RandomSampler
)的generator
属性(对于 PyTorch >= 1.6) - PyTorch <= 1.5.1 中的主要随机数生成器
你可以通过主 [Accelerator
] 的 rng_types
参数选择要同步的随机数生成器。在 PyTorch >= 1.6 中,建议依赖本地 generator
以避免在所有进程中设置相同的主要随机数生成器种子。
如果你已安装了 torchdata>=0.8.0
,并且在你的 [~utils.DataLoaderConfiguration
] 中设置了 use_stateful_dataloader=True
,这些类将直接继承自 StatefulDataLoader
,并维护一个 state_dict
。
有关内部实现的更多详细信息,请参阅 内部实现页面。