BK-SDM

BK-SDM

高效轻量的Stable Diffusion压缩模型

BK-SDM是一种压缩版Stable Diffusion模型,通过移除U-Net中的部分模块实现轻量化。该模型采用有限数据进行蒸馏预训练,适用于SD v1和v2各版本,提供基础、小型和微型三种规模。BK-SDM在保持图像质量的同时,显著提高了推理速度,降低了计算资源需求,为高效文本到图像生成提供了新选择。

Stable DiffusionAI绘图模型压缩知识蒸馏图像生成Github开源项目

块移除知识蒸馏稳定扩散

BK-SDM: 轻量、快速且经济的稳定扩散版本的官方代码库 [ArXiv] [ECCV 2024]。

BK-SDM是轻量级文本到图像(T2I)合成模型:

  • 从SD的U-Net中移除了某些残差和注意力模块。
  • 使用非常有限的数据进行蒸馏预训练,但(令人惊讶地)仍然有效。

⚡快速链接:KD预训练 | MS-COCO评估 | DreamBooth微调 | 演示

公告

模型描述

安装

conda create -n bk-sdm python=3.8 conda activate bk-sdm git clone https://github.com/Nota-NetsPresso/BK-SDM.git cd BK-SDM pip install -r requirements.txt

关于我们使用的torch版本说明:

  • 在单个24GB RTX3090上进行MS-COCO评估和DreamBooth微调时使用torch 1.13.1
  • 在单个80GB A100上进行KD预训练时使用torch 2.0.1
    • 如果在A100上使用总批量大小256进行预训练导致GPU内存不足,请检查torch版本并考虑升级到torch>2.0.0

使用🤗Diffusers的最小示例

使用默认PNDM调度器和50个去噪步骤:

import torch from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-small", torch_dtype=torch.float16) pipe = pipe.to("cuda") prompt = "一个装有各种花朵的金色花瓶" image = pipe(prompt).images[0] image.save("example.png")
import torch from diffusers import StableDiffusionPipeline, UNet2DConditionModel pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) pipe.unet = UNet2DConditionModel.from_pretrained("nota-ai/bk-sdm-small", subfolder="unet", torch_dtype=torch.float16) pipe = pipe.to("cuda") prompt = "一个装有各种花朵的金色花瓶" image = pipe(prompt).images[0] image.save("example.png")

蒸馏预训练

我们的代码基于Diffusers 0.15.0版本的train_text_to_image.py。要访问最新版本,请使用此链接

[可选] 用于检查可运行性的小型测试

bash scripts/get_laion_data.sh preprocessed_11k bash scripts/kd_train_toy.sh

<详细信息> <摘要>注意</摘要>

  • 一个小型数据集(11K图像-文本对)将被下载到./data/laion_aes/preprocessed_11k(tar.gz格式1.7GB; 解压后数据文件夹1.8GB)。
  • 可以使用小型脚本来验证代码的可执行性,并找到适合你的GPU的批量大小。使用批量大小8(=4×2),训练BK-SDM-Base 20次迭代大约需要5分钟,使用22GB的GPU内存。

</详细信息>

单GPU训练BK-SDM-{Base, Small, Tiny}

bash scripts/get_laion_data.sh preprocessed_212k bash scripts/kd_train.sh

<详细信息> <摘要>注意</摘要>

  • 包含212K(=0.22M)对的数据集将被下载到./data/laion_aes/preprocessed_212k(tar.gz格式18GB; 解压后数据文件夹20GB)。
  • 使用批量大小256(=4×64),训练BK-SDM-Base 50K次迭代大约需要300小时,使用53GB的GPU内存。使用批量大小64(=4×16),需要60小时,使用28GB的GPU内存。
  • 训练BK-SDM-{Small, Tiny}会导致GPU内存使用减少5~10%。

</详细信息>

单GPU训练BK-SDM-{Base-2M, Small-2M, Tiny-2M}

bash scripts/get_laion_data.sh preprocessed_2256k bash scripts/kd_train_2m.sh

<详细信息> <摘要>注意</摘要>

  • 包含2256K(=2.3M)对的数据集将被下载到./data/laion_aes/preprocessed_2256k(tar.gz格式182GB; 解压后数据文件夹204GB)。
  • 除了数据集之外,kd_train_2m.shkd_train.sh相同;在相同的迭代次数下,训练计算保持不变。

</详细信息>

多GPU训练

bash scripts/kd_train_toy_ddp.sh

<详细信息> <摘要>注意</摘要>

  • 支持多GPU训练(样例结果:链接),尽管我们论文中的所有实验都是使用单个GPU进行的。感谢@youngwanLEE分享脚本 :)

</详细信息>

用BK-SDM压缩SD-v2

bash scripts/kd_train_v2-base-im512.sh bash scripts/kd_train_v2-im768.sh # 对于推理,请参见:'scripts/generate_with_trained_unet.sh'

关于训练代码的说明

<详细信息> <摘要> KD训练的关键部分 </摘要>

  • 通过调整config.json定义学生U-Net [链接]
  • 通过复制教师U-Net的权重初始化学生U-Net [链接]
  • 为特征KD定义钩子位置 [链接]
  • 定义特征和输出KD的损失 [链接]

</详细信息>

<详细信息> <摘要> 关键学习超参数 </摘要>

--unet_config_name "bk_small" # 选项: ["bk_base", "bk_small", "bk_tiny"] --use_copy_weight_from_teacher # 使用教师权重初始化学生unet --learning_rate 5e-05 --train_batch_size 64 --gradient_accumulation_steps 4 --lambda_sd 1.0 --lambda_kd_output 1.0 --lambda_kd_feat 1.0

</详细信息>

在MS-COCO基准测试上的评估

我们使用以下代码获得MS-COCO上的结果。使用PNDM调度器和25步去噪生成512×512图像后,我们将它们下采样到256×256以计算评分。

使用发布的模型生成(默认使用BK-SDM-Small)

在单个3090 GPU上,'(2)'每个模型需要约10小时,'(3)'需要几分钟。

  • (1) 下载 metadata.csvreal_im256.npz

    bash scripts/get_mscoco_files.sh # ./data/mscoco_val2014_30k/metadata.csv: MS-COCO验证集中的30K提示(用于'(2)') # ./data/mscoco_val2014_41k_full/real_im256.npz: 41K真实图像的FID统计数据(用于'(3)')
    <details> <summary> 关于 'real_im256.npz' 的说明 </summary>
    • 遵循评估协议[DALL·E, Imagen],真实图像的FID统计数据是在MS-COCO完整验证集(41K图像)上计算的。通过'(1)'下载的预计算统计文件位于 ./data/mscoco_val2014_41k_full/real_im256.npz
    • 此外,可以使用 python3 src/get_stat_mscoco_val2014.py 计算 real_im256.npz,该脚本会下载所有图像,将它们调整为256×256大小,并计算FID统计数据。
    </details>
  • (2) 基于MS-COCO验证集中的30K提示生成512×512图像 → 将它们调整为256×256:

    python3 src/generate.py # python3 src/generate.py --model_id nota-ai/bk-sdm-base --save_dir ./results/bk-sdm-base # python3 src/generate.py --model_id nota-ai/bk-sdm-tiny --save_dir ./results/bk-sdm-tiny

    [批量生成] 增加 --batch_sz(默认:1)可以加快推理速度,但会增加显存使用量。感谢 @Godofnothing 提供此功能 :)

    <details> <summary> 点击查看推理成本详情。 </summary>
    • 设置:BK-SDM-Small 在 MS-COCO 30K 图像生成任务上

    • 我们在论文结果中使用了评估批次大小为1。不同的批次大小会影响随机潜在编码的采样,导致略微不同的生成分数。

      评估批次大小1248
      GPU内存4.9GB6.3GB11.3GB19.6GB
      生成时间9.4小时7.9小时7.6小时7.3小时
      FID16.9817.0117.1616.97
      IS31.6831.2031.6231.22
      CLIP 评分0.26770.26790.26770.2675
    </details>
  • (3) 计算 FID、IS 和 CLIP 评分:

    bash scripts/eval_scores.sh # 对于其他模型,修改脚本中的 `./results/bk-sdm-*` 路径以指定不同的模型。

[训练后] 使用训练好的 U-Net 进行生成

bash scripts/get_mscoco_files.sh bash scripts/generate_with_trained_unet.sh

零样本 MS-COCO 256×256 30K 结果

请参阅 MODEL_CARD.md 中的结果

使用 🤗PEFT 进行 DreamBooth 微调

我们的轻量级 SD 骨干网络可用于高效的个性化生成。DreamBooth 能够根据少量图像改进文本到图像的扩散模型。DreamBooth+LoRA 可以大幅降低微调成本。

DreamBooth 数据集

数据集下载到 ./data/dreambooth/dataset [文件夹树]:30个主题 × 25个提示 × 4~6张图像。

git clone https://github.com/google/dreambooth ./data/dreambooth

DreamBooth 微调(默认使用 BK-SDM-Base

我们的代码基于 PEFT 0.1.0train_dreambooth.py。要访问最新版本,请使用此链接

  • (1) 不使用 LoRA — 全面微调 & 在我们的论文中使用
    bash scripts/finetune_full.sh # 学习率 1e-6 bash scripts/generate_after_full_ft.sh
  • (2) 使用 LoRA — 参数高效微调
    bash scripts/finetune_lora.sh # 学习率 1e-4 bash scripts/generate_after_lora_ft.sh
  • 在单个 3090 GPU 上,每个主题的微调需要 10~20 分钟。

个性化生成结果

请参阅 MODEL_CARD.md 中的 DreamBooth 结果

Gradio 演示

查看我们的 Gradio 演示代码(主要文件:app.py)! <details> <summary> [2023年8月1日] 在 Hugging Face 本周精选空间 🔥 中被推荐 </summary> <img alt="本周精选空间" img src="https://yellow-cdn.veclightyear.com/0a4dffa0/81cd9b9e-55ec-40d8-b083-886daaf7e436.png" width="100%"> </details>

Core ML 权重

对于 iOS 或 macOS 应用程序,我们已将模型转换为 Core ML 格式。它们可在 🤗Hugging Face Models(nota-ai/coreml-bk-sdm)上获取,并可与 Apple 的 Core ML Stable Diffusion 库 一起使用。

  • iPhone 14 上 4 秒推理(10 步去噪):结果

许可证

本项目及其权重受 CreativeML Open RAIL-M 许可证 约束,旨在减轻使用高度先进的机器学习系统可能产生的任何潜在负面影响。该许可证的摘要 如下:

1. 您不能使用该模型故意生产或分享非法或有害的输出或内容,
2. 我们不对您生成的输出主张任何权利,您可以自由使用它们,但需对其使用负责,使用时不应违反许可证中规定的条款,
3. 您可以重新分发权重并商业使用该模型和/或将其作为服务提供。如果您这样做,请注意您必须包含与许可证中相同的使用限制,并向所有用户分享 CreativeML OpenRAIL-M 的副本。

致谢

引用

@article{kim2023bksdm, title={BK-SDM: A Lightweight, Fast, and Cheap Version of Stable Diffusion}, author={Kim, Bo-Kyeong and Song, Hyoung-Kyu and Castells, Thibault and Choi, Shinkook}, journal={arXiv preprint arXiv:2305.15798}, year={2023}, url={https://arxiv.org/abs/2305.15798} }

编辑推荐精选

讯飞智文

讯飞智文

一键生成PPT和Word,让学习生活更轻松

讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。

AI办公办公工具AI工具讯飞智文AI在线生成PPTAI撰写助手多语种文档生成AI自动配图热门
讯飞星火

讯飞星火

深度推理能力全新升级,全面对标OpenAI o1

科大讯飞的星火大模型,支持语言理解、知识问答和文本创作等多功能,适用于多种文件和业务场景,提升办公和日常生活的效率。讯飞星火是一个提供丰富智能服务的平台,涵盖科技资讯、图像创作、写作辅助、编程解答、科研文献解读等功能,能为不同需求的用户提供便捷高效的帮助,助力用户轻松获取信息、解决问题,满足多样化使用场景。

热门AI开发模型训练AI工具讯飞星火大模型智能问答内容创作多语种支持智慧生活
Spark-TTS

Spark-TTS

一种基于大语言模型的高效单流解耦语音令牌文本到语音合成模型

Spark-TTS 是一个基于 PyTorch 的开源文本到语音合成项目,由多个知名机构联合参与。该项目提供了高效的 LLM(大语言模型)驱动的语音合成方案,支持语音克隆和语音创建功能,可通过命令行界面(CLI)和 Web UI 两种方式使用。用户可以根据需求调整语音的性别、音高、速度等参数,生成高质量的语音。该项目适用于多种场景,如有声读物制作、智能语音助手开发等。

Trae

Trae

字节跳动发布的AI编程神器IDE

Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。

AI工具TraeAI IDE协作生产力转型热门
咔片PPT

咔片PPT

AI助力,做PPT更简单!

咔片是一款轻量化在线演示设计工具,借助 AI 技术,实现从内容生成到智能设计的一站式 PPT 制作服务。支持多种文档格式导入生成 PPT,提供海量模板、智能美化、素材替换等功能,适用于销售、教师、学生等各类人群,能高效制作出高品质 PPT,满足不同场景演示需求。

讯飞绘文

讯飞绘文

选题、配图、成文,一站式创作,让内容运营更高效

讯飞绘文,一个AI集成平台,支持写作、选题、配图、排版和发布。高效生成适用于各类媒体的定制内容,加速品牌传播,提升内容营销效果。

热门AI辅助写作AI工具讯飞绘文内容运营AI创作个性化文章多平台分发AI助手
材料星

材料星

专业的AI公文写作平台,公文写作神器

AI 材料星,专业的 AI 公文写作辅助平台,为体制内工作人员提供高效的公文写作解决方案。拥有海量公文文库、9 大核心 AI 功能,支持 30 + 文稿类型生成,助力快速完成领导讲话、工作总结、述职报告等材料,提升办公效率,是体制打工人的得力写作神器。

openai-agents-python

openai-agents-python

OpenAI Agents SDK,助力开发者便捷使用 OpenAI 相关功能。

openai-agents-python 是 OpenAI 推出的一款强大 Python SDK,它为开发者提供了与 OpenAI 模型交互的高效工具,支持工具调用、结果处理、追踪等功能,涵盖多种应用场景,如研究助手、财务研究等,能显著提升开发效率,让开发者更轻松地利用 OpenAI 的技术优势。

Hunyuan3D-2

Hunyuan3D-2

高分辨率纹理 3D 资产生成

Hunyuan3D-2 是腾讯开发的用于 3D 资产生成的强大工具,支持从文本描述、单张图片或多视角图片生成 3D 模型,具备快速形状生成能力,可生成带纹理的高质量 3D 模型,适用于多个领域,为 3D 创作提供了高效解决方案。

3FS

3FS

一个具备存储、管理和客户端操作等多种功能的分布式文件系统相关项目。

3FS 是一个功能强大的分布式文件系统项目,涵盖了存储引擎、元数据管理、客户端工具等多个模块。它支持多种文件操作,如创建文件和目录、设置布局等,同时具备高效的事件循环、节点选择和协程池管理等特性。适用于需要大规模数据存储和管理的场景,能够提高系统的性能和可靠性,是分布式存储领域的优质解决方案。

下拉加载更多