Skip to content

潜在一致性蒸馏

潜在一致性模型(LCMs) 能够在几步之内生成高质量的图像,这代表了一个巨大的飞跃,因为许多流程至少需要25步以上。LCMs是通过将潜在一致性蒸馏方法应用于任何Stable Diffusion模型而产生的。这种方法通过在潜在空间中应用单阶段引导蒸馏,并结合跳步方法来一致地跳过时间步长以加速蒸馏过程(更多细节请参阅论文的第4.1、4.2和4.3节)。

如果你在vRAM有限的GPU上进行训练,尝试启用gradient_checkpointinggradient_accumulation_stepsmixed_precision以减少内存使用并加速训练。你还可以通过启用xFormersbitsandbytes'的8位优化器来进一步减少内存使用。

本指南将探讨train_lcm_distill_sd_wds.py脚本,帮助你更熟悉它,以及如何根据你的用例进行调整。

在运行脚本之前,请确保从源代码安装库:

bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

然后导航到包含训练脚本的示例文件夹,并安装你使用的脚本所需的所有依赖项:

bash
cd examples/consistency_distillation
pip install -r requirements.txt

初始化一个 🤗 Accelerate 环境(尝试启用 torch.compile 以显著加快训练速度):

bash
accelerate config

要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:

bash
accelerate config default

或者如果你的环境不支持交互式 shell,比如笔记本,你可以使用:

py
from accelerate.utils import write_basic_config

write_basic_config()

最后,如果你想在自己的数据集上训练模型,请查看创建训练数据集指南,学习如何创建与训练脚本兼容的数据集。

脚本参数

训练脚本提供了许多参数,帮助你自定义训练过程。所有参数及其描述都可以在parse_args()函数中找到。该函数为每个参数提供了默认值,例如训练批次大小和学习率,但你也可以在训练命令中设置自己的值。

例如,要使用fp16格式加速训练并启用混合精度,请在训练命令中添加--mixed_precision参数:

bash
accelerate launch train_lcm_distill_sd_wds.py \
  --mixed_precision="fp16"

大部分参数与文本到图像训练指南中的参数相同,因此在本指南中,你将重点关注与潜在一致性蒸馏相关的参数。

  • --pretrained_teacher_model:预训练的潜在扩散模型路径,用作教师模型
  • --pretrained_vae_model_name_or_path:预训练的VAE路径;SDXL VAE已知存在数值不稳定性问题,因此此参数允许你指定替代的VAE(例如由madebyollin提供的VAE,适用于fp16)
  • --w_min--w_max:指导尺度采样的最小和最大指导尺度值
  • --num_ddim_timesteps:DDIM采样的步数
  • --loss_type:用于潜在一致性蒸馏的损失类型(L2或Huber);通常首选Huber损失,因为它对异常值更鲁棒
  • --huber_c:Huber损失参数

训练脚本

训练脚本首先创建一个数据集类——Text2ImageDataset——用于预处理图像并创建训练数据集。

py
def transform(example):
    image = example["image"]
    image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)

    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
    image = TF.crop(image, c_top, c_left, resolution, resolution)
    image = TF.to_tensor(image)
    image = TF.normalize(image, [0.5], [0.5])

    example["image"] = image
    return example

为了提高在云端存储的大数据集上的读写性能,该脚本使用WebDataset格式创建了一个预处理管道,以应用转换并创建用于训练的数据集和数据加载器。图像在处理后直接输入到训练循环中,无需先下载完整的数据集。

py
processing_pipeline = [
    wds.decode("pil", handler=wds.ignore_and_continue),
    wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
    wds.map(filter_keys({"image", "text"})),
    wds.map(transform),
    wds.to_tuple("image", "text"),
]

main()函数中,加载了所有必要的组件,如噪声调度器、分词器、文本编码器和VAE。教师UNet也在这里加载,然后你可以从教师UNet创建一个学生UNet。学生UNet在训练过程中由优化器更新。

py
teacher_unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)

unet = UNet2DConditionModel(**teacher_unet.config)
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()

现在你可以创建优化器来更新UNet参数:

py
optimizer = optimizer_class(
    unet.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

创建数据集

py
dataset = Text2ImageDataset(
    train_shards_path_or_url=args.train_shards_path_or_url,
    num_train_examples=args.max_train_samples,
    per_gpu_batch_size=args.train_batch_size,
    global_batch_size=args.train_batch_size * accelerator.num_processes,
    num_workers=args.dataloader_num_workers,
    resolution=args.resolution,
    shuffle_buffer_size=1000,
    pin_memory=True,
    persistent_workers=True,
)
train_dataloader = dataset.train_dataloader

接下来,你准备好设置训练循环并实现潜在一致性蒸馏方法(更多细节请参见论文中的算法1)。这一部分的脚本负责向潜在变量添加噪声、采样并创建指导尺度嵌入,以及从噪声中预测原始图像。

py
pred_x_0 = predicted_origin(
    noise_pred,
    start_timesteps,
    noisy_model_input,
    noise_scheduler.config.prediction_type,
    alpha_schedule,
    sigma_schedule,
)

model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0

接下来,它获取教师模型预测LCM预测,计算损失,然后将损失反向传播到LCM。

py
if args.loss_type == "l2":
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
    loss = torch.mean(
        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
    )

如果你想了解更多关于训练循环的工作原理,请查看理解管道、模型和调度器教程,该教程分解了去噪过程的基本模式。

启动脚本

现在你已经准备好启动训练脚本并开始蒸馏了!

在本指南中,你将使用--train_shards_path_or_url来指定Conceptual Captions 12M数据集在Hub上的存储路径这里。将MODEL_DIR环境变量设置为教师模型的名称,并将OUTPUT_DIR设置为你想要保存模型的位置。

bash
export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"

accelerate launch train_lcm_distill_sd_wds.py \
    --pretrained_teacher_model=$MODEL_DIR \
    --output_dir=$OUTPUT_DIR \
    --mixed_precision=fp16 \
    --resolution=512 \
    --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
    --max_train_steps=1000 \
    --max_train_samples=4000000 \
    --dataloader_num_workers=8 \
    --train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
    --validation_steps=200 \
    --checkpointing_steps=200 --checkpoints_total_limit=10 \
    --train_batch_size=12 \
    --gradient_checkpointing --enable_xformers_memory_efficient_attention \
    --gradient_accumulation_steps=1 \
    --use_8bit_adam \
    --resume_from_checkpoint=latest \
    --report_to=wandb \
    --seed=453645634 \
    --push_to_hub

训练完成后,你可以使用新的LCM进行推理。

py
from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
import torch

unet = UNet2DConditionModel.from_pretrained("your-username/your-model", torch_dtype=torch.float16, variant="fp16")
pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16, variant="fp16")

pipeline.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipeline.to("cuda")

prompt = "sushi rolls in the form of panda heads, sushi platter"

image = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]

LoRA

LoRA 是一种显著减少可训练参数数量的训练技术。因此,训练速度更快,并且更容易存储生成的权重,因为它们要小得多(约 100MB)。使用 train_lcm_distill_lora_sd_wds.pytrain_lcm_distill_lora_sdxl.wds.py 脚本来使用 LoRA 进行训练。

LoRA 训练脚本在 LoRA 训练 指南中有更详细的讨论。

Stable Diffusion XL

Stable Diffusion XL (SDXL) 是一种强大的文本到图像模型,能够生成高分辨率图像,并在其架构中添加了第二个文本编码器。使用 train_lcm_distill_sdxl_wds.py 脚本来使用 LoRA 训练 SDXL 模型。

SDXL 训练脚本在 SDXL 训练 指南中有更详细的讨论。

下一步

恭喜你蒸馏了一个 LCM 模型!要了解更多关于 LCM 的信息,以下内容可能会有所帮助: