import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# get_datasets, Net, compute_loss_adversarial_enc, compute_loss_autoenc,
# train and test are defined elsewhere in this script.


def main():
    # Hyperparameters and dataset preparation.
    train_adversarial = True
    use_cuda = True
    epochs = 2000
    lr = 0.0005
    train_set, test_set = get_datasets(balance_train=True)
    num_features = train_set[0][0].shape[0]
    batch_size = 400
    test_batch_size = 50

    torch.manual_seed(7347)
    device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
    print('using device {0}'.format(device))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size, shuffle=True)

    model = Net(activation=nn.LeakyReLU(), num_features=num_features, embed_size=80).to(device)
    PATH = None
    # Optionally resume from a checkpoint:
    # model.load_state_dict(torch.load(PATH, map_location=device), strict=False)

    # Separate optimizers for the autoencoder, the discriminator and the encoder,
    # with the adversarial parts trained at a tenth of the reconstruction rate.
    reconstruction_optimizer = optim.AdamW(model.autoenc_params(), lr=lr)
    discriminative_optimizer = optim.AdamW(model.disc_params(), lr=lr * 0.1)
    encoder_optimizer = optim.AdamW(model.enc_params(), lr=lr * 0.1)

    if train_adversarial:
        compute_loss = compute_loss_adversarial_enc
        optimizer = {'rec': reconstruction_optimizer,
                     'dis': discriminative_optimizer,
                     'enc': encoder_optimizer}
        optimizers = [reconstruction_optimizer, discriminative_optimizer, encoder_optimizer]
        schedulers = [StepLR(x, step_size=70, gamma=0.9) for x in optimizers]
    else:
        compute_loss = compute_loss_autoenc
        optimizer = {'rec': reconstruction_optimizer}
        schedulers = [StepLR(reconstruction_optimizer, step_size=50, gamma=0.9)]

    for epoch in range(1, epochs + 1):
        # Evaluate on the test set every 50 epochs.
        if epoch % 50 == 0:
            test(model, compute_loss, device, test_loader)
        train(model, compute_loss, device, train_loader, optimizer, epoch)
        for scheduler in schedulers:
            scheduler.step()
        # Checkpoint every 100 epochs and report the current learning rate.
        if epoch % 100 == 0:
            torch.save(model.state_dict(), "mnist_cnn{0}.pt".format(epoch))
            print('learning rate: {0}'.format(scheduler.get_last_lr()))
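

# Assumed entry point so the script can be executed directly; the original
# listing shows only main() itself.
if __name__ == '__main__':
    main()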