损失函数是机器学习模型训练中不可或缺的部分,它用于评估模型预测结果与真实值之间的差距。在 PyTorch 中,有多种损失函数可供选择,以下是一些常用的损失函数及其使用方法。

常用损失函数

  1. 均方误差损失(Mean Squared Error, MSE) 均方误差损失是最常用的回归损失函数之一,它计算预测值与真实值之间差的平方的平均值。

    criterion = nn.MSELoss()
    output = net(input)
    loss = criterion(output, target)
    
  2. 交叉熵损失(Cross Entropy Loss) 交叉熵损失通常用于分类问题,它计算预测概率分布与真实分布之间的交叉熵。

    criterion = nn.CrossEntropyLoss()
    output = net(input)
    loss = criterion(output, target)
    
  3. 二元交叉熵损失(Binary Cross Entropy Loss) 二元交叉熵损失是交叉熵损失的一种特殊情况,适用于二分类问题。

    criterion = nn.BCEWithLogitsLoss()
    output = net(input)
    loss = criterion(output, target)
    

损失函数示例

以下是一个使用 PyTorch 训练分类模型的示例:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 50 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 训练数据
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 训练模型
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print('Finished Training')

更多关于 PyTorch 损失函数的教程,请访问PyTorch 损失函数教程