PyTorch 是一个开源的机器学习库,由 Facebook 的 AI 研究团队开发。它提供了强大的深度学习框架,被广泛应用于计算机视觉、自然语言处理等领域。

特点

  • 动态计算图:PyTorch 使用动态计算图,这使得它更容易进行调试和实验。
  • GPU 加速:PyTorch 支持CUDA,可以在GPU上加速深度学习模型。
  • 丰富的API:PyTorch 提供了丰富的API,包括神经网络、优化器、损失函数等。

快速开始

要开始使用 PyTorch,首先需要安装它。您可以通过以下命令安装:

pip install torch torchvision

示例

以下是一个简单的 PyTorch 示例:

import torch
import torch.nn as nn
import torch.optim as optim

# 创建一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 实例化网络
net = SimpleNet()

# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)

# 输入数据
input = torch.randn(1, 10)
target = torch.tensor([1.0], dtype=torch.float32)

# 前向传播
output = net(input)

# 计算损失
loss = criterion(output, target)

# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

更多信息

如果您想了解更多关于 PyTorch 的信息,请访问我们的官方文档

[center]PyTorch Logo