手写数字识别是机器学习领域的一个经典任务,它可以通过深度学习模型来实现。在这个例子中,我们将使用 PyTorch 框架来构建一个手写数字识别模型。

数据集介绍

首先,我们需要一个手写数字的数据集。MNIST 数据集是一个常用的手写数字数据集,包含了 0 到 9 的数字图片。我们可以在 MNIST 数据集官网 下载这个数据集。

模型构建

以下是使用 PyTorch 构建一个简单的卷积神经网络 (CNN) 模型的示例代码:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = x.view(-1, 128 * 7 * 7)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

训练与评估

构建好模型后,我们需要进行训练和评估。以下是训练模型的示例代码:

import torch.optim as optim

# 假设 trainloader 和 testloader 已经被定义好了
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

结论

通过以上步骤,我们已经使用 PyTorch 框架实现了一个简单的手写数字识别模型。当然,这只是一个简单的示例,实际应用中可能需要更复杂的模型和更多的调优。

更多 PyTorch 深度学习实践 可以在这里找到。

(center) handwritten_digits (center)