Пример #1
0
def main():
    """Train LeNet-5 on CIFAR-10 and report test accuracy after every epoch."""
    epoch = 1000
    batch_size = 32

    # Identical preprocessing for both splits: resize to 32x32, then
    # convert the PIL image to a float tensor.
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])

    cifar_train = datasets.CIFAR10('./datasets', train=True,
                                   transform=preprocess, download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('./datasets', train=False,
                                  transform=preprocess, download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)

    # Peek at one batch to confirm tensor shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # device = torch.device('cuda')  # uncomment to compute on the GPU
    model = Lenet5()
    print('model结构:', model)
    # CrossEntropyLoss takes raw logits (it applies softmax internally);
    # "pred" elsewhere means logits after softmax/argmax.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch_id in range(epoch):
        # ---- training pass ----
        model.train()
        for batch_id, (x, label) in enumerate(cifar_train):
            logits = model(x)  # logits: [b, 10], label: [b]
            loss = criterion(logits, label)

            # Gradients accumulate across backward() calls, so clear them
            # before each step rather than letting them pile up.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # apply the parameter update

        print(epoch_id, loss.item())

        # ---- evaluation pass ----
        model.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                logits = model(x)
                # The index of the largest logit is the predicted class.
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num  # fraction of correct predictions
            print(epoch_id, acc)
Пример #2
0
def main():
    """Train LeNet-5 on CIFAR-10 (CPU) and report test accuracy each epoch.

    Bug fixed: the training loop unpacked each batch as ``(x, y)`` but the
    body used ``label``, which silently referred to the stale batch fetched
    before the loop — every optimization step trained against the wrong
    targets. The loop now unpacks ``(x, label)``.
    """
    batch_size = 16
    cifar_train = datasets.CIFAR10(root='cifar',download=True,train=True,transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]
    ))
    cifar_train = DataLoader(cifar_train,batch_size=batch_size,shuffle=True)
    cifar_test = datasets.CIFAR10(root='cifar',download=True,train=False,transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]))
    cifar_test = DataLoader(cifar_test,batch_size=batch_size,shuffle=True)

    # Sanity-check one batch's shapes before training.
    x,label = next(iter(cifar_train))
    print("x:",x.shape,'label:',label.shape)


    device = torch.device('cpu')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
    print(model)
    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]; loss is a 0-dim scalar tensor.
            x,label = x.to(device),label.to(device)
            logits = model(x)
            loss = criteon(logits,label)
            # Zero stale gradients, backprop, then apply the update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        print(epoch,loss.item())

        model.eval()
        with torch.no_grad():
            # Evaluate accuracy over the whole test set.
            total_correct =0
            total_num =0
            for x,label in cifar_test:
                x,label = x.to(device),label.to(device)
                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b] predicted class indices
                # [b] vs [b] => scalar count of correct predictions
                total_correct +=torch.eq(pred,label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct/total_num
            print(epoch, acc)
Пример #3
0
def main():
    """Train LeNet-5 on normalized CIFAR-10 on the GPU.

    Bug fixed: ``iter(loader).next()`` is Python-2 era syntax — Python 3
    iterators (including DataLoader's) expose ``__next__``, not ``next()``,
    so the call raised AttributeError. Replaced with the builtin ``next()``.
    """
    batchsz = 128

    # ImageNet-style per-channel normalization, shared by both splits.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalize
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalize
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criteon(logits, label)

            # Zero accumulated grads, backprop, then apply the update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            # Accuracy over the full test set.
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)
Пример #4
0
def main():
    """Train LeNet-5 on CIFAR-10 on the GPU, printing the last-batch loss per epoch.

    Bugs fixed: ``loss.bacward()`` (typo) raised AttributeError so backprop
    never ran, and ``iter(loader).next()`` is Python-2 syntax — replaced
    with the builtin ``next()``.
    """
    batchsz = 32
    cifar_train = datasets.CIFAR10("cifar",
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10("cifar",
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print("x:", x.shape, "label:", label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
    for epoch in range(1000):

        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)  # logits: [b, 10]
            loss = criteon(logits, label)  # loss: 0-dim scalar tensor

            # backprop: gradients accumulate across backward() calls, so the
            # optimizer is zeroed first to get fresh gradients instead of a
            # sum with the previous step's.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, loss.item())
Пример #5
0
def main():
    """Export LeNet-5 to ONNX, run a sanity forward pass, and save a simplified model."""
    torch.manual_seed(1234)
    net = Lenet5()
    net.eval()

    # Dummy 1x1x32x32 input used both to trace the export and to sanity-check output.
    dummy = torch.ones(1, 1, 32, 32)
    base_dir = os.path.dirname(os.path.realpath(__file__))
    onnx_file = os.path.join(base_dir, "lenet.onnx")
    torch.onnx.export(net, dummy, onnx_file,
                      export_params=True,
                      opset_version=11,
                      input_names=['input'],
                      output_names=['output'])

    out = net(dummy)
    print('lenet out shape:', out.shape)
    print('lenet out:', out)

    # Simplify the exported graph and save it alongside the original.
    model = onnx.load(onnx_file)
    model_simp, check = simplify(model)
    simplified_file = os.path.join(base_dir, "lenet_simplify.onnx")
    onnxmltools.utils.save_model(model_simp, simplified_file)
Пример #6
0
def main():
    """Evaluate a saved LeNet-5 checkpoint on the test set and print accuracy."""
    device = torch.device('cuda:0')

    network = Lenet5()
    # Restore trained weights from disk before evaluating.
    network.load_state_dict(torch.load('saved_models/lenet5_check_point.pkl'))
    network = network.to(device)

    correct_num = 0
    for x_test, y_test in test_loader:
        x_test, y_test = x_test.to(device), y_test.to(device)
        # Predicted class = index of the largest output per sample.
        prediction = torch.argmax(network(x_test), dim=1)
        correct_num += prediction.eq(y_test).sum().float()

    accuracy = correct_num / len(test_loader.dataset)
    print(accuracy.item())
Пример #7
0
def main():
    """Train LeNet-5 with SGD for five epochs and checkpoint the weights."""
    device = torch.device('cuda:0')
    network = Lenet5().to(device)

    optimizer = optim.SGD(network.parameters(), lr=3e-3, momentum=0.9)

    for epoch in range(5):
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(network(x), y).to(device)

            # Standard update: clear stale grads, backprop, then
            # apply w' = w - lr * grad.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log every 50th batch.
            if batch_idx % 50 == 0:
                print(epoch, batch_idx, loss.item())

    torch.save(network.state_dict(), 'saved_models/lenet5_check_point.pkl')
Пример #8
0
def main():
    """Train LeNet-5 on CIFAR-10 and print per-epoch loss and accuracy.

    Bugs fixed:
    - ``iter(loader).nest()``  -> ``next(iter(loader))`` (typo plus Py2 idiom)
    - ``Lenet5.to(device)``    -> ``Lenet5().to(device)`` (the class was never
      instantiated)
    - ``model.parameter()``    -> ``model.parameters()``
    - ``pred.eq(pred, label)`` -> ``torch.eq(pred, label)`` (accuracy was
      computed by comparing predictions with themselves)
    """
    batchsz = 32
    cifar_train = datasets.CIFAR10('cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)  # dataset yields one image at a time
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)  # applies softmax internally
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)
            logits = model(x)  # logits: [b, 10]
            # loss: 0-dim scalar tensor
            loss = criteon(logits, label)

            # backprop: zero accumulated grads, then compute and apply new ones
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # .item() converts the 0-dim tensor to a Python number; this is the
        # loss of the final batch only.
        print('epoch:', epoch, 'loss', loss.item())

        model.eval()
        with torch.no_grad():  # no autograd graph needed for evaluation
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # index of the max logit = predicted class
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print('epoch:', epoch, 'acc:', acc)
Пример #9
0
# One-hot encode the integer class labels (10 classes) for both splits.
temp = np.zeros((y_train.size, 10))
temp[np.arange(y_train.size), y_train] = 1
y_train = temp
temp = np.zeros((y_test.size, 10))
temp[np.arange(y_test.size), y_test] = 1
y_test = temp

max_epochs = 10

# LeNet-5 for 28x28 single-channel input: two conv stages (6 then 16
# filters of size 3, pad 1, stride 1) followed by fully-connected layers
# of 120 and 84 units and a 10-way output.
network = Lenet5(input_dim=(1, 28, 28),
                 conv_param={
                     'filter_num1': 6,
                     'filter_size1': 3,
                     'filter_num2': 16,
                     'filter_size2': 3,
                     'pad': 1,
                     'stride': 1
                 },
                 hidden_size1=120,
                 hidden_size2=84,
                 output_size=10,
                 weight_init_std=0.01)

# NOTE: per the author's experiments, SGD and momentum failed to converge and
# AdaGrad stalled late in training around 80% accuracy; RMSprop and Adam did
# well — Adam in particular converged fast with high accuracy, matching theory.
trainer = Trainer(network,
                  x_train,
                  y_train,
                  x_test,
                  y_test,
                  epochs=max_epochs,
Пример #10
0
def main():
    """Train LeNet-5 on CIFAR-10 for 1000 epochs, then evaluate once.

    Bugs fixed: ``iter(loader).next()`` -> ``next(iter(loader))`` (Py2-only
    syntax), and the training loop's enumerate index no longer shadows the
    ``batchsz`` constant.
    """
    batchsz = 32
    # CIFAR10(root, train, transform, download): downloads into `root` when
    # missing; the second argument selects the training (True) or test
    # (False) split; `transform` is applied to each image as it is loaded.
    cifar_train = datasets.CIFAR10('CIFAR10',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
    # Batch the train/test splits.
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('CIFAR10',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)

    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    model.train()
    for epoch in range(1000):
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)

            logits = model(x)
            loss = criteon(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, loss.item())

    # NOTE(review): unlike the sibling examples, evaluation runs only once,
    # after all epochs finish. Preserved as written.
    model.eval()
    with torch.no_grad():
        total_correct = 0
        total_num = 0
        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)

            logits = model(x)
            pred = logits.argmax(dim=1)
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)

        acc = total_correct / total_num
        print(epoch, acc)
Пример #11
0
def main():
    """Train LeNet-5 on CIFAR-10 with per-epoch evaluation on the GPU.

    Fixed: the training loop's enumerate index shadowed the ``batchsz``
    constant; the misspelled local ``lable`` is renamed ``label``.
    """
    # load data
    batchsz = 128
    cifar_train = datasets.CIFAR10('cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    cifar_test = datasets.CIFAR10('cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    #x, label = next(iter(cifar_train))
    #print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda:0')
    model = Lenet5().to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    #train
    for epoch in range(1000):
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            # logits: [b, 10], label: [b], loss: 0-dim scalar tensor
            loss = criteon(logits, label)

            #backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test: accuracy over the full test split
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                # count predictions matching the ground truth
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)
Пример #12
0
def main():
    """Train LeNet-5 on a custom logo dataset, checkpointing every 10 epochs.

    Bug fixed: ``iter(loader).next()`` is Python-2 syntax and raises
    AttributeError on Python 3 — replaced with the builtin ``next()``.
    """
    batchsz = 5  # batch size
    logo_train = ReadData(train=True)  # training split
    # Wrap the training split in a DataLoader.
    logo_train = DataLoader(logo_train, batch_size=batchsz, shuffle=True)
    # Test split.
    logo_test = ReadData(train=False)
    logo_test = DataLoader(logo_test, batch_size=batchsz, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(logo_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')  # kept for the commented-out CUDA path below
    # model = Lenet5().to(device)  # CUDA variant
    model = Lenet5()  # instantiate the model (CPU)

    # criteon = nn.CrossEntropyLoss().to(device)  # CUDA variant
    criteon = nn.CrossEntropyLoss()  # CPU variant
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)  # print the model structure
    for epoch in range(100):
        # training pass
        model.train()
        for batchidx, (x, label) in enumerate(logo_train):
            # x: [b, 3, 32, 32], label: [b]
            # x, label = x.to(device), label.to(device)
            logits = model(x)  # logits: [b, 10]
            # loss: 0-dim scalar tensor
            loss = criteon(logits, label)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # evaluation pass: accuracy over the test split
        model.eval()
        with torch.no_grad():
            #test
            total_correct = 0
            total_num = 0
            for x, label in logo_test:
                # x: [b, 3, 32, 32], label: [b]
                # x, label = x.to(device), label.to(device)
                logits = model(x)  # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b] => scalar tensor
                # print('预测:',torch.max(logits,1)[1].data.numpy(), "实际", label[:batchsz].numpy())
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
                # print(correct)
            acc = total_correct / total_num
            print('epoch:', epoch, 'acc:', acc)
            # Checkpoint the whole model object every 10 epochs.
            if epoch % 10 == 0:
                save_path = "./model/" + str(epoch) + "_model.plk"
                torch.save(model, save_path)
                print("保存模型成功")
    # Save the final model.
    last_model = "./model/last_model.plk"
    torch.save(model, last_model)
Пример #13
0
def main():
    """Train LeNet-5 on CIFAR-10 (CPU) with per-epoch evaluation.

    Bugs fixed: ``iter(loader).next()`` -> ``next(iter(loader))`` (Py2-only
    syntax); the accuracy sum now calls ``.item()`` so ``acc`` is a plain
    float instead of a 0-dim tensor, matching the sibling examples.
    """
    batch_size = 32
    cifar_train = datasets.CIFAR10('cifar', True, transform = transforms.Compose([
        transforms.Resize((32,32)),
        transforms.ToTensor()
    ]), download= True)
    cifar_train = DataLoader(cifar_train, batch_size = batch_size, shuffle = True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print('x : ', x.shape, 'label : ', label.shape)

    # device = torch.device('cuda')
    # model = Lenet5().to(device)
    model = Lenet5()
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = 1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batch_idx, (x, label) in enumerate(cifar_train):
            # x, label = x.to(device), label.to(device)
            # x: [batch, 3, 32, 32], label: [b]
            logits = model(x)  # logits: [b, 10]
            loss = criteon(logits, label)  # loss: 0-dim scalar tensor

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, loss.item())

        model.eval()
        with torch.no_grad():
            # test: accuracy over the full test set
            total_correct, total_num = 0, 0
            for x, label in cifar_test:
                # x, label = x.to(device), label.to(device)
                logits = model(x)             # [batch, 10]
                pred = logits.argmax(dim = 1)  # [batch]
                # [b] vs [b] -> count of correct predictions
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, acc)
Пример #14
0
# Inspect the raw splits before padding.
#print(X_train.shape())
print(y_train[0])
print(mnist.train.images.shape, mnist.train.labels.shape)
#print(mnist.train.images.shape, y_train.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
#print(mnist.train.labels[0])
# for i in range(0, mnist.train.labels.shape[0]):
#     mnist.train.labels[i] = y_train[i]

# Zero-pad each 28x28 image by 2 pixels per side to reach the 32x32 input
# size; the batch and channel axes are left untouched.
X_train = np.pad(X_train, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')
X_validation = np.pad(X_validation, ((0, 0), (2, 2), (2, 2), (0, 0)),
                      'constant')
X_test = np.pad(X_test, ((0, 0), (2, 2), (2, 2), (0, 0)), 'constant')

print(X_train.shape)
print(X_validation.shape)
print(X_test.shape)

print("New Input shape: {}".format(X_train[0].shape))

# Build the LeNet-5 wrapper from the prepared splits, train it, and report
# test accuracy. (The 'model.ckpt' path below suggests a TensorFlow
# implementation — confirm against the Lenet5 class.)
lenet_network = Lenet5(X_train, y_train, X_test, y_test, X_validation,
                       y_validation)
accuracy = lenet_network.train(epochs=150, batch_size=100)
print("Accuracy on test set: {:.3f}".format(accuracy))
'''
# TODO: some refactoring for restoring the model
lenet_network_restored = Lenet5(X_train, y_train, X_test, y_test, X_validation, y_validation)
lenet_network_restored.restore_model(path='tmp/model.ckpt')
'''
Пример #15
0
def main():
    """Train LeNet-5 on normalized CIFAR-10 on the GPU, with per-epoch evaluation.

    Bug fixed: ``iter(loader).next()`` is Python-2 syntax and raises
    AttributeError on Python 3 — replaced with the builtin ``next()``.
    """
    batchsz = 128  # samples per batch
    # Download CIFAR-10 into ./cifar; resize to 32x32, convert to tensor,
    # and apply ImageNet-style per-channel normalization.
    cifar_train = datasets.CIFAR10('cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz,
                             shuffle=True)  # yields batchsz samples per step
    # Test split, same preprocessing as the training split.
    cifar_test = datasets.CIFAR10('cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = next(iter(cifar_train))  # sanity-check batch shapes
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')  # run on the GPU
    model = Lenet5().to(device)

    criteon = nn.CrossEntropyLoss().to(device)  # loss function
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)  # print the model structure

    for epoch in range(1000):
        model.train()  # training mode
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            x, label = x.to(device), label.to(device)
            logits = model(x)  # forward pass through LeNet-5
            # logits: [b, 10], label: [b], loss: 0-dim scalar tensor
            loss = criteon(logits, label)
            # backprop: zero stale gradients so they don't accumulate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # apply the parameter update
        # loss of the final batch only; .item() converts to a Python number
        print(epoch, 'loss:', loss.item())

        model.eval()  # evaluation mode
        with torch.no_grad():  # no autograd graph needed here
            # test: accuracy over the full test set
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                logits = model(x)  # [b, 10]
                # index of the max logit along dim 1 = predicted class
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b]: count how many predictions match the labels
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)  # running total of samples seen

            acc = total_correct / total_num  # mean accuracy
            print(epoch, 'test acc:', acc)
Пример #16
0
def main():
    """Train LeNet-5 on normalized CIFAR-10, with per-epoch evaluation.

    Bugs fixed: the CUDA-unavailable fallback device string was ``'gpu'``,
    which is not a valid torch device and raises RuntimeError — it is now
    ``'cpu'``. Also replaced Python-2-only ``iter(loader).next()`` with the
    builtin ``next()``.
    """
    batchsz = 128

    cifar_train = datasets.CIFAR10('../../../use/data/cifar',
                                   True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           mean=[0.485, 0.456, 0.406],
                                           std=[0.229, 0.224, 0.225])
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('../../../use/data/cifar',
                                  False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(
                                          mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    # Sanity-check one batch's shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Prefer the GPU; fall back to the CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Lenet5().to(device)
    # model = ResNet18().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            x, label = x.to(device), label.to(device)

            logits = model(x)  # logits: [b, 10]
            # loss: 0-dim scalar tensor
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test: accuracy over the full test set
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                x, label = x.to(device), label.to(device)

                logits = model(x)  # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b] => scalar count of correct predictions
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'test acc:', acc)