PyTorch 提供了强大的数据加载和增强功能,使得在图像、视频等数据集上进行训练和测试变得更加便捷。以下是一些关于 PyTorch 数据加载和增强的基本教程。
数据加载
PyTorch 使用 torch.utils.data.DataLoader
来加载数据集。以下是一个简单的例子:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
数据增强
数据增强是提高模型泛化能力的重要手段。PyTorch 提供了多种数据增强方法,例如随机裁剪、旋转、翻转等。
以下是一个使用数据增强的例子:
from torchvision import transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
代码示例
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
for data, target in train_loader:
# 在这里进行模型训练
pass
扩展阅读
更多关于 PyTorch 数据加载和增强的信息,可以参考以下链接:
MNIST 数据集示例