PyTorch 是一个强大的深度学习框架,其模块系统是构建神经网络的核心。以下是关于 PyTorch 模块的关键信息:

1. 模块基础概念 🔧

  • Module:PyTorch 中所有神经网络组件的基类,继承自 torch.nn.Module
  • 子模块:通过 nn.ModuleListnn.Sequential 组合多个模块,形成复杂网络结构。
  • 参数管理:模块自动封装参数(如权重和偏置),通过 parameters() 方法访问。

2. 常见模块类型 🧠

模块类型 功能说明 示例链接
nn.Linear 全连接层,用于特征变换 /pytorch/线性层
nn.Conv2d 二维卷积层,处理图像数据 /pytorch/卷积层
nn.ReLU 激活函数,引入非线性 /pytorch/激活函数
nn.BatchNorm2d 批归一化层,加速训练 /pytorch/归一化层

3. 模块使用示例 📜

import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 5)  # 全连接层示例
        self.relu = nn.ReLU()          # 激活函数
        self.conv = nn.Conv2d(3, 16, kernel_size=3)  # 卷积层示例

    def forward(self, x):
        x = self.layer(x)
        x = self.relu(x)
        x = self.conv(x)
        return x

4. 模块与优化器联动 🔄

  • 模块参数会自动注册到优化器中,例如:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
  • 通过 model.train()model.eval() 切换训练/推理模式,影响模块行为(如 Dropout、BatchNorm)。

PyTorch_modules_structure
*图:PyTorch 模块层级结构示意图*