项目简介
MNIST数据集是深度学习领域的经典入门案例,包含60,000个手写数字训练样本和10,000个测试样本。本项目将带你通过PyTorch实现一个简单的卷积神经网络,并可视化训练过程中的特征提取结果。
技术路线
数据准备
- 加载MNIST数据集:
torchvision.datasets.MNIST
- 数据预处理:归一化、转换为张量
- 使用
DataLoader
进行批量加载
- 加载MNIST数据集:
模型构建
- 定义CNN网络结构:
class MNIST_CNN(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( nn.Conv2d(1, 32, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(32, 64, kernel_size=5), nn.ReLU(), nn.AdaptiveAvgPool2d((4,4)), nn.Flatten(), nn.Linear(64*4*4, 10) ) def forward(self, x): return self.layers(x)
- 模型训练与验证
- 定义CNN网络结构:
可视化分析
- 使用TensorBoard记录训练指标
- 可视化卷积核权重变化:
torchviz
工具生成计算图 - 展示混淆矩阵与特征图对比
扩展学习
📌 本项目代码已通过PyTorch 2.0测试,建议使用GPU加速训练过程。可视化部分可结合Matplotlib进行结果展示。