PyTorch 提供了丰富的数据加载和预处理工具,可以帮助我们更高效地处理数据。本教程将介绍如何在 PyTorch 中进行数据加载。

安装 PyTorch

在开始之前,请确保你已经安装了 PyTorch。你可以通过以下命令进行安装:

pip install torch torchvision

数据加载

PyTorch 提供了 torch.utils.data.Datasettorch.utils.data.DataLoader 两个类来处理数据加载。

创建 Dataset

首先,你需要创建一个继承自 torch.utils.data.Dataset 的类,并实现 __len____getitem__ 方法。

from torch.utils.data import Dataset
import os

class CustomDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_list = os.listdir(root_dir)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        file_path = os.path.join(self.root_dir, file_name)
        # 这里添加加载和处理文件的代码
        return file_path

使用 DataLoader

接下来,你可以使用 DataLoader 来加载数据。

from torch.utils.data import DataLoader

dataset = CustomDataset(root_dir='path/to/your/data')
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for data in dataloader:
    # 这里添加处理数据的代码
    pass

图像数据加载

对于图像数据,PyTorch 提供了 torchvision.datasets 中的 ImageFolder 类,可以方便地加载图像数据。

from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(root='path/to/your/image/data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

扩展阅读

想要了解更多关于 PyTorch 数据加载的信息,可以阅读 PyTorch 官方文档

返回首页