TorchVision 是一个开源的计算机视觉库,它基于 PyTorch 框架。它提供了丰富的计算机视觉模型和工具,使得在 PyTorch 上进行图像和视频处理变得更加容易。

主要功能

  • 预训练模型: 提供了多种预训练模型,如 ResNet、VGG、MobileNet 等。
  • 数据加载和处理: 提供了便捷的数据加载和处理工具,支持多种图像格式和预处理操作。
  • 可视化工具: 提供了图像和模型的可视化工具,方便用户理解和调试。

安装

pip install torchvision

快速开始

以下是一个简单的使用例子:

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

# 加载模型
model = models.resnet18(pretrained=True)

# 创建一个转换器,用于将图像转换为模型需要的格式
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载并预处理图像
image = Image.open("path_to_image.jpg")
image = transform(image).unsqueeze(0)

# 使用模型进行预测
outputs = model(image)
_, predicted = torch.max(outputs, 1)
print('Predicted:', predicted.item())

更多详细信息和教程,请访问 TorchVision 官方文档

ResNet18