什么是图像分类?

图像分类是计算机视觉的核心任务之一,目标是通过深度学习模型将输入图像分配到预定义的类别中。
例如:

  • 输入一张猫的图片,输出类别为“猫”
  • 输入一张汽车图片,输出类别为“汽车”
图像分类基础

使用PyTorch实现图像分类的步骤 🧰

  1. 准备数据集

    • 常用数据集:CIFAR-10、MNIST、ImageNet
    • 数据预处理:归一化、数据增强、分训练集/验证集
    • 示例代码:
      from torchvision import datasets, transforms
      transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
      train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
      
  2. 构建模型

    • 使用经典架构:ResNet、VGG、MobileNet
    • 自定义网络示例:
      import torch.nn as nn
      class SimpleCNN(nn.Module):
          def __init__(self):
              super().__init__()
              self.layers = nn.Sequential(
                  nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
                  nn.ReLU(),
                  nn.MaxPool2d(kernel_size=2),
                  nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                  nn.ReLU(),
                  nn.MaxPool2d(kernel_size=2)
              )
          def forward(self, x):
              return self.layers(x)
      
  3. 训练与评估

    • 使用优化器:SGD、Adam
    • 损失函数:交叉熵损失 nn.CrossEntropyLoss()
    • 训练循环示例:
      model = SimpleCNN()
      criterion = nn.CrossEntropyLoss()
      optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
      for epoch in range(5):
          for images, labels in train_loader:
              outputs = model(images)
              loss = criterion(outputs, labels)
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()
      
  4. 模型保存与加载

    • 使用 torch.save() 保存训练好的模型
    • 使用 torch.load() 加载模型进行推理
PyTorch教程

扩展学习 🌐

图像分类模型