Ejemplo n.º 1
0
# Script entry point: builds a Resnet, optionally moves it to the GPU, creates
# an Adam optimizer, and defines (but, in this excerpt, never calls) a nested
# training function.
# NOTE(review): this excerpt is cut off inside train_model(); nothing after
# the forward pass (loss, backward, optimizer step, evaluation) is visible.
if __name__ == '__main__':
    # GPU is used only when the first command-line argument is exactly 'cuda'.
    use_cuda = len(sys.argv) > 1 and sys.argv[1] == 'cuda'
    num_epochs = 80
    # Train/test loaders come from get_dataset(), defined elsewhere
    # (CIFAR-10 according to the original author's note).
    trainloader, testloader = get_dataset()
    # `debug` is forwarded to Resnet as `dbg`; `benchmark` is not read
    # anywhere in the visible code.
    benchmark, debug = False, True
    resnet = Resnet(n=2,dbg=debug)
    resnet.train()
    if use_cuda:
        resnet = resnet.cuda()
        # NOTE(review): moving each residual block separately suggests
        # `residual_blocks` may be a plain Python list rather than a registered
        # container; if so, resnet.parameters() below might not include those
        # blocks' weights -- confirm against the Resnet definition.
        for block in resnet.residual_blocks:
            block.cuda()
    # NOTE(review): `current_lr` is not passed to the optimizer below, which
    # hard-codes lr=1e-4.
    current_lr = 1e-4
#     optimizer = optim.SGD(resnet.parameters(), lr=current_lr, weight_decay=0.0001, momentum=0.9)
    optimizer = optim.Adam(resnet.parameters(), lr=1e-4, weight_decay=0.0001)
    # Metric histories; nothing in the visible excerpt appends to these lists.
    train_accs, test_accs = [], []
    gradient_norms = []
    def train_model():
      """Iterate over `trainloader` for `num_epochs` epochs (excerpt truncated).

      Reads `num_epochs`, `trainloader`, `use_cuda`, `optimizer` and `resnet`
      from the enclosing scope. Only batch unpacking, the optional move to the
      GPU, gradient zeroing and the forward pass are visible here.
      """
      # Local `current_lr` shadows the outer variable of the same name.
      current_lr=1e-4
      # Presumably early-stopping bookkeeping; neither these two values nor
      # `n_iters` are used in the visible part of the function.
      stopping_threshold, current_count = 3, 0
      n_iters = 0
      for e in range(num_epochs):
        # NOTE(review): the original comment here ("modify learning rate
        # at ...") is cut off; no learning-rate change is visible.
          for i, data in enumerate(trainloader, 0):
              x, y = data
              if use_cuda:
                x, y = x.cuda(), y.cuda()
              # Clear gradients accumulated from the previous batch.
              optimizer.zero_grad()
              preds = resnet(x)
def main():
    """Train a ``Resnet`` classifier, keep the best checkpoint, then test it.

    Relies on names defined elsewhere in this module: ``Resnet``, ``device``,
    ``lr``, ``epochs``, ``batchsz``, ``viz``, ``train_loader``,
    ``val_loader``, ``test_loader`` and ``evalute``.

    Each epoch makes one pass over ``train_loader`` using Adam and
    cross-entropy loss, appending the per-batch loss to the ``'loss'``
    window through ``viz.line``. After every epoch the model is scored with
    ``evalute`` on ``val_loader``; whenever that score beats the best seen so
    far, the weights are written to ``best_canny.mdl``. At the end, the
    weights saved during this run (if any) are reloaded and scored on
    ``test_loader``.
    """
    ckpt_path = 'best_canny.mdl'

    # An alternative backbone (pretrained resnet18 with its last layer swapped
    # for a 10-way linear head) was previously sketched here by the author.
    model = Resnet().to(device)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether *this* run wrote a checkpoint, so a stale file left by an
    # earlier run is never loaded by mistake.
    checkpoint_saved = False
    global_step = 0

    # Seed both plot windows with an initial (0, 0) point.
    viz.line([0], [0],
             win='loss',
             opts=dict(title='loss', xlabel='batch', ylabel='loss'))
    viz.line([0], [0],
             win='val_acc',
             opts=dict(title='val_acc', xlabel='batch', ylabel='accuracy'))

    for epoch in range(epochs):
        # Switch back to training mode once per epoch; presumably evalute()
        # leaves the model in eval mode -- confirm against its definition.
        model.train()

        for x, y in train_loader:
            # Per the original author's note: x is [b, 3, 224, 224], y is [b].
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
            # Progress report: batch size times the number of batches so far.
            print('已进行到:', batchsz * global_step)

        # Validate after every epoch.
        val_acc = evalute(model, val_loader)
        print(val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc

            torch.save(model.state_dict(), ckpt_path)
            checkpoint_saved = True

        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)

    if checkpoint_saved:
        model.load_state_dict(torch.load(ckpt_path))
        print('loaded from ckpt!')
    else:
        # No epoch improved on the initial best_acc (or epochs == 0), so no
        # checkpoint belongs to this run; evaluate the in-memory weights.
        print('no checkpoint saved in this run; testing current weights')

    test_acc = evalute(model, test_loader)

    print('test acc:', test_acc)