PyTorch 是一个流行的开源机器学习库,用于应用深度学习。本教程将为你提供一个 PyTorch 的入门指南,包括基础概念和实践示例。

快速开始

以下是一些使用 PyTorch 的基本步骤:

  • 安装 PyTorch:确保你已经安装了 PyTorch。你可以从 PyTorch 官网 下载适合你系统的版本。

  • 导入 PyTorch 库

import torch
  • 创建一个简单的神经网络
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

net = SimpleNet()
  • 训练模型
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 假设 x_data 和 y_data 是你的输入和目标数据
x_data = torch.randn(10, 10)
y_data = torch.randn(10, 1)

optimizer.zero_grad()
output = net(x_data)
loss = criterion(output, y_data)
loss.backward()
optimizer.step()

图像识别示例

以下是一个简单的图像识别示例,使用 PyTorch 进行训练和测试。

import torchvision.transforms as transforms
from torchvision import datasets, models, utils
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# 加载预训练模型
model = models.resnet18(pretrained=True)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

扩展阅读

如果你想要更深入地了解 PyTorch,以下是一些推荐的资源:

PyTorch Logo