示例#1
0
def train():
    """Train an SRResNet generator with MSE, then with AE + WGAN-GP updates.

    Phase 1 runs 19,999 iterations that fit ``netG`` to reproduce its input
    batch under an MSE loss. The loss is plotted every iteration, SR samples
    are generated every 50 iterations, and ``netG`` is saved to
    ``./SRResNet_PL.pt`` every 5000 iterations.

    Phase 2 runs ``args.epochs`` iterations. Each one performs:
      * one autoencoder step (``netE`` -> ``netG``);
      * five WGAN-GP critic steps for ``netD``;
      * one adversarial step for ``netG``, whose input is a real image batch
        rather than noise.
    Costs are plotted under ``./plots/<args.dataset>``.
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    # Phase 2 steps optimizerD / optimizerE and uses ae_criterion. Previously
    # optimizerD was commented out and the other two were never defined, so
    # the second loop died with a NameError on its first iteration.
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    vgg_scale = 0.0784  # 1/12.75; weight of the (disabled) perceptual loss
    mse_criterion = nn.MSELoss()
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda(0)
    mone = (one * -1).cuda(0)

    save_dir = './plots/' + args.dataset
    gen = utils.inf_train_gen(train_gen)

    # ---- Phase 1: train SRResNet with a pixel-wise MSE loss ----
    # netE only feeds the (disabled) perceptual loss; keep it frozen.
    for p in netE.parameters():
        p.requires_grad = False

    for iteration in range(1, 20000):
        netG.zero_grad()
        # In this project stack_data returns (image batch, VGG-input batch).
        real_data, vgg_data = stack_data(args, next(gen))
        real_data_v = autograd.Variable(real_data)
        fake = netG(real_data_v)
        # Perceptual loss (disabled):
        #   vgg_features_real = netE(autograd.Variable(vgg_data))
        #   diff = netE(fake) - vgg_features_real.cuda(0)
        #   perceptual_loss = vgg_scale * diff.pow(2).sum(3).mean()
        #   perceptual_loss.backward(one)
        mse_loss = mse_criterion(fake, real_data_v)
        mse_loss.backward(one)
        optimizerG.step()

        plot.plot(save_dir, '/mse cost SRResNet',
                  np.round(mse_loss.data.cpu().numpy(), 4))
        if iteration % 50 == 49:
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()
        plot.tick()
        if iteration % 5000 == 0:
            torch.save(netG.state_dict(), './SRResNet_PL.pt')

    # ---- Phase 2: alternate autoencoder, critic and generator updates ----
    for iteration in range(args.epochs):
        # Update autoencoder (netE -> netG).
        # NOTE(review): netE was frozen for phase 1 and is not unfrozen here,
        # so optimizerE.step() has no gradients to apply. Confirm whether
        # netE is meant to be trained in this phase.
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        real_data = stack_data(args, next(gen))[0]
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()

        # Update critic netD (5 steps per generator step).
        for p in netD.parameters():
            p.requires_grad = True
        for _ in range(5):
            real_data = stack_data(args, next(gen))[0]
            real_data_v = autograd.Variable(real_data)
            netD.zero_grad()
            # train with real data
            D_real = netD(real_data_v).mean()
            D_real.backward(mone)
            # Train with fake data: super-resolve the real batch and detach
            # it (via .data) so that netG receives no gradient here.
            fake = autograd.Variable(netG(real_data_v).data)
            D_fake = netD(fake).mean()
            D_fake.backward(one)
            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator netG adversarially (real images as input).
        # Freeze netD, whose gradients are not needed here. Clear netG's
        # gradients first: otherwise the autoencoder gradients from this
        # iteration are applied a second time by optimizerG.step().
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        real_data = stack_data(args, next(gen))[0]
        real_data_v = autograd.Variable(real_data)
        fake = netG(real_data_v)
        G = netD(fake).mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples.
        plot.plot(save_dir, '/disc cost',
                  np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost',
                  np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))

        # Dev-set critic cost and SR samples every 100 iterations.
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)[0]
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                dev_disc_costs.append(-D.mean().cpu().data.numpy())
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        # Flush logs for the first iterations and then every 100 iterations.
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
示例#2
0
def train():
    """Jointly train an autoencoder (netE -> netG) and a WGAN-GP (netG, netD).

    For each training batch the function performs:
      * one MSE autoencoder step;
      * five critic steps with gradient penalty, all on the current batch;
      * one adversarial generator step from Gaussian noise.
    The learning rates of all three optimizers decay by ``gamma=0.99`` once
    per epoch. Every 100 iterations the dev critic cost is plotted and
    samples are generated. netG/netD checkpoints are saved under
    ``models/<dataset>/`` whenever the iteration counter is a multiple of 100.
    """
    args = load_args()
    train_gen, test_gen = load_data(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    if args.use_spectral_norm:
        # Only hand trainable parameters to Adam (presumably spectral norm
        # leaves some netD parameters with requires_grad=False).
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
                                       netD.parameters()),
                                lr=2e-4, betas=(0.0, 0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))

    # ExponentialLR is meant to be stepped once per epoch. Stepping it on
    # every batch, as this code previously did, multiplied the LR by
    # 0.99**n_batches per epoch and effectively stopped training early.
    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99)
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)

    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    save_dir = './plots/' + args.dataset
    iteration = 0
    for epoch in range(args.epochs):
        for data, _ in train_gen:
            # ---- Update autoencoder ----
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            netE.zero_grad()
            real_data_v = autograd.Variable(data).cuda()
            real_data_v = real_data_v.view(args.batch_size, -1)
            encoding = netE(real_data_v)
            fake = netG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            ae_loss.backward(one)
            optimizerE.step()
            optimizerG.step()

            # ---- Update critic netD (5 steps, all on the current batch) ----
            for p in netD.parameters():
                p.requires_grad = True
            for _ in range(5):
                real_data_v = autograd.Variable(data).cuda()
                netD.zero_grad()
                # train with real data
                D_real = netD(real_data_v).mean()
                D_real.backward(mone)
                # Train with fake data. Volatile noise keeps netG out of
                # the graph, and .data detaches the sample.
                noise = torch.randn(args.batch_size, args.dim).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(netG(noisev).data)
                D_fake = netD(fake).mean()
                D_fake.backward(one)

                # train with gradient penalty
                gradient_penalty = ops.calc_gradient_penalty(
                    args, netD, real_data_v.data, fake.data)
                gradient_penalty.backward()

                D_cost = D_fake - D_real + gradient_penalty
                Wasserstein_D = D_real - D_fake
                optimizerD.step()

            # ---- Update generator netG (GAN) ----
            # Freeze netD and clear netG's gradients. Without zero_grad() the
            # autoencoder gradients from above are applied a second time.
            for p in netD.parameters():
                p.requires_grad = False
            netG.zero_grad()
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise)
            fake = netG(noisev)
            G = netD(fake).mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step()

            # ---- Write logs and save samples ----
            plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
            plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
            plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
            plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())

            # Dev critic cost and samples every 100 iterations.
            if iteration % 100 == 99:
                dev_disc_costs = []
                for images, _ in test_gen:
                    imgs_v = autograd.Variable(images, volatile=True).cuda()
                    D = netD(imgs_v)
                    dev_disc_costs.append(-D.mean().cpu().data.numpy())
                plot.plot(save_dir, '/dev disc cost', np.mean(dev_disc_costs))
                utils.generate_image(iteration, netG, save_dir, args)

            # Flush logs for the first iterations and then every 100.
            if (iteration < 5) or (iteration % 100 == 99):
                plot.flush()
            plot.tick()
            if iteration % 100 == 0:
                utils.save_model(netG, optimizerG, iteration,
                                 'models/{}/G_{}'.format(args.dataset, iteration))
                utils.save_model(netD, optimizerD, iteration,
                                 'models/{}/D_{}'.format(args.dataset, iteration))
            iteration += 1

        # Decay the learning rates once per epoch.
        schedulerD.step()
        schedulerG.step()
        schedulerE.step()
def train():
    """Train the third (HD) stage of a three-stage generator stack.

    The first- and second-stage networks (``netG``, ``SecondE``, ``SecondG``)
    and an earlier third-stage checkpoint (``ThridE``, ``ThridG``) are loaded
    from disk. Only the third stage (``optimizerE``/``optimizerG``) and its
    critic ``ThridD`` (``optimizerD``) are optimized.

    Each iteration performs one autoencoder step on an HD batch, five
    WGAN-GP critic steps, and one adversarial step through the whole stack,
    then prints timing and losses. Everything runs inside
    ``torch.cuda.device(1)``.
    """
    with torch.cuda.device(1):
        args = load_args()
        train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
        torch.manual_seed(1)
        netG = first_layer.FirstG(args).cuda()
        SecondG = second_layer.SecondG(args).cuda()
        SecondE = second_layer.SecondE(args).cuda()

        ThridG = third_layer.ThirdG(args).cuda()
        ThridE = third_layer.ThirdE(args).cuda()
        ThridD = third_layer.ThirdD(args).cuda()

        netG.load_state_dict(torch.load('./1stLayer/1stLayerG71999.model'))
        SecondG.load_state_dict(torch.load('./2ndLayer/2ndLayerG71999.model'))
        SecondE.load_state_dict(torch.load('./2ndLayer/2ndLayerE71999.model'))
        ThridE.load_state_dict(torch.load('./3rdLayer/3rdLayerE10999.model'))
        ThridG.load_state_dict(torch.load('./3rdLayer/3rdLayerG10999.model'))

        # No optimizer ever steps the first two stages. Freeze them so the
        # generator's backward pass stops at the third stage instead of
        # accumulating gradients in these networks that are never cleared.
        for pretrained in (netG, SecondG, SecondE):
            for p in pretrained.parameters():
                p.requires_grad = False

        optimizerD = optim.Adam(ThridD.parameters(), lr=1e-4, betas=(0.5, 0.9))
        optimizerG = optim.Adam(ThridG.parameters(), lr=1e-4, betas=(0.5, 0.9))
        optimizerE = optim.Adam(ThridE.parameters(), lr=1e-4, betas=(0.5, 0.9))
        ae_criterion = nn.MSELoss()
        one = torch.FloatTensor([1]).cuda()
        mone = (one * -1).cuda()

        dataLoader = BSDDataLoader(args.dataset, args.batch_size, args)
        save_dir = './plots/' + args.dataset

        def sample_hd(z):
            # Noise -> stage 1 -> stage 2 (E, G) -> stage 3 (E, G).
            return ThridG(ThridE(SecondG(SecondE(netG(z, True), True)), True))

        for iteration in range(args.epochs):
            start_time = time.time()

            # ---- Update autoencoder (ThridE -> ThridG) ----
            for p in ThridD.parameters():
                p.requires_grad = False
            ThridG.zero_grad()
            ThridE.zero_grad()
            real_data = dataLoader.getNextHDBatch().cuda()
            real_data_v = autograd.Variable(real_data)
            encoding = ThridE(real_data_v)
            fake = ThridG(encoding)
            ae_loss = ae_criterion(fake, real_data_v)
            ae_loss.backward(one)
            optimizerE.step()
            optimizerG.step()

            # ---- Update critic ThridD (5 steps) ----
            for p in ThridD.parameters():
                p.requires_grad = True
            for _ in range(5):
                real_data = dataLoader.getNextHDBatch().cuda()
                real_data_v = autograd.Variable(real_data)
                ThridD.zero_grad()
                # train with real data
                D_real = ThridD(real_data_v).mean()
                D_real.backward(mone)
                # Train with fake data. Volatile noise keeps the stack out
                # of the graph, and .data detaches the sample.
                noise = generateTensor(args.batch_size).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(sample_hd(noisev).data)
                D_fake = ThridD(fake).mean()
                D_fake.backward(one)

                # train with gradient penalty
                gradient_penalty = ops.calc_gradient_penalty(
                    args, ThridD, real_data_v.data, fake.data)
                gradient_penalty.backward()
                optimizerD.step()

            # ---- Update generator (GAN) ----
            # Freeze the critic and clear ThridG's gradients. Without
            # zero_grad() the autoencoder gradients from this iteration are
            # applied a second time by optimizerG.step().
            for p in ThridD.parameters():
                p.requires_grad = False
            ThridG.zero_grad()
            noise = generateTensor(args.batch_size).cuda()
            noisev = autograd.Variable(noise)
            G = ThridD(sample_hd(noisev)).mean()
            G.backward(mone)
            G_cost = -G
            optimizerG.step()

            # Save stage-3 checkpoints and sample images every 1000 iterations.
            if iteration % 1000 == 999:
                torch.save(ThridE.state_dict(), './3rdLayer/3rdLayerE%d.model' % iteration)
                torch.save(ThridG.state_dict(), './3rdLayer/3rdLayerG%d.model' % iteration)
                utils.generate_image(iteration, netG, save_dir, args)
                utils.generate_MidImage(iteration, netG, SecondE, SecondG, save_dir, args)
                utils.generate_HDImage(iteration, netG, SecondE, SecondG, ThridE, ThridG, save_dir, args)

            # Print the Inception score of one generated batch every 2000 iterations.
            if iteration % 2000 == 1999:
                noise = generateTensor(args.batch_size).cuda()
                noisev = autograd.Variable(noise, volatile=True)
                fake = autograd.Variable(sample_hd(noisev).data)
                print(inception_score(fake.data.cpu().numpy(), resize=True, batch_size=5)[0])

            endtime = time.time()
            print('iter:', iteration, 'total time %4f' % (endtime - start_time),
                  'ae loss %4f' % ae_loss.data[0],
                  'G cost %4f' % G_cost.data[0])
示例#4
0
def train(args):
    """Train a 5-layer CIFAR hypernetwork: an encoder plus per-layer generators.

    ``netE`` maps a noise sample to five codes. ``W1``..``W5`` turn each code
    into one layer's weights, and ``netD`` is trained on the codes via
    ``ops.calc_d_loss``. The encoder can optionally be pretrained first
    (``args.pretrain_e``). Every 100 batches the generated networks are
    evaluated on ``cifar_test``, stats are logged, and the hypernet is saved
    whenever the ``test_clf`` accuracy improves.
    """
    torch.manual_seed(1)
    netE = models.Encoder(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    W4 = models.GeneratorW4(args).cuda()
    W5 = models.GeneratorW5(args).cuda()
    netD = models.DiscriminatorZ(args).cuda()
    print(netE, W1, W2, W3, W4, W5, netD)

    def _adam(module):
        # All networks in this model share the same Adam settings.
        return optim.Adam(module.parameters(),
                          lr=5e-5,
                          betas=(0.5, 0.9),
                          weight_decay=5e-4)

    optimE = _adam(netE)
    optimW1 = _adam(W1)
    optimW2 = _adam(W2)
    optimW3 = _adam(W3)
    optimW4 = _adam(W4)
    optimW5 = _adam(W5)
    optimD = _adam(netD)
    # Optimizers updated by the hypernetwork loss (everything except netD).
    hyper_optims = [optimE, optimW1, optimW2, optimW3, optimW4, optimW5]

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0.

    cifar_train, cifar_test = datagen.load_cifar(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 5)

    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(700):
            x = utils.sample_d(x_dist, e_batch_size)
            z = utils.sample_d(z_dist, e_batch_size)
            codes = torch.stack(netE(x)).view(-1, args.z * 5)
            qz = utils.sample_d(qz_dist, e_batch_size)
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(cifar_train):
            z = utils.sample_d(x_dist, args.batch_size)
            ze = utils.sample_d(z_dist, args.batch_size)
            # qz and noise are unused, but they are still sampled so the
            # random-number stream matches earlier runs.
            qz = utils.sample_d(qz_dist, args.batch_size)
            codes = netE(z)
            noise = utils.sample_d(qz_dist, args.batch_size)
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args,
                                          netD,
                                          ze,
                                          codes,
                                          log_pz,
                                          cifar=True)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])
            l4 = W4(codes[3])
            l5 = W5(codes[4])

            gp, grads, norms = ops.calc_gradient_penalty(z,
                                                         [W1, W2, W3, W4, W5],
                                                         netE,
                                                         cifar=True)
            grads = [grad.mean(0).mean(0).item() for grad in grads]
            clf_loss = 0.
            for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                loss, correct = train_clf(args, [g1, g2, g3, g4, g5], data,
                                          target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size
            # One discriminator logit per (sample, layer) code, 5 layers per
            # sample. This was hard-coded to 160 (i.e. batch_size 32).
            # Assumes d_q has args.batch_size * 5 rows, mirroring the MNIST
            # variant's batch_size * 3 -- TODO confirm.
            n_codes = args.batch_size * 5
            one_qz = torch.ones((n_codes, 1)).cuda()
            log_qz = ops.log_density(torch.ones(n_codes, 1), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()

            # Step every hypernetwork optimizer. optimW3.step() used to be
            # missing, so W3 was zeroed each batch but never updated.
            for opt in hyper_optims:
                opt.step()
            for opt in hyper_optims:
                opt.zero_grad()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Acc: {}, MD Loss: {}, D Loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(cifar_test):
                        z = utils.sample_d(x_dist, args.batch_size)
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        l4 = W4(codes[3])
                        l5 = W5(codes[4])
                        for (g1, g2, g3, g4, g5) in zip(l1, l2, l3, l4, l5):
                            loss, correct = train_clf(args,
                                                      [g1, g2, g3, g4, g5],
                                                      data, y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(cifar_test.dataset) * args.batch_size
                    test_acc /= len(cifar_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3, l4, l5])
                    stats.update_logger(l1, l2, l3, l4, l5, logger)
                    stats.update_acc(logger, test_acc)
                    stats.update_grad(logger, grads, norms)
                    stats.save_logger(logger, args.exp)
                    stats.plot_logger(logger)

                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_cifar(
                            args, [netE, netD, W1, W2, W3, W4, W5], clf_acc)
示例#5
0
def train():
    """Jointly train an autoencoder (netE -> netG) and a WGAN-GP (netG, netD).

    Runs ``args.epochs`` iterations, drawing batches from an infinite
    training generator. Each iteration performs:
      * one MSE autoencoder step;
      * five critic steps with gradient penalty;
      * one adversarial generator step from Gaussian noise.
    Costs are plotted under ``./plots/<args.dataset>``. Every 100 iterations
    the dev critic cost is recorded and autoencoder reconstructions are
    generated.
    """
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    np.set_printoptions(precision=4)
    netG, netD, netE = load_models(args)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    save_dir = './plots/' + args.dataset
    gen = utils.inf_train_gen(train_gen)

    for iteration in range(args.epochs):
        # ---- Update autoencoder (netE -> netG) ----
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        real_data = stack_data(args, next(gen))
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()

        # ---- Update critic netD (5 steps per generator step) ----
        for p in netD.parameters():
            p.requires_grad = True
        for _ in range(5):
            real_data = stack_data(args, next(gen))
            real_data_v = autograd.Variable(real_data)
            netD.zero_grad()
            # train with real data
            D_real = netD(real_data_v).mean()
            D_real.backward(mone)
            # Train with fake data. Volatile noise totally freezes netG,
            # and .data detaches the sample.
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev).data)
            D_fake = netD(fake).mean()
            D_fake.backward(one)

            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # ---- Update generator netG (GAN) ----
        # Freeze netD, whose gradients are not needed here. Clear netG's
        # gradients: without zero_grad() the autoencoder gradients computed
        # above are applied a second time by optimizerG.step().
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake).mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # ---- Write logs and save samples ----
        plot.plot(save_dir, '/disc cost',
                  np.round(D_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/gen cost',
                  np.round(G_cost.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/ae cost',
                  np.round(ae_loss.data.cpu().numpy(), 4))

        # Dev critic cost and AE reconstructions every 100 iterations.
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                dev_disc_costs.append(-D.mean().cpu().data.numpy())
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))
            utils.generate_ae_image(iteration, netE, netG, save_dir, args,
                                    real_data_v)
        # Flush logs for the first iterations and then every 100.
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
def train(args):
    """Train a 3-layer MNIST hypernetwork: an encoder plus per-layer generators.

    ``netE`` maps a uniform noise sample to three codes. ``W1``..``W3`` turn
    each code into one layer's weights, and ``netD`` is trained on the codes
    via ``ops.calc_d_loss``. The encoder can optionally be pretrained first
    (``args.pretrain_e``). Every 100 batches the generated networks are
    evaluated on ``mnist_test`` and the hypernet is saved whenever the
    ``test_clf`` accuracy improves.
    """
    torch.manual_seed(1)
    netE = models.Encoderz(args).cuda()
    W1 = models.GeneratorW1(args).cuda()
    W2 = models.GeneratorW2(args).cuda()
    W3 = models.GeneratorW3(args).cuda()
    netD = models.DiscriminatorQz(args).cuda()
    print(netE, W1, W2, W3)

    def _adam(module, lr):
        # Shared Adam settings; only the learning rate differs.
        return optim.Adam(module.parameters(),
                          lr=lr,
                          betas=(0.5, 0.9),
                          weight_decay=5e-4)

    optimE = _adam(netE, 5e-5)
    optimW1 = _adam(W1, 5e-5)
    optimW2 = _adam(W2, 5e-5)
    optimW3 = _adam(W3, 5e-5)
    optimD = _adam(netD, 5e-4)  # the critic uses a 10x larger learning rate
    # Optimizers updated by the hypernetwork loss (everything except netD).
    hyper_optims = [optimE, optimW1, optimW2, optimW3]

    best_test_acc, best_clf_acc, best_test_loss = 0., 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    args.best_clf_loss, args.best_clf_acc = np.inf, 0

    mnist_train, mnist_test = datagen.load_mnist(args)
    x_dist = utils.create_d(args.ze)
    z_dist = utils.create_d(args.z)
    qz_dist = utils.create_d(args.z * 3)
    u_dist = utils.create_uniform()
    print("==> pretraining encoder")
    e_batch_size = 1000
    if args.pretrain_e:
        for j in range(1000):
            x = utils.sample_uniform(u_dist, (e_batch_size, args.ze))
            z = utils.sample_uniform(u_dist, (e_batch_size, args.z))
            codes = torch.stack(netE(x)).view(-1, args.z * 3)
            qz = utils.sample_uniform(u_dist, (e_batch_size, args.z * 3))
            mean_loss, cov_loss = ops.pretrain_loss(codes, qz)
            loss = mean_loss + cov_loss
            loss.backward()
            optimE.step()
            netE.zero_grad()
            print('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            if loss.item() < 0.1:
                print('Finished Pretraining Encoder')
                break

    print('==> Begin Training')
    for _ in range(args.epochs):
        for batch_idx, (data, target) in enumerate(mnist_train):
            z = utils.sample_uniform(u_dist, (args.batch_size, args.ze))
            ze = utils.sample_uniform(u_dist, (args.batch_size, args.z))
            # qz and noise are unused, but they are still sampled so the
            # random-number stream matches earlier runs.
            qz = utils.sample_uniform(u_dist, (args.batch_size, args.z * 3))
            codes = netE(z)
            noise = utils.sample_uniform(u_dist, (args.batch_size, args.z * 3))
            log_pz = ops.log_density(ze, 2).view(-1, 1)
            d_loss, d_q = ops.calc_d_loss(args, netD, ze, codes, log_pz)
            optimD.zero_grad()
            d_loss.backward(retain_graph=True)
            optimD.step()

            l1 = W1(codes[0])
            l2 = W2(codes[1])
            l3 = W3(codes[2])

            gp, grads, norms = ops.calc_gradient_penalty(z, [W1, W2, W3], netE)
            # Scalar summary of the first three gradient tensors (kept as a tuple).
            grads = tuple(g.mean(0).mean(0).item() for g in grads[:3])
            clf_loss = 0
            for (g1, g2, g3) in zip(l1, l2, l3):
                loss, correct = train_clf(args, [g1, g2, g3], data, target)
                clf_loss += loss
            G_loss = clf_loss / args.batch_size
            # One discriminator logit per (sample, layer) code. The target is
            # a constant, so it does not need requires_grad.
            n_codes = args.batch_size * 3
            one_qz = torch.ones((n_codes, 1)).cuda()
            log_qz = ops.log_density(torch.ones(n_codes, 1), 2).view(-1, 1)
            Q_loss = F.binary_cross_entropy_with_logits(d_q + log_qz, one_qz)
            total_hyper_loss = Q_loss + G_loss
            total_hyper_loss.backward()

            for opt in hyper_optims:
                opt.step()
            for opt in hyper_optims:
                opt.zero_grad()

            if batch_idx % 50 == 0:
                acc = correct
                print('**************************************')
                print('Iter: {}'.format(len(logger['acc'])))
                print('Acc: {}, MD Loss: {}, D loss: {}'.format(
                    acc, total_hyper_loss, d_loss))
                print('penalties: ', gp[0].item(), gp[1].item(), gp[2].item())
                print('grads: ', grads)
                print('best test loss: {}'.format(args.best_loss))
                print('best test acc: {}'.format(args.best_acc))
                print('best clf acc: {}'.format(args.best_clf_acc))
                print('**************************************')

            if batch_idx > 1 and batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                with torch.no_grad():
                    for i, (data, y) in enumerate(mnist_test):
                        z = utils.sample_uniform(u_dist,
                                                 (args.batch_size, args.ze))
                        codes = netE(z)
                        l1 = W1(codes[0])
                        l2 = W2(codes[1])
                        l3 = W3(codes[2])
                        for (g1, g2, g3) in zip(l1, l2, l3):
                            loss, correct = train_clf(args, [g1, g2, g3], data,
                                                      y)
                            test_acc += correct.item()
                            test_loss += loss.item()
                    test_loss /= len(mnist_test.dataset) * args.batch_size
                    test_acc /= len(mnist_test.dataset) * args.batch_size
                    clf_acc, clf_loss = test_clf(args, [l1, l2, l3])
                    stats.update_logger(l1, l2, l3, logger)
                    stats.update_acc(logger, test_acc)

                    print('Test Accuracy: {}, Test Loss: {}'.format(
                        test_acc, test_loss))
                    print('Clf Accuracy: {}, Clf Loss: {}'.format(
                        clf_acc, clf_loss))
                    if test_loss < best_test_loss:
                        best_test_loss, args.best_loss = test_loss, test_loss
                    if test_acc > best_test_acc:
                        best_test_acc, args.best_acc = test_acc, test_acc
                    if clf_acc > best_clf_acc:
                        best_clf_acc, args.best_clf_acc = clf_acc, clf_acc
                        utils.save_hypernet_mnist(args,
                                                  [netE, netD, W1, W2, W3],
                                                  clf_acc)
def train():
    """Train the first-layer encoder / generator / critic networks.

    Each iteration runs three phases:

    1. Autoencoder step: ``netE`` encodes a real batch and ``netG`` decodes
       it. Both are updated on reconstruction MSE plus an MSE term that pulls
       the encoding toward zero.
    2. Five critic (``netD``) steps with a Wasserstein-style loss: raise the
       mean score on real data, lower it on generated data, and add
       ``ops.calc_gradient_penalty``.
    3. One adversarial generator step that raises the critic's mean score on
       generated samples.

    Every 1000 iterations the E and G weights are saved under
    ``./1stLayer/`` and sample images are written via
    ``utils.generate_image``. Every 2000 iterations an inception score is
    computed on 1000 generated samples. Runs for ``args.epochs`` iterations
    and requires CUDA.
    """
    args = load_args()
    torch.manual_seed(1)
    netG = first_layer.FirstG(args).cuda()
    netD = first_layer.FirstD(args).cuda()
    netE = first_layer.FirstE(args).cuda()

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()

    dataLoader = BSDDataLoader(args.dataset, args.batch_size, args)
    save_dir = './plots/' + args.dataset  # loop-invariant
    incep_score = 0

    for iteration in range(args.epochs):
        start_time = time.time()

        # --- Update autoencoder (E + G) ---------------------------------
        # The critic is not part of this loss; freezing it avoids needless
        # gradient bookkeeping.
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        real_data = dataLoader.getNextLoBatch().cuda()
        encoding = netE(real_data)
        fake = netG(encoding)
        # Reconstruction error plus a penalty pulling the code toward 0.
        # zeros_like matches whatever shape netE produces.
        ae_loss = ae_criterion(fake, real_data) + ae_criterion(
            encoding, torch.zeros_like(encoding))
        # An implicit gradient of 1 is what the old backward(one) meant.
        # A shape-[1] gradient for a 0-dim loss is rejected as a shape
        # mismatch by recent PyTorch releases.
        ae_loss.backward()
        optimizerE.step()
        optimizerG.step()

        # --- Update critic (D) ------------------------------------------
        for p in netD.parameters():
            p.requires_grad = True
        for _ in range(5):
            real_data = dataLoader.getNextLoBatch().cuda()
            netD.zero_grad()
            # Real data: maximise the mean critic score.
            D_real = netD(real_data).mean()
            (-D_real).backward()
            # Generated data: G is held fixed here, so sample without
            # building its autograd graph (`volatile=True` is a no-op on
            # PyTorch >= 0.4).
            # NOTE(review): assumes generateTensor(n) returns n noise
            # vectors and the second netG argument selects decoding from
            # noise rather than from an encoding -- confirm.
            with torch.no_grad():
                fake = netG(generateTensor(args.batch_size).cuda(), True)
            D_fake = netD(fake).mean()
            D_fake.backward()
            # Gradient penalty on real/fake inputs (neither carries history).
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data, fake)
            gradient_penalty.backward()
            optimizerD.step()

        # --- Update generator (adversarial) -----------------------------
        # Freeze the critic so backward only fills G's gradients. Also clear
        # G's gradients left over from the autoencoder step; otherwise this
        # step would apply the AE gradient a second time.
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        fake = netG(generateTensor(args.batch_size).cuda(), True)
        G = netD(fake).mean()
        (-G).backward()
        G_cost = -G
        optimizerG.step()

        # Save checkpoints and sample images every 1000 iterations.
        if iteration % 1000 == 999:
            torch.save(netE.state_dict(),
                       './1stLayer/1stLayerE%d.model' % iteration)
            torch.save(netG.state_dict(),
                       './1stLayer/1stLayerG%d.model' % iteration)
            utils.generate_image(iteration, netG, save_dir, args)
        endtime = time.time()

        # Inception score on 1000 generated samples every 2000 iterations.
        if iteration % 2000 == 1999:
            with torch.no_grad():
                fake = netG(generateTensor(1000).cuda(), True)
            incep_score = inception_score(fake.cpu().numpy(),
                                          resize=True,
                                          batch_size=5)[0]

        # .item() reads a 0-dim tensor; .data[0] raises on PyTorch >= 0.4.
        print('iter:', iteration, 'total time %4f' % (endtime - start_time),
              'ae loss %4f' % ae_loss.item(), 'G cost %4f' % G_cost.item(),
              'inception score %4f' % incep_score)