class="hljs-ln-code"> class="hljs-ln-line">pip install fairscale
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">pip install transformers
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">pip install requests
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">pip install accelerate
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">pip install diffusers
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">pip install einop
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line">pip install safetensors
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line">pip install voluptuous
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line">pip install jax
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line">pip install jaxlib
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">pip install peft
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">pip install deepface==0.0.92
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line">pip install tensorflow==2.9.0 # 为了避免最后评估阶段使用deepface时的错误,这里选择降级版本
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line">pip install keras
  • class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line">pip install opencv-python
  • class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    4 导入

    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line"># ========== 标准库模块 ==========
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">import os
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">import math
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">import glob
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">import shutil
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">import subprocess
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"># ========== 第三方库 ==========
    9. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line">import numpy as np
    10. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line">import torch
    11. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line">import torch.nn.functional as F
    12. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">from PIL import Image
    13. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">from tqdm.auto import tqdm
    14. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line">
    15. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"># ========== 深度学习相关库 ==========
    16. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line">from torchvision import transforms
    17. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line">
    18. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"># Transformers (Hugging Face)
    19. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line">from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor
    20. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line">
    21. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"># Diffusers (Hugging Face)
    22. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line">from diffusers import (
    23. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> AutoencoderKL,
    24. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line"> DDPMScheduler,
    25. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> UNet2DConditionModel,
    26. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> DiffusionPipeline
    27. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line">)
    28. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line">from diffusers.optimization import get_scheduler
    29. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line">from diffusers.training_utils import compute_snr
    30. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line">
    31. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"># ========== LoRA 模型库 ==========
    32. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line">from peft import LoraConfig, get_peft_model, PeftModel
    33. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line">
    34. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"># ========== 面部检测库 ==========
    35. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line">from deepface import DeepFace
    36. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line">
    37. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line">import cv2
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    5 准备数据

    当前演示使用的是 Brad Pitt(布拉德·皮特),我们的目标是让模型绘制的 man 是 Brad Pitt,粗略地换个表述:AI 换脸。

    那根据我们之前的描述,标注应该长什么样呢?

    答:都带 “man”,下面是我们当前数据集的标注示例:

    1. a man with a beard and a suit jacket
    2. a man in a suit and tie standing in front of a crowd
    3. a man with long hair and a tie
    4. ...

    相信你发现了,所有的标注,都不会含有 “Brad Pitt”,那这篇文章训练出的 LoRA 模型的 Trigger Words(触发词)是什么?

    答:“a man”。因为训练时所有标注都用 “a man” 来指代 Brad Pitt,模型会逐渐把这个短语和他的外貌联系起来。

    是不是很有趣,看似简单的 Prompt 中也有一些真实有用的小技巧和逻辑。别急着去炼丹,我们继续往下看。

    在这里,我们使用 Brad Pitt 的 100 张图片进行演示,数据集已经上传到了Demos/data/14,你可以下载后放到当前目录下的 ./data/14 下。这个路径没有什么说法,单纯是为了对齐示例代码,你也可以修改代码关于数据的路径,这里不会有限制,你甚至可以直接用其他的数据集,只要它的文件组织如下:

    -- 图片1
    -- 图片1.txt
    -- 图片2
    -- 图片2.txt
    ...

    注意:图片和对应的文本标注需要同名,且位于同一文件夹中。这是因为代码会分别对图片和标注文件按文件名排序,再按顺序一一配对,只有同名才能保证配对正确。

    值得一提的是,样例数据集中图片的尺寸和比例并不一致,只是接近正方形。但这没有太大关系,因为在数据预处理时会自动缩放(resize)并中心裁剪(center crop)到统一的分辨率,所以不用担心你的数据集无法训练。

    6 设置项目路径

    很好!现在你已经知道这篇文章数据集相关的所有前置知识,直接复制下面的代码运行,不用在意其中的任何代码细节,你只需要知道会创建一个文件夹SD,之后的所有结果都会被存放在其中:

    import zipfile

    # Project name and dataset name
    project_name = "Brad"
    dataset_name = "Brad"

    # Root directory and main directory; every artifact produced later lives under main_dir
    root_dir = "./"  # current directory
    main_dir = os.path.join(root_dir, "SD")  # main directory

    # Project directory
    project_dir = os.path.join(main_dir, project_name)

    # Dataset and model paths
    images_folder = os.path.join(main_dir, "Datasets", dataset_name)
    prompts_folder = os.path.join(main_dir, "Datasets", "prompts")
    captions_folder = images_folder  # caption .txt files sit in the same folder as the images
    output_folder = os.path.join(project_dir, "logs")  # model checkpoints and validation outputs

    # Validation prompt file path
    validation_prompt_name = "validation_prompt.txt"
    validation_prompt_path = os.path.join(prompts_folder, validation_prompt_name)

    # Model checkpoint path
    model_path = os.path.join(project_dir, "logs", "checkpoint-last")

    # Other paths
    zip_file = os.path.join("./", "data/14/Datasets.zip")
    inference_path = os.path.join(project_dir, "inference")  # folder for inference results

    # Create the working directories (no-op if they already exist)
    os.makedirs(images_folder, exist_ok=True)
    os.makedirs(prompts_folder, exist_ok=True)
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs(inference_path, exist_ok=True)

    # Check for the sample dataset archive and extract it into main_dir
    print("📂 正在检查并解压样例数据集...")

    if not os.path.exists(zip_file):
        print("❌ 未找到数据集压缩文件 Datasets.zip!")
        print("请下载数据集:\n../Demos/data/14/Datasets.zip\n并放在 ./data/14 文件夹下")
    else:
        # Extract with the standard-library zipfile module instead of shelling out to
        # `unzip`. This needs no external binary, is safe for paths containing spaces
        # or shell metacharacters, and a corrupt archive raises zipfile.BadZipFile
        # instead of falling through to the success message. extractall() overwrites
        # existing files, like `unzip -o`.
        with zipfile.ZipFile(zip_file) as archive:
            archive.extractall(main_dir)
        print(f"✅ 项目 {project_name} 已准备好!")
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    7 导入数据

    下面,我们需要自定义一个 Dataset 类,它的作用是告诉模型如何处理你的数据集,这个自定义的类能够返回图像和文本标注分别作为 data 和 label。接下来的内容会有点“干”,你也可以将其先当作黑盒,我会在每个函数之后提供一个简练的解释帮你理解。

    7.1 怎么扩充数据集?

    拓展文章:e. 数据增强:torchvision.transforms 常用方法解析

    这里有一个非常熟悉的词:transform,但它跟我们耳熟能详的 transformer 可不同,transform 就是单纯地对图像进行操作,比如调整大小、翻转,又或者随机裁剪一部分区域,这些操作统称为数据增强。

    数据增强就是扩充数据集的外挂,以下图为例,即便进行水平翻转+颜色变化+中心裁剪,它也是一只企鹅。

    这大大地扩充了数据集。

    知道了概念后,简单定义当前的数据增强如下:

    # Side length (in pixels) that every training image is brought to.
    resolution = 512

    # Preprocessing / augmentation pipeline, applied in order to each PIL image:
    #   1. Resize: shorter edge scaled to `resolution` using bilinear interpolation
    #   2. CenterCrop: keep the central resolution x resolution square
    #   3. RandomHorizontalFlip: mirror left-right at random (torchvision default p=0.5)
    #   4. ToTensor: convert the PIL image to a float tensor
    train_transform = transforms.Compose([
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
        transforms.CenterCrop(resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    7.2 怎么让模型理解文本?

    使用 CLIPTokenizer,这是 Hugging Face transformers 库中的一个类,专门用于对文本进行分词(tokenization)操作。CLIP,全称 Contrastive Language-Image Pretraining(对比语言-图像预训练),Contrastive(对比)这个词点明了它的核心思想,这是一种非常有意思的训练方式:通过最大化匹配的文本-图像对之间的相似性,同时最小化不匹配的文本-图像对之间的相似性来实现训练。

    学习资料

    论文链接:Learning Transferable Visual Models From Natural Language Supervision。对理论感兴趣的话,可以进一步查看以下四个非常棒的视频:

    1. 对比学习论文综述【论文精读】
    2. CLIP 论文逐段精读【论文精读】
    3. CLIP 改进工作串讲(上)【论文精读·42】
    4. CLIP 改进工作串讲(下)【论文精读·42】

    你将发现两个宝藏 UP 主,我无法用语言表达对他们的赞美,只能道一句:“导师好!”。

    具体来说,CLIPTokenizer 将输入的 prompt 拆解为 token(单词或子词),并将这些 token 映射为input_ids 供 CLIP 模型的 text_encoder 处理,从而生成 prompt 的嵌入向量,以让模型理解。

    就像一切数据到了计算机中都会变成 0 和 1 以便处理一样,向上抽象一层来看,CLIP 做的事情就是将人类可以阅读的文本描述转换成模型能够理解的形式。

    拓展:看看 Tokenizer 实际上做了什么

    from transformers import CLIPTokenizer

    # Load a pretrained CLIP tokenizer from the Hugging Face Hub.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # Sample prompt used throughout this demonstration.
    prompt_text = "A man in a graphic tee and sport coat."

    # Inspect the raw sub-word tokens before they are mapped to ids.
    tokens = tokenizer.tokenize(prompt_text)
    print("Tokens:", tokens)

    # Encode the prompt into ids: pad up to the maximum length when shorter,
    # truncate when longer, and hand back PyTorch tensors.
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
    )

    # Show the encoded ids followed by the attention mask.
    for label, values in (
        ("Tokenized Input IDs:", inputs.input_ids),
        ("Attention Mask:", inputs.attention_mask),
    ):
        print(label, values)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    输出:

    Tokens: ['a', 'man', 'in', 'a', 'graphic', 'tee', 'and', 'sport', 'coat', '.']
    Tokenized Input IDs: tensor([[49406, 320, 786, 530, 320, 4245, 3385, 537, 2364, 7356,
             269, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
             49407, 49407, 49407, 49407, 49407, 49407, 49407]])
    Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0]])
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    问:49407 是什么?我们的 prompt 中似乎没有重复的词。

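    答:结束标记 <|endoftext|>(id 为 49407)。CLIP 的 tokenizer 同时用它作为填充符,而我们设置了 padding="max_length",所以序列会被 49407 一直填充到最大长度 77,这些填充位置在 attention mask 中对应为 0。思考一下,设置padding=False后输出应该是什么样的?先不要往下滑。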

    具体解释:

    padding=False时的输出:

    Tokens: ['a', 'man', 'in', 'a', 'graphic', 'tee', 'and', 'sport', 'coat', '.']
    Tokenized Input IDs: tensor([[49406, 320, 786, 530, 320, 4245, 3385, 537, 2364, 7356,
             269, 49407]])
    Attention Mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    是不是和预期一致呢?

    接下来,input_ids 将被传入 text_encoder,生成文本的嵌入向量。

    7.3 自定义数据集

    在认识 transform 和 tokenizer 之后,我们可以定义自己的数据集。这个 Text2ImageDataset 负责将图像和文本配对,并进行数据的预处理,以便输入到模型中。

    # Image file suffixes the dataset loader globs for: the five base suffixes in
    # lower case, followed by the same five in upper case. Each suffix is matched
    # literally, so mixed-case names such as ".Jpg" are not included.
    # NOTE(review): on case-insensitive file systems (e.g. Windows, default macOS)
    # the lower- and upper-case patterns may match the same file twice; confirm
    # that the consumer de-duplicates before pairing with captions.
    IMAGE_EXTENSIONS = [
        convert(suffix)
        for convert in (str.lower, str.upper)
        for suffix in (".png", ".jpg", ".jpeg", ".webp", ".bmp")
    ]
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">class Text2ImageDataset(torch.utils.data.Dataset):
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line"> """
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line"> (1) 目标:
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line"> - 用于构建文本到图像模型的微调数据集
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"> """
    9. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line"> def __init__(self, images_folder, captions_folder, transform, tokenizer):
    10. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line"> """
    11. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line"> (2) 参数:
    12. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line"> - images_folder: str, 图像文件夹路径
    13. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line"> - captions_folder: str, 标注文件夹路径
    14. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line"> - transform: function, 将原始图像转换为 torch.Tensor
    15. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line"> - tokenizer: CLIPTokenizer, 将文本标注转为 word ids
    16. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line"> """
    17. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> # 初始化图像路径列表,并根据指定的扩展名找到所有图像文件
    18. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> self.image_paths = []
    19. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line"> for ext in IMAGE_EXTENSIONS:
    20. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line"> self.image_paths.extend(glob.glob(os.path.join(images_folder, f"*{ext}")))
    21. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"> self.image_paths = sorted(self.image_paths)
    22. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line">
    23. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> # 加载对应的文本标注,依次读取每个文本文件中的内容
    24. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line"> caption_paths = sorted(glob.glob(os.path.join(captions_folder, "*.txt")))
    25. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> captions = []
    26. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> for p in caption_paths:
    27. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line"> with open(p, "r", encoding="utf-8") as f:
    28. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> captions.append(f.readline().strip())
    29. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line">
    30. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line"> # 确保图像和文本标注数量一致
    31. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"> if len(captions) != len(self.image_paths):
    32. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line"> raise ValueError("图像数量与文本标注数量不一致,请检查数据集。")
    33. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line">
    34. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> # 使用 tokenizer 将文本标注转换为 word ids
    35. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> inputs = tokenizer(
    36. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line"> captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    37. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"> )
    38. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line"> self.input_ids = inputs.input_ids
    39. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line"> self.transform = transform
    40. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line">
    def __getitem__(self, idx):
        """
        Return the ``(image_tensor, input_ids)`` pair stored at position ``idx``.

        The image file is opened, converted to RGB and passed through
        ``self.transform``. If any part of that raises, a warning is printed and
        a placeholder sample is returned instead of propagating the error:

        - an all-zero image tensor of shape ``(3, resolution, resolution)``,
          where ``resolution`` is a module-level setting defined elsewhere in
          this file (NOTE(review): it must match the transform's output size,
          otherwise batching with ``torch.stack`` will fail);
        - an all-zero tensor shaped like the stored token ids. These are token
          id 0, not an empty or padded caption.
        """
        path = self.image_paths[idx]
        token_ids = self.input_ids[idx]
        try:
            # Load, force 3-channel RGB, then apply the augmentation/normalization transform
            pixels = self.transform(Image.open(path).convert("RGB"))
        except Exception as err:
            print(f"⚠️ 无法加载图像路径: {path}, 错误: {err}")
            # Placeholder sample so one unreadable file does not crash the DataLoader
            return torch.zeros((3, resolution, resolution)), torch.zeros_like(token_ids)
        return pixels, token_ids
    55. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="55"> class="hljs-ln-code"> class="hljs-ln-line">
    56. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="56"> class="hljs-ln-code"> class="hljs-ln-line"> def __len__(self):
    57. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="57"> class="hljs-ln-code"> class="hljs-ln-line"> return len(self.image_paths)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    解释

    8 定义微调相关的函数

    8.1 加载 LoRA

    前置文章:

    LoRA(Low-Rank Adaptation) 是一种非常高效的参数微调方法,通过在预训练模型的特定层添加小的低秩矩阵(可以联想线性代数中的奇异值分解),来实现模型的微调,这也是一类 Adapter。

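    LoRA 的核心思想是冻结大模型中原有的权重矩阵 $W \in \mathbb{R}^{d \times k}$,只把它的更新量近似为两个低秩矩阵的乘积:$W' = W + \Delta W = W + BA$,其中 $B \in \mathbb{R}^{d \times r}$、$A \in \mathbb{R}^{r \times k}$,且秩 $r \ll \min(d, k)$。这样,每个矩阵需要训练的参数量就从 $d \times k$ 降到 $r(d + k)$,从而大幅减少需要微调的参数数量,提高训练效率并节省存储空间。一般而言,模型越大,参数量的缩减比例越夸张:对于 GPT-3(175B),LoRA 微调的可训练参数量约为全量微调的 1/10000。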

    通常,在微调时我们只对模型的特定部分(如注意力机制中的 Q、K、V 矩阵)进行 LoRA 微调,而不是微调整个模型。这里选择对 unet 和 text_encoder 增加 LoRA,因为这两个模块直接负责图像生成和文本引导中的关键任务:unet 在逆扩散(去噪)过程中根据文本条件预测噪声,text_encoder 则将输入文本转换为引导生成的特征向量。VAE 只负责图像与潜在表示之间的编解码,因此保持冻结。下面,我们定义一个函数来加载模型并应用 LoRA。

    def prepare_lora_model(lora_config, pretrained_model_name_or_path, model_path=None, resume=False, merge_lora=False):
        """
        Load the Stable Diffusion components and attach LoRA adapters to the
        text encoder and the UNet.

        Loads the tokenizer, DDPM noise scheduler, CLIP text encoder, VAE and
        UNet from ``pretrained_model_name_or_path``. LoRA adapters are either
        created fresh from ``lora_config`` or, when ``resume`` is True, loaded
        from ``model_path``. Optionally the adapters are merged into the base
        weights for inference.

        Args:
            lora_config: peft ``LoraConfig`` applied to both the text encoder and
                the UNet when starting a new run (not used when ``resume`` is True).
            pretrained_model_name_or_path: Hugging Face model id or local path of
                the base Stable Diffusion model.
            model_path: directory holding previously saved LoRA adapters in its
                ``text_encoder`` and ``unet`` subfolders; required when ``resume``
                is True.
            resume: load the adapters from ``model_path`` instead of creating
                new ones.
            merge_lora: merge the LoRA weights into the base models and switch
                both to eval mode (inference only).

        Returns:
            (tokenizer, noise_scheduler, unet, vae, text_encoder), i.e.
            CLIPTokenizer, DDPMScheduler, UNet2DConditionModel (LoRA-wrapped
            unless merged), AutoencoderKL, CLIPTextModel (LoRA-wrapped unless
            merged).

        Raises:
            ValueError: if ``resume`` is True and ``model_path`` is None or does
                not exist.

        Note:
            Relies on the module-level ``weight_dtype`` and ``DEVICE`` settings
            defined elsewhere in this file.
        """
        # Noise scheduler: controls how noise is added to / removed from latents
        noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

        # Tokenizer: turns caption text into token ids
        tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer"
        )

        # CLIP text encoder: turns token ids into text features that condition the UNet
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=weight_dtype,
            subfolder="text_encoder"
        )

        # VAE: maps images to and from the latent space the diffusion model works in
        vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="vae"
        )

        # UNet: the denoising network of the diffusion model
        unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=weight_dtype,
            subfolder="unet"
        )

        if resume:
            if model_path is None or not os.path.exists(model_path):
                raise ValueError("当 resume 设置为 True 时,必须提供有效的 model_path")
            # PEFT loads saved adapters frozen by default (is_trainable=False).
            # Passing is_trainable=True keeps ONLY the LoRA weights trainable while
            # the base weights stay frozen. Setting requires_grad=True on every
            # parameter instead would unfreeze the entire UNet and text encoder,
            # i.e. full fine-tuning. For inference-only merging, keep them frozen.
            adapter_trainable = not merge_lora
            text_encoder = PeftModel.from_pretrained(
                text_encoder, os.path.join(model_path, "text_encoder"), is_trainable=adapter_trainable
            )
            unet = PeftModel.from_pretrained(
                unet, os.path.join(model_path, "unet"), is_trainable=adapter_trainable
            )
            print(f"✅ 已从 {model_path} 恢复模型权重")
        else:
            # Fresh run: wrap both models with new LoRA adapters built from lora_config
            text_encoder = get_peft_model(text_encoder, lora_config)
            unet = get_peft_model(unet, lora_config)

        # Report how many parameters will actually be trained
        print("📊 Text Encoder 可训练参数:")
        text_encoder.print_trainable_parameters()
        print("📊 UNet 可训练参数:")
        unet.print_trainable_parameters()

        if merge_lora:
            # Fold the LoRA weights into the base weights (inference only); this
            # returns the underlying base model without the PEFT wrapper
            text_encoder = text_encoder.merge_and_unload()
            unet = unet.merge_and_unload()

            # Switch to evaluation mode
            text_encoder.eval()
            unet.eval()

        # The VAE is never trained
        vae.requires_grad_(False)

        # Move all models to the target device and cast to the configured dtype
        unet.to(DEVICE, dtype=weight_dtype)
        vae.to(DEVICE, dtype=weight_dtype)
        text_encoder.to(DEVICE, dtype=weight_dtype)

        return tokenizer, noise_scheduler, unet, vae, text_encoder
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    解释:

    为什么只微调 unet 和 text_encoder 最终却返回这么多模块?

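    因为在后面的微调中,我们会自己编写每一步训练逻辑,而不是把整条生成管线当作黑盒直接调用,所以每个模块都会用到:tokenizer 将文本标注转换为 token,text_encoder 将 token 编码为文本特征,VAE 将图像编码为潜在表示,noise_scheduler 负责向潜在表示添加噪声,unet 则在文本特征的条件下预测噪声。其中只有 unet 和 text_encoder 挂载了 LoRA 并参与训练。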

    8.2 准备优化器

    接下来,我们为应用了 LoRA 的 UNet 和文本编码器(text_encoder)分别设置不同的学习率,这也是各类 LoRA 训练 UI(俗称“炼丹炉”)中经常需要调节的选项。

    def prepare_optimizer(unet, text_encoder, unet_learning_rate=5e-4, text_encoder_learning_rate=1e-4, **optimizer_kwargs):
        """
        Build an AdamW optimizer over the trainable (LoRA) parameters of the UNet
        and the text encoder, using a separate learning rate for each.

        Args:
            unet: UNet model (typically LoRA-wrapped via peft).
            text_encoder: text encoder model (typically LoRA-wrapped via peft).
            unet_learning_rate: learning rate of the UNet parameter group.
            text_encoder_learning_rate: learning rate of the text-encoder
                parameter group.
            **optimizer_kwargs: extra keyword arguments forwarded to
                ``torch.optim.AdamW`` (e.g. ``weight_decay``, ``betas``, ``eps``).
                With none given, AdamW's own defaults apply, as before.

        Returns:
            torch.optim.AdamW whose ``param_groups[0]`` holds the UNet parameters
            and ``param_groups[1]`` the text-encoder parameters.

        Raises:
            ValueError: if neither model has any parameter with
                ``requires_grad=True`` (e.g. the adapters were merged or loaded
                frozen). Such an optimizer would silently never update anything.
        """
        # Only parameters left trainable are collected; with peft-wrapped models
        # these are the LoRA layers because the base weights are frozen
        unet_lora_layers = [p for p in unet.parameters() if p.requires_grad]
        text_encoder_lora_layers = [p for p in text_encoder.parameters() if p.requires_grad]

        if not unet_lora_layers and not text_encoder_lora_layers:
            raise ValueError(
                "No trainable parameters found in unet or text_encoder; nothing to optimize."
            )

        # One parameter group per model so each gets its own learning rate
        trainable_params = [
            {"params": unet_lora_layers, "lr": unet_learning_rate},
            {"params": text_encoder_lora_layers, "lr": text_encoder_learning_rate},
        ]

        return torch.optim.AdamW(trainable_params, **optimizer_kwargs)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    8.3 定义 collate_fn 函数

    在大多数常见的机器学习任务中(例如图像分类或回归),数据集通常是简单的 (data, label) 结构,PyTorch 的 DataLoader 默认就能处理这样的数据结构,并将样本打包成批次(batch)。在我们的项目中,每个样本同样是一个由图像张量和文本 token 编码组成的元组 (tensor, input_id)。默认的 collate_fn 也能将这些样本打包成批次,但访问时需要使用索引,例如 batch[0] 和 batch[1]。

    为了使代码更具可读性,我们可以自定义一个 collate_fn 函数,将批次数据组织成字典的形式,方便通过键名直接访问,例如 batch["pixel_values"] 和 batch["input_ids"]。自定义的 collate_fn 定义如下:

    def collate_fn(examples):
        """
        Collate ``(image_tensor, input_ids)`` samples into a dict batch.

        Args:
            examples: sequence of ``(image_tensor, input_ids)`` pairs as returned
                by the dataset's ``__getitem__``.

        Returns:
            dict with
            - ``"pixel_values"``: image tensors stacked along a new leading batch
              dimension and cast to float32;
            - ``"input_ids"``: token-id tensors stacked along a new leading batch
              dimension (dtype unchanged).
        """
        samples = list(examples)
        images = [image for image, _ in samples]
        token_ids = [ids for _, ids in samples]
        return {
            "pixel_values": torch.stack(images, dim=0).float(),
            "input_ids": torch.stack(token_ids, dim=0),
        }
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    解释:

    补充:PyTorch 的 torch.stack() 函数会将多个张量沿一个新维度堆叠在一起。例如,将一批形状为 (C, H, W) 的图像张量堆叠成 (batch_size, C, H, W) 的形式,确保每个批次数据的组织结构一致。注意,torch.stack() 要求所有输入张量的形状完全相同,否则会直接报错。

    拓展:自定义和默认 collate_fn 的对比

    下面提供了一个对比函数,来展示自定义 collate_fn 和默认 collate_fn 在处理当前数据时的不同。你可以通过运行代码来观察自定义和默认方式的使用差异。

    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line">import torch
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">from torch.utils.data import DataLoader, Dataset
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">
    def compare_dataloaders(dataset, batch_size):
        """
        Print the structure of one batch built with the custom ``collate_fn`` and
        one built with PyTorch's default collation, side by side.

        Both loaders shuffle, so the two batches generally hold different
        samples; only their container type and tensor shapes are compared.

        Args:
            dataset: dataset whose items are ``(image_tensor, input_ids)`` pairs.
            batch_size: number of samples per batch for both loaders.

        Returns:
            (custom_batch, default_batch): the dict produced by ``collate_fn``
            and the default-collated batch, which is unpacked below as
            ``(pixel_values, input_ids)``.
        """
        loader_options = {"shuffle": True, "batch_size": batch_size}

        # Case 1: custom collate_fn -> batch is a dict accessed by key
        custom_loader = DataLoader(dataset, collate_fn=collate_fn, **loader_options)
        # Case 2: default collation -> batch is accessed by position
        default_loader = DataLoader(dataset, **loader_options)

        # Draw one batch from each loader (custom loader first)
        custom_batch = next(iter(custom_loader))
        default_batch = next(iter(default_loader))

        print("使用自定义 collate_fn:")
        print("批次的类型:", type(custom_batch))
        print("批次 pixel_values 的形状:", custom_batch["pixel_values"].shape)
        print("批次 input_ids 的形状:", custom_batch["input_ids"].shape)

        print("\n使用默认 collate_fn:")
        print("批次的类型:", type(default_batch))
        default_pixels, default_ids = default_batch
        print("批次 pixel_values 的形状:", default_pixels.shape)
        print("批次 input_ids 的形状:", default_ids.shape)

        return custom_batch, default_batch
    39. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line">
    # Compare the custom and default collation on a single batch of two samples
    custom_batch, default_batch = compare_dataloaders(
        dataset,
        batch_size=2,
    )

    输出

    使用自定义 collate_fn:
    批次的类型: <class 'dict'>
    批次 pixel_values 的形状: torch.Size([2, 3, 224, 224])
    批次 input_ids 的形状: torch.Size([2, 16])

    使用默认 collate_fn:
    批次的类型: <class 'list'>
    批次 pixel_values 的形状: torch.Size([2, 3, 224, 224])
    批次 input_ids 的形状: torch.Size([2, 16])

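    具体选择哪一种由你决定,默认的方法实际上更普遍。注意:默认 collate_fn 返回的是列表,需要按位置解包;而后文的训练循环通过 batch["pixel_values"]、batch["input_ids"] 按键取值,因此后续构建 DataLoader 时使用的是自定义 collate_fn。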

    9 设置相关参数

    9.1 设备配置

    当前的微调毫无疑问需要用到显卡(GPU)。对于 Apple 芯片的 Mac,需要把 "cuda" 改为 "mps",也就是改用下方代码中被注释掉的那一行。需要注意的是,在 PyTorch 版本过低(早于 1.12,尚未提供 MPS 后端)的环境中,调用 torch.backends.mps.is_available() 会报错,所以这里默认将其注释掉。

    # Pick the compute device: a CUDA GPU when one is visible, otherwise the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

    # Apple-silicon Macs (M1, M2, ...): use the line below instead. It calls
    # torch.backends.mps.is_available(), which older PyTorch builds do not have.
    # DEVICE = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))

    print(f"🖥 当前使用的设备: {DEVICE}")

    9.2 模型与训练参数配置

    这里的参数大多与之前的函数相关,下面是你可以调节的内容:

    # ---- Training hyper-parameters ----
    train_batch_size = 2           # number of samples processed per training step
    weight_dtype = torch.bfloat16  # dtype for weights/inputs; bfloat16 to save memory and speed up compute
    snr_gamma = 5                  # Min-SNR loss-weighting coefficient (a falsy value disables the weighting)

    # ---- Reproducibility ----
    seed = 1126
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # ---- Stable Diffusion LoRA fine-tuning settings ----

    # Optimizer: separate learning rates for the two LoRA-adapted networks
    unet_learning_rate = 1e-4          # step size for the UNet parameters
    text_encoder_learning_rate = 1e-4  # step size for the text-encoder parameters

    # Learning-rate schedule: cosine annealing with periodic restarts
    lr_scheduler_name = "cosine_with_restarts"
    lr_warmup_steps = 100   # linearly ramp the LR up to its peak over the first 100 steps
    max_train_steps = 2000  # total number of optimisation steps
    num_cycles = 3          # number of cosine decay/restart cycles over the run

    # Base Stable Diffusion checkpoint to fine-tune
    pretrained_model_name_or_path = "stablediffusionapi/cyberrealistic-41"

    # Module-name suffixes that receive LoRA adapters
    _TEXT_ENCODER_LORA_TARGETS = ("q_proj", "v_proj", "k_proj", "out_proj")  # text-encoder attention projections
    _UNET_LORA_TARGETS = ("to_k", "to_q", "to_v", "to_out.0")                # UNet attention projections

    lora_config = LoraConfig(
        r=32,                # rank of the low-rank update matrices
        lora_alpha=16,       # scaling coefficient for the LoRA update
        target_modules=[*_TEXT_ENCODER_LORA_TARGETS, *_UNET_LORA_TARGETS],
        lora_dropout=0,      # 0 disables dropout on the LoRA path
    )

    10 微调前的准备

    10.1 准备数据集

    # Tokenizer that ships with the base model (loaded from its "tokenizer" subfolder)
    tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer"
    )

    # Image/caption dataset built from the two folders configured earlier
    dataset = Text2ImageDataset(
        images_folder=images_folder,
        captions_folder=captions_folder,
        transform=train_transform,
        tokenizer=tokenizer,
    )

    # Use up to 8 loader workers, but never more than the CPUs this process may
    # run on: oversubscribing small machines only adds overhead and makes
    # PyTorch warn about an excessive worker count.
    _available_cpus = (
        len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)
    )
    num_workers = min(8, _available_cpus)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=collate_fn,  # the custom collate_fn defined earlier (yields dict batches)
        batch_size=train_batch_size,
        num_workers=num_workers,
    )

    print("✅ 数据集准备完成!")

    解释:

    10.2 准备模型和优化器

    # Build the LoRA-wrapped components via the helper defined earlier.
    # NOTE(review): resume=False / merge_lora=False presumably mean "start from the
    # base weights" and "keep the LoRA adapters unmerged" — confirm against prepare_lora_model.
    (tokenizer, noise_scheduler, unet, vae, text_encoder) = prepare_lora_model(
        lora_config, pretrained_model_name_or_path, model_path,
        resume=False, merge_lora=False,
    )

    # Optimizer over the UNet and text encoder, each with its own learning rate
    optimizer = prepare_optimizer(
        unet, text_encoder,
        unet_learning_rate=unet_learning_rate,
        text_encoder_learning_rate=text_encoder_learning_rate,
    )

    # Learning-rate scheduler driven by the hyper-parameters configured above
    _scheduler_kwargs = dict(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=max_train_steps,
        num_cycles=num_cycles,
    )
    lr_scheduler = get_scheduler(lr_scheduler_name, **_scheduler_kwargs)

    print("✅ 模型和优化器准备完成!可以开始训练。")

    解释:

    11 开始微调

    主要流程和结构如下:

    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line"># 禁用并行化,避免警告
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">os.environ["TOKENIZERS_PARALLELISM"] = "false"
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line"># 初始化
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">global_step = 0
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">best_face_score = float("inf") # 初始化为正无穷大,存储最佳面部相似度分数
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"># 进度条显示训练进度
    9. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line">progress_bar = tqdm(
    10. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line"> range(max_train_steps), # 根据 num_training_steps 设置
    11. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line"> desc="训练步骤",
    12. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">)
    13. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">
    14. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line"># 训练循环
    15. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line">for epoch in range(math.ceil(max_train_steps / len(train_dataloader))):
    16. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line"> # 如果你想在训练中增加评估,那在循环中增加 train() 是有必要的
    17. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line"> unet.train()
    18. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> text_encoder.train()
    19. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line">
    20. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line"> for step, batch in enumerate(train_dataloader):
    21. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"> if global_step >= max_train_steps:
    22. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line"> break
    23. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line">
    24. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line"> # 编码图像为潜在表示(latent)
    25. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> latents = vae.encode(batch["pixel_values"].to(DEVICE, dtype=weight_dtype)).latent_dist.sample()
    26. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line"> latents = latents * vae.config.scaling_factor # 根据 VAE 的缩放因子调整潜在空间
    27. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line">
    28. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line"> # 为潜在表示添加噪声,生成带噪声的图像
    29. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> noise = torch.randn_like(latents) # 生成与潜在表示相同形状的随机噪声
    30. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line"> timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=DEVICE).long()
    31. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"> noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    32. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line">
    33. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> # 获取文本的嵌入表示
    34. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line"> encoder_hidden_states = text_encoder(batch["input_ids"].to(DEVICE))[0]
    35. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line">
    36. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line"> # 计算目标值
    37. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"> if noise_scheduler.config.prediction_type == "epsilon":
    38. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line"> target = noise # 预测噪声
    39. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line"> elif noise_scheduler.config.prediction_type == "v_prediction":
    40. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line"> target = noise_scheduler.get_velocity(latents, noise, timesteps) # 预测速度向量
    41. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line">
    42. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line"> # UNet 模型预测
    43. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line"> model_pred = unet(noisy_latents, timesteps, encoder_hidden_states)[0]
    44. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line">
    45. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line"> # 计算损失
    46. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line"> if not snr_gamma:
    47. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"> loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    48. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line"> else:
    49. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="49"> class="hljs-ln-code"> class="hljs-ln-line"> # 计算信噪比 (SNR) 并根据 SNR 加权 MSE 损失
    50. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="50"> class="hljs-ln-code"> class="hljs-ln-line"> snr = compute_snr(noise_scheduler, timesteps)
    51. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="51"> class="hljs-ln-code"> class="hljs-ln-line"> mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0]
    52. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="52"> class="hljs-ln-code"> class="hljs-ln-line"> if noise_scheduler.config.prediction_type == "epsilon":
    53. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="53"> class="hljs-ln-code"> class="hljs-ln-line"> mse_loss_weights = mse_loss_weights / snr
    54. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="54"> class="hljs-ln-code"> class="hljs-ln-line"> elif noise_scheduler.config.prediction_type == "v_prediction":
    55. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="55"> class="hljs-ln-code"> class="hljs-ln-line"> mse_loss_weights = mse_loss_weights / (snr + 1)
    56. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="56"> class="hljs-ln-code"> class="hljs-ln-line">
    57. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="57"> class="hljs-ln-code"> class="hljs-ln-line"> # 计算加权的 MSE 损失
    58. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="58"> class="hljs-ln-code"> class="hljs-ln-line"> loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    59. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="59"> class="hljs-ln-code"> class="hljs-ln-line"> loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
    60. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="60"> class="hljs-ln-code"> class="hljs-ln-line"> loss = loss.mean()
    61. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="61"> class="hljs-ln-code"> class="hljs-ln-line">
    62. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="62"> class="hljs-ln-code"> class="hljs-ln-line"> # 反向传播
    63. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="63"> class="hljs-ln-code"> class="hljs-ln-line"> loss.backward()
    64. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="64"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.step()
    65. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="65"> class="hljs-ln-code"> class="hljs-ln-line"> lr_scheduler.step()
    66. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="66"> class="hljs-ln-code"> class="hljs-ln-line"> optimizer.zero_grad()
    67. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="67"> class="hljs-ln-code"> class="hljs-ln-line"> progress_bar.update(1)
    68. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="68"> class="hljs-ln-code"> class="hljs-ln-line"> global_step += 1
    69. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="69"> class="hljs-ln-code"> class="hljs-ln-line">
    70. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="70"> class="hljs-ln-code"> class="hljs-ln-line"> # 打印训练损失
    71. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="71"> class="hljs-ln-code"> class="hljs-ln-line"> if global_step % 100 == 0 or global_step == max_train_steps:
    72. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="72"> class="hljs-ln-code"> class="hljs-ln-line"> print(f"🔥 步骤 {global_step}, 损失: {loss.item()}")
    73. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="73"> class="hljs-ln-code"> class="hljs-ln-line">
    74. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="74"> class="hljs-ln-code"> class="hljs-ln-line"> # 保存中间检查点,当前简单设置为每 500 步保存一次
    75. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="75"> class="hljs-ln-code"> class="hljs-ln-line"> if global_step % 500 == 0:
    76. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="76"> class="hljs-ln-code"> class="hljs-ln-line"> save_path = os.path.join(output_folder, f"checkpoint-{global_step}")
    77. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="77"> class="hljs-ln-code"> class="hljs-ln-line"> os.makedirs(save_path, exist_ok=True)
    78. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="78"> class="hljs-ln-code"> class="hljs-ln-line">
    79. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="79"> class="hljs-ln-code"> class="hljs-ln-line"> # 使用 save_pretrained 保存 PeftModel
    80. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="80"> class="hljs-ln-code"> class="hljs-ln-line"> unet.save_pretrained(os.path.join(save_path, "unet"))
    81. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="81"> class="hljs-ln-code"> class="hljs-ln-line"> text_encoder.save_pretrained(os.path.join(save_path, "text_encoder"))
    82. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="82"> class="hljs-ln-code"> class="hljs-ln-line"> print(f"💾 已保存中间模型到 {save_path}")
    83. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="83"> class="hljs-ln-code"> class="hljs-ln-line">
    # Final export: write the trained models to <output_folder>/checkpoint-last, using the
    # same "unet" / "text_encoder" sub-folder layout as the periodic checkpoints above.
    save_path = os.path.join(output_folder, "checkpoint-last")
    os.makedirs(save_path, exist_ok=True)
    for _sub_dir, _component in (("unet", unet), ("text_encoder", text_encoder)):
        _component.save_pretrained(os.path.join(save_path, _sub_dir))
    print(f"💾 已保存最终模型到 {save_path}")

    print("🎉 微调完成!")
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    训练完成后的 checkpoint 会保存到 ./SD/Brad/logs/checkpoint-last 中,以 max_train_steps=200 为例,模型输出如下:

    12 生成图像和评估

    12.1 什么是 pipeline?

    pipeline 是 Hugging Face 生态中一种高层次的封装工具,通常用于推理。默认情况下,pipeline 以 eval 模式加载模型,因此适合用于生成或评估场景。我们这里使用的是 diffusers 库中的 DiffusionPipeline,它将之前提到的多个模型组件(如 UNet、VAE、文本编码器等)组合在一起,实现从文本到图像的生成。

    pipeline 的工作原理也跟之前微调过程类似:

    1. 文本编码:pipeline 中的文本编码器会将输入的 prompt 转换为特征向量。
    2. 噪声注入:在潜在空间中,模型从随机噪声开始生成图像。
    3. 迭代去噪:UNet 使用从文本编码器得到的特征向量指导去噪过程,逐步将噪声还原为高质量图像。
    4. 图像解码:最终,VAE 将潜在表示解码为实际的图像。

    12.2 推理相关的参数

    1. 什么是推理步数(num_inference_steps)?

      推理步数控制扩散模型生成图像时的去噪迭代次数。步数越多,生成的图像质量越高,但推理时间也相应增加。这是一个需要你根据图像质量和时间需求去权衡的参数,通常在肉眼觉得够好的时候,就可以了。

    2. 如何决定 prompt 的影响程度(guidance_scale)?

      guidance_scale 决定了文本提示对生成图像的影响程度。较高的 guidance_scale 会让模型更严格地按照 prompt 生成图像,数值通常在 7.5 到 10 之间调整,过高可能会导致图像失真,同样需要你去权衡。这个参数与文本生成任务中的 temperature 参数有些类似,都是在“多样性”和“贴合度”之间做取舍,但方向相反:guidance_scale 越高,生成结果越贴近 prompt、多样性越低;而 temperature 越高,输出越随机。

    3. 怎么确保相同 prompt 生成相同的图像?

      设置固定的随机数种子(seed),可以确保同样的 prompt 在每次运行时生成相同的图像。可以通过使用 torch.Generator 生成随机数并设置种子(seed),示例如下:

      generator = torch.Generator().manual_seed(42)

    12.3 加载用于验证的 prompts

    这是一组用于生图的文本提示(prompts),本实验中位于 ./SD/Datasets/prompts/validation_prompt.txt,文件中每一行是一个 prompt,下面摘取几行 prompt 预览:

    定义加载 prompts 的函数如下:

    def load_validation_prompts(validation_prompt_path):
        """
        Load validation prompts from a text file, one prompt per line.

        Args:
            validation_prompt_path: str, path to the UTF-8 prompt file.

        Returns:
            list[str]: the prompts in file order, each stripped of surrounding
            whitespace. Blank / whitespace-only lines are skipped, so an accidental
            empty line does not become an empty prompt that would still be rendered
            and scored.
        """
        # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading BOM, which
        # str.strip() would not remove and would otherwise end up in the first prompt.
        with open(validation_prompt_path, "r", encoding="utf-8-sig") as f:
            stripped = (line.strip() for line in f)
            validation_prompt = [prompt for prompt in stripped if prompt]
        return validation_prompt
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    12.4 定义生成图像的函数

    结合之前的讨论,我们可以定义一个生成图像的函数:

    def generate_images(pipeline, prompts, num_inference_steps=50, guidance_scale=7.5, output_folder="inference", generator=None):
        """
        Generate one image per prompt with a DiffusionPipeline, save each one, and return them.

        Args:
            pipeline: DiffusionPipeline, already loaded and configured.
            prompts: list of text prompts; one image is generated per entry.
            num_inference_steps: int, number of denoising steps; more steps usually give
                better images but take longer.
            guidance_scale: float, how strongly the prompt steers generation.
            output_folder: str, directory the images are written to (created if missing).
            generator: torch.Generator or None. The same generator object is handed to
                every pipeline call, so a seeded generator makes the run reproducible;
                without one the images may differ between runs.

        Returns:
            list of generated images in prompt order. The image for the k-th prompt
            (1-based) is also saved as "generated_<k>.png" inside output_folder.
        """
        print("🎨 正在生成图像...")
        os.makedirs(output_folder, exist_ok=True)

        images = []
        for idx, text in enumerate(tqdm(prompts, desc="生成图像中"), start=1):
            result = pipeline(
                text,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            img = result.images[0]
            img.save(os.path.join(output_folder, f"generated_{idx}.png"))
            images.append(img)

        print(f"✅ 已生成并保存 {len(prompts)} 张图像到 {output_folder}")
        return images
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    12.5 定义评估函数

    虽然图像生成的好与坏现在更多地由人去判断,但最基础的评估还是可以交给机器。以当前实验为例,我们的目的是 “AI 换脸”,那就可以有两个新的度量:生成图像与训练图像之间的面部相似度,以及无法检测到面部的生成图像数量。

    除了人脸生成之外,AI 图像生成领域还有很多其他应用场景。那么,有没有通用的评估方法来衡量生成图像与文本提示的匹配度呢?

    有,CLIP 评分

    是的,CLIP 除了可以用于编码文本输入,还可以用来评估最终的模型:无论生成的是人脸、风景还是物体,它都可以帮助我们判断生成图像与文本提示的相关性。

    对于当前实验,我们采取这三种方式对模型进行度量,完整流程如下:

    1. 使用 load_validation_prompts() 函数从文件中加载 prompts。
    2. 使用 prepare_lora_model() 函数加载已经经过 LoRA 微调的 UNet 和文本编码器(text_encoder),并合并 LoRA 权重。模型会从上一次训练保存的文件中恢复权重。
    3. 使用已经微调的 UNet 和文本编码器来创建 DiffusionPipeline
    4. 加载 CLIP 模型后续用于评估。
    5. 使用 DeepFace 提取训练图像的面部嵌入 train_emb,并与生成图像的面部嵌入进行对比,计算面部相似度。
    6. 进行评估,最后打印结果。
    def _extract_face_embedding(img):
        """
        Return the first GhostFaceNet face embedding DeepFace reports for `img`, or None.

        `img` is forwarded unchanged to DeepFace.represent (in this file: an image file
        path for training images, a BGR numpy array for generated images). An empty
        result or a face_confidence of 0 is treated as "no face detected", so training
        and generated images are filtered by exactly the same rule.
        NOTE(review): relies on enforce_detection=False reporting a failed detection via
        face_confidence == 0 rather than raising; confirm against the pinned deepface.
        """
        reps = DeepFace.represent(
            img,
            detector_backend="ssd",
            model_name="GhostFaceNet",
            enforce_detection=False,
        )
        if not reps or reps[0].get('face_confidence', 0) == 0:
            return None
        return reps[0]['embedding']


    def evaluate(lora_config):
        """
        Load the LoRA-finetuned model, generate images for the validation prompts, and score them.

        Steps:
        1. Load the validation prompts.
        2. Load the LoRA-finetuned UNet / text encoder (resumed from the checkpoint, LoRA
           weights merged) and build a DiffusionPipeline from them.
        3. Generate one image per prompt with a fixed seed.
        4. Score the images: face distance to the training images, CLIP score, and the
           number of generated images with no detectable face.
        5. Print the results.

        Reads module-level settings defined elsewhere in the file: validation_prompt_path,
        pretrained_model_name_or_path, model_path, weight_dtype, DEVICE, seed,
        images_folder, IMAGE_EXTENSIONS, inference_path.

        Returns:
            (face_score, clip_score, mis):
            - face_score: mean Euclidean distance between L2-normalised generated and
              training face embeddings (lower is better); NaN when no training face
              embedding could be extracted.
            - clip_score: mean CLIP logits_per_image over generated images that contain a
              face (higher is better).
            - mis: number of generated images in which no face was detected.
            If no generated image contains a face, (0, 0, mis) is returned.
        """
        print("📂 加载验证提示...")
        validation_prompts = load_validation_prompts(validation_prompt_path)

        print("🔧 准备 LoRA 模型...")
        # Prepare the LoRA model for inference: resume from the checkpoint and merge LoRA weights.
        tokenizer, noise_scheduler, unet, vae, text_encoder = prepare_lora_model(
            lora_config,
            pretrained_model_name_or_path,
            model_path=model_path,
            resume=True,     # restore weights from the saved checkpoint
            merge_lora=True  # merge LoRA weights into the base weights
        )

        # Build the DiffusionPipeline with the finetuned UNet / text encoder swapped in.
        print("🔄 创建 DiffusionPipeline...")
        pipeline = DiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=unet,
            text_encoder=text_encoder,
            torch_dtype=weight_dtype,
            safety_checker=None,
        )
        pipeline = pipeline.to(DEVICE)

        # Load the CLIP model and processor used for the text-image score.
        print("🎯 加载 CLIP 模型...")
        clip_model_name = "openai/clip-vit-base-patch32"
        clip_model = CLIPModel.from_pretrained(clip_model_name).to(DEVICE)
        clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
        clip_model.eval()

        # Seeded generator so the generated images are reproducible.
        generator = torch.Generator(device=DEVICE)
        generator.manual_seed(seed)

        # Reference face embeddings from the training images.
        print("📂 加载训练图像的面部嵌入...")
        # Case-insensitive extension match so e.g. "photo.JPG" is not silently skipped.
        wanted_exts = [ext.lower() for ext in IMAGE_EXTENSIONS]
        train_image_paths = sorted(
            p for p in glob.glob(os.path.join(images_folder, "*"))
            if any(p.lower().endswith(ext) for ext in wanted_exts)
        )
        train_emb_list = []
        for img_path in tqdm(train_image_paths, desc="提取训练图像面部嵌入"):
            embedding = _extract_face_embedding(img_path)
            if embedding is not None:
                train_emb_list.append(embedding)

        if len(train_emb_list) == 0:
            print("⚠️ 未能提取到任何训练图像的面部嵌入。")
            # No reference faces: an empty 1-D tensor would make torch.cdist fail later,
            # so mark the reference set as missing instead.
            train_emb = None
        else:
            train_emb = torch.tensor(train_emb_list).to(DEVICE)

        generated_images = generate_images(
            pipeline=pipeline,
            prompts=validation_prompts,
            num_inference_steps=30,
            guidance_scale=7.5,
            output_folder=inference_path,
            generator=generator
        )

        # mis counts generated images in which no face could be detected.
        clip_score, mis = 0, 0
        valid_emb = []
        print("📊 正在计算评估分数...")

        for i, image in enumerate(tqdm(generated_images, desc="评估图像中")):
            # DeepFace is given a BGR array (converted from the RGB image).
            opencv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            embedding = _extract_face_embedding(opencv_image)
            if embedding is None:
                mis += 1
                continue

            # CLIP score for this prompt/image pair.
            current_prompt = validation_prompts[i]
            inputs = clip_processor(text=current_prompt, images=image, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                outputs = clip_model(**inputs)
            clip_score += outputs.logits_per_image.item()

            valid_emb.append(embedding)

        # No generated image contained a face: return the default scores.
        if len(valid_emb) == 0:
            print("⚠️ 无法检测到面部嵌入!")
            return 0, 0, mis

        # Average the CLIP score over the images that were actually scored.
        clip_score /= len(valid_emb)

        if train_emb is None:
            face_score = float("nan")
        else:
            # Mean pairwise Euclidean distance between L2-normalised embeddings.
            gen_emb = torch.tensor(valid_emb).to(DEVICE)
            gen_emb = gen_emb / gen_emb.norm(p=2, dim=-1, keepdim=True)
            ref_emb = train_emb / train_emb.norm(p=2, dim=-1, keepdim=True)
            face_score = torch.cdist(gen_emb, ref_emb, p=2).mean().item()
        print("📈 评估完成!")

        print(f"✅ 面部相似度评分 (平均欧氏距离): {face_score:.4f} (越低越好,表示生成图像与训练图像更相似)")
        print(f"✅ CLIP 评分 (平均相似度): {clip_score:.4f} (越高越好,表示生成图像与文本提示的相关性更强)")
        print(f"✅ 无面部图像数量: {mis} (无法检测到面部的生成图像数量)")

        return face_score, clip_score, mis

    # Run the evaluation.
    evaluate(lora_config)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">
    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line">📂 加载验证提示...
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">🔧 准备 LoRA 模型...
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line">An error occurred while trying to fetch stablediffusionapi/cyberrealistic-41: stablediffusionapi/cyberrealistic-41 does not appear to have a file named diffusion_pytorch_model.safetensors.
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line">Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line">An error occurred while trying to fetch stablediffusionapi/cyberrealistic-41: stablediffusionapi/cyberrealistic-41 does not appear to have a file named diffusion_pytorch_model.safetensors.
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line">Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line">✅ 已从 ./SD/Brad/logs/checkpoint-last 恢复模型权重
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line">vae/diffusion_pytorch_model.safetensors not found
    9. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="9"> class="hljs-ln-code"> class="hljs-ln-line">🔄 创建 DiffusionPipeline...
    10. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="10"> class="hljs-ln-code"> class="hljs-ln-line">Loading pipeline components...: 100%
    11. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="11"> class="hljs-ln-code"> class="hljs-ln-line"> 6/6 [00:00<00:00, 10.45it/s]
    12. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="12"> class="hljs-ln-code"> class="hljs-ln-line">An error occurred while trying to fetch /root/.cache/huggingface/hub/models--stablediffusionapi--cyberrealistic-41/snapshots/31259688a2398b11f5e7156bac475c459afaccd8/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /root/.cache/huggingface/hub/models--stablediffusionapi--cyberrealistic-41/snapshots/31259688a2398b11f5e7156bac475c459afaccd8/vae.
    13. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="13"> class="hljs-ln-code"> class="hljs-ln-line">Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
    14. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="14"> class="hljs-ln-code"> class="hljs-ln-line">You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
    15. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="15"> class="hljs-ln-code"> class="hljs-ln-line">🎯 加载 CLIP 模型...
    16. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="16"> class="hljs-ln-code"> class="hljs-ln-line">📂 加载训练图像的面部嵌入...
    17. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="17"> class="hljs-ln-code"> class="hljs-ln-line">提取训练图像面部嵌入: 100%
    18. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="18"> class="hljs-ln-code"> class="hljs-ln-line"> 100/100 [00:38<00:00,  1.75it/s]
    19. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="19"> class="hljs-ln-code"> class="hljs-ln-line">🎨 正在生成图像...
    20. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="20"> class="hljs-ln-code"> class="hljs-ln-line">生成图像中: 100%
    21. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="21"> class="hljs-ln-code"> class="hljs-ln-line"> 25/25 [12:53<00:00, 31.07s/it]
    22. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="22"> class="hljs-ln-code"> class="hljs-ln-line">100%
    23. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="23"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:28<00:00,  1.08it/s]
    24. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="24"> class="hljs-ln-code"> class="hljs-ln-line">100%
    25. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="25"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:28<00:00,  1.07it/s]
    26. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="26"> class="hljs-ln-code"> class="hljs-ln-line">100%
    27. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="27"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.07it/s]
    28. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="28"> class="hljs-ln-code"> class="hljs-ln-line">100%
    29. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="29"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    30. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="30"> class="hljs-ln-code"> class="hljs-ln-line">100%
    31. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="31"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.06it/s]
    32. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="32"> class="hljs-ln-code"> class="hljs-ln-line">100%
    33. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="33"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    34. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="34"> class="hljs-ln-code"> class="hljs-ln-line">100%
    35. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="35"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    36. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="36"> class="hljs-ln-code"> class="hljs-ln-line">100%
    37. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="37"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    38. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="38"> class="hljs-ln-code"> class="hljs-ln-line">100%
    39. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="39"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    40. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="40"> class="hljs-ln-code"> class="hljs-ln-line">100%
    41. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="41"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    42. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="42"> class="hljs-ln-code"> class="hljs-ln-line">100%
    43. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="43"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    44. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="44"> class="hljs-ln-code"> class="hljs-ln-line">100%
    45. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="45"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    46. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="46"> class="hljs-ln-code"> class="hljs-ln-line">100%
    47. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="47"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    48. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="48"> class="hljs-ln-code"> class="hljs-ln-line">100%
    49. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="49"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    50. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="50"> class="hljs-ln-code"> class="hljs-ln-line">100%
    51. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="51"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    52. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="52"> class="hljs-ln-code"> class="hljs-ln-line">100%
    53. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="53"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    54. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="54"> class="hljs-ln-code"> class="hljs-ln-line">100%
    55. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="55"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    56. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="56"> class="hljs-ln-code"> class="hljs-ln-line">100%
    57. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="57"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    58. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="58"> class="hljs-ln-code"> class="hljs-ln-line">100%
    59. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="59"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    60. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="60"> class="hljs-ln-code"> class="hljs-ln-line">100%
    61. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="61"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    62. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="62"> class="hljs-ln-code"> class="hljs-ln-line">100%
    63. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="63"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    64. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="64"> class="hljs-ln-code"> class="hljs-ln-line">100%
    65. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="65"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    66. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="66"> class="hljs-ln-code"> class="hljs-ln-line">100%
    67. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="67"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    68. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="68"> class="hljs-ln-code"> class="hljs-ln-line">100%
    69. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="69"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    70. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="70"> class="hljs-ln-code"> class="hljs-ln-line">100%
    71. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="71"> class="hljs-ln-code"> class="hljs-ln-line"> 30/30 [00:29<00:00,  1.05it/s]
    72. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="72"> class="hljs-ln-code"> class="hljs-ln-line">✅ 已生成并保存 25 张图像到 ./SD/Brad/inference
    73. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="73"> class="hljs-ln-code"> class="hljs-ln-line">📊 正在计算评估分数...
    74. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="74"> class="hljs-ln-code"> class="hljs-ln-line">评估图像中: 100%
    75. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="75"> class="hljs-ln-code"> class="hljs-ln-line"> 25/25 [00:08<00:00,  3.37it/s]
    76. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="76"> class="hljs-ln-code"> class="hljs-ln-line">📈 评估完成!
    77. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="77"> class="hljs-ln-code"> class="hljs-ln-line">✅ 面部相似度评分 (平均欧氏距离): 1.2472 (越低越好,表示生成图像与训练图像更相似)
    78. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="78"> class="hljs-ln-code"> class="hljs-ln-line">✅ CLIP 评分 (平均相似度): 31.5189 (越高越好,表示生成图像与文本提示的相关性更强)
    79. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="79"> class="hljs-ln-code"> class="hljs-ln-line">✅ 无面部图像数量: 0 (无法检测到面部的生成图像数量)
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

     

    生成的图像会保存在 ./SD/Brad/inference 中。 

    13 拓展作业

    1. 当前 prompt 的触发词(trigger words)只是 “a man” 吗?
      仔细观察之前数据集中使用的 prompt:

    2. 使用当前数据集训练出的模型,如果 prompt 设置为 “a man”,生成的图像应该是什么样的?

    3. 除了之前设置的参数外,探究与图像生成相关的其他参数(位于 evaluate() 中):

    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line">generated_images = generate_images(
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line"> pipeline=pipeline,
    3. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="3"> class="hljs-ln-code"> class="hljs-ln-line"> prompts=validation_prompts,
    4. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="4"> class="hljs-ln-code"> class="hljs-ln-line"> num_inference_steps=30, # 修改推理步数
    5. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="5"> class="hljs-ln-code"> class="hljs-ln-line"> guidance_scale=7.5, # 修改文本提示影响程度
    6. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="6"> class="hljs-ln-code"> class="hljs-ln-line"> output_folder=inference_path,
    7. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="7"> class="hljs-ln-code"> class="hljs-ln-line"> generator=generator # 注释这一行,看看不传入 generator 时生成的图像是否有变化?尝试运行三次进行对比。
    8. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="8"> class="hljs-ln-code"> class="hljs-ln-line"> )
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    希望你能通过运行代码文件找到这些问题的答案。

    14 用脚本微调 SD(可选)

    这一步是可选的,脚本的代码处理逻辑与本文一致。

    14.1 克隆仓库

    1. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="1"> class="hljs-ln-code"> class="hljs-ln-line"># 如果已经克隆仓库的话跳过这行
    2. class="hljs-ln-numbers"> class="hljs-ln-line hljs-ln-n" data-line-number="2"> class="hljs-ln-code"> class="hljs-ln-line">git clone https://github.com/Hoper-J/AI-Guide-and-Demos-zh_CN
    class="hljs-button signin active" data-title="登录复制" data-report-click="{"spm":"1001.2101.3001.4334"}" onclick="hljs.signin(event)">

    14.2 执行脚本

    1. 切换到 CodePlayground 文件夹:

      cd AI-Guide-and-Demos-zh_CN/CodePlayground
    2. 准备样例数据集:

      # 如果已经下载过,可以跳过,将之后的命令参数修改为对应路径
      wget https://github.com/Hoper-J/AI-Guide-and-Demos-zh_CN/raw/refs/heads/master/Demos/data/14/Datasets.zip
      unzip Datasets.zip
      
    3. 使用指定的数据集和提示文件:

      python sd_lora.py -d ./Datasets/Brad -gp ./Datasets/prompts/validation_prompt.txt
    4. 指定其他参数:

      python sd_lora.py -d ./Datasets/Brad -gp ./Datasets/prompts/validation_prompt.txt -e 500 -b 4 -u 1e-4 -t 1e-5

    更详细的介绍见 CodePlayground(点击对应的图标或文本即可展开)。

    15 参考链接

    注:本文转载自blog.csdn.net的PlutoZuo的文章"https://blog.csdn.net/PlutoZuo/article/details/133636032"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
    复制链接

    评论记录:

    未查询到任何数据!