PyTorch 提供了强大的数据加载和增强功能,使得在图像、视频等数据集上进行训练和测试变得更加便捷。以下是一些关于 PyTorch 数据加载和增强的基本教程。

数据加载

PyTorch 使用 torch.utils.data.DataLoader 来加载数据集。以下是一个简单的例子:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

数据增强

数据增强是提高模型泛化能力的重要手段。PyTorch 提供了多种数据增强方法,例如随机裁剪、旋转、翻转等。

以下是一个使用数据增强的例子:

from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

代码示例

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

for data, target in train_loader:
    # 在这里进行模型训练
    pass

扩展阅读

更多关于 PyTorch 数据加载和增强的信息,可以参考以下链接:

PyTorch 官方文档 - 数据增强

MNIST 数据集示例