在深度学习模型开发过程中,能够清晰地查看模型结构和参数信息对于调试和优化至关重要。PyTorch作为一个灵活的深度学习框架,虽然提供了print(model)方法来打印模型结构,但输出信息往往不够直观和全面。为了解决这个问题,pytorch-summary应运而生,它为PyTorch提供了类似于Keras中model.summary()的功能,能够生成简洁明了的模型结构摘要。
pytorch-summary是一个轻量级的PyTorch模型可视化工具,由GitHub用户sksq96开发。它的主要目标是提供与print(model)互补的信息,帮助用户更好地理解和分析模型结构。截至目前,该项目在GitHub上已获得超过4000颗星,受到广大PyTorch用户的欢迎。
安装pytorch-summary非常简单,可以通过pip直接安装:
pip install torchsummary
或者从GitHub克隆源代码:
git clone https://github.com/sksq96/pytorch-summary
安装完成后,使用方法也很直观:
from torchsummary import summary summary(your_model, input_size=(channels, H, W))
其中,your_model是你定义的PyTorch模型,input_size指定了输入数据的维度。需要注意的是,input_size参数是必需的,因为pytorch-summary需要进行一次前向传播来收集模型信息。
下面我们通过几个具体的例子来展示pytorch-summary的强大功能。
首先,让我们看一个简单的CNN模型在MNIST数据集上的应用:
import torch import torch.nn as nn import torch.nn.functional as F from torchsummary import summary class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) return F.log_softmax(x, dim=1) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = Net().to(device) summary(model, (1, 28, 28))
运行上述代码,我们将得到如下输出:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 10, 24, 24] 260
Conv2d-2 [-1, 20, 8, 8] 5,020
Dropout2d-3 [-1, 20, 8, 8] 0
Linear-4 [-1, 50] 16,050
Linear-5 [-1, 10] 510
================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.08
Estimated Total Size (MB): 0.15
----------------------------------------------------------------
从这个输出中,我们可以清楚地看到模型的每一层结构、输出shape、参数数量,以及整个模型的参数统计和内存占用估算。
对于更复杂的模型,pytorch-summary同样能够提供清晰的概览。以VGG16为例:
import torch from torchvision import models from torchsummary import summary device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') vgg = models.vgg16().to(device) summary(vgg, (3, 224, 224))
这将生成VGG16模型的详细摘要,包括其所有卷积层、全连接层和激活函数。
pytorch-summary还支持具有多个输入的模 型:
import torch import torch.nn as nn from torchsummary import summary class SimpleConv(nn.Module): def __init__(self): super(SimpleConv, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1), nn.ReLU(), ) def forward(self, x, y): x1 = self.features(x) x2 = self.features(y) return x1, x2 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleConv().to(device) summary(model, [(1, 16, 16), (1, 28, 28)])
这个例子展示了如何为具有两个不同大小输入的模型生成摘要。
pytorch-summary为PyTorch用户提供了一个强大而简单的工具,用于可视化和理解深度学习模型的结构。它不仅有助于调试和优化模型,还能帮助研究人员和开发者更好地解释和展示他们的工作。随着深度学习模型日益复杂,这样的工具在模型开发过程中的重要性也将日益凸显。
虽然pytorch-summary已经非常实用,但开发者社区仍在不断改进和扩展其功能。例如,最新的torchinfo项目就是在pytorch-summary的基础上进行了进一步的优化和功能扩展。因此,建议用户关 注项目的最新发展,以便使用最新和最优化的版本。
总的来说,pytorch-summary是每个PyTorch开发者工具箱中不可或缺的一部分。无论你是刚开始学习深度学习,还是已经是经验丰富的研究者,这个工具都能在你的项目中发挥重要作用,帮助你更好地理解和优化你的模型。
AI辅助编程,代码自动修复
Trae是一种自适应的集成开发环境(IDE),通过自动化和多元协作改变开发流程。利用Trae,团队能够更快速、精确地编写和部署代码,从而提高编程效率和项目交付速度。Trae具备上下文感知和代码自动完成功能,是提升开发效率的理想工具。
最强AI数据分析助手
小浣熊家族Raccoon,您的AI智能助手,致力于通过先进的人工智能技术,为用户提供高效、便捷的智能服务。无论是日常咨询还是专业问题解答,小浣熊都能以快速、准确的响应满足您的需求,让您的生活更加智能便捷。
像人一样思考的AI智能体
imini 是一款超级AI智能体,能根据人类指令,自主思考、自主完成、并且交付结果的AI智能体。
AI数字人视频创作平台
Keevx 一款开箱即用的AI数字人视频创作平台,广泛适用于电商广告、企业培训与社媒宣传,让全球企业与个人创作者无需拍摄剪辑,就能快速生成多语言、高质量的专业视频。
一站式AI创作平台
提供 AI 驱动的图片、视频生成及数字人等功能,助力创意创作
AI办公助手,复杂任务高效处理
AI办公助手,复杂任务高效处理。办公效率低?扣子空间AI助手支持播客生成、PPT制作、网页开发及报告写作,覆盖科研、商业、舆情等领域的专家Agent 7x24小时响应,生活工作无缝切换,提升50%效率!
AI小说写作助手,一站式润色、改写、扩写
蛙蛙写作—国内先进的AI写作平台,涵盖小说、学术、社交媒体等多场景。提供续写、改写、润色等功能,助力创作者高效优化写作流程。界面简洁,功能全面,适合各类写作者提升内容品质和工作效率。
全能AI智能助手,随时解答生活与工作的多样问题
问小白,由元石科技研发的AI智能助手,快速准确地解答各种生活和工作问题,包括但不限于搜索、规划和社交互动,帮助用户在日常生活中提高效率,轻松管理个人事务。
实时语音翻译/同声传译工具
Transly是一个多场景的AI大语言模型驱动的同声传译、专业翻译助手,它拥有超精准的音频识别翻译能力,几乎零延迟的使用体验和支持多国语言可以让你带它走遍全球,无论你是留学生、商务人士、韩剧美剧爱好者,还是出国游玩、多国会议、跨国追星等等,都可以满足你所有需要同传的场景需求,线上线下通用,扫除语言障碍,让全世界的语言交流不再有国界。
一键生成PPT和Word,让学习生活更轻松
讯飞智文是一个利用 AI 技术的项目,能够帮助用户生成 PPT 以及各类文档。无论是商业领域的市场分析报告、年度目标制定,还是学生群体的职业生涯规划、实习避坑指南,亦或是活动策划、旅游攻略等内容,它都能提供支持,帮助用户精准表达,轻松呈现各种信息。
最新AI工具、AI资讯
独家AI资源、AI项目落地
微信扫一扫关注公众号