PyTorch 是一个开源的机器学习库,由 Facebook 的 AI 研究团队开发,用于应用深度学习。它提供了动态计算图和强大的 GPU 加速,非常适合研究和开发。
快速开始
以下是使用 PyTorch 创建一个简单的神经网络的基本步骤:
导入 PyTorch 库:
import torch
定义网络结构:
class Net(torch.nn.Module): def __init__(self): super(Net, self).__init__() # 定义网络层 self.conv1 = torch.nn.Conv2d(1, 20, 5) self.conv2 = torch.nn.Conv2d(20, 50, 5) self.fc1 = torch.nn.Linear(4*4*50, 500) self.fc2 = torch.nn.Linear(500, 10) def forward(self, x): x = torch.relu(self.conv1(x)) x = torch.max_pool2d(x, 2, 2) x = torch.relu(self.conv2(x)) x = torch.max_pool2d(x, 2, 2) x = x.view(-1, 4*4*50) x = torch.relu(self.fc1(x)) x = self.fc2(x) return x
实例化网络:
net = Net()
定义损失函数和优化器:
criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
训练网络:
for epoch in range(num_epochs): for data, target in train_loader: optimizer.zero_grad() output = net(data) loss = criterion(output, target) loss.backward() optimizer.step()
更多信息
要了解更多关于 PyTorch 的信息,请访问我们的 PyTorch 教程页面。
PyTorch Logo