Torchvision 是一个基于 PyTorch 的库,专门用于计算机视觉任务。它提供了许多预训练模型和工具,可以帮助开发者快速实现图像分类、目标检测、图像分割等任务。

主要功能

  • 预训练模型: 提供了多种预训练模型,如 ResNet、VGG、AlexNet 等,可以用于图像分类、目标检测等任务。
  • 数据加载和增强: 提供了丰富的数据加载和增强工具,方便进行数据预处理。
  • 模型评估: 提供了多种评估指标,如准确率、召回率等,方便对模型性能进行评估。

快速开始

要使用 Torchvision,首先需要安装 PyTorch。可以通过以下命令进行安装:

pip install torch torchvision

安装完成后,可以使用以下代码进行简单的图像分类:

import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# 加载预训练模型
model = models.resnet50(pretrained=True)

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载图像
image = Image.open('path/to/image.jpg')
image = transform(image).unsqueeze(0)

# 预测
output = model(image)
_, predicted = torch.max(output, 1)

# 打印结果
print('Predicted class:', predicted.item())

更多详细信息和教程,请访问 PyTorch 官方文档

相关资源

Torchvision Logo