PyTorch 是一个流行的开源机器学习库,用于深度学习研究和开发。它由 Facebook 的 AI 研究团队开发,旨在提供灵活和高效的深度学习工具。

快速开始

  1. 安装 PyTorch

  2. 创建一个简单的神经网络

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    # 定义一个简单的神经网络
    class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(20, 50, 5)
            self.fc1 = nn.Linear(4*4*50, 500)
            self.fc2 = nn.Linear(500, 10)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 4*4*50)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    
    # 实例化网络
    net = SimpleNet()
    
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    # 训练网络
    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
    
            optimizer.zero_grad()
    
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    
    print('Finished Training')
    
  3. 使用 PyTorch 进行图像识别

    PyTorch 提供了强大的图像处理工具,例如 torchvision 库。

    import torchvision.transforms as transforms
    import torchvision.datasets as datasets
    import matplotlib.pyplot as plt
    
    # 加载数据集
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    
    trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    
    # 显示图像
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    
    fig = plt.figure()
    for idx in range(4):
        ax = fig.add_subplot(2, 2, idx+1)
        plt.imshow(images[idx])
        ax.set_title('label: %d' % labels[idx])
    plt.show()
    

扩展阅读

希望这些信息能帮助您开始使用 PyTorch!🎉