构建神经网络

构建神经网络

神经网络由对数据进行操作的 layers/modules 组成。torch.nn 命名空间提供了构建神经网络所需的所有构建块。PyTorch 中的每个 module 都继承自 nn.Module。神经网络本身是由其他模块(layers)组成的模块。这种嵌套结构允许轻松地构建和管理复杂的体系结构。

接下来我们将构建一个神经网络来对 FashionMNIST 数据集中的图像进行分类。

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

获取训练设备

我们希望能够在 GPU 之类的硬件加速器上训练我们的模型,如果它可用的话。我们去看看 torch.cuda 可用,否则我们继续使用 CPU。

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))
Using cuda device

定义类

通过子类化 nn.Module 来定义神经网络。模块,并在 __init__ 中初始化神经网络层。每一个神经网络。模块的子类在forward方法中实现对输入数据的操作。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits