自定义损失函数是 PyTorch 中高级应用的一个重要部分。在本教程中,我们将学习如何创建并使用自定义损失函数。
基础知识
在 PyTorch 中,损失函数用于衡量预测值与真实值之间的差异。默认情况下,PyTorch 提供了许多常用的损失函数,如均方误差(MSE)、交叉熵等。但在某些情况下,你可能需要根据特定问题定义自己的损失函数。
创建自定义损失函数
要创建自定义损失函数,你需要定义一个 Python 函数,该函数接收预测值和真实值作为输入,并返回一个损失值。
import torch
import torch.nn as nn
class CustomLoss(nn.Module):
def __init__(self):
super(CustomLoss, self).__init__()
def forward(self, input, target):
# 这里定义你的损失计算逻辑
return torch.mean((input - target) ** 2)
使用自定义损失函数
创建自定义损失函数后,你可以在训练过程中使用它来计算损失。
model = ... # 你的模型
criterion = CustomLoss()
optimizer = torch.optim.Adam(model.parameters())
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()