def train(args, model):
    torch.manual_seed(1)
    netE = models.Encoderz(args).cuda()
    netG = models.Final_Small(args).cuda()
    netD = models.DiscriminatorQz(args).cuda()
    print(netE, netG, netD)

    optimE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.tensor(1.).cuda()
    mone = one * -1

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e is True:
        for j in range(500):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            code = netE(x)
            qz = utils.sample_d(z_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(code, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            data, target = data.cuda(), target.cuda()
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(z_dist, args.batch_size)
            code = netE(z)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, [code], log_pz)
            d_loss.backward(retain_graph=True)
            optimD.step()
            optimE.step()
            optimD.zero_grad()
            optimE.zero_grad()

            gen_layers = netG(code)
            gp, grads, norms = ops.calc_gradient_penalty_layer(z, netG, netE)
            grads = grads.mean(0).mean(0).item()
            accs = torch.zeros(len(gen_layers)).cuda()
            losses = torch.zeros(len(gen_layers)).cuda()
            for i, layer in enumerate(gen_layers):
                output = model(data, layer)
                loss = F.cross_entropy(output, target)
                pred = output.data.max(1, keepdim=True)[1]
                correct = pred.eq(target.data.view_as(pred)).long().cpu().sum()
                losses[i] = loss
                accs[i] = correct
            G_loss, correct = losses.max(), accs.mean()

            # BCE targets must not require grad, and the log-density term has
            # to live on the same device as d_q.
            one_qz = torch.ones((args.batch_size, 1)).cuda()
            log_qz = ops.log_density(torch.ones(args.batch_size, 1).cuda(), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss  # + gp.sum().cuda()
            total_hyper_loss.backward()
            optimE.step()
            optimG.step()
            optimE.zero_grad()
            optimG.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Iter: {}'.format(len(logger['acc'])))
                print('Acc: {}, MD Loss: {}, D loss: {}'.format(acc, total_hyper_loss, d_loss))
                print('penalties: ', gp.item())
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, target) in enumerate(mnist_test):
                        data, target = data.cuda(), target.cuda()
                        z = utils.sample_d(x_dist, args.batch_size)
                        code = netE(z)
                        gen_layers = netG(code)
                        for layer in gen_layers:
                            output = model(data, layer)
                            test_loss += F.cross_entropy(output, target)
                            pred = output.data.max(1, keepdim=True)[1]
                            test_acc += pred.eq(target.data.view_as(pred)).float().sum()
                    test_loss /= len(mnist_test.dataset) * args.batch_size
                    test_acc /= len(mnist_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, gen_layers)
                    stats.update_logger(gen_layers, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_layer(args, [netE, netD, netG], clf_acc)
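# Every pretraining loop in this file calls ops.pretrain_loss(code, qz) to
# match the first and second moments of the encoder output to samples from
# the prior. The ops module is not shown here, so the following is only a
# minimal sketch of such a moment-matching loss under that assumption, not
# the repo's verified implementation.
def pretrain_loss_sketch(code, qz):
    # First moment: squared distance between batch means.
    mean_loss = (code.mean(0) - qz.mean(0)).pow(2).mean()
    # Second moment: squared distance between batch covariance matrices.
    code_c = code - code.mean(0)
    qz_c = qz - qz.mean(0)
    cov_code = code_c.t() @ code_c / (code.size(0) - 1)
    cov_qz = qz_c.t() @ qz_c / (qz.size(0) - 1)
    cov_loss = (cov_code - cov_qz).pow(2).mean()
    return mean_loss, cov_loss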
def train(args):
    torch.manual_seed(1)
    netE = models.Encoderz(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorQz(args).cuda()
    print(netE, W1, W2, W3)

    if args.resume is not None:
        d = torch.load(args.resume)
        netE = utils.load_net_only(netE, d['E'])
        netD = utils.load_net_only(netD, d['D'])
        W1 = utils.load_net_only(W1, d['W1'])
        W2 = utils.load_net_only(W2, d['W2'])
        W3 = utils.load_net_only(W3, d['W3'])

    optimE = optim.Adam(netE.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=args.lr, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 3)
    one = torch.tensor(1.).cuda()
    mone = one * -1

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.resume is None and args.pretrain_e is True:
        for j in range(1000):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 3)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1_w, l1_b = W1(codes[0])
            l2_w, l2_b = W2(codes[1])
            l3_w, l3_b = W3(codes[2])
            clf_loss = 0
            for (g1_w, g1_b, g2_w, g2_b, g3_w, g3_b) in zip(l1_w, l1_b, l2_w, l2_b, l3_w, l3_b):
                g1 = (g1_w, g1_b)
                g2 = (g2_w, g2_b)
                g3 = (g3_w, g3_b)
                loss, correct = train_clf(args, [g1, g2, g3], data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size  # * args.beta

            # one target per code: 3 codes per batch element; targets stay on
            # the GPU and do not require grad
            one_qz = torch.ones((args.batch_size * 3, 1)).cuda()
            log_qz = ops.log_density(torch.ones(args.batch_size * 3, 1).cuda(), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss  # + gp.sum().cuda()
            total_hyper_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimE.zero_grad()
            optimW1.zero_grad(); optimW2.zero_grad(); optimW3.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Iter: {}'.format(len(logger['acc'])))
                print('Acc: {}, MD Loss: {}, D loss: {}'.format(acc, total_hyper_loss, d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(mnist_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1_w, l1_b = W1(codes[0])
                        l2_w, l2_b = W2(codes[1])
                        l3_w, l3_b = W3(codes[2])
                        for (g1_w, g1_b, g2_w, g2_b, g3_w, g3_b) in zip(
                                l1_w, l1_b, l2_w, l2_b, l3_w, l3_b):
                            g1 = (g1_w, g1_b)
                            g2 = (g2_w, g2_b)
                            g3 = (g3_w, g3_b)
                            loss, correct = train_clf(args, [g1, g2, g3], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size
                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss:
                    best_test_loss, args.best_loss = test_loss, test_loss
                if test_acc > best_test_acc:
                    utils.save_hypernet_mnist(args, [netE, netD, W1, W2, W3], test_acc)
                    best_test_acc, args.best_acc = test_acc, test_acc
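# All trainers below draw latent codes with utils.create_d(dim) and
# utils.sample_d(dist, n). Since utils is not shown, the following is a
# plausible minimal version under the assumption that the prior is an
# isotropic unit Gaussian; the repo's actual helpers may differ.
import torch
from torch.distributions import Normal

def create_d_sketch(dim):
    # Unit Gaussian prior over `dim` dimensions.
    return Normal(torch.zeros(dim), torch.ones(dim))

def sample_d_sketch(d, n, cuda=True):
    # Draw n samples of shape (n, dim) and optionally move them to the GPU.
    z = d.sample((n,))
    return z.cuda() if cuda else z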
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=1e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(2000):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            ops.batch_zero_grad([netE, W1, W2, W3, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])

            if args.use_d:
                ops.free_params([netD])
                ops.frozen_params([netE, W1, W2, W3])
                for code in codes:
                    noise = utils.sample_d(z_dist, args.batch_size)
                    d_real = netD(noise)
                    d_fake = netD(code)
                    d_real_loss = -1 * torch.log((1 - d_real).mean())
                    d_fake_loss = -1 * torch.log(d_fake.mean())
                    d_real_loss.backward(retain_graph=True)
                    d_fake_loss.backward(retain_graph=True)
                    d_loss = d_real_loss + d_fake_loss
                optimD.step()
                ops.frozen_params([netD])
                ops.free_params([netE, W1, W2, W3])

            for (g1, g2, g3) in zip(l1, l2, l3):
                correct, loss = train_clf(args, [g1, g2, g3], data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            loss = loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} MNIST Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, Loss: {}'.format(acc, loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                ensemble = 5
                for i, (data, y) in enumerate(mnist_test):
                    # average `ensemble` randomly drawn networks into one model
                    en1, en2, en3 = [], [], []
                    for _ in range(ensemble):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        rand = np.random.randint(32)
                        en1.append(W1(codes[0])[rand])
                        en2.append(W2(codes[1])[rand])
                        en3.append(W3(codes[2])[rand])
                    g1 = torch.stack(en1).mean(0)
                    g2 = torch.stack(en2).mean(0)
                    g3 = torch.stack(en3).mean(0)
                    correct, loss = train_clf(args, [g1, g2, g3], data, y)
                    test_acc += correct.item()
                    test_loss += loss.item()
                test_loss /= len(mnist_test.dataset)
                test_acc /= len(mnist_test.dataset)
                """
                # per-sample (non-ensembled) evaluation alternative:
                for (g1, g2, g3) in zip(l1, l2, l3):
                    correct, loss = train_clf(args, [g1, g2, g3], data, y)
                    test_acc += correct.item()
                    test_loss += loss.item()
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size
                """
                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    if test_acc > .95:
                        utils.save_hypernet_mnist(args, [netE, W1, W2, W3], test_acc)
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc
def train(args):
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)  # _hidden(args, [0, 1, 2, 3, 4])
    one = torch.tensor(1.).cuda()
    mone = one * -1

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_z_like((e_batch_size, args.ze))
            z = utils.sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    bb = 0
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            utils.batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            # the encoder is bypassed here: each layer code is drawn directly
            # from the prior instead of codes = netE(z)
            z1 = utils.sample_z_like((args.batch_size, args.z))
            z2 = utils.sample_z_like((args.batch_size, args.z))
            z3 = utils.sample_z_like((args.batch_size, args.z))
            z4 = utils.sample_z_like((args.batch_size, args.z))
            z5 = utils.sample_z_like((args.batch_size, args.z))
            codes = [z1, z2, z3, z4, z5]
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            """
            # Z Adversary
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            """
            clf_loss = 0
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            clf_loss *= args.beta
            clf_loss.backward()
            optimE.step(); optimW1.step(); optimW2.step()
            optimW3.step(); optimW4.step(); optimW5.step()
            utils.batch_zero_grad([optimE, optimW1, optimW2, optimW3, optimW4, optimW5])
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0:
                bb += 1
                acc = correct
                print('**************************************')
                print('epoch: {}'.format(bb))
                print('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, 0))  # d_loss
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(best_clf_acc))
                print('**************************************')

            if batch_idx % 100 == 0:
                with torch.no_grad():
                    test_acc = 0.
                    test_loss = 0.
                    for i, (data, y) in enumerate(cifar_test):
                        w1_code = utils.sample_z_like((args.batch_size, args.z))
                        w2_code = utils.sample_z_like((args.batch_size, args.z))
                        w3_code = utils.sample_z_like((args.batch_size, args.z))
                        w4_code = utils.sample_z_like((args.batch_size, args.z))
                        w5_code = utils.sample_z_like((args.batch_size, args.z))
                        l1 = W1(w1_code)
                        l2 = W2(w2_code)
                        l3 = W3(w3_code)
                        l4 = W4(w4_code)
                        l5 = W5(w5_code)
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
def train(args):
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)
    one = torch.tensor(1.).cuda()
    mone = one * -1

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 5)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz, cifar=True)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            gp, grads, norms = ops.calc_gradient_penalty(
                z, [W1, W2, W3, W4, W5], netE, cifar=True)
            reduce = lambda x: x.mean(0).mean(0).item()
            grads = [reduce(grad) for grad in grads]

            clf_loss = 0.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size

            # one discriminator target per code: 5 codes per batch element
            one_qz = torch.ones((args.batch_size * 5, 1)).cuda()
            log_qz = ops.log_density(torch.ones(args.batch_size * 5, 1).cuda(), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            optimE.zero_grad()
            optimW1.zero_grad()
            optimW2.zero_grad()
            optimW3.zero_grad()
            optimW4.zero_grad()
            optimW5.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Acc: {}, MD Loss: {}, D Loss: {}'.format(acc, total_hyper_loss, d_loss))
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(cifar_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(
                            args, [netE, netD, W1, W2, W3, W4, W5], clf_acc)
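# The Qz trainers above call ops.log_density(z, 2) and reshape the result
# with .view(-1, 1), so it must return one value per sample. A sketch under
# the assumption that the second argument is the standard deviation of an
# isotropic Gaussian (ops is not shown, so this is illustrative only):
import math

def log_density_sketch(z, scale):
    # Per-sample log-density of N(0, scale^2 I), summed over feature dims.
    var = scale ** 2
    return -0.5 * (z.pow(2) / var + math.log(2 * math.pi * var)).sum(-1)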
def train(args):
    from torch import optim
    #torch.manual_seed(8734)
    netE = models.Encoderz(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    E1 = models.GeneratorE1(args).cuda()
    E2 = models.GeneratorE2(args).cuda()
    #E3 = models.GeneratorE3(args).cuda()
    #E4 = models.GeneratorE4(args).cuda()
    #D1 = models.GeneratorD1(args).cuda()
    D1 = models.GeneratorD2(args).cuda()
    D2 = models.GeneratorD3(args).cuda()
    D3 = models.GeneratorD4(args).cuda()
    print(netE, netD)
    print(E1, E2, D1, D2, D3)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    Eoptim = [
        optim.Adam(E1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4),
        optim.Adam(E2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4),
    ]
    Doptim = [
        optim.Adam(D1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4),
        optim.Adam(D2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4),
        optim.Adam(D3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4),
    ]
    Enets = [E1, E2]
    Dnets = [D1, D2, D3]

    best_test_loss = np.inf
    args.best_loss = best_test_loss

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(100):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            netE.zero_grad()
            for opt in Eoptim:  # `opt`, not `optim`, to avoid shadowing torch.optim
                opt.zero_grad()
            for opt in Doptim:
                opt.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = torch.log((1 - d_real).mean())
                d_fake_loss = torch.log(d_fake.mean())
                d_real_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
                d_fake_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            netD.zero_grad()

            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            Eweights, Dweights = [], []
            k = 0  # separate code index; `i` is reused by enumerate below
            for net in Enets:
                Eweights.append(net(codes[k]))
                k += 1
            for net in Dnets:
                Dweights.append(net(codes[k]))
                k += 1
            d_real = []
            for code in codes:
                d = netD(code)
                d_real.append(d)
            netD.zero_grad()
            d_loss = torch.stack(d_real).log().mean() * 10.

            for layers in zip(*(Eweights + Dweights)):
                loss, _ = train_clf(args, layers, data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
            d_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
            optimE.step()
            for opt in Eoptim:
                opt.step()
            for opt in Doptim:
                opt.step()
            loss = loss.item()

            if batch_idx % 50 == 0:
                print('**************************************')
                print('AE MNIST Test, beta: {}'.format(args.beta))
                print('MSE Loss: {}'.format(loss))
                print('D loss: {}'.format(d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    Eweights, Dweights = [], []
                    k = 0
                    for net in Enets:
                        Eweights.append(net(codes[k]))
                        k += 1
                    for net in Dnets:
                        Dweights.append(net(codes[k]))
                        k += 1
                    for layers in zip(*(Eweights + Dweights)):
                        loss, out = train_clf(args, layers, data, y)
                        test_loss += loss.item()
                    if i == 10:  # evaluate on 10 test batches only
                        break
                test_loss /= 10 * len(y) * args.batch_size
                print('Test Loss: {}'.format(test_loss))
                if test_loss < best_test_loss:
                    print('==> new best stats, saving')
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                archE = sampleE(args).cuda()
                archD = sampleD(args).cuda()
                rand = np.random.randint(args.batch_size)
                eweight = list(zip(*Eweights))[rand]
                dweight = list(zip(*Dweights))[rand]
                modelE = utils.weights_to_clf(eweight, archE, args.statE['layer_names'])
                modelD = utils.weights_to_clf(dweight, archD, args.statD['layer_names'])
                utils.generate_image(args, batch_idx, modelE, modelD, data.cuda())
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorQ(args).cuda()
    netQ = Q().cuda()
    print(netE, W1, W2, W3, netD, netQ)
    netD.apply(weight_init)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    #q_params = itertools.chain(W2.parameters(), netD.parameters())
    q_params = itertools.chain(W2.parameters(), netQ.parameters())
    optimQ = optim.Adam(q_params, lr=2e-4, betas=(0.5, 0.999), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    real_filters = x_gen()
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    if args.pretrain_e:
        j = 0
        final = 100.
        e_batch_size = 1000
        print("==> pretraining encoder")
        for j in range(300):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    c_acc, c_loss = [], []
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            """ generate encoding """
            valid = torch.ones((args.batch_size, 1), dtype=torch.float32,
                               requires_grad=False).cuda()
            fake = torch.zeros((args.batch_size, 1), dtype=torch.float32,
                               requires_grad=False).cuda()
            real = torch.tensor(next(real_filters), requires_grad=True).cuda()
            labels = to_categorical(target.numpy(), cols=10)
            ops.batch_zero_grad([optimE, optimW1, optimW2, optimW3])
            z = utils.sample_d(x_dist, args.batch_size)
            c = torch.tensor(np.random.uniform(
                -1, 1, (args.batch_size, args.factors))).float().cuda()
            ycat = torch.tensor(np.random.randint(0, 10, args.batch_size)).long().cuda()
            y = to_categorical(ycat, cols=10).float()
            codes = netE(z)

            """ train generator """
            l1 = W1(codes[0])
            l2 = W2(codes[1], y)  #, c)
            l3 = W3(codes[2])
            clf_loss = []
            for i, (g1, g2, g3) in enumerate(zip(l1, l2, l3)):
                correct, loss = train_clf(args, [g1, g2, g3], data, target)
                clf_loss.append(loss)
                scaled_loss = args.beta * loss  # + adv_loss
                scaled_loss.backward(retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            loss = torch.stack(clf_loss).mean().item()

            """ train discriminator """
            """
            optimD.zero_grad()
            for g2 in l2:
                real_pred, _, _ = netD(real)
                d_real_loss = F.mse_loss(real_pred, valid)
                fake_pred, _, _ = netD(g2.detach())
                d_fake_loss = F.mse_loss(fake_pred, fake)
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward(retain_graph=True)
            optimD.step()
            """

            """ MI loss """
            # maximize the mutual information between the labels and a given
            # filter -- last conv layer only
            optimQ.zero_grad()
            optimE.zero_grad()
            sampled_labels = np.random.randint(0, 10, args.batch_size)
            gt_labels = torch.tensor(sampled_labels, requires_grad=False).long().cuda()
            label = to_categorical(sampled_labels, cols=10)
            z = utils.sample_d(x_dist, args.batch_size)
            embedding = netE(z)[1]
            gen_final_conv = W2(embedding, label)  #, code)
            inter_acc, inter_loss = 0, 10.
            for m in gen_final_conv:
                q_pred = netQ(m)
                pred = q_pred.max(1, keepdim=True)[1]
                inter_acc = pred.eq(gt_labels.data.view_as(pred)).long().cpu().sum()
                inter_loss = F.cross_entropy(q_pred, gt_labels)
                mi_loss = args.alpha * inter_loss
                mi_loss.backward(retain_graph=True)
            optimQ.step()
            optimE.step()
            c_acc.append(inter_acc.item())
            c_loss.append(inter_loss.item())

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} MNIST Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, Loss: {}, MI loss: {}'.format(acc, loss, mi_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('categorical acc: {}'.format(max(c_acc) / len(label)))
                print('categorical loss: {}'.format(max(c_loss) / len(label)))
                print('**************************************')
                c_acc, c_loss = [], []

            if batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, target) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    c = torch.tensor(np.random.uniform(
                        -1, 1, (args.batch_size, args.factors))).float().cuda()
                    y = to_categorical(np.random.randint(0, 10, args.batch_size),
                                       cols=10).float().cuda()
                    l1 = W1(codes[0])
                    l2 = W2(codes[1], y)  #, c)
                    l3 = W3(codes[2])
                    for (g1, g2, g3) in zip(l1, l2, l3):
                        correct, loss = train_clf(args, [g1, g2, g3], data, target)
                        test_acc += correct.item()
                        test_loss += loss.item()
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size
                print('Accuracy: {}, Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    utils.save_clf(args, [g1, g2, g3], test_acc)
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(100):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            netE.zero_grad()
            W1.zero_grad()
            W2.zero_grad()
            W3.zero_grad()
            W4.zero_grad()
            W5.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            #ops.free_params([netD]); ops.frozen_params([netE, W1, W2, W3])
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = torch.log((1 - d_real).mean())
                d_fake_loss = torch.log(d_fake.mean())
                d_real_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
                d_fake_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            #ops.frozen_params([netD]); ops.free_params([netE, W1, W2, W3])
            netD.zero_grad()

            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            d_real = []
            for code in codes:
                d = netD(code)
                d_real.append(d)
            netD.zero_grad()
            d_loss = torch.stack(d_real).log().mean() * 10.

            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
            d_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(), retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} MNIST Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, Loss: {}'.format(acc, loss))
                print('D loss: {}'.format(d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0])
                    l2 = W2(codes[1])
                    l3 = W3(codes[2])
                    l4 = W4(codes[3])
                    l5 = W5(codes[4])
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                        test_acc += correct.item()
                        test_loss += loss.item()
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size
                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    if test_acc > .85:
                        utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5], test_acc)
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc
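# Every trainer in this file scores generated weights by passing them to
# train_clf(args, layers, data, target), which runs a functional forward
# pass through the target architecture. That function is not shown; the
# sketch below assumes a plain fully connected network over flattened inputs
# and the (correct, loss) return convention used by the MNIST trainers. The
# repo's actual target networks differ (several trainers generate conv
# filters), so treat this as illustrative only.
import torch.nn.functional as F

def train_clf_sketch(args, layers, data, target):
    x = data.view(data.size(0), -1).cuda()
    target = target.cuda()
    for i, w in enumerate(layers):
        # each `w` is a generated weight tensor of shape (out, in); no bias
        # term in this sketch
        x = F.linear(x, w)
        if i < len(layers) - 1:
            x = F.relu(x)
    loss = F.cross_entropy(x, target)
    pred = x.data.max(1, keepdim=True)[1]
    correct = pred.eq(target.data.view_as(pred)).long().cpu().sum()
    return correct, loss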
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)
            l4 = W4(codes[3]).mean(0)
            l5 = W5(codes[4]).mean(0)

            # Z Adversary
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = utils.sample_d(z_dist, args.batch_size)
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])

            correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, target, val=True)
            scaled_loss = args.beta * loss
            scaled_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} CIFAR Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                total_correct = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0]).mean(0)
                    l2 = W2(codes[1]).mean(0)
                    l3 = W3(codes[2]).mean(0)
                    l4 = W4(codes[3]).mean(0)
                    l5 = W5(codes[4]).mean(0)
                    correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, y, val=True)
                    test_acc += correct.item()
                    total_correct += correct.item()
                    test_loss += loss.item()
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)
                print('Test Accuracy: {}, Test Loss: {}, ({}/{})'.format(
                    test_acc, test_loss, total_correct, len(cifar_test.dataset)))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    utils.save_clf(args, [l1, l2, l3, l4, l5], test_acc)
                    #utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    Aux = AuxDz(args).cuda()
    print(netE, W1, W2, W3, Aux)  #netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)
    optimAux = optim.Adam(Aux.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.load_e:
        netE, optimE, _ = utils.load_model(args, netE, optimE)
        print('==> loading pretrained encoder')
    if args.pretrain_e:
        for j in range(2000):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(mnist_train):
            ops.batch_zero_grad([netE, W1, W2, W3, netD])
            ops.batch_zero_grad([optimE, optimW1, optimW2, optimW3, optimAux])
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)

            if args.use_aux:
                # last conv layer only: split the latent space into chunks,
                # each representing a class; use a separate name so the
                # minibatch `target` is not clobbered for train_clf below
                factors = torch.split(codes[1], args.z // 10, 1)
                for y, factor in enumerate(factors):
                    aux_target = (torch.ones(args.batch_size, dtype=torch.long) * y).cuda()
                    aux_pred = Aux(factor)
                    aux_loss = F.cross_entropy(aux_pred, aux_target)
                    aux_loss.backward(retain_graph=True)
                optimAux.step()

            correct, loss = train_clf(args, [l1, l2, l3], data, target, val=True)
            scaled_loss = args.beta * loss
            scaled_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            loss = loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('MNIST Test, beta: {}'.format(args.beta))
                print('Acc: {}, Loss: {}'.format(acc, loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx % 200 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0]).mean(0)
                    l2 = W2(codes[1]).mean(0)
                    l3 = W3(codes[2]).mean(0)
                    correct, loss = train_clf(args, [l1, l2, l3], data, y, val=True)
                    test_acc += correct.item()
                    test_loss += loss.item()
                test_loss /= len(mnist_test.dataset)
                test_acc /= len(mnist_test.dataset)
                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    #netQ = models.DiscriminatorQ(args).cuda()
    netQ = Q().cuda()
    print(netE, W1, W2, W3, netQ)

    optimE = optim.Adam(netE.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    q_params = list(W2.parameters()) + list(netQ.parameters())
    optimQ = optim.Adam(q_params, lr=2e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    if args.use_d:
        # netD/optimD are referenced below only when args.use_d is set; a
        # DiscriminatorZ with the hyperparameters of the other trainers is
        # assumed here, since this function never constructed one
        netD = models.DiscriminatorZ(args).cuda()
        optimD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.load_e:
        netE, optimE, _ = utils.load_model(args, netE, optimE)
        print('==> loading pretrained encoder')
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            ops.batch_zero_grad([netE, W1, W2, W3, netQ])
            """ generate encoding """
            z = utils.sample_d(x_dist, args.batch_size)
            c = sample_categorical(args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1], c)
            l3 = W3(codes[2])

            """ train discriminator """
            if args.use_d:
                for _ in range(args.disc_iters):
                    d_losses = []
                    noise = utils.sample_d(z_dist, args.batch_size)
                    codes_d = netE(z)
                    for code in codes_d:
                        d_real = netD(noise)
                        d_fake = netD(code)
                        d_losses.append(-(torch.mean(d_real) - torch.mean(d_fake)))
                    for d_loss in d_losses:
                        d_loss.backward(retain_graph=True)
                    optimD.step()
                    netD.zero_grad()
                    optimD.zero_grad()

            """ train generator """
            d_fake, clf_loss = [], []
            if args.use_d:
                for code in codes:
                    d_fake.append(netD(code))
            for i, (g1, g2, g3) in enumerate(zip(l1, l2, l3)):
                correct, loss = train_clf(args, [g1, g2, g3], data, target)
                clf_loss.append(loss)
                if args.use_d:
                    # reduce the discriminator output to a scalar before
                    # adding it to the classification loss
                    scaled_loss = args.beta * (loss + d_fake[i].mean())
                else:
                    scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            loss = torch.stack(clf_loss).mean().item()

            """ MI loss """
            # maximize the mutual information between the labels and a given
            # filter -- last conv layer only
            netQ.zero_grad()
            gen_final_conv = W2(codes[1], c)
            for m in gen_final_conv:
                q_c_x = netQ(m)
                mi_loss = MI_loss(args, q_c_x, c)
                _, q_loss = embedding_clf(args, m, netQ, c)
                (q_loss * 10).backward(retain_graph=True)
                mi_loss.backward(retain_graph=True)
            optimE.step()
            optimW2.step()
            optimQ.step()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('{} MNIST Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, Loss: {}, MI loss: {}, Q loss: {}'.format(
                    acc, loss, mi_loss, q_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                q_test_acc = 0.
                q_test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    idx = [0, 0] + [i for i in range(10) for _ in range(3)]
                    c = np.zeros([args.batch_size, 10])
                    c[range(args.batch_size), idx] = 1
                    c = torch.tensor(c, dtype=torch.float32).cuda()
                    l1 = W1(codes[0])
                    l2 = W2(codes[1], c)
                    l3 = W3(codes[2])
                    for (g1, g2, g3) in zip(l1, l2, l3):
                        correct, loss = train_clf(args, [g1, g2, g3], data, y)
                        test_acc += correct.item()
                        test_loss += loss.item()
                        q_correct, q_loss = embedding_clf(args, g2, netQ, c)
                        q_test_acc += q_correct.item()
                        q_test_loss += q_loss.item()
                test_loss /= len(mnist_test.dataset) * args.batch_size
                test_acc /= len(mnist_test.dataset) * args.batch_size
                q_test_acc /= len(mnist_test.dataset) * args.batch_size
                q_test_loss /= len(mnist_test.dataset) * args.batch_size
                print('Accuracy: {}, Loss: {}'.format(test_acc, test_loss))
                print('Q Accuracy: {}, Q Loss: {}'.format(q_test_acc, q_test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    utils.save_hypernet_mnist(args, [netE, W1, W2, W3], test_acc)
                if test_loss < best_test_loss:
                    best_test_loss = test_loss
                    args.best_loss = test_loss
                if test_acc > best_test_acc:
                    best_test_acc = test_acc
                    args.best_acc = test_acc