在这个教程中,我们将学习如何在 PyTorch 中保存和加载模型。保存模型是为了将训练好的模型保存到磁盘,以便将来可以加载和继续使用或者用于部署。以下是保存和加载模型的基本步骤。
保存模型
要保存一个 PyTorch 模型,你可以使用 torch.save()
函数。这个函数接受两个参数:模型的状态字典和文件路径。
- 模型状态字典:通常通过调用模型的
.state_dict()
方法来获取。 - 文件路径:指定保存模型文件的路径。
以下是一个示例代码:
import torch
import torch.nn as nn
# 定义一个简单的模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = SimpleModel()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
加载模型
加载模型与保存模型相反,使用 torch.load()
函数。你需要指定保存模型的文件路径。
以下是一个加载模型的示例代码:
# 加载模型
model.load_state_dict(torch.load('model.pth'))
# 检查模型是否加载成功
print(model)
图片示例
在保存模型时,你可能会遇到一些问题,比如如何选择合适的保存格式。以下是一个关于 PyTorch 模型保存格式的图片示例。
更多信息
想要了解更多关于 PyTorch 的信息,可以访问我们的官方文档:PyTorch 官方文档。