自定义模块是 PyTorch 中的一个强大功能,它允许用户创建自己的神经网络层或组件。在本教程中,我们将学习如何创建和使用自定义模块。

创建自定义模块

创建自定义模块非常简单。首先,你需要定义一个继承自 torch.nn.Module 的类。然后,在该类中定义 __init__forward 方法。

import torch
import torch.nn as nn

class MyCustomModule(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyCustomModule, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

使用自定义模块

创建自定义模块后,你可以在你的模型中像使用任何其他 PyTorch 模块一样使用它。

model = nn.Sequential(
    MyCustomModule(input_size=10, output_size=5),
    nn.ReLU(),
    nn.Linear(5, 3)
)

扩展阅读

想了解更多关于 PyTorch 的知识?请访问我们的 PyTorch 教程

图片示例

以下是一个自定义模块的示例图片:

Custom Module