模型保存和加载
译者:runzhi214
项目地址:https://pytorch.apachecn.org/2.0/tutorials/beginner/basics/saveloadrun_tutorial
原始地址:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
在这个章节我们会学习如何持久化模型状态来保存、加载和执行模型预测。
模型权重的保存和加载
PyTorch 将模型学习到的参数存储在一个内部状态字典中,叫 state_dict
。它们可以通过 torch.save
方法来持久化。
输出:
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth
0%| | 0.00/528M [00:00<?, ?B/s]
4%|4 | 22.5M/528M [00:00<00:02, 236MB/s]
9%|8 | 46.5M/528M [00:00<00:02, 245MB/s]
13%|#3 | 70.5M/528M [00:00<00:01, 248MB/s]
18%|#7 | 94.4M/528M [00:00<00:01, 249MB/s]
22%|##2 | 118M/528M [00:00<00:01, 250MB/s]
27%|##6 | 142M/528M [00:00<00:01, 250MB/s]
31%|###1 | 166M/528M [00:00<00:01, 249MB/s]
36%|###5 | 190M/528M [00:00<00:01, 249MB/s]
40%|#### | 214M/528M [00:00<00:01, 249MB/s]
45%|####5 | 238M/528M [00:01<00:01, 250MB/s]
50%|####9 | 262M/528M [00:01<00:01, 251MB/s]
54%|#####4 | 286M/528M [00:01<00:01, 250MB/s]
59%|#####8 | 310M/528M [00:01<00:00, 249MB/s]
63%|######3 | 333M/528M [00:01<00:00, 249MB/s]
68%|######7 | 357M/528M [00:01<00:00, 247MB/s]
72%|#######2 | 381M/528M [00:01<00:00, 248MB/s]
77%|#######6 | 405M/528M [00:01<00:00, 249MB/s]
81%|########1 | 429M/528M [00:01<00:00, 250MB/s]
86%|########5 | 453M/528M [00:01<00:00, 251MB/s]
90%|######### | 477M/528M [00:02<00:00, 251MB/s]
95%|#########5| 502M/528M [00:02<00:00, 253MB/s]
100%|##########| 528M/528M [00:02<00:00, 258MB/s]
100%|##########| 528M/528M [00:02<00:00, 251MB/s]
要加载模型权重,你需要先创建一个跟要加载权重的模型结构一样的模型,然后使用 load_state_dict()
方法加载参数。
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
注意: 请确保在进行推理前调用
model.eval()
方法来将 dropout 层和 batch normalization 层设置为评估模式(evaluation模式)。如果不这么做的话会产生并不一致的推理结果。
保存和加载模型结构
在加载模型权重的时候,我们需要首先实例化一个模型类,因为模型类定义了神经网络的结构。我们也想把模型类结构和模型一起保存,那就可以通过将 model
传递给保存函数(而不是 model.state_dict()
)。
然后我们可以这样载入模型:
关联的教程
在PyTorch中保存、加载一个Checkpoint -- 译者注:该文档目前未完成翻译