自定义模块是 PyTorch 中的一个强大功能,它允许用户创建自己的神经网络层或组件。在本教程中,我们将学习如何创建和使用自定义模块。
创建自定义模块
创建自定义模块非常简单。首先,你需要定义一个继承自 torch.nn.Module
的类。然后,在该类中定义 __init__
和 forward
方法。
import torch
import torch.nn as nn
class MyCustomModule(nn.Module):
def __init__(self, input_size, output_size):
super(MyCustomModule, self).__init__()
self.linear = nn.Linear(input_size, output_size)
def forward(self, x):
return self.linear(x)
使用自定义模块
创建自定义模块后,你可以在你的模型中像使用任何其他 PyTorch 模块一样使用它。
model = nn.Sequential(
MyCustomModule(input_size=10, output_size=5),
nn.ReLU(),
nn.Linear(5, 3)
)
扩展阅读
想了解更多关于 PyTorch 的知识?请访问我们的 PyTorch 教程。
图片示例
以下是一个自定义模块的示例图片: