项目简介

MNIST数据集是深度学习领域的经典入门案例,包含60,000个手写数字训练样本和10,000个测试样本。本项目将带你通过PyTorch实现一个简单的卷积神经网络,并可视化训练过程中的特征提取结果。

技术路线

  1. 数据准备

    • 加载MNIST数据集:torchvision.datasets.MNIST
    • 数据预处理:归一化、转换为张量
    • 使用DataLoader进行批量加载
    mnist_data_sample
  2. 模型构建

    • 定义CNN网络结构:
      class MNIST_CNN(nn.Module):
          def __init__(self):
              super().__init__()
              self.layers = nn.Sequential(
                  nn.Conv2d(1, 32, kernel_size=5),  
                  nn.ReLU(),  
                  nn.MaxPool2d(2),  
                  nn.Conv2d(32, 64, kernel_size=5),  
                  nn.ReLU(),  
                  nn.AdaptiveAvgPool2d((4,4)),  
                  nn.Flatten(),  
                  nn.Linear(64*4*4, 10)  
              )
          def forward(self, x):  
              return self.layers(x)
      
    • 模型训练与验证
    neural_network_structure
  3. 可视化分析

    • 使用TensorBoard记录训练指标
    • 可视化卷积核权重变化:torchviz工具生成计算图
    • 展示混淆矩阵与特征图对比
    visualization_result

扩展学习

📌 本项目代码已通过PyTorch 2.0测试,建议使用GPU加速训练过程。可视化部分可结合Matplotlib进行结果展示。