欢迎访问 PyTorch 图像分类教程!本教程将带你从零开始构建一个图像分类模型,适合初学者和有一定基础的开发者。通过本教程,你将掌握使用 PyTorch 进行图像分类的核心概念和实践方法。

🧰 环境准备

  1. 安装 PyTorch
    请参考 PyTorch 官方安装指南 完成环境搭建。

  2. 依赖库

    • Python 3.8+
    • Torch 1.10+
    • torchvision
    • matplotlib(用于可视化)
  3. 数据集
    使用 CIFAR-10 数据集进行训练和验证。

🛠️ 实现步骤

1. 导入必要模块

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

2. 构建模型

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(32 * 6 * 6, 10)  # 假设输入图像为 32x32

    def forward(self, x):
        x = self.layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

3. 训练模型

model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练循环示例
for epoch in range(10):
    for images, labels in dataloader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

📈 模型评估与优化

  • 评估方法
    使用验证集计算准确率:

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")
    
  • 优化技巧

    • 数据增强:使用 transforms 提升泛化能力
    • 学习率调整:尝试 torch.optim.lr_scheduler
    • 模型保存:使用 torch.save(model.state_dict(), 'model.pth')

🌐 扩展阅读

PyTorch_教程