def test(args, model, epoch=None, grad=False):
    model.eval()
    if args.dataset == 'cifar':
        _, test_loader = datagen.load_cifar(args)
    elif args.dataset == 'cifar100':
        _, test_loader = datagen.load_10_class_cifar100(args)
    test_loss = 0.
    correct = 0.
    criterion = nn.CrossEntropyLoss()
    for i, (data, target) in enumerate(test_loader):
        data, target = data.cuda(), target.cuda()
        output = model(data)
        if grad is False:
            test_loss += criterion(output, target).item()
        else:
            # keep the loss attached to the graph when gradients are needed
            test_loss += criterion(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        output = None  # drop the reference so the graph can be freed
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
    test_loss /= len(test_loader.dataset)
    acc = correct.item() / len(test_loader.dataset)
    if epoch:
        print('Epoch: {}, Average loss: {}, Accuracy: {}/{} ({}%)'.format(
            epoch, test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
    return acc, test_loss
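# Editorial usage sketch (hypothetical, not part of the original file): how
# test() above would typically be called. `models.LeNet` appears elsewhere in
# this repo; any classifier matching args.dataset would do.
#
#   model = models.LeNet().cuda()
#   acc, loss = test(args, model, epoch=1)
#   print('test acc {:.4f}, loss {:.4f}'.format(acc, loss))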
def test(args, Z, names, ext):
    model = weights_to_clf(Z, names)
    model.eval()
    if not ext:
        _, test_loader = datagen.load_cifar(args)
    else:
        test_loader = load_cifar_ext(args)
    test_loss = 0.
    correct = 0.
    criterion = nn.CrossEntropyLoss()
    for i, (data, target) in enumerate(test_loader):
        data, target = data.cuda(), target.cuda()
        data = data.view(100, 3, 32, 32)  # assumes a test batch size of 100
        target = target.view(-1)
        output = model(data)
        test_loss += criterion(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
    if not ext:
        test_loss /= len(test_loader.dataset)
        acc = correct.item() / len(test_loader.dataset)
    else:
        # the external set is assumed to hold 2000 examples
        test_loss /= 2000
        acc = correct.item() / 2000
    return acc, test_loss
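# Editorial sketch (assumption, not the original implementation):
# `weights_to_clf` is called above but never defined in this file. Assuming it
# copies generated weight tensors into a template network, a minimal version
# could look like this; the real signature and key naming may differ.
def _weights_to_clf_sketch(params, names, arch):
    """Load generated tensors into a fresh `arch()` classifier (assumed API)."""
    model = arch().cuda()
    state = model.state_dict()
    for name, w in zip(names, params):
        # assumes one tensor per named parameter, shapes already matching
        state[name] = w
    model.load_state_dict(state)
    return model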
def train(args, model, grad=False):
    if args.dataset == 'cifar':
        train_loader, _ = datagen.load_cifar(args)
    elif args.dataset == 'cifar100':
        train_loader, _ = datagen.load_10_class_cifar100(args)
    train_loss, train_acc = 0., 0.
    criterion = nn.CrossEntropyLoss()
    if args.ft:
        # freeze everything but the final layer for fine-tuning
        for child in list(model.children())[:-1]:
            # print('removing {}'.format(child))
            for param in child.parameters():
                param.requires_grad = False
    optimizer = optim.Adam(model.parameters(), lr=2e-3, weight_decay=1e-4)
    # optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    for epoch in range(args.epochs):
        model.train()
        total = 0
        correct = 0
        for i, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        train_acc = 100. * correct / total
        # evaluate after every epoch; the returned stats are the last epoch's
        acc, loss = test(args, model, epoch + 1)
    return acc, loss
def test_cifar(args, Z, names, arch):
    _, test_loader = datagen.load_cifar(args)
    criterion = nn.CrossEntropyLoss()
    pop_size = args.batch_size
    with torch.no_grad():
        correct = 0.
        test_loss = 0
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            outputs = []
            for i in range(pop_size):
                params = [Z[0][i], Z[1][i], Z[2][i], Z[3][i], Z[4][i]]
                model = weights_to_clf(params, names, arch)
                output = model(data)
                outputs.append(output)
            # majority vote over the population: most frequent label wins
            pop_outputs = torch.stack(outputs)
            pop_labels = pop_outputs.max(2, keepdim=True)[1].view(
                pop_size, 100, 1)  # 100 = assumed test batch size
            modes, idxs = torch.mode(pop_labels, dim=0, keepdim=True)
            modes = modes.view(100, 1)
            correct += modes.eq(target.data.view_as(modes)).long().cpu().sum()
            # note: the loss is computed from the last population member only
            test_loss += criterion(output, target).item()  # sum up batch loss
    test_loss /= len(test_loader.dataset)
    acc = (correct.float() / len(test_loader.dataset)).item()
    return acc, test_loss
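# Editorial sketch: the ensemble decision rule used in test_cifar, isolated
# for clarity. torch.mode over the population axis picks the most frequent
# predicted class per example; shapes here are illustrative assumptions.
def _majority_vote_sketch(pop_outputs):
    """pop_outputs: (pop_size, batch, n_classes) stacked logits."""
    pop_labels = pop_outputs.max(2)[1]        # (pop_size, batch) class predictions
    modes, _ = torch.mode(pop_labels, dim=0)  # (batch,) most common vote per example
    return modes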
def __init__(self, args):
    self.s = args.s
    self.z = args.z
    self.batch_size = args.batch_size
    self.epochs = 200
    self.alpha = 1
    self.beta = args.beta
    self.target = args.target
    self.use_bn = args.use_bn
    self.bias = args.bias
    self.n_hidden = args.n_hidden
    self.pretrain_e = args.pretrain_e
    self.dataset = args.dataset
    self.test_ensemble = args.test_ensemble
    self.test_uncertainty = args.test_uncertainty
    self.vote = args.vote
    self.device = torch.device('cuda')
    torch.manual_seed(8734)

    self.hypergan = HyperGAN(args, self.device)
    self.hypergan.print_hypergan()
    self.hypergan.attach_optimizers(5e-3, 1e-4, 5e-5)

    if self.dataset == 'mnist':
        self.data_train, self.data_test = datagen.load_mnist()
    elif self.dataset == 'cifar':
        self.data_train, self.data_test = datagen.load_cifar()
    self.best_test_acc = 0.
    self.best_test_loss = np.inf
def load_data(args):
    if args.dataset == 'mnist':
        return datagen.load_mnist(args)
    elif args.dataset == 'cifar':
        return datagen.load_cifar(args)
    elif args.dataset == 'fmnist':
        return datagen.load_fashion_mnist(args)
    elif args.dataset == 'cifar_hidden':
        class_list = [0]  # just load class 0
        return datagen.load_cifar_hidden(args, class_list)
    else:
        print('Dataset not specified correctly')
        print('choose --dataset <mnist, fmnist, cifar, cifar_hidden>')
def __init__(self, args):
    self.lr = args.lr
    self.wd = args.wd
    self.epochs = 200
    self.dataset = args.dataset
    self.test_uncertainty = args.test_uncertainty
    self.vote = args.vote
    self.device = torch.device('cuda')
    torch.manual_seed(8734)

    self.model = models.LeNet_Dropout().to(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr,
                                      weight_decay=self.wd)
    if self.dataset == 'mnist':
        self.data_train, self.data_test = datagen.load_mnist()
    elif self.dataset == 'cifar':
        self.data_train, self.data_test = datagen.load_cifar()
    self.best_test_acc = 0.
    self.best_test_loss = np.inf
    print(self.model)
def __init__(self, args):
    self.lr = args.lr
    self.wd = args.wd
    self.epochs = 200
    self.dataset = args.dataset
    self.test_uncertainty = args.test_uncertainty
    self.vote = args.vote
    self.n_models = args.n_models
    self.device = torch.device('cuda')
    torch.manual_seed(8734)

    self.ensemble = [
        models.LeNet().to(self.device) for _ in range(self.n_models)
    ]
    self.attach_optimizers()
    if self.dataset == 'mnist':
        self.data_train, self.data_test = datagen.load_mnist()
    elif self.dataset == 'cifar':
        self.data_train, self.data_test = datagen.load_cifar()
    self.best_test_acc = 0.
    self.best_test_loss = np.inf
    print(self.ensemble[0], ' X {}'.format(self.n_models))
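# Editorial sketch (assumption): `attach_optimizers` above is not defined in
# this file. Consistent with how the ensemble is built, it would give each
# member its own Adam optimizer; the real helper may differ.
def _attach_optimizers_sketch(ensemble, lr, wd):
    """Return one Adam optimizer per ensemble member (hypothetical helper)."""
    return [torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wd)
            for m in ensemble]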
def train(args):
    torch.manual_seed(8734)
    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    netD = DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    if args.hidden:
        c_idx = [0, 1, 2, 3, 4]
        cifar_train, cifar_test = datagen.load_cifar_hidden(args, c_idx)
    else:
        cifar_train, cifar_test = datagen.load_cifar(args)

    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(2000):
            x = sample_z_like((e_batch_size, args.ze))
            z = sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
                optimE.step()
                netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = sample_z_like((args.batch_size, args.ze,))
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            # Z Adversary
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()

            # Generator (Mean test)
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                scaled_loss = args.beta * loss
                scaled_loss.backward(retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = correct
                # torch-native norms: np.linalg.norm cannot consume CUDA tensors
                norm_z1 = l1.data.norm().item()
                norm_z2 = l2.data.norm().item()
                norm_z3 = l3.data.norm().item()
                norm_z4 = l4.data.norm().item()
                norm_z5 = l5.data.norm().item()
                print('**************************************')
                print('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print('Filter norm: ', norm_z1)
                print('Filter norm: ', norm_z2)
                print('Filter norm: ', norm_z3)
                print('Linear norm: ', norm_z4)
                print('Linear norm: ', norm_z5)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((args.batch_size, args.ze,))
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                        test_acc += correct.item()
                        test_loss += loss.item()
                test_loss /= len(cifar_test.dataset) * args.batch_size
                test_acc /= len(cifar_test.dataset) * args.batch_size
                print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD],
                                              test_acc)
                    print('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
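# Editorial sketch (assumption, not the original implementation):
# `pretrain_loss`, used during encoder pretraining above, is not defined in
# this file. A moment-matching loss consistent with the printed "Mean Loss"
# and "Cov Loss" terms would align the batch mean and covariance of the
# encoder codes with those of prior samples z:
def _pretrain_loss_sketch(code, z):
    """Match first and second moments of `code` to the prior batch `z`."""
    mean_loss = (code.mean(0) - z.mean(0)).pow(2).sum()
    # center both batches, then compare empirical covariance matrices
    cc = code - code.mean(0, keepdim=True)
    zc = z - z.mean(0, keepdim=True)
    cov_c = cc.t() @ cc / (cc.size(0) - 1)
    cov_z = zc.t() @ zc / (zc.size(0) - 1)
    cov_loss = (cov_c - cov_z).pow(2).sum()
    return mean_loss, cov_loss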
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=5e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    m_best_test_acc, m_best_test_loss = 0., np.inf
    c_best_test_acc, c_best_test_loss = 0., np.inf
    args.m_best_loss, args.m_best_acc = m_best_test_loss, m_best_test_acc
    args.c_best_loss, args.c_best_acc = c_best_test_loss, c_best_test_acc

    mnist_train, mnist_test = datagen.load_mnist(args)
    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        # dataset-identity masks: zeros tag MNIST batches, ones tag CIFAR batches
        mask1 = torch.zeros(e_batch_size, args.ze).cuda()
        mask2 = torch.ones(e_batch_size, args.ze).cuda()
        for j in range(500):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            if j % 2 == 0:
                x = torch.cat((x, mask1), dim=0)
            if j % 2 == 1:
                x = torch.cat((x, mask2), dim=0)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
                optimE.step()
                netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (mnist, cifar) in enumerate(zip(mnist_train, cifar_train)):
            # alternate batches between the two datasets, tagging each with a mask
            if batch_idx % 2 == 0:
                data, target = mnist
                mask = torch.zeros(args.batch_size, args.ze).cuda()
            else:
                data, target = cifar
                mask = torch.ones(args.batch_size, args.ze).cuda()
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            z = torch.cat((z, mask), dim=0)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            # Z Adversary
            for code in codes:
                noise = utils.sample_d(z_dist, args.batch_size)
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()

            # Generator (Mean test)
            netD.zero_grad()
            z = utils.sample_d(x_dist, args.batch_size)
            z = torch.cat((z, mask), dim=0)
            codes = netE(z)
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            d_real = []
            for code in codes:
                d = netD(code)
                d_real.append(d)
            netD.zero_grad()
            d_loss = torch.stack(d_real).log().mean() * 10.

            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                scaled_loss = args.beta * loss
                if loss != loss:  # NaN guard: bail out if the loss diverged
                    sys.exit(0)
                scaled_loss.backward(retain_graph=True)
            d_loss.backward(torch.tensor(-1, dtype=torch.float).cuda(),
                            retain_graph=True)
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0 or batch_idx % 50 == 1:
                acc = correct
                print('**************************************')
                if batch_idx % 50 == 0:
                    print('MNIST Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                if batch_idx % 50 == 1:
                    print('CIFAR Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print('best test loss: {}, {}'.format(args.m_best_loss, args.c_best_loss))
                print('best test acc: {}, {}'.format(args.m_best_acc, args.c_best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                m_test_acc = 0.
                m_test_loss = 0.
                for i, (data, y) in enumerate(mnist_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    z = torch.cat((z, torch.zeros(args.batch_size, args.ze).cuda()),
                                  dim=0)
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                        m_test_acc += correct.item()
                        m_test_loss += loss.item()
                m_test_loss /= len(mnist_test.dataset) * args.batch_size
                m_test_acc /= len(mnist_test.dataset) * args.batch_size
                print('MNIST Test Accuracy: {}, Test Loss: {}'.format(
                    m_test_acc, m_test_loss))

                c_test_acc = 0.
                c_test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    z = torch.cat((z, torch.ones(args.batch_size, args.ze).cuda()),
                                  dim=0)
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code)
                    l2 = W2(w2_code)
                    l3 = W3(w3_code)
                    l4 = W4(w4_code)
                    l5 = W5(w5_code)
                    for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                        c_test_acc += correct.item()
                        c_test_loss += loss.item()
                c_test_loss /= len(cifar_test.dataset) * args.batch_size
                c_test_acc /= len(cifar_test.dataset) * args.batch_size
                print('CIFAR Test Accuracy: {}, Test Loss: {}'.format(
                    c_test_acc, c_test_loss))

                if m_test_loss < m_best_test_loss or m_test_acc > m_best_test_acc:
                    # utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    print('==> new best stats, saving')
                    if m_test_loss < m_best_test_loss:
                        m_best_test_loss = m_test_loss
                        args.m_best_loss = m_test_loss
                    if m_test_acc > m_best_test_acc:
                        m_best_test_acc = m_test_acc
                        args.m_best_acc = m_test_acc
                if c_test_loss < c_best_test_loss or c_test_acc > c_best_test_acc:
                    # utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    print('==> new best stats, saving')
                    if c_test_loss < c_best_test_loss:
                        c_best_test_loss = c_test_loss
                        args.c_best_loss = c_test_loss
                    if c_test_acc > c_best_test_acc:
                        c_best_test_acc = c_test_acc
                        args.c_best_acc = c_test_acc
def train(args):
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)  # _hidden(args, [0, 1, 2, 3, 4])
    one = torch.tensor(1).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_z_like((e_batch_size, args.ze))
            z = utils.sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
                optimE.step()
                netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    bb = 0
    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            utils.batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            # sample layer codes directly from the prior instead of the encoder
            z1 = utils.sample_z_like((args.batch_size, args.z,))
            z2 = utils.sample_z_like((args.batch_size, args.z,))
            z3 = utils.sample_z_like((args.batch_size, args.z,))
            z4 = utils.sample_z_like((args.batch_size, args.z,))
            z5 = utils.sample_z_like((args.batch_size, args.z,))
            # codes = netE(z)
            codes = [z1, z2, z3, z4, z5]
            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])
            """
            # Z Adversary
            for code in codes:
                noise = utils.sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1-d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()
            """
            clf_loss = 0
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            clf_loss *= args.beta
            clf_loss.backward()
            optimE.step(); optimW1.step(); optimW2.step()
            optimW3.step(); optimW4.step(); optimW5.step()
            utils.batch_zero_grad([optimE, optimW1, optimW2, optimW3, optimW4, optimW5])
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0:
                bb += 1
                acc = correct
                print('**************************************')
                print("epoch: {}".format(bb))
                print('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, 0))  # d_loss
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(best_clf_acc))
                print('**************************************')

            if batch_idx % 100 == 0:
                with torch.no_grad():
                    test_acc = 0.
                    test_loss = 0.
                    for i, (data, y) in enumerate(cifar_test):
                        w1_code = utils.sample_z_like((args.batch_size, args.z,))
                        w2_code = utils.sample_z_like((args.batch_size, args.z,))
                        w3_code = utils.sample_z_like((args.batch_size, args.z,))
                        w4_code = utils.sample_z_like((args.batch_size, args.z,))
                        w5_code = utils.sample_z_like((args.batch_size, args.z,))
                        # w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                        l1 = W1(w1_code)
                        l2 = W2(w2_code)
                        l3 = W3(w3_code)
                        l4 = W4(w4_code)
                        l5 = W5(w5_code)
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    # clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    # stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                    # print('Clf Accuracy: {}, Clf Loss: {}'.format(clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
def train(self):
    cifar_train, cifar_test = datagen.load_cifar()
    best_test_acc, best_test_loss = 0., np.inf
    self.best_loss, self.best_acc = best_test_loss, best_test_acc
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    if self.pretrain_e:
        print("==> pretraining encoder")
        self.pretrain_encoder()

    print('==> Begin Training')
    for epoch in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):
            s = torch.randn(self.batch_size, self.s).cuda()
            codes = self.hypergan.mixer(s)
            params = self.hypergan.generator(codes)

            # Z Adversary
            for code in codes:
                noise = torch.randn(self.batch_size, self.z).cuda()
                d_real = self.hypergan.discriminator(noise)
                d_fake = self.hypergan.discriminator(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            self.hypergan.optim_disc.step()
            d_loss = 0  # reset the logged value after the discriminator step

            losses, corrects = [], []
            for layers in zip(*params):
                correct, loss = self.train_clf(layers, data, target, val=True)
                losses.append(loss)
                corrects.append(correct)
            loss = torch.stack(losses).mean()
            correct = torch.stack(corrects).mean()
            scaled_loss = self.beta * loss
            scaled_loss.backward()
            self.hypergan.optim_mixer.step()
            self.hypergan.update_generator()
            self.hypergan.zero_grad()
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 100 == 0:
                acc = 100 * (correct / self.batch_size)
                print('**************************************')
                print('CIFAR Test, epoch: {}'.format(epoch))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print('best test loss: {}'.format(self.best_loss))
                print('best test acc: {}'.format(self.best_acc))
                print('**************************************')

                with torch.no_grad():
                    test_acc = 0.
                    test_loss = 0.
                    total_correct = 0.
                    for i, (data, target) in enumerate(cifar_test):
                        z = torch.randn(self.batch_size, self.s).cuda()
                        codes = self.hypergan.mixer(z)
                        params = self.hypergan.generator(codes)
                        losses, corrects = [], []
                        for layers in zip(*params):
                            correct, loss = self.train_clf(layers, data, target,
                                                           val=True)
                            losses.append(loss)
                            corrects.append(correct)
                        losses = torch.stack(losses)
                        corrects = torch.stack(corrects)
                        test_acc += corrects.mean().item()
                        total_correct += corrects.sum().item()
                        test_loss += losses.mean().item()
                    test_loss /= len(cifar_test.dataset)
                    test_acc /= len(cifar_test.dataset)
                    total_correct /= self.batch_size
                    print('[Epoch {}] Test Loss: {}, Test Accuracy: {}, ({}/{})'.format(
                        epoch, test_loss, test_acc, total_correct,
                        len(cifar_test.dataset)))
                    if test_loss < best_test_loss or test_acc > best_test_acc:
                        print('==> new best stats, saving')
                        if test_loss < best_test_loss:
                            best_test_loss = test_loss
                            self.best_loss = test_loss
                        if test_acc > best_test_acc:
                            best_test_acc = test_acc
                            self.best_acc = best_test_acc
def train(args):
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=5e-4)

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)
    one = torch.tensor(1).cuda()
    mone = one * -1

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 5)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz, cifar=True)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            gp, grads, norms = ops.calc_gradient_penalty(z, [W1, W2, W3, W4, W5],
                                                         netE, cifar=True)
            reduce = lambda x: x.mean(0).mean(0).item()
            grads = [reduce(grad) for grad in grads]

            clf_loss = 0.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size
            # 160 appears to be hardcoded to the discriminator's output size here
            one_qz = torch.ones((160, 1), requires_grad=True).cuda()
            log_qz = ops.log_density(torch.ones(160, 1), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()

            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            optimE.zero_grad()
            optimW1.zero_grad()
            optimW2.zero_grad()
            optimW3.zero_grad()
            optimW4.zero_grad()
            optimW5.zero_grad()
            total_loss = total_hyper_loss.item()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Acc: {}, MD Loss: {}, D Loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                # print('penalties: ', [gp[x].item() for x in range(len(gp))])
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(cifar_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args, [g1, g2, g3, g4, g5],
                                                      data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)
                    print('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(args, [netE, netD, W1, W2, W3, W4, W5],
                                                  clf_acc)
def train(args):
    torch.manual_seed(8734)
    netG = Generator(args).cuda()
    netD = Discriminator(args).cuda()
    print(netG, netD)
    optimG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)

    cifar_train, cifar_test = datagen.load_cifar(args)
    train = inf_gen(cifar_train)
    print('saving reals')
    reals, _ = next(train)
    if not os.path.exists('results/'):
        os.makedirs('results')
    save_image(reals, 'results/reals.png')

    one = torch.tensor(1.).cuda()
    mone = (one * -1)
    total_batches = 0

    print('==> Begin Training')
    for iteration in range(args.epochs):
        total_batches += 1
        ops.batch_zero_grad([netG, netD])
        # train the critic for disc_iters steps per generator step
        for p in netD.parameters():
            p.requires_grad = True
        for _ in range(args.disc_iters):
            data, targets = next(train)
            netD.zero_grad()
            d_real = netD(data).mean()
            d_real.backward(mone, retain_graph=True)
            noise = torch.randn(args.batch_size, args.z, requires_grad=True).cuda()
            with torch.no_grad():
                fake = netG(noise)
            fake.requires_grad_(True)
            d_fake = netD(fake)
            d_fake = d_fake.mean()
            d_fake.backward(one, retain_graph=True)
            gp = ops.grad_penalty_3dim(args, netD, data, fake)
            gp.backward()
            d_cost = d_fake - d_real + gp
            wasserstein_d = d_real - d_fake
            optimD.step()

        # generator step: freeze the critic and maximize D(G(z))
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        noise = torch.randn(args.batch_size, args.z, requires_grad=True).cuda()
        fake = netG(noise)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        g_cost = -G
        optimG.step()

        if iteration % 100 == 0:
            print('iter: ', iteration, 'train D cost', d_cost.cpu().item())
            print('iter: ', iteration, 'train G cost', g_cost.cpu().item())
        if iteration % 300 == 0:
            val_d_costs = []
            for i, (data, target) in enumerate(cifar_test):
                data = data.cuda()
                d = netD(data)
                val_d_cost = -d.mean().item()
                val_d_costs.append(val_d_cost)
            utils.generate_image(args, iteration, netG)
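# Editorial sketch (assumption): `ops.grad_penalty_3dim` is used above but not
# shown. A standard WGAN-GP penalty on interpolates between real and fake
# batches (Gulrajani et al., 2017) would look like this; the repo's helper may
# differ in shape handling and the penalty coefficient.
def _grad_penalty_sketch(netD, real, fake, lamb=10.):
    """Penalize deviation of the critic's gradient norm from 1 on interpolates."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interp = netD(interp)
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return lamb * ((grad_norm - 1) ** 2).mean()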
def train(args):
    torch.manual_seed(8734)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=5e-3, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)

    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    print("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = ops.pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
                optimE.step()
                netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):
            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            z = utils.sample_d(x_dist, args.batch_size)
            codes = netE(z)
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)
            l4 = W4(codes[3]).mean(0)
            l5 = W5(codes[4]).mean(0)

            # Z Adversary
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = utils.sample_d(z_dist, args.batch_size)
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1 - d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()

            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])
            correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, target,
                                      val=True)
            scaled_loss = args.beta * loss
            scaled_loss.backward()
            optimE.step()
            optimW1.step()
            optimW2.step()
            optimW3.step()
            optimW4.step()
            optimW5.step()
            loss = loss.item()

            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = (correct / 1)
                print('**************************************')
                print('{} CIFAR Test, beta: {}'.format(args.model, args.beta))
                print('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 199 == 0:
                test_acc = 0.
                test_loss = 0.
                total_correct = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = utils.sample_d(x_dist, args.batch_size)
                    codes = netE(z)
                    l1 = W1(codes[0]).mean(0)
                    l2 = W2(codes[1]).mean(0)
                    l3 = W3(codes[2]).mean(0)
                    l4 = W4(codes[3]).mean(0)
                    l5 = W5(codes[4]).mean(0)
                    correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, y,
                                              val=True)
                    test_acc += correct.item()
                    total_correct += correct.item()
                    test_loss += loss.item()
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)
                print('Test Accuracy: {}, Test Loss: {}, ({}/{})'.format(
                    test_acc, test_loss, total_correct, len(cifar_test.dataset)))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print('==> new best stats, saving')
                    utils.save_clf(args, [l1, l2, l3, l4, l5], test_acc)
                    # utils.save_hypernet_cifar(args, [netE, W1, W2, W3, W4, W5, netD], test_acc)
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc