def train():
    """Two-phase training of a super-resolution GAN.

    Phase 1 pretrains the generator as an SRResNet with plain MSE loss
    (the VGG perceptual-loss path is present but disabled).  Phase 2 runs
    WGAN-GP adversarial training with an autoencoder (netE/netG) objective.

    Relies on project helpers defined elsewhere in this module:
    load_args, load_models, stack_data, utils, ops, plot.
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)
    # BUG FIX: optimizerD was commented out, and optimizerE / ae_criterion
    # were never defined, yet all three are used in the adversarial phase
    # below — the original raised NameError at runtime. Define all of them
    # up front (same hyperparameters as the sibling train() functions).
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    vgg_scale = 0.0784  # 1/12.75 — scale for the (disabled) VGG perceptual loss
    mse_criterion = nn.MSELoss()
    # Gradient seeds for .backward(grad): +1 / -1 implement the WGAN
    # maximize/minimize convention on scalar outputs (legacy PyTorch API).
    one = torch.FloatTensor([1]).cuda(0)
    mone = (one * -1).cuda(0)
    gen = utils.inf_train_gen(train_gen)

    """ train SRResNet with MSE """
    for iteration in range(1, 20000):
        start_time = time.time()
        # The encoder (feature extractor) stays frozen during pretraining.
        for p in netE.parameters():
            p.requires_grad = False
        netG.zero_grad()
        _data = next(gen)
        real_data, vgg_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # Perceptual loss — disabled; pure MSE is used instead.
        # vgg_data_v = autograd.Variable(vgg_data)
        # vgg_features_real = netE(vgg_data_v)
        fake = netG(real_data_v)
        # vgg_features_fake = netE(fake)
        # diff = vgg_features_fake - vgg_features_real.cuda(0)
        # perceptual_loss = vgg_scale * ((diff.pow(2)).sum(3).mean())  # mean(sum(square(diff)))
        # perceptual_loss.backward(one)
        mse_loss = mse_criterion(fake, real_data_v)
        mse_loss.backward(one)
        optimizerG.step()
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/mse cost SRResNet',
                  np.round(mse_loss.data.cpu().numpy(), 4))
        if iteration % 50 == 49:
            utils.generate_sr_image(iteration, netG, save_dir, args, real_data_v)
        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()
        plot.tick()
        if iteration % 5000 == 0:
            torch.save(netG.state_dict(), './SRResNet_PL.pt')

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        # NOTE(review): the pretraining loop above unpacks stack_data() into
        # two values, while here the raw return value is used directly —
        # confirm stack_data's return shape for this code path.
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()

        """ Update D network """
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True   # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
            # instead of noise, use image (super-resolution: G conditions on the image)
            fake = autograd.Variable(netG(real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty (WGAN-GP)
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # again use real data instead of noise
        fake = netG(real_data_v)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        # plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            # utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
            utils.generate_sr_image(iteration, netG, save_dir, args, real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
def train():
    """WGAN-GP training with an autoencoder (netE/netG) objective.

    Alternates: (1) one AE reconstruction step on netE+netG, (2) five
    critic (netD) steps with gradient penalty, (3) one generator step.
    Learning rates decay exponentially per batch; models are checkpointed
    every 100 iterations.

    Relies on project helpers defined elsewhere: load_args, load_data,
    load_models, utils, ops, plot.
    """
    args = load_args()
    train_gen, test_gen = load_data(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)
    if args.use_spectral_norm:
        # SN-GAN recipe: beta1 = 0 and only trainable params are optimized.
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
                                       netD.parameters()),
                                lr=2e-4, betas=(0.0, 0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))
    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99)
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)
    ae_criterion = nn.MSELoss()
    # +1 / -1 gradient seeds for the WGAN maximize/minimize convention.
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    iteration = 0
    for epoch in range(args.epochs):
        for batch_idx, (data, targets) in enumerate(train_gen):
            start_time = time.time()
            """ Update AutoEncoder """
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            netE.zero_grad()
            real_data_v = autograd.Variable(data).cuda()
            real_data_v = real_data_v.view(args.batch_size, -1)
            encoding = netE(real_data_v)
            fake = netG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            ae_loss.backward(one)
            optimizerE.step()
            optimizerG.step()
            """ Update D network """
            for p in netD.parameters():
                p.requires_grad = True
            # BUG FIX: this loop previously reused `i`, silently shadowing
            # the enumerate() index of the outer batch loop.
            for _critic_iter in range(5):
                # NOTE(review): the same batch `data` is reused for all 5
                # critic steps, and unlike the AE update above it is NOT
                # flattened with .view(batch_size, -1) — confirm netD
                # expects this shape.
                real_data_v = autograd.Variable(data).cuda()
                # train with real data
                netD.zero_grad()
                D_real = netD(real_data_v)
                D_real = D_real.mean()
                D_real.backward(mone)
                # train with fake data (netG frozen via volatile input)
                noise = torch.randn(args.batch_size, args.dim).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(netG(noisev).data)
                inputv = fake
                D_fake = netD(inputv)
                D_fake = D_fake.mean()
                D_fake.backward(one)
                # train with gradient penalty (WGAN-GP)
                gradient_penalty = ops.calc_gradient_penalty(
                    args, netD, real_data_v.data, fake.data)
                gradient_penalty.backward()
                D_cost = D_fake - D_real + gradient_penalty
                Wasserstein_D = D_real - D_fake
                optimizerD.step()
            # Update generator network (GAN)
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise)
            fake = netG(noisev)
            G = netD(fake)
            G = G.mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step()
            schedulerD.step()
            schedulerG.step()
            schedulerE.step()
            # Write logs and save samples
            save_dir = './plots/'+args.dataset
            plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
            plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
            plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
            plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())
            # Calculate dev loss and generate samples every 100 iters
            if iteration % 100 == 99:
                dev_disc_costs = []
                for i, (images, targets) in enumerate(test_gen):
                    imgs_v = autograd.Variable(images, volatile=True).cuda()
                    D = netD(imgs_v)
                    _dev_disc_cost = -D.mean().cpu().data.numpy()
                    dev_disc_costs.append(_dev_disc_cost)
                plot.plot(save_dir, '/dev disc cost', np.mean(dev_disc_costs))
                utils.generate_image(iteration, netG, save_dir, args)
                # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
            # Save logs every 100 iters
            if (iteration < 5) or (iteration % 100 == 99):
                plot.flush()
            plot.tick()
            if iteration % 100 == 0:
                utils.save_model(netG, optimizerG, iteration,
                                 'models/{}/G_{}'.format(args.dataset, iteration))
                utils.save_model(netD, optimizerD, iteration,
                                 'models/{}/D_{}'.format(args.dataset, iteration))
            iteration += 1
def train():
    """Train the third layer of a three-stage stacked GAN on GPU 1.

    Loads frozen first/second-layer checkpoints, then trains the
    third-layer autoencoder (ThridE/ThridG) and critic (ThridD) with a
    WGAN-GP objective.  Fake samples are produced by chaining all three
    stages: noise -> netG -> SecondE -> SecondG -> ThridE -> ThridG.
    Periodically checkpoints, renders samples, and prints an inception
    score.

    NOTE(review): "Thrid" is a persistent typo for "Third" carried through
    from the model classes' names; left as-is here.
    """
    with torch.cuda.device(1):
        args = load_args()
        train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
        torch.manual_seed(1)
        # Stage-1 and stage-2 networks are only used for inference
        # (producing inputs to the third stage); their weights come from
        # fixed checkpoint files below.
        netG = first_layer.FirstG(args).cuda()
        SecondG = second_layer.SecondG(args).cuda()
        SecondE = second_layer.SecondE(args).cuda()
        ThridG = third_layer.ThirdG(args).cuda()
        ThridE = third_layer.ThirdE(args).cuda()
        ThridD = third_layer.ThirdD(args).cuda()
        netG.load_state_dict(torch.load('./1stLayer/1stLayerG71999.model'))
        SecondG.load_state_dict(torch.load('./2ndLayer/2ndLayerG71999.model'))
        SecondE.load_state_dict(torch.load('./2ndLayer/2ndLayerE71999.model'))
        ThridE.load_state_dict(torch.load('./3rdLayer/3rdLayerE10999.model'))
        ThridG.load_state_dict(torch.load('./3rdLayer/3rdLayerG10999.model'))
        # Only the third-layer networks are optimized.
        optimizerD = optim.Adam(ThridD.parameters(), lr=1e-4, betas=(0.5, 0.9))
        optimizerG = optim.Adam(ThridG.parameters(), lr=1e-4, betas=(0.5, 0.9))
        optimizerE = optim.Adam(ThridE.parameters(), lr=1e-4, betas=(0.5, 0.9))
        ae_criterion = nn.MSELoss()
        # +1 / -1 gradient seeds for the WGAN maximize/minimize convention
        # (legacy tensor.backward(gradient) API).
        one = torch.FloatTensor([1]).cuda()
        mone = (one * -1).cuda()
        dataLoader = BSDDataLoader(args.dataset, args.batch_size, args)
        for iteration in range(args.epochs):
            start_time = time.time()
            """ Update AutoEncoder """
            # Freeze the critic while the autoencoder is updated.
            for p in ThridD.parameters():
                p.requires_grad = False
            ThridG.zero_grad()
            ThridE.zero_grad()
            real_data = dataLoader.getNextHDBatch().cuda()
            real_data_v = autograd.Variable(real_data)
            encoding = ThridE(real_data_v)
            fake = ThridG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            ae_loss.backward(one)
            optimizerE.step()
            optimizerG.step()
            """ Update D network """
            for p in ThridD.parameters():
                p.requires_grad = True
            # 5 critic updates per generator update (WGAN-GP schedule).
            for i in range(5):
                real_data = dataLoader.getNextHDBatch().cuda()
                real_data_v = autograd.Variable(real_data)
                # train with real data
                ThridD.zero_grad()
                D_real = ThridD(real_data_v)
                D_real = D_real.mean()
                D_real.backward(mone)
                # train with fake data; volatile input freezes the whole
                # generator chain during critic updates.
                noise = generateTensor(args.batch_size).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(ThridG(ThridE(SecondG(SecondE(netG(noisev, True), True)), True)).data)
                inputv = fake
                D_fake = ThridD(inputv)
                D_fake = D_fake.mean()
                D_fake.backward(one)
                # train with gradient penalty
                gradient_penalty = ops.calc_gradient_penalty(args, ThridD, real_data_v.data, fake.data)
                gradient_penalty.backward()
                optimizerD.step()
            # Update generator network (GAN)
            noise = generateTensor(args.batch_size).cuda()
            noisev = autograd.Variable(noise)
            fake = ThridG(ThridE(SecondG(SecondE(netG(noisev, True), True)), True))
            G = ThridD(fake)
            G = G.mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step()
            # Write logs and save samples
            save_dir = './plots/' + args.dataset
            # Checkpoint and render samples every 1000 iters.
            if iteration % 1000 == 999:
                torch.save(ThridE.state_dict(), './3rdLayer/3rdLayerE%d.model' % iteration)
                torch.save(ThridG.state_dict(), './3rdLayer/3rdLayerG%d.model' % iteration)
                utils.generate_image(iteration, netG, save_dir, args)
                utils.generate_MidImage(iteration, netG, SecondE, SecondG, save_dir, args)
                utils.generate_HDImage(iteration, netG, SecondE, SecondG, ThridE, ThridG, save_dir, args)
            # Report inception score on a fresh batch every 2000 iters.
            if iteration % 2000 == 1999:
                noise = generateTensor(args.batch_size).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(ThridG(ThridE(SecondG(SecondE(netG(noisev, True), True)), True)).data)
                print(inception_score(fake.data.cpu().numpy(), resize=True, batch_size=5)[0])
            endtime = time.time()
            print('iter:', iteration, 'total time %4f' % (endtime-start_time),
                  'ae loss %4f' % ae_loss.data[0], 'G cost %4f' % G_cost.data[0])
def train(args):
    """Train a CIFAR hypernetwork.

    An encoder netE maps noise to 5 code chunks; generators W1..W5 map the
    chunks to weights of a 5-layer target classifier.  netD discriminates
    codes from a prior to regularize the encoder.  Optionally pretrains
    netE by moment-matching its codes to the prior.

    Relies on project helpers defined elsewhere: models, datagen, utils,
    ops, stats, train_clf, test_clf, and a `logger` from enclosing scope.
    """
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.
    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)  # prior over the 5 concatenated code chunks
    one = torch.tensor(1).cuda()
    mone = one * -1
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        # Moment-match encoder codes to the qz prior until loss < 0.1.
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 5)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            # Discriminator step on codes vs prior.
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz,
                                          cifar=True)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()
            # Generate one weight set per code chunk.
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            gp, grads, norms = ops.calc_gradient_penalty(
                z, [W1, W2, W3, W4, W5], netE, cifar=True)
            reduce = lambda x: x.mean(0).mean(0).item()
            grads = [reduce(grad) for grad in grads]
            # Classification loss of each sampled weight set on the batch.
            clf_loss = 0.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5],
                                          data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size
            # GENERALIZATION: the original hard-coded 160 (= 32 * 5, the
            # default batch size times 5 code chunks); parameterize on
            # args.batch_size like the sibling MNIST train() does with
            # args.batch_size * 3.
            one_qz = torch.ones((args.batch_size * 5, 1),
                                requires_grad=True).cuda()
            log_qz = ops.log_density(torch.ones(args.batch_size * 5, 1),
                                     2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            # BUG FIX: optimW3.step() was missing — W1/W2/W4/W5 were
            # stepped but GeneratorW3's gradients were computed and then
            # discarded by zero_grad(), so W3 was never trained.
            optimW3.step()
            optimW4.step()
            optimW5.step()
            optimE.zero_grad()
            optimW1.zero_grad()
            optimW2.zero_grad()
            optimW3.zero_grad()
            optimW4.zero_grad()
            optimW5.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct  # accuracy of the last sampled weight set
                print('**************************************')
                print('Acc: {}, MD Loss: {}, D Loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                #print ('penalties: ', [gp[x].item() for x in range(len(gp))])
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(cifar_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(
                                args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    # NOTE(review): `logger` is not defined in this
                    # function — it must come from module/enclosing scope.
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(
                            args, [netE, netD, W1, W2, W3, W4, W5], clf_acc)
def train():
    """WGAN-GP training with an autoencoder (netE/netG) objective.

    Each iteration: (1) one AE reconstruction step on netE+netG,
    (2) five critic (netD) steps with gradient penalty, (3) one generator
    step from noise.  Logs costs via plot and periodically evaluates the
    critic on the dev set and renders AE reconstructions.

    Relies on project helpers defined elsewhere: load_args, load_models,
    stack_data, utils, ops, plot.
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    np.set_printoptions(precision=4)
    netG, netD, netE = load_models(args)
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    # +1 / -1 gradient seeds for the WGAN maximize/minimize convention
    # (legacy tensor.backward(gradient) API).
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    gen = utils.inf_train_gen(train_gen)
    # NOTE(review): `preprocess` is built but never applied in this
    # function — presumably vestigial; confirm before removing.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        # Freeze the critic while the autoencoder is updated.
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """
        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        # 5 critic updates per generator update (WGAN-GP schedule).
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()
            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()
        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()
        # Write logs and save samples
        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance', np.round(Wasserstein_D.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            # utils.generate_image(iteration, netG, save_dir, args)
            utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
def train(args):
    """Train an MNIST hypernetwork.

    Encoder netE maps uniform noise to 3 code chunks; generators W1..W3
    map the chunks to weights of a 3-layer target classifier.  netD
    discriminates codes from a prior to regularize the encoder.
    Optionally pretrains netE by moment-matching codes to the prior.

    Relies on project helpers defined elsewhere: models, datagen, utils,
    ops, stats, train_clf, test_clf, and a `logger` from enclosing scope.
    """
    torch.manual_seed(1)
    netE = models.Encoderz(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorQz(args).cuda()
    print(netE, W1, W2, W3)

    optimE = optim.Adam(netE.parameters(), lr=5e-5, betas=(0.5, 0.9),
                        weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-5, betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-5, betas=(0.5, 0.9),
                         weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-5, betas=(0.5, 0.9),
                         weight_decay=5e-4)
    # Discriminator uses a 10x larger learning rate than the rest.
    optimD = optim.Adam(netD.parameters(), lr=5e-4, betas=(0.5, 0.9),
                        weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss, = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0
    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 3)  # prior over the 3 concatenated code chunks
    u_dist = utils.create_uniform()
    one = torch.tensor(1.).cuda()
    mone = one * -1
    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e is True:
        # Moment-match encoder codes to the qz prior until loss < 0.1.
        for j in range(1000):
            x = utils.sample_uniform(u_dist, (e_batch_size, args.ze))
            z = utils.sample_uniform(u_dist, (e_batch_size, args.z))
            codes = torch.stack(netE(x)).view(-1, args.z * 3)
            qz = utils.sample_uniform(u_dist, (e_batch_size, args.z * 3))
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            z = utils.sample_uniform(u_dist, (args.batch_size, args.ze))
            ze = utils.sample_uniform(u_dist, (args.batch_size, args.z))
            qz = utils.sample_uniform(u_dist, (args.batch_size, args.z * 3))
            codes = netE(z)
            noise = utils.sample_uniform(u_dist, (args.batch_size, args.z * 3))
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            # Discriminator step on codes vs prior.
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()
            # Generate one weight set per code chunk.
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            gp, grads, norms = ops.calc_gradient_penalty(z, [W1, W2, W3], netE)
            reduce = lambda x: x.mean(0).mean(0).item()
            grads = reduce(grads[0]), reduce(grads[1]), reduce(grads[2])
            # Classification loss of each sampled weight set on the batch.
            clf_loss = 0
            for (g1, g2, g3) in zip(l1, l2, l3):
                loss, correct = train_clf(args, [g1, g2, g3], data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size  # * args.beta
            # 3 code chunks per sample -> batch_size * 3 discriminator outputs.
            one_qz = torch.ones((args.batch_size * 3, 1),
                                requires_grad=True).cuda()
            log_qz = ops.log_density(torch.ones(args.batch_size * 3, 1),
                                     2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss  #+ (gp.sum().cuda())#mean().cuda()
            total_hyper_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimE.zero_grad()
            optimW1.zero_grad(), optimW2.zero_grad(), optimW3.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct  # accuracy of the last sampled weight set
                print('**************************************')
                # NOTE(review): `logger` is not defined in this function —
                # it must come from module/enclosing scope.
                print('Iter: {}'.format(len(logger['acc'])))
                print('Acc: {}, MD Loss: {}, D loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                print('penalties: ', gp[0].item(), gp[1].item(), gp[2].item())
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(mnist_test):
                        z = utils.sample_uniform(u_dist, (args.batch_size, args.ze))
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        for (g1, g2, g3) in zip(l1, l2, l3):
                            loss, correct = train_clf(args, [g1, g2, g3],
                                                      data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(mnist_test.dataset) * args.batch_size
                    test_acc /= len(mnist_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3])
                    stats.update_logger(l1, l2, l3, logger)
                    stats.update_acc(logger, test_acc)
                    #stats.update_grad(logger, grads, norms)
                    #stats.save_logger(logger, args.exp)
                    #stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_mnist(args, [netE, netD, W1, W2, W3], clf_acc)
def train():
    """Train the first layer of a stacked GAN with a WGAN-GP objective.

    Alternates: (1) one autoencoder (netE/netG) reconstruction step with
    an extra penalty pulling encodings toward zero, (2) five critic
    (netD) steps with gradient penalty, (3) one generator step from noise.
    Checkpoints every 1000 iterations and reports an inception score
    every 2000 iterations.

    Relies on project helpers defined elsewhere: load_args, first_layer,
    BSDDataLoader, generateTensor, ops, utils, inception_score.
    """
    args = load_args()
    torch.manual_seed(1)
    netG = first_layer.FirstG(args).cuda()
    netD = first_layer.FirstD(args).cuda()
    netE = first_layer.FirstE(args).cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    # +1 / -1 gradient seeds for the WGAN maximize/minimize convention
    # (legacy tensor.backward(gradient) API).
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    dataLoader = BSDDataLoader(args.dataset, args.batch_size, args)
    incep_score = 0
    # Zero target used to regularize encodings toward the origin;
    # 4*4*5 matches the encoder output size — TODO confirm against FirstE.
    zeros = autograd.Variable(torch.zeros(args.batch_size, 4 * 4 * 5).cuda())

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        # Freeze the critic while the autoencoder is updated.
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        real_data = dataLoader.getNextLoBatch().cuda()
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        # Reconstruction loss plus an L2 pull of the encoding toward zero.
        ae_loss = ae_criterion(fake, real_data_v) + ae_criterion(
            encoding, zeros)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """
        for p in netD.parameters():
            p.requires_grad = True
        # 5 critic updates per generator update (WGAN-GP schedule).
        for i in range(5):
            real_data = dataLoader.getNextLoBatch().cuda()
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data; volatile input freezes netG here.
            noise = generateTensor(args.batch_size).cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev, True).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()
            optimizerD.step()
        # Update generator network (GAN)
        noise = generateTensor(args.batch_size).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev, True)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()
        # Write logs and save samples
        save_dir = './plots/' + args.dataset
        # Checkpoint and render samples every 1000 iters.
        if iteration % 1000 == 999:
            torch.save(netE.state_dict(), './1stLayer/1stLayerE%d.model' % iteration)
            torch.save(netG.state_dict(), './1stLayer/1stLayerG%d.model' % iteration)
            utils.generate_image(iteration, netG, save_dir, args)
        endtime = time.time()
        # Recompute the inception score every 2000 iters on 1000 samples;
        # the cached value is printed on every iteration.
        if iteration % 2000 == 1999:
            noise = generateTensor(1000).cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev, True).data)
            incep_score = (inception_score(fake.data.cpu().numpy(),
                                           resize=True, batch_size=5))[0]
        print('iter:', iteration, 'total time %4f' % (endtime - start_time),
              'ae loss %4f' % ae_loss.data[0], 'G cost %4f' % G_cost.data[0],
              'inception score %4f' % incep_score)