torchvision

王哲峰 / 2023-01-17


Datasets

Built-in datasets

Example:

from torch.utils.data import DataLoader
from torchvision import datasets

# The root directory must already contain the ImageNet data
imagenet_data = datasets.ImageNet("path/to/imagenet_root/")
data_loader = DataLoader(
    imagenet_data,
    batch_size = 4,
    shuffle = True,
    num_workers = 4,  # data-loading worker processes; tune for your machine
)

Types:

Base classes for custom image datasets

Reading and writing data

torchvision.io

Video

import torchvision
video_path = "path to a test video"
# Constructor allocates memory and a threaded decoder
# instance per video. At the moment it takes two arguments:
# path to the video file, and a wanted stream.
reader = torchvision.io.VideoReader(video_path, "video")

# The information about the video can be retrieved using the
# `get_metadata()` method. It returns a dictionary for every stream, with
# duration and other relevant metadata (often frame rate)
reader_md = reader.get_metadata()

# metadata is structured as a dict of dicts with following structure
# {"stream_type": {"attribute": [attribute per stream]}}
#
# following would print out the list of frame rates for every present video stream
print(reader_md["video"]["fps"])

# we explicitly select the stream we would like to operate on. In
# the constructor we select a default video stream, but
# in practice, we can set whichever stream we would like
reader.set_current_stream("video:0")

Images

Data transforms and augmentation

torchvision.transforms

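Introduction to transforms and augmentation

All torchvision datasets take two parameters that accept a callable containing the transformation logic: transform, which modifies the input features, and target_transform, which modifies the labels.

Most transforms accept both PIL images and tensor images, but some accept only PIL images and others only tensor images.

Transforms that accept tensor images also accept batches of tensor images. A tensor image has shape (C, H, W); a batch has shape (B, C, H, W).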

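The expected range of pixel values in a tensor image is determined by its dtype: floating-point tensor images are expected to hold values in [0, 1], and integer tensor images values in [0, MAX_DTYPE], where MAX_DTYPE is the largest value the dtype can represent.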

Forms of transforms:

Common transforms

Feature extraction

Pretrained models and weights

torchvision.models

Model types

General information on pretrained weights

Initializing pretrained models

from torchvision.models import resnet50, ResNet50_Weights

# Old weights with accuracy 76.130%
resnet50(weights = ResNet50_Weights.IMAGENET1K_V1)
resnet50(weights = "IMAGENET1K_V1")  # the string form is also supported

# New weights with accuracy 80.858%
resnet50(weights = ResNet50_Weights.IMAGENET1K_V2)
resnet50(weights = "IMAGENET1K_V2")  # the string form is also supported

# Best available weights (currently an alias for IMAGENET1K_V2)
resnet50(weights = ResNet50_Weights.DEFAULT)

# No weights: random initialization
resnet50(weights = None)
resnet50()  # same as weights = None

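Using pretrained models

Before a pretrained model is used, the input images must be preprocessed, and the preprocessing differs from model to model, for example: resizing to the right resolution with the right interpolation, applying the inference transforms, and rescaling the values.

The inference transforms for each model are provided in its weights documentation: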

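from torchvision.models import ResNet50_Weights

# Initialize the inference transforms that belong to these weights
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

# Apply them to the input image
# (img is assumed to be a PIL image or tensor image loaded beforehand)
img_transformed = preprocess(img)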

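Some models use modules that behave differently during training and evaluation, such as batch normalization. Switch between the two modes with model.train() or model.eval() as appropriate: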

# Initialize model
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights = weights)

# Set model to eval mode
model.eval()

Model registration mechanism

Using models from Hub

PyTorch Hub

import torch

# Option 1: passing weights param as string
model = torch.hub.load(
    "pytorch/vision", 
    "resnet50", 
    weights = "IMAGENET1K_V2"
)

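# Option 2: passing weights param as enum
weights = torch.hub.load(
    "pytorch/vision", 
    "get_weight",
    name = "ResNet50_Weights.IMAGENET1K_V2"
)
model = torch.hub.load(
    "pytorch/vision", 
    "resnet50", 
    weights = weights
)

# List every weight available for a model
weight_enum = torch.hub.load(
    "pytorch/vision",
    "get_model_weights",
    name = "resnet50"
)
print([weight for weight in weight_enum])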

Utilities

APIs

TODO

Operators

torchvision.ops

Tasks

classification

detection

segmentation

similarity learning

video classification

Other APIs