卷积神经网络(CNN)是深度学习中用于图像识别、图像分类等任务的重要模型。本教程将详细介绍 PyTorch 中如何实现和训练 CNN。

1. 引言

PyTorch 是一个流行的深度学习框架,以其动态计算图和易于使用的 API 而闻名。CNN 是 PyTorch 中实现图像识别任务的主要模型。

2. 安装 PyTorch

在开始之前,确保你的环境中已安装 PyTorch。你可以通过以下命令安装:

pip install torch torchvision

3. 数据准备

在进行 CNN 训练之前,需要准备数据集。PyTorch 提供了多种数据集,例如 CIFAR-10、MNIST 等。

import torchvision.datasets as datasets

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)

4. 定义 CNN 模型

在 PyTorch 中,你可以使用 torch.nn 模块定义 CNN 模型。

import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 64 * 8 * 8)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

5. 训练模型

定义损失函数和优化器,然后开始训练模型。

import torch.optim as optim

model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

6. 评估模型

在训练完成后,可以使用测试集评估模型的性能。

test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Loss: {:.4f}'.format(test_loss / len(test_loader.dataset)))
print('Accuracy: {:.4f}'.format(correct / total))

7. 总结

通过本教程,你学习了如何在 PyTorch 中实现和训练 CNN。希望这能帮助你更好地理解和应用 CNN。

更多 PyTorch 教程

CNN 架构