什么是图像分类?
图像分类是计算机视觉的核心任务之一,目标是通过深度学习模型将输入图像分配到预定义的类别中。
例如:
- 输入一张猫的图片,输出类别为“猫”
- 输入一张汽车图片,输出类别为“汽车”
使用PyTorch实现图像分类的步骤 🧰
准备数据集
- 常用数据集:CIFAR-10、MNIST、ImageNet
- 数据预处理:归一化、数据增强、分训练集/验证集
- 示例代码:
from torchvision import datasets, transforms transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
构建模型
- 使用经典架构:ResNet、VGG、MobileNet
- 自定义网络示例:
import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.layers = nn.Sequential( nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2), nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2) ) def forward(self, x): return self.layers(x)
训练与评估
- 使用优化器:SGD、Adam
- 损失函数:交叉熵损失
nn.CrossEntropyLoss()
- 训练循环示例:
model = SimpleCNN() criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(5): for images, labels in train_loader: outputs = model(images) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step()
模型保存与加载
- 使用
torch.save()
保存训练好的模型 - 使用
torch.load()
加载模型进行推理
- 使用