コード例 #1
0
def main(*, train_adversarial=True, use_cuda=True, epochs=2000, lr=0.0005,
         checkpoint_path=None):
    """Train ``Net`` on the datasets from ``get_datasets`` and save checkpoints.

    Runs either adversarial training (reconstruction, discriminator and
    encoder optimizers) or plain autoencoder training (reconstruction
    optimizer only). A ``StepLR`` schedule is stepped for every active
    optimizer after each epoch.

    Every 50 epochs the model is evaluated with ``test``; this happens before
    that epoch's training pass. Every 100 epochs its ``state_dict`` is saved
    to ``mnist_cnn<epoch>.pt`` in the current working directory.

    Args:
        train_adversarial: if true, use ``compute_loss_adversarial_enc`` with
            all three optimizers. Otherwise use ``compute_loss_autoenc`` with
            only the reconstruction optimizer.
        use_cuda: allow running on CUDA when ``torch.cuda.is_available()``.
        epochs: number of training epochs (numbered 1..epochs inclusive).
        lr: learning rate of the reconstruction optimizer. The discriminator
            and encoder optimizers use ``lr * 0.1``.
        checkpoint_path: optional path to a saved ``state_dict`` to
            warm-start from. It is loaded with ``strict=False``, so missing
            or unexpected keys are tolerated.
    """
    # prepare dataset
    train_set, test_set = get_datasets(balance_train=True)

    # Assumes each dataset item is a (features, ...) tuple whose first element
    # is a 1-D tensor. TODO confirm against get_datasets.
    num_features = train_set[0][0].shape[0]
    batch_size = 400
    test_batch_size = 50
    # NOTE(review): the seed is set after get_datasets(), so any randomness
    # used while building/balancing the datasets is not covered by it.
    torch.manual_seed(7347)

    device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
    print('using device {0}'.format(device))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=test_batch_size, shuffle=True)

    model = Net(activation=nn.LeakyReLU(),
                num_features=num_features,
                embed_size=80).to(device)
    if checkpoint_path is not None:
        model.load_state_dict(
            torch.load(checkpoint_path, map_location=device), strict=False)

    reconstruction_optimizer = optim.AdamW(model.autoenc_params(), lr=lr)
    discriminative_optimizer = optim.AdamW(model.disc_params(), lr=lr * 0.1)
    encoder_optimizer = optim.AdamW(model.enc_params(), lr=lr * 0.1)

    if train_adversarial:
        compute_loss = compute_loss_adversarial_enc
        optimizer = {'rec': reconstruction_optimizer,
                     'dis': discriminative_optimizer,
                     'enc': encoder_optimizer}
        # Dict order is insertion order: rec, dis, enc.
        schedulers = [StepLR(opt, step_size=70, gamma=0.9)
                      for opt in optimizer.values()]
    else:
        compute_loss = compute_loss_autoenc
        optimizer = {'rec': reconstruction_optimizer}
        schedulers = [StepLR(reconstruction_optimizer, step_size=50, gamma=0.9)]

    for epoch in range(1, epochs + 1):
        if epoch % 50 == 0:
            test(model, compute_loss, device, test_loader)
        train(model, compute_loss, device, train_loader, optimizer, epoch)
        for scheduler in schedulers:
            scheduler.step()
        if epoch % 100 == 0:
            torch.save(model.state_dict(), "mnist_cnn{0}.pt".format(epoch))
        # get_last_lr() reports the LR actually in effect. get_lr() called
        # outside step() is deprecated and, for StepLR, can return the next
        # decayed value. Report every scheduler, not just the last one.
        current_lrs = [group_lr
                       for sched in schedulers
                       for group_lr in sched.get_last_lr()]
        print('learning rate: {0}'.format(current_lrs))