logo

PyTorch 数据增强

wangzf / 2024-09-01


目录

数据增强

数据增强(Data Augmentation)已经成为深度学习时代的常规做法, 数据增强目的是为了增加训练数据的丰富度, 让模型接触多样性的数据以增加模型的泛化能力。

通常,数据增强可分为在线(online)与离线(offline)两种方式:

实际上,这两种方法理论上是等价的,一般的框架都采用在线方式的数据增强, PyTorch 的 transforms 就是在线方式。

PyTorch transforms

可以使用 transforms 对数据集进行转换操作,使得数据集可以作为机器学习算法可以使用的形式:

torchvision transforms

torchvision.transforms

transforms 简介

torchvision.transforms 是广泛使用的图像变换库,包含二十多种基础方法以及多种组合功能, 通常可以用 torchvision.transforms.Compose([]) 把各方法串联在一起使用。 大多数的 transforms 类都有对应的 functional transforms,可供用户自定义调整。

torchvision.transforms 库中包含二十多种变换方法,那么多的方法里应该如何挑选, 以及如何设置参数呢?数据增强的方向一定是测试数据集中可能存在的情况。 举个例子,做人脸检测可以用水平翻转(如前置相机的镜像就是水平翻转), 但不宜采用垂直翻转(这里指一般业务场景,特殊业务场景有垂直翻转的人脸就另说)。 因为真实应用场景不存在倒转(垂直翻转)的人脸,因此在训练过程选择数据增强时就不应包含垂直翻转。

所有的 torchvision datasets 都有两个接受包含转换逻辑的可调用对象的参数:

大部分 transform 同时接受 PIL 图像和 tensor 图像, 但是也有一些 tansform 只接受 PIL 图像,或只接受 tensor 图像。 对于 tensor 图像,transform 接受 tensor 图像或批量 tensor 图像:

transforms 机制

开始采用 torchvision.transforms.Compose 把变换的方法包装起来,放到 Dataset 中; 在 DataLoader 依次读数据时,调用 Dataset__getitem__, 每个 sample 读取时,会根据 compose 里的方法依次地对数据进行变换, 以此完成在线数据增强。而具体的 transforms 方法通常包装成一个 Module 类, 具体实现会在各 functional 中。

常用转换

torchvision.transform 模块提供了多个常用转换

torchtext transforms

transforms 简介

常用转换

torchaudio transforms

transforms 简介

常用转换

AIbumentations

参考