首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

HW5-补充2:||Top-K vs Top-P||——生成式模型中的采样策略与 Temperature 的影响

  • 25-02-16 23:22
  • 4253
  • 10072
blog.csdn.net

目录

0 前言

1 采样方法概述

2 Top-K采样详解

2.1 工作原理

2.2 数学表述

2.3 代码示例

3 Top-P采样详解

3.1 工作原理

3.2 数学表述

3.3 代码示例

4 Temperature的作用

4.1 工作原理

4.2 例子

4.3 代码示例

5 在大模型中的应用

5.1 Top-K和Top-P采样是否可以一起使用?

5.2 如果我只想使用Top-K或者Top-P应该怎么办?

6 参考链接


0 前言

本文为李宏毅学习笔记——2024春《GENERATIVE AI》篇——作业笔记HW5的补充内容2。

如果你还没获取到LLM API,请查看我的另一篇笔记:

HW1~2:LLM API获取步骤及LLM API使用演示:环境配置与多轮对话演示-CSDN博客

完整内容参见:

李宏毅学习笔记——2024春《GENERATIVE AI》篇

在文章HW5-补充1:深入理解 Beam Search——原理, 示例与代码实现-CSDN博客中我们探讨了 Beam Search 和 Greedy Search,现在来聊聊 model.generate() 中常见的三个参数: top-k, top-p 和 temperature。

1 采样方法概述

在生成文本时,模型为每个可能的下一个词汇分配一个概率分布,选择下一个词汇的策略直接决定了输出的质量和多样性。以下是几种常见的选择方法:

  • Greedy Search(贪心搜索): 每次选择概率最高的词汇。
  • Beam Search(束搜索): 保留多个候选序列,平衡生成质量和多样性。
  • Top-K 采样: 限制候选词汇数量。
  • Top-P 采样(Nucleus Sampling): 根据累积概率选择候选词汇,动态调整词汇集。

为了直观叙述,假设我们当前的概率分布为:

词汇概率
A0.4
B0.3
C0.2
D0.05
0.05

2 Top-K采样详解

2.1 工作原理

Top-K 采样是一种通过限制候选词汇数量来增加生成文本多样性的方法。在每一步生成过程中,模型只考虑概率最高(Top)的 K 个词汇,然后从这 K 个词汇中根据概率进行采样(K=1 就是贪心搜索)。

步骤:

  1. 获取概率分布: 模型为每个可能的下一个词汇生成一个概率分布。
  2. 筛选 Top-K: 选择概率最高的 K 个词汇,忽略其余词汇。
  3. 重新归一化: 将筛选后的 K 个词汇的概率重新归一化,使其总和为 1。
  4. 采样: 根据重新归一化后的概率分布,从 Top-K 词汇中随机采样一个词汇作为下一个生成的词。

2.2 数学表述

2.3 代码示例

假设 K=3。

  1. import numpy as np
  2. # 概率分布
  3. probs = np.array([0.4, 0.3, 0.2, 0.05, 0.05])
  4. words = ['A', 'B', 'C', 'D', '']
  5. # 设置 Top-K
  6. K = 3
  7. # 获取概率最高的 K 个词汇索引
  8. top_indices = np.argsort(probs)[-K:]
  9. # 保留这些 K 个词汇及其概率
  10. top_k_probs = np.zeros_like(probs)
  11. top_k_probs[top_indices] = probs[top_indices]
  12. # 归一化保留的 K 个词汇的概率
  13. top_k_probs = top_k_probs / np.sum(top_k_probs)
  14. # 打印 Top-K 采样的结果
  15. print("Top-K 采样选择的词汇和对应的概率:")
  16. for i in top_indices:
  17. print(f"{words[i]}: {top_k_probs[i]:.2f}")

输出:

Top-K 采样选择的词汇和对应的概率: 
C: 0.22
B: 0.33
A: 0.44 

3 Top-P采样详解

3.1 工作原理

Top-P 采样(又称 Nucleus Sampling)是一种动态选择候选词汇的方法。与 Top-K 采样不同,Top-P 采样不是固定选择 K 个词汇,而是选择一组累计概率达到 P 的词汇集合(即从高到低加起来的概率)。这意味着 Top-P 采样可以根据当前的概率分布动态调整候选词汇的数量,从而更好地平衡生成的多样性和质量。

步骤:

  1. 获取概率分布: 模型为每个可能的下一个词汇生成一个概率分布。
  2. 排序概率: 将词汇按照概率从高到低排序。
  3. 累积概率: 计算累积概率,直到达到预设的阈值 P。
  4. 筛选 Top-P: 选择累积概率达到 P 的最小词汇集合。
  5. 重新归一化: 将筛选后的词汇概率重新归一化。
  6. 采样: 根据重新归一化后的概率分布,从 Top-P 词汇中随机采样一个词汇作为下一个生成的词。

3.2 数学表述

3.3 代码示例

假设 P=0.6。

  1. import numpy as np
  2. # 概率分布
  3. probs = np.array([0.4, 0.3, 0.2, 0.05, 0.05])
  4. words = ['A', 'B', 'C', 'D', '']
  5. # 设置 Top-P
  6. P = 0.6
  7. # 对概率进行排序
  8. sorted_indices = np.argsort(probs)[::-1] # 从大到小排序
  9. sorted_probs = probs[sorted_indices]
  10. # 累积概率
  11. cumulative_probs = np.cumsum(sorted_probs)
  12. # 找到累积概率大于等于 P 的索引
  13. cutoff_index = np.where(cumulative_probs >= P)[0][0]
  14. # 保留累积概率达到 P 的词汇及其概率
  15. top_p_probs = np.zeros_like(probs)
  16. top_p_probs[sorted_indices[:cutoff_index + 1]] = sorted_probs[:cutoff_index + 1]
  17. # 归一化保留的词汇的概率
  18. top_p_probs = top_p_probs / np.sum(top_p_probs)
  19. # 打印 Top-P 采样的结果
  20. print("\nTop-P 采样选择的词汇和对应的概率:")
  21. for i in np.where(top_p_probs > 0)[0]:
  22. print(f"{words[i]}: {top_p_probs[i]:.2f}")
  1. Top-P 采样选择的词汇和对应的概率:
  2. A: 0.57
  3. B: 0.43

4 Temperature的作用

Temperature(温度) 是控制生成文本随机性的参数。

4.1 工作原理

在进行采样前,模型实际上会对概率分布应用温度调整:

4.2 例子

举个例子,P(A)=0.16,P(B)=0.04。

1)当T约等于0时:

选择A的概率为:(无限接近于1)

\frac{0.16^\infty }{0.16^\infty + 0.16^\infty*0.25^\infty} 

选择B的概率为:(无限接近于0)

\frac{0.04^\infty }{0.04^\infty + 0.04^\infty*4^\infty}

2)当T大于1,等于2时:

选择A的概率为2/3;

选择B的概率为1/3。

如此,当T大于1时,低概率词汇的选择概率便大大增加了。

4.3 代码示例

这里将展示 Temperature 对概率的影响。

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. # 概率分布
  4. probs = np.array([0.4, 0.3, 0.2, 0.05, 0.05])
  5. words = ['A', 'B', 'C', 'D', '']
  6. # 设置 Top-K
  7. K = 5
  8. # 设置不同的 Temperature 值
  9. temperatures = [0.5, 1.0, 1.5]
  10. # 创建一个图表
  11. plt.figure(figsize=(10, 6))
  12. # 遍历不同的温度
  13. for temp in temperatures:
  14. # 使用 Temperature 调整概率
  15. adjusted_probs = probs ** (1.0 / temp)
  16. adjusted_probs = adjusted_probs / np.sum(adjusted_probs) # 归一化
  17. # 打印当前 Temperature 的概率分布
  18. print(f"\n--- Temperature = {temp} ---")
  19. for i, prob in enumerate(adjusted_probs):
  20. print(f"{words[i]}: {prob:.2f}")
  21. # 绘制概率分布图
  22. plt.plot(words, adjusted_probs, label=f"Temperature = {temp}")
  23. # 绘制原始概率分布的对比
  24. plt.plot(words, probs, label="Original", linestyle="--", color="black")
  25. # 添加图表信息
  26. plt.xlabel("Word")
  27. plt.ylabel("Probability")
  28. plt.title("Effect of Temperature on Top-K Probability Distribution")
  29. plt.legend()
  30. # 显示图表
  31. plt.show()

输出:

  1. --- Temperature = 0.5 ---
  2. A: 0.54
  3. B: 0.31
  4. C: 0.14
  5. D: 0.01
  6. : 0.01
  7. --- Temperature = 1.0 ---
  8. A: 0.40
  9. B: 0.30
  10. C: 0.20
  11. D: 0.05
  12. : 0.05
  13. --- Temperature = 1.5 ---
  14. A: 0.34
  15. B: 0.28
  16. C: 0.21
  17. D: 0.08
  18. : 0.08

观察图片可以直观看到:

  • 当 temperature < 1 时,概率分布变得更加尖锐,高概率词更可能被选择,适用于需要高确定性的任务,如生成技术文档或代码。
  • 当 temperature > 1 时,概率分布变得更加平坦,使得低概率词也有更多机会被选中,适用于需要创造性和多样性的任务,如写作或对话生成。

5 在大模型中的应用

5.1 Top-K和Top-P采样是否可以一起使用?

可以,通过同时设置 top_k 和 top_p 参数,模型会首先应用 Top-K 筛选,限制候选词汇数量,然后在这有限的词汇中应用 Top-P 采样,动态调整词汇集合。(K个加起来都没达到P则直接全选,不再重新采样)

使用 Hugging Face Transformers 库的简单示例:

  1. import warnings
  2. from transformers import AutoTokenizer, AutoModelForCausalLM
  3. import torch
  4. # 忽略 FutureWarning 警告
  5. warnings.filterwarnings("ignore", category=FutureWarning)
  6. # 指定模型
  7. model_name = "distilgpt2"
  8. # 加载分词器和模型
  9. tokenizer = AutoTokenizer.from_pretrained(model_name)
  10. model = AutoModelForCausalLM.from_pretrained(model_name)
  11. # 将模型移动到设备
  12. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  13. model.to(device)
  14. # 输入文本
  15. input_text = "Hello GPT"
  16. # 编码输入文本
  17. inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
  18. attention_mask = torch.ones_like(inputs).to(device)
  19. # 设置 Top-K 和 Top-P 采样
  20. top_k = 10
  21. top_p = 0.5
  22. temperature = 0.8
  23. # 生成文本,结合 Top-K 和 Top-P 采样
  24. with torch.no_grad():
  25. outputs = model.generate(
  26. inputs,
  27. attention_mask=attention_mask,
  28. max_length=50,
  29. do_sample=True,
  30. top_k=top_k, # 设置 Top-K
  31. top_p=top_p, # 设置 Top-P
  32. temperature=temperature, # 控制生成的随机性
  33. no_repeat_ngram_size=2, # 防止重复 n-gram
  34. pad_token_id=tokenizer.eos_token_id
  35. )
  36. # 解码生成的文本
  37. generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  38. print("结合 Top-K 和 Top-P 采样生成的文本: ")
  39. print(generated_text)

输出示例:

  1. 结合 Top-K 和 Top-P 采样生成的文本:
  2. Hello GPT.
  3. The first time I heard of the G-E-X-1, I was wondering what the future holds for the company. I had no idea what it was. It was a very big company, and it had

参数解释:

  • top_k=10: 首先限制候选词汇为概率最高的 10 个。
  • top_p=0.5: 在这 10 个词汇中,从高到低,选择累积概率达到 0.5 的词汇归一化后进行采样。
  • temperature=0.8: 控制生成的随机性,较低的温度使模型更倾向于高概率词汇。

5.2 如果我只想使用Top-K或者Top-P应该怎么办?

对于只使用 Top-K:

将 top_p 设置为 1(表示不使用 Top-P 采样)

  1. outputs = model.generate(
  2. inputs,
  3. max_length=50,
  4. do_sample=True,
  5. top_k=top_k, # 设置 Top-K
  6. top_p=1.0, # 不使用 Top-P
  7. temperature=temperature, # 控制生成的随机性
  8. no_repeat_ngram_size=2, # 防止重复 n-gram
  9. eos_token_id=tokenizer.eos_token_id
  10. )

对于只使用 Top-P:

将 top_k 设置为 0(表示不使用 Top-K 采样)。

  1. outputs = model.generate(
  2. inputs,
  3. max_length=50,
  4. do_sample=True,
  5. top_k=0, # 不使用 Top-K
  6. top_p=top_p, # 设置 Top-P
  7. temperature=temperature, # 控制生成的随机性
  8. no_repeat_ngram_size=2, # 防止重复 n-gram
  9. eos_token_id=tokenizer.eos_token_id
  10. )

6 参考链接

  • Hugging Face Transformers 文档
  • Nucleus Sampling: A Dynamic Top-P Sampling Technique
注:本文转载自blog.csdn.net的笨笨sg的文章"https://blog.csdn.net/a131529/article/details/144193512"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2491) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

101
推荐
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2025 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top