Skip to content

执行和延迟任务

当你运行你的常规脚本时,指令是按顺序执行的。使用 Accelerate 将你的脚本部署到多个 GPU 上时会引入一个复杂性:虽然每个进程都会按顺序执行所有指令,但有些进程可能会比其他进程更快。

你可能需要等待所有进程都到达某个特定点之后再执行某个指令。例如,在确保每个进程都完成训练之前,你不应该保存模型,而且在所有模型权重都加载完毕之前,你也不希望继续训练。为此,只需在代码中写入以下行:

accelerator.wait_for_everyone()

此指令将阻塞所有先到达的进程,直到所有其他进程都到达该点(如果你的脚本仅在一个 GPU 或 CPU 上运行,这将不会有任何作用)。

下面列出了几个使用此工具的示例情况:

下载数据集

在下载数据集时,你应该先在主进程中下载,然后加载缓存的数据集。

python
with accelerator.main_process_first():
    datasets = load_dataset("glue", "mrpc")

在内部,这与调用以下内容相同:

python
# First do something on the main process
if accelerator.is_main_process:
    datasets = load_dataset("glue", "mrpc")
else:
    accelerator.wait_for_everyone()

# And then send it to the rest of them
if not accelerator.is_main_process:
    datasets = load_dataset("glue", "mrpc")
else:
    accelerator.wait_for_everyone()

保存 state_dict

在保存模型的 state_dict 时,由于通常只在主进程中保存一个文件,因此你应该指定这一点:

python
if accelerator.is_main_process:
    model = accelerator.unwrap_model(model)
    torch.save(model.state_dict(), "weights.pth")

加载 state_dict

在将 state_dict 加载到模型、优化器或调度器时,你应该等待所有工作进程都加载完权重后再继续训练。

python
with accelerator.main_process_first():
    state = torch.load("weights.pth")
    model.load_state_dict(state)

应用多工作进程的 CPU 操作

在多个工作进程上应用 map() 操作(例如分词)应该首先在主进程中进行,然后再传播到每个工作进程。

python
datasets = load_dataset("glue", "mrpc")

with accelerator.main_process_first():
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

应用检查,如提前停止

为了使用由特定进程设置的标志进行检查,应使用 set_triggercheck_trigger API。有用的示例包括使用提前停止和监控损失(因为每个进程的损失略有不同)。

当你的条件满足时,调用 [Accelerator.set_trigger],在检查任何进程是否满足该条件时,调用 [Accelerator.check_trigger]:

python
for (x,y) in data_loader:
    logits = model(x)
    loss = loss_func(logits, y)
    # Assume `should_do_early_stopping` is a custom defined function that returns a conditional
    if should_do_early_stopping(loss):
        accelerator.set_trigger()

    # Later in the training script when we need to check for the breakpoint
    if accelerator.check_trigger():
        break