Skip to content

Stable Diffusion 3

Stable Diffusion 3 (SD3) 是由 Patrick Esser、Sumith Kulal、Andreas Blattmann、Rahim Entezari、Jonas Muller、Harry Saini、Yam Levi、Dominik Lorenz、Axel Sauer、Frederic Boesel、Dustin Podell、Tim Dockhorn、Zion English、Kyle Lacey、Alex Goodwin、Yannik Marek 和 Robin Rombach 在论文 Scaling Rectified Flow Transformers for High-Resolution Image Synthesis 中提出的。

论文的摘要如下:

扩散模型通过反转数据向噪声的前向路径来从噪声中生成数据,并已成为处理高维感知数据(如图像和视频)的强大生成建模技术。整流流是一种最近提出的生成模型,它通过直线连接数据和噪声。尽管其在理论属性和概念简单性方面具有优势,但尚未被明确确立为标准实践。在这项工作中,我们通过偏向于感知相关尺度的噪声采样技术来改进现有的整流流模型训练方法。通过大规模研究,我们展示了这种方法在高质量文本到图像合成方面相对于现有扩散公式的优越性能。此外,我们提出了一种新颖的基于Transformer的文本到图像生成架构,该架构为两种模态使用单独的权重,并实现了图像和文本token之间的双向信息流,从而提高了文本理解排版和人类偏好评分。我们证明,这种架构遵循可预测的扩展趋势,并且通过各种指标和人类评估,验证损失的降低与文本到图像合成的改进相关。

使用示例

由于模型是受限的,在使用diffusers之前,你需要先访问 Stable Diffusion 3 Medium Hugging Face页面,填写表格并接受限制。一旦你进入,你需要登录以便系统知道你已接受限制。

使用以下命令登录:

bash
huggingface-cli login
python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world.png")

注意: Stable Diffusion 3.5 也可以使用 SD3 管道运行,并且所有提到的优化和技术也同样适用于它。SD3 系列中共有三个官方模型:

SD3 的内存优化

SD3 使用了三个文本编码器,其中一个是非常大的 T5-XXL 模型。这使得在 VRAM 小于 24GB 的 GPU 上运行模型变得具有挑战性,即使使用 fp16 精度也是如此。以下部分概述了 Diffusers 中的一些内存优化,使得在低资源硬件上运行 SD3 变得更加容易。

使用模型卸载进行推理

Diffusers 中最基本的内存优化功能允许你在推理过程中将模型的组件卸载到 CPU 以节省内存,同时会略微增加推理延迟。模型卸载仅在需要执行时将模型组件移动到 GPU,而将剩余组件保留在 CPU 上。

python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world.png")

在推理过程中丢弃T5文本编码器

在推理过程中移除内存密集型的4.7B参数T5-XXL文本编码器可以显著降低SD3的内存需求,且仅会略微损失性能。

python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16
)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world-no-T5.png")

使用量化版本的T5文本编码器

我们可以利用bitsandbytes库将T5-XXL文本编码器加载并量化为8位精度。这使你能够在仅略微影响性能的情况下继续使用所有三个文本编码器。

首先安装bitsandbytes库。

shell
pip install bitsandbytes

然后使用BitsAndBytesConfig加载T5-XXL模型。

python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
text_encoder = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_3",
    quantization_config=quantization_config,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    text_encoder_3=text_encoder,
    device_map="balanced",
    torch_dtype=torch.float16
)

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world-8bit-T5.png")

你可以在这里找到端到端的脚本。

SD3的性能优化

使用Torch编译加速推理

在SD3管道中使用编译后的组件可以将推理速度提高多达4倍。以下代码片段展示了如何编译SD3管道中的Transformer和VAE组件。

python
import torch
from diffusers import StableDiffusion3Pipeline

torch.set_float32_matmul_precision("high")

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16
).to("cuda")
pipe.set_progress_bar_config(disable=True)

pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Warm Up
prompt = "a photo of a cat holding a sign that says hello world"
for _ in range(3):
    _ = pipe(prompt=prompt, generator=torch.manual_seed(1))

# Run Inference
image = pipe(prompt=prompt, generator=torch.manual_seed(1)).images[0]
image.save("sd3_hello_world.png")

查看完整脚本这里

使用长提示与T5文本编码器

默认情况下,T5文本编码器提示使用最大序列长度为256。这可以通过设置max_sequence_length来接受更少或更多的标记进行调整。请记住,较长的序列需要额外的资源,并导致更长的生成时间,例如在批量推理期间。

python
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree.  As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"

image = pipe(
    prompt=prompt,
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=4.5,
    max_sequence_length=512,
).images[0]

向T5文本编码器发送不同的提示

你可以向CLIP文本编码器和T5文本编码器发送不同的提示,以防止提示被CLIP文本编码器截断,并改进生成效果。

python
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. A river of warm, melted butter, pancake-like foliage in the background, a towering pepper mill standing in for a tree."

prompt_3 = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree.  As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"

image = pipe(
    prompt=prompt,
    prompt_3=prompt_3,
    negative_prompt="",
    num_inference_steps=28,
    guidance_scale=4.5,
    max_sequence_length=512,
).images[0]

用于稳定扩散3的Tiny AutoEncoder

用于稳定扩散的Tiny AutoEncoder(TAESD3)是Ollin Boer Bohan开发的一个小型蒸馏版本的Stable Diffusion 3的VAE,可以几乎瞬间解码[StableDiffusion3Pipeline]的潜在变量。

与稳定扩散3一起使用:

python
import torch
from diffusers import StableDiffusion3Pipeline, AutoencoderTiny

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "slice of delicious New York-style berry cheesecake"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("cheesecake.png")

通过from_single_file加载原始检查点

SD3Transformer2DModelStableDiffusion3Pipeline类支持通过from_single_file方法加载原始检查点。此方法允许你加载用于训练模型的原始检查点文件。

SD3Transformer2DModel加载原始检查点

python
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors")

StableDiffusion3Pipeline加载单个检查点

不带T5加载单个文件检查点

python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
    torch_dtype=torch.float16,
    text_encoder_3=None
)
pipe.enable_model_cpu_offload()

image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file.png')

带T5加载单个文件检查点

TIP

以下示例加载了一个以8位浮点格式存储的检查点,这需要PyTorch 2.3或更高版本。

python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()

image = pipe("a picture of a cat holding a sign that says hello world").images[0]
image.save('sd3-single-file-t5-fp8.png')

加载Stable Diffusion 3.5 Transformer模型的单文件检查点

python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe("a cat holding a sign that says hello world").images[0]
image.save("sd35.png")

StableDiffusion3Pipeline

[[autodoc]] StableDiffusion3Pipeline - all - call