PyTorch 动态图教程

动态图是 PyTorch 中的一种计算图表示方式,它允许用户在运行时动态地创建和修改计算图。以下是一些关于 PyTorch 动态图的基础教程。

动态图基础

动态图与静态图的主要区别在于,动态图在运行时可以改变其结构,而静态图在构建后是固定的。

  • 创建动态图

    • 使用 torch.autograd.backward 进行反向传播。
  • 修改动态图

    • 使用 torch.autograd.grad 计算梯度。

示例

假设我们有一个简单的神经网络:

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()

我们可以计算损失和梯度:

x = torch.randn(1, 10)
y = torch.randn(1)
output = net(x)
loss = torch.nn.functional.mse_loss(output, y)
loss.backward()

扩展阅读

更多关于 PyTorch 的动态图内容,请参考PyTorch 官方文档

PyTorch Logo