在这个教程中,我们将学习如何在 PyTorch 中保存和加载模型。保存模型是为了将训练好的模型保存到磁盘,以便将来可以加载和继续使用或者用于部署。以下是保存和加载模型的基本步骤。

保存模型

要保存一个 PyTorch 模型,你可以使用 torch.save() 函数。这个函数接受两个参数:模型的状态字典和文件路径。

  • 模型状态字典:通常通过调用模型的 .state_dict() 方法来获取。
  • 文件路径:指定保存模型文件的路径。

以下是一个示例代码:

import torch
import torch.nn as nn

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = SimpleModel()

# 保存模型
torch.save(model.state_dict(), 'model.pth')

加载模型

加载模型与保存模型相反,使用 torch.load() 函数。你需要指定保存模型的文件路径。

以下是一个加载模型的示例代码:

# 加载模型
model.load_state_dict(torch.load('model.pth'))

# 检查模型是否加载成功
print(model)

图片示例

在保存模型时,你可能会遇到一些问题,比如如何选择合适的保存格式。以下是一个关于 PyTorch 模型保存格式的图片示例。

PyTorch 模型保存格式示例

更多信息

想要了解更多关于 PyTorch 的信息,可以访问我们的官方文档:PyTorch 官方文档

返回 PyTorch 教程首页