首页 最新 热门 推荐

  • 首页
  • 最新
  • 热门
  • 推荐

PyTorch 2.0 一行代码加速模型,简单易懂的基础介绍和实用示例

  • 25-04-18 13:21
  • 2211
  • 7957
juejin.cn

PyTorch 2.0 引入了一个非常强大的新功能——torch.compile(),只需在已有的 PyTorch 模型或函数上加一行代码,就能显著提升运行速度,训练和推理最快可达原来的 1.3 到 2 倍!这对于使用 Hugging Face Transformers、TIMM 等流行模型的开发者来说尤其方便,无需修改现有代码,直接享受性能提升。

什么是 torch.compile()?

  • torch.compile() 是 PyTorch 2.0 的核心新特性之一,它通过自动将 PyTorch 代码编译成更高效的底层代码来加速模型运行。
  • 它支持绝大多数 PyTorch 代码,包括复杂的控制流(if、for 等)、动态形状张量和自定义函数。
  • 只需一行代码包裹模型或函数即可,无需改写代码,兼容性极好。
  • 第一次运行时会进行编译,速度较慢,后续运行速度显著加快。

PyTorch 2.0 加速的原理简述

  • TorchDynamo:动态捕获 Python 代码中的 PyTorch 操作,生成计算图。
  • AOTAutograd:提前生成反向传播代码,优化梯度计算。
  • PrimTorch:统一和简化 PyTorch 内部算子,方便编译器优化。
  • TorchInductor:深度学习编译器,生成针对 GPU 和 CPU 优化的高性能代码,使用 OpenAI Triton 技术加速 CUDA 内核。

安装 PyTorch 2.0(Nightly 版本)

GPU 版本(推荐较新 GPU,如 NVIDIA A100、RTX 30 系列)

bash
代码解读
复制代码
pip3 install numpy --pre torch --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117

CPU 版本

bash
代码解读
复制代码
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu

如何使用 torch.compile()?

只需将模型或函数用 torch.compile() 包装即可:

python
代码解读
复制代码
import torch # 定义一个简单模型 class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(100, 10) def forward(self, x): return torch.relu(self.linear(x)) model = MyModel().cuda() # 使用 torch.compile 进行加速 opt_model = torch.compile(model) # 运行加速后的模型 input_tensor = torch.randn(32, 100).cuda() output = opt_model(input_tensor) print(output)

代码示例详解与加速效果

1. 自定义函数加速示例

python
代码解读
复制代码
import torch def simple_fn(x): for _ in range(20): y = torch.sin(x).cuda() x = x + y return x compiled_fn = torch.compile(simple_fn, backend="inductor") input_tensor = torch.randn(10000).cuda() # 第一次运行较慢,后续运行加速明显 result = compiled_fn(input_tensor)
  • 这里展示了如何对普通函数进行加速。
  • 由于融合了多次逐点操作,减少了内存访问,提升了性能。
  • 新 GPU 上加速效果更明显。

2. 加速 ResNet50 模型(PyTorch Hub)

python
代码解读
复制代码
import torch model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).cuda() opt_model = torch.compile(model, backend="inductor") input_tensor = torch.randn(1, 3, 224, 224).cuda() # 预热,第一次运行较慢 opt_model(input_tensor) # 后续运行加速明显 import time start = time.time() for _ in range(10): opt_model(input_tensor) print("加速后的平均推理时间:", (time.time() - start) / 10)
  • 预热后,使用 torch.compile 的模型运行速度比原始模型快约 1.3 到 2 倍。

3. 加速 Hugging Face BERT 模型

python
代码解读
复制代码
import torch from transformers import BertTokenizer, BertModel tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') model = BertModel.from_pretrained("bert-base-uncased").cuda() # 只需一行代码加速 opt_model = torch.compile(model) text = "PyTorch 2.0 让模型运行更快!" encoded_input = tokenizer(text, return_tensors='pt').to('cuda') output = opt_model(**encoded_input) print(output.last_hidden_state.shape)
  • 适用于 Hugging Face 上的所有主流 Transformer 模型,无需修改代码。
  • 加速范围通常在 1.5 倍到 2 倍之间。

4. 加速 TIMM 图像模型

python
代码解读
复制代码
import timm import torch model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2).cuda() opt_model = torch.compile(model, backend="inductor") input_tensor = torch.randn(64, 3, 224, 224).cuda() output = opt_model(input_tensor) print(output.shape)
  • TIMM 模型也能开箱即用地获得显著加速。

torch.compile() 参数说明

参数说明推荐设置
backend编译器后端,默认是 "inductor",支持多种后端"inductor"(默认,性能最好)
mode编译模式,影响编译速度和运行速度"default"(大模型)、"reduce-overhead"(小模型)、"max-autotune"(最优性能但编译慢)
dynamic是否启用动态形状支持,减少因不同输入大小导致的重新编译默认为 None,自动启用动态形状

示例:

python
代码解读
复制代码
opt_model = torch.compile(model, backend="inductor", mode="reduce-overhead", dynamic=True)

注意事项和最佳实践

  • 第一次运行慢:torch.compile() 在第一次执行时会进行编译,速度较慢,建议预热几次后再进行性能测试。
  • 动态形状支持:默认自动支持动态形状,适合文本、时间序列等输入长度不固定的场景。
  • 硬件影响:新一代 GPU(如 A100、RTX 30 系列)加速效果更明显,桌面级 GPU 也能提升,但幅度稍小。
  • 兼容性:绝大多数 PyTorch 代码和流行模型都能无缝支持,极少数复杂代码可能需要调整。
  • 分布式训练:建议对内部模型使用 torch.compile(),避免直接对分布式包装器(如 DDP)使用。

总结

PyTorch 2.0 的 torch.compile() 是一项革命性功能,极大简化了深度学习模型的加速过程:

  • 只需一行代码即可加速已有模型和自定义函数。
  • 支持复杂控制流和动态形状,兼容性强。
  • 在 Hugging Face、TIMM、ResNet 等主流模型上已验证可达 30% 到 2 倍的加速。
  • 适合大多数 GPU 和 CPU 环境,尤其是新一代 GPU。
  • 通过简单参数调节,可兼顾编译速度和运行效率。

参考代码仓库与资源

  • PyTorch 官方文档和教程
  • Hugging Face Transformers
  • TIMM 图像模型库
  • PyTorch 2.0 torchdynamo GitHub 讨论区:github.com/pytorch/tor…

通过掌握 torch.compile(),你可以轻松提升模型性能,节省训练和推理时间,助力 AI 项目快速迭代。赶快试试吧!

注:本文转载自juejin.cn的程序员小jobleap的文章"https://juejin.cn/post/7493968018340528182"。版权归原作者所有,此博客不拥有其著作权,亦不承担相应法律责任。如有侵权,请联系我们删除。
复制链接
复制链接
相关推荐
发表评论
登录后才能发表评论和回复 注册

/ 登录

评论记录:

未查询到任何数据!
回复评论:

分类栏目

后端 (14832) 前端 (14280) 移动开发 (3760) 编程语言 (3851) Java (3904) Python (3298) 人工智能 (10119) AIGC (2810) 大数据 (3499) 数据库 (3945) 数据结构与算法 (3757) 音视频 (2669) 云原生 (3145) 云平台 (2965) 前沿技术 (2993) 开源 (2160) 小程序 (2860) 运维 (2533) 服务器 (2698) 操作系统 (2325) 硬件开发 (2492) 嵌入式 (2955) 微软技术 (2769) 软件工程 (2056) 测试 (2865) 网络空间安全 (2948) 网络与通信 (2797) 用户体验设计 (2592) 学习和成长 (2593) 搜索 (2744) 开发工具 (7108) 游戏 (2829) HarmonyOS (2935) 区块链 (2782) 数学 (3112) 3C硬件 (2759) 资讯 (2909) Android (4709) iOS (1850) 代码人生 (3043) 阅读 (2841)

热门文章

103
后端
关于我们 隐私政策 免责声明 联系我们
Copyright © 2020-2024 蚁人论坛 (iYenn.com) All Rights Reserved.
Scroll to Top