变换
内容
变换¶
数据并不总是以训练机器学习算法所需的最终处理形式出现。使用 transforms
来执行一些数据操作,使其适合于训练。所有 TorchVision 数据集都有两个参数:用于修改特征的 transform
和用于修改标签的 target_transform
。它们接受包含变换逻辑的可调用对象。torchvision.transforms 模块提供了几种常用的开箱即用的变换。
FashionMNIST 函数是 PIL 图像格式,标签是整数。为了进行训练,需要将特征作为归一化张量,标签作为一个热编码张量。为了做这些变换,使用 ToTensor
和 Lambda
。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
f = lambda y: torch.zeros(10,
dtype=torch.float).scatter_(0,
torch.tensor(y),
value=1)
ds = datasets.FashionMNIST(
root="../../datasets",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(f)
)
C:\Users\xinet\.conda\envs\torch\lib\site-packages\torchvision\datasets\mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
ToTensor()
¶
ToTensor
将 PIL
图像或 NumPy ndarray
转换为 FloatTensor
。并将图像的像素强度值在 \([0, 1]\)范围内进行缩放。
Lambda
变换¶
Lambda
转换应用任何用户定义的 lambda
函数。这里,我们定义一个函数将整数转换为一个热编码张量。它首先创建一个大小为 10 的零张量(数据集中标签的数量),然后调用 scatter_
,它在标签 y
给出的索引上赋值为 1
。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))