变换

数据并不总是以训练机器学习算法所需的最终处理形式出现。使用 transforms 来执行一些数据操作,使其适合于训练。所有 TorchVision 数据集都有两个参数:用于修改特征的 transform 和用于修改标签的 target_transform。它们接受包含变换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的开箱即用的变换。

FashionMNIST 函数是 PIL 图像格式,标签是整数。为了进行训练,需要将特征作为归一化张量,标签作为一个热编码张量。为了做这些变换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda


f = lambda y: torch.zeros(10, 
                          dtype=torch.float).scatter_(0, 
                                                      torch.tensor(y), 
                                                      value=1)


ds = datasets.FashionMNIST(
    root="../../datasets",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(f)
)
C:\Users\xinet\.conda\envs\torch\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

ToTensor()

ToTensorPIL 图像或 NumPy ndarray 转换为 FloatTensor。并将图像的像素强度值在 \([0, 1]\)范围内进行缩放。

Lambda 变换

Lambda 转换应用任何用户定义的 lambda 函数。这里,我们定义一个函数将整数转换为一个热编码张量。它首先创建一个大小为 10 的零张量(数据集中标签的数量),然后调用 scatter_,它在标签 y 给出的索引上赋值为 1

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))