Exemplo n.º 1
0
        layers.FlattenLayer(),
        nn.Linear(512, 10))

    # 下面用分步add的方法构造,看起来比较丑但是输出的时候名称清晰
    net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                        nn.BatchNorm2d(64), nn.ReLU(),
                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", layers.GlobalAvgPool2d())
    net.add_module("fc",
                   nn.Sequential(layers.FlattenLayer(), nn.Linear(512, 10)))

    X = torch.rand(1, 1, 224, 224)
    for name, layer in net.named_children():
        X = layer(X)
        print(name, ' output shape: ', X.shape)
    print('————————————————————————————')

    # 最后类似之前,进行一下测试,这里也减小了图片大小
    # 1epoch = 189.7sec
    batch_size = 256
    train_iter, test_iter = data_process.load_data_fashion_mnist(batch_size,
                                                                 resize=96)
    lr, num_epochs = 0.001, 5
    optim = torch.optim.Adam(net.parameters(), lr=lr)
    train.train_ch5(net, train_iter, test_iter, batch_size, optim, device,
                    num_epochs)
Exemplo n.º 2
0
# 同上操作的测试集,注意train=False
mnist_test = torchvision.datasets.FashionMNIST(
    root=r"./Datasets", train=False,
    download=True, transform=transforms.ToTensor())

# 可以用len得到数据集的大小,用type得到数据集的类型,用下标得到某个具体的数据
print(len(mnist_train), len(mnist_test))
feature, label = mnist_train[10]
# 能看到图片为1*28*28,是一维的灰度图
print(feature.shape, label)

# 编写好用索引获取标签真名的函数get_fashion_mnist_labels
# 和绘制数据图片的show_fashion_mnist后,下面来尝试显示一下
# 类似之前,这里也有特征X和标签y
X, y = [], []
for i in range(10):
    # 这个数据集的数据组织方式就是第0是特征数据,第1是标签
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
# 将这个得到的特征和标签转换为图片绘制
plot.show_fashion_mnist(X, data_process.get_fashion_mnist_labels(y))

# 将读取数据的流程封装到load_data_fashion_mnist中使用
train_iter, _ = data_process.load_data_fashion_mnist(256)

# 测试下读取完数据所需的时间
start = time.time()
for X, y in train_iter:
    continue
print('%.2f s' % (time.time()-start))