pytorch加载模型和模型推理常见操作

这篇具有很好参考价值的文章主要介绍了pytorch加载模型和模型推理常见操作。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

1.pth保存模型的说明

.pth文件可以保存模型的拓扑结构和参数,也可以只保存模型的参数,取决于model.save()中的参数。

torch.save(model.state_dict(), 'mymodel.pth')  # 只保存模型权重参数,不保存模型结构
torch.save(model, 'mymodel.pth')  # 保存整个model的状态
#model为已经训练好的模型

使用方式1得到的.pth重构模型代码如下:

model = My_model(*args, **kwargs)
model.load_state_dict(torch.load('mymodel.pth'))
model.eval()

使用方式2得到的.pth重构模型代码如下:

model=torch.load('mymodel.pth')
model.eval()

2.pth文件load细节

以只保存模型参数的pth为例

epth_encoder = depth.ResnetEncoder(18, False)  # 加载encoder模型
loaded_dict_enc = torch.load('depth/models/weights_19/encoder.pth')#数据类型:有序字典

loaded_dict_enc 的类型是:<class ‘odict_items’>(有序字典),本质还是python的字典类型,有键值对,其中键指的是每层网络结构的名字,数据类型是字符串型,值指的是每层网络结构的参数,数据类型是numpy张量。
运行下面这一行代码,可以更加细致的发现pth中含有的信息。

 for k, v in loaded_dict_enc.items():
        print(k)
        print(v)

运行结果反映了,第一个键(key)为encoder.conv1.weight即表示encoder模型第一个卷积层的权重。对应的值(values)是下图的张量。这些参数张量都是pth文件中保存的,不会发生变化。
pytorch模型推理代码,pytorch,python,深度学习,人工智能,算法

3.state_dict

state_dict是Python的字典对象,可用于保存模型参数、超参数以及优化器的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。
可以用state_dict非常细致的查看网络结构是否正确,能够清晰反映各层滤波器的大小。

 for param_tensor in depth_encoder.state_dict():
        print(param_tensor, '\t', depth_encoder.state_dict()[param_tensor].size())

pytorch模型推理代码,pytorch,python,深度学习,人工智能,算法

4.模型参数读入

filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
depth_encoder.load_state_dict(filtered_dict_enc)

5.eval()

eval()是PyTorch中用来将神经网络设置为评估模式的方法。在评估模式下,网络的参数不会被更新,Dropout和Batch Normalization层的行为也会有所不同。通常在测试阶段使用评估模式。
eval() 可以作为模型推理的性能提升方法,在评估模式下,计算图是不被跟踪的,这样可以节省内存使用,提升性能。还可以使用torch.no_grad()配合使用,在评估阶段关闭梯度跟踪,进一步提升性能。

depth_encoder.eval()  # 切换到评估模式,使得模型BN层等失效

6.模型推理

关闭梯度流跟踪和eval()共同提升模型推理性能。文章来源地址https://www.toymoban.com/news/detail-639703.html

encoder_input = torch.randn(1, 3, 256, 256)
with torch.no_grad():
     encoder_output = depth_encoder(encoder_input))

到了这里,关于pytorch加载模型和模型推理常见操作的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包赞助服务器费用

相关文章

  • 【深度强化学习】(1) DQN 模型解析,附Pytorch完整代码

    【深度强化学习】(1) DQN 模型解析,附Pytorch完整代码

    大家好,今天和各位讲解一下深度强化学习中的基础模型 DQN,配合 OpenAI 的 gym 环境,训练模型完成一个小游戏,完整代码可以从我的 GitHub 中获得: https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model DQN(Deep Q Network) 算法由 DeepMind 团队提出,是深度神经网络和 Q-Learning 算

    2023年04月08日
    浏览(12)
  • 【深度强化学习】(8) iPPO 模型解析,附Pytorch完整代码

    【深度强化学习】(8) iPPO 模型解析,附Pytorch完整代码

    大家好,今天和各位分享一下多智能体深度强化学习算法 ippo,并基于 gym 环境完成一个小案例。完整代码可以从我的 GitHub 中获得:https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model 多智能体的情形相比于单智能体更加复杂,因为 每个智能体在和环境交互的同时也在和其他

    2024年02月03日
    浏览(101)
  • 【深度强化学习】(6) PPO 模型解析,附Pytorch完整代码

    【深度强化学习】(6) PPO 模型解析,附Pytorch完整代码

    大家好,今天和各位分享一下深度强化学习中的 近端策略优化算法 (proximal policy optimization, PPO ),并借助 OpenAI 的 gym 环境完成一个小案例,完整代码可以从我的 GitHub 中获得: https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model PPO 算法之所以被提出,根本原因在于 Polic

    2023年04月08日
    浏览(9)
  • 【深度强化学习】(2) Double DQN 模型解析,附Pytorch完整代码

    【深度强化学习】(2) Double DQN 模型解析,附Pytorch完整代码

    大家好,今天和大家分享一个深度强化学习算法 DQN 的改进版 Double DQN,并基于 OpenAI 的 gym 环境库完成一个小游戏,完整代码可以从我的 GitHub 中获得: https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model DQN 算法的原理是指导机器人不断与环境交互,理解最佳的行为方式,最

    2024年02月03日
    浏览(13)
  • 【深度强化学习】(4) Actor-Critic 模型解析,附Pytorch完整代码

    【深度强化学习】(4) Actor-Critic 模型解析,附Pytorch完整代码

    大家好,今天和各位分享一下深度强化学习中的 Actor-Critic 演员评论家算法, Actor-Critic 算法是一种综合了策略迭代和价值迭代的集成算法 。我将使用该模型结合 OpenAI 中的 Gym 环境完成一个小游戏,完整代码可以从我的 GitHub 中获得: https://github.com/LiSir-HIT/Reinforcement-Learning

    2024年02月03日
    浏览(23)
  • PyTorch多进程模型推理

    进程:一个在内存中运行的应用程序,每个进程有自己独立的一块内存空间。 资源分配的最小单位 。 线程:进程中的一个执行单元, 程序执行的最小单位 。一个进程可以有多个线程。 Python的多线程特点:在Python中,由于GIL的存在,在多线程的时候, 同一时间只能有一个线

    2024年02月01日
    浏览(10)
  • Python与深度学习:Keras、PyTorch和Caffe的使用和模型设计

    Python与深度学习:Keras、PyTorch和Caffe的使用和模型设计

      深度学习已经成为当今计算机科学领域的热门技术,而Python则是深度学习领域最受欢迎的编程语言之一。在Python中,有多个深度学习框架可供选择,其中最受欢迎的包括Keras、PyTorch和Caffe。本文将介绍这三个框架的使用和模型设计,帮助读者了解它们的优势、特点和适用场

    2024年02月09日
    浏览(11)
  • Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

    Python使用pytorch深度学习框架构造Transformer神经网络模型预测红酒分类例子

    经典的红酒分类数据集是指UCI机器学习库中的Wine数据集。该数据集包含178个样本,每个样本有13个特征,可以用于分类任务。 具体每个字段的含义如下: alcohol:酒精含量百分比 malic_acid:苹果酸含量(克/升) ash:灰分含量(克/升) alcalinity_of_ash:灰分碱度(以mEq/L为单位)

    2024年02月02日
    浏览(12)
  • 深度学习网络模型 MobileNet系列MobileNet V1、MobileNet V2、MobileNet V3网络详解以及pytorch代码复现

    深度学习网络模型 MobileNet系列MobileNet V1、MobileNet V2、MobileNet V3网络详解以及pytorch代码复现

    DW与PW计算量 普通卷积计算量 计算量对比 因此理论上普通卷积是DW+PW卷积的8到9倍 Residual blok与Inverted residual block对比: Residual blok :先采用1 x 1的卷积核来对特征矩阵进行压缩,减少输入特征矩阵的channel,再通过3 x 3的卷积核进行特征处理,再采用1 x 1的卷积核来扩充channel维

    2024年02月01日
    浏览(16)
  • pytorch快速训练ai作画模型的python代码

    在 PyTorch 中训练 AI 作画模型的基本步骤如下: 准备数据集: 需要准备一个包含许多图像的数据集, 这些图像可以是手绘的或者是真实的图像. 定义模型: 选择一个适当的深度学习模型, 并使用 PyTorch 定义该模型. 例如, 可以使用卷积神经网络 (CNN) 或者生成对抗网络 (GAN). 训练模型

    2024年02月09日
    浏览(16)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包