分布式推理
分布式推理可以分为三个类别:
- 将整个模型加载到每个 GPU 上,并将批次的各个部分依次通过每个 GPU 的模型副本
- 将模型的部分加载到每个 GPU 上,并一次处理单个输入
- 将模型的部分加载到每个 GPU 上,并使用称为调度管道并行的技术来结合前两种技术。
我们将介绍第一类和最后一类,展示如何实现这些更现实的场景。
自动将批次的各个部分发送到每个加载的模型
这是最占用内存的解决方案,因为它要求每个 GPU 在任何时候都保持一个完整的模型副本在内存中。
通常在这样做时,用户会将模型发送到特定设备以从 CPU 加载,然后将每个提示移动到不同的设备。
使用 diffusers
库的基本管道可能如下所示:
import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
然后根据特定的提示进行推理:
def run_inference(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
pipe.to(rank)
if torch.distributed.get_rank() == 0:
prompt = "a dog"
elif torch.distributed.get_rank() == 1:
prompt = "a cat"
result = pipe(prompt).images[0]
result.save(f"result_{rank}.png")
你会注意到,我们需要检查 rank 以确定要发送的提示,这可能会有些繁琐。
用户可能会想到,使用 Accelerate,通过 Accelerator
为这样的任务准备一个数据加载器可能是一个简单的管理方法。(要了解更多信息,请参阅 快速入门 中的相关部分)
它能管理吗?可以。但是否会增加不必要的额外代码:也是的。
使用 Accelerate,我们可以通过使用 [Accelerator.split_between_processes
] 上下文管理器(该管理器也存在于 PartialState
和 AcceleratorState
中)来简化这个过程。此函数会自动将你传递的任何数据(无论是提示、一组张量、先前数据的字典等)拆分到所有进程中(可能需要填充),以便你立即使用。
让我们使用这个上下文管理器重写上面的示例:
from accelerate import PartialState # Can also be Accelerator or AcceleratorState
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipe.to(distributed_state.device)
# Assume two processes
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
result = pipe(prompt).images[0]
result.save(f"result_{distributed_state.process_index}.png")
然后,要运行代码,我们可以使用 Accelerate:
如果你已经使用 accelerate config
生成了一个配置文件:
accelerate launch distributed_inference.py
如果你有特定的配置文件想要使用:
accelerate launch --config_file my_config.json distributed_inference.py
或者如果你不想创建任何配置文件并使用两个 GPU 启动:
注意:你会看到一些警告,提示某些值是根据你的系统猜测的。要消除这些警告,你可以运行
accelerate config default
或者通过accelerate config
创建一个配置文件。
accelerate launch --num_processes 2 distributed_inference.py
我们现在已将分割这些数据所需的样板代码减少到几行代码,非常简单。
但是,如果我们有不均匀的提示分发到 GPU 上怎么办?例如,如果有 3 个提示,但只有 2 个 GPU 呢?
在上下文管理器中,第一个 GPU 会接收前两个提示,第二个 GPU 接收第三个提示,确保所有提示都被分割且不需要额外的开销。
但是,如果我们想对 所有 GPU 的结果进行某些操作呢?(比如将它们全部收集起来并进行某种后处理) 你可以传递 apply_padding=True
以确保提示列表被填充到相同的长度,额外的数据将从最后一个样本中获取。这样,所有 GPU 都将有相同数量的提示,然后你可以收集结果。
例如:
from accelerate import PartialState # Can also be Accelerator or AcceleratorState
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipe.to(distributed_state.device)
# Assume two processes
with distributed_state.split_between_processes(["a dog", "a cat", "a chicken"], apply_padding=True) as prompt:
result = pipe(prompt).images
在第一个 GPU 上,提示词将是 ["a dog", "a cat"]
,在第二个 GPU 上将是 ["a chicken", "a chicken"]
。 请确保丢弃最后一个样本,因为它将是前一个样本的重复。
你可以找到更复杂的示例 这里,例如如何与 LLMs 一起使用。
内存高效的管道并行(实验性)
接下来的部分将讨论使用 管道并行。这是一个 实验性 的 API,利用 torch.distributed.pipelining 作为原生解决方案。
管道并行的基本思想是:假设你有 4 个 GPU 和一个足够大的模型,可以使用 device_map="auto"
将其 拆分 到四个 GPU 上。使用这种方法,你可以一次发送 4 个输入(例如这里,任何数量都可以),每个模型块将处理一个输入,然后在前一个块完成后再接收下一个输入,这使得它比前面描述的方法 更高效 且更快。以下是从 PyTorch 仓库中提取的示意图:
为了说明如何在 Accelerate 中使用这种方法,我们创建了一个 示例动物园,展示了多种不同的模型和情况。在本教程中,我们将展示如何在两个 GPU 上使用 GPT2。
在继续之前,请确保你已经安装了最新版本的 PyTorch,运行以下命令:
pip install --upgrade torch
pip install torch
首先在 CPU 上创建模型:
from transformers import GPT2ForSequenceClassification, GPT2Config
config = GPT2Config()
model = GPT2ForSequenceClassification(config)
model.eval()
接下来你需要创建一些示例输入来使用。这些输入有助于 torch.distributed.pipelining
跟踪模型。
input = torch.randint(
low=0,
high=config.vocab_size,
size=(2, 1024), # bs x seq_len
device="cpu",
dtype=torch.int64,
requires_grad=False,
)
接下来我们需要实际执行跟踪并准备好模型。为此,请使用 [inference.prepare_pippy
] 函数,它将自动完全包装模型以实现管道并行性:
from accelerate.inference import prepare_pippy
example_inputs = {"input_ids": input}
model = prepare_pippy(model, example_args=(input,))
从这里开始,剩下的就是实际执行分布式推理了!
args = some_more_arguments
with torch.no_grad():
output = model(*args)
当完成所有数据将仅存在于最后一个进程中:
from accelerate import PartialState
if PartialState().is_last_process:
print(output)
就是这样!如需进一步了解,请查看 Accelerate 仓库 中的推理示例和我们的 文档,我们将努力改进这一集成。