Example #1
    def __init__(self, dataType='BSD', batch_size=10, args=None):
        self.dataType = dataType
        if dataType == 'BSD':
            self.dataPath = './dataset/'
            self.imgList = os.listdir(self.dataPath)
            self.batchSize = batch_size
            self.len = len(self.imgList)
            self.loimgs = torch.zeros((300, 3, 32, 32))
            self.midImgs = torch.zeros((300, 3, 64, 64))
            self.HDImgs = torch.zeros((300, 3, 128, 128))
            self.iter = 0
            preprocess = torchTrans.Compose([
                torchTrans.ToTensor(),
                torchTrans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            for i in range(self.len):
                # read and resize to 128x128, dropping the last channel
                # (assumes a 4-channel RGBA read from mpimg.imread)
                imgH = cv2.resize(
                    mpimg.imread(self.dataPath + self.imgList[i + self.iter]),
                    (128, 128))[:, :, 0:-1]
                imgM = cv2.resize(imgH, (64, 64))
                img = cv2.resize(imgM, (32, 32))
                imgH = preprocess(imgH)
                imgM = preprocess(imgM)
                img = preprocess(img)
                self.loimgs[i, :, :, :] = img
                self.midImgs[i, :, :, :] = imgM
                self.HDImgs[i, :, :, :] = imgH
        elif dataType == 'CIFAR':
            train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
            self.batchSize = batch_size
            self.gen = utils.inf_train_gen(train_gen)
        elif dataType == 'PASCAL':
            self.dataPath = './VOCdevkit/VOC2012/'
            self.imgList = []
            for line in open(self.dataPath + 'ImageSets/Main/trainval.txt'):
                self.imgList.append(line[0:-1])

            self.batchSize = batch_size
            self.len = len(self.imgList)
            self.loimgs = torch.zeros((self.len, 3, 32, 32))
            self.midImgs = torch.zeros((self.len, 3, 64, 64))
            self.HDImgs = torch.zeros((self.len, 3, 128, 128))
            self.iter = 0
            preprocess = torchTrans.Compose([
                torchTrans.ToTensor(),
                torchTrans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
            for i in range(self.len):
                imgH = cv2.resize(
                    mpimg.imread(self.dataPath + 'JPEGImages/' +
                                 self.imgList[i + self.iter] + '.jpg'),
                    (128, 128))
                imgM = cv2.resize(imgH, (64, 64))
                img = cv2.resize(imgH, (32, 32))
                imgH = preprocess(imgH)
                imgM = preprocess(imgM)
                img = preprocess(img)
                self.loimgs[i, :, :, :] = img
                self.midImgs[i, :, :, :] = imgM
                self.HDImgs[i, :, :, :] = imgH
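
# --- Editor's note: hedged sketch, not part of the original example. ---
# The snippets above and below pull batches with next(gen) from
# utils.inf_train_gen(train_gen). A minimal implementation consistent with
# that usage (an assumption about the real `utils` module) is:
def inf_train_gen(train_gen):
    # Loop over a finite batch iterable forever; if train_gen is a
    # zero-argument callable (as the dev_gen() calls elsewhere suggest),
    # iterate over train_gen() instead.
    while True:
        for batch in train_gen:
            yield batch
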
def train():
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    # the WGAN stage further below uses optimizerD, optimizerE and ae_criterion,
    # so define them here alongside the generator optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    vgg_scale = 0.0784  # 1/12.75
    mse_criterion = nn.MSELoss()
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda(0)
    mone = (one * -1).cuda(0)

    gen = utils.inf_train_gen(train_gen)
    """ train SRResNet with MSE """
    for iteration in range(1, 20000):
        start_time = time.time()
        # for p in netD.parameters():
        #     p.requires_grad = False
        for p in netE.parameters():
            p.requires_grad = False

        netG.zero_grad()
        _data = next(gen)
        real_data, vgg_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        #Perceptual loss
        #vgg_data_v = autograd.Variable(vgg_data)
        #vgg_features_real = netE(vgg_data_v)
        fake = netG(real_data_v)
        #vgg_features_fake = netE(fake)
        #diff = vgg_features_fake - vgg_features_real.cuda(0)
        #perceptual_loss = vgg_scale * ((diff.pow(2)).sum(3).mean())  # mean(sum(square(diff)))
        #perceptual_loss.backward(one)
        mse_loss = mse_criterion(fake, real_data_v)
        mse_loss.backward(one)
        optimizerG.step()

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/mse cost SRResNet',
                  np.round(mse_loss.data.cpu().numpy(), 4))
        if iteration % 50 == 49:
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        if (iteration < 5) or (iteration % 50 == 49):
            plot.flush()
        plot.tick()
        if iteration % 5000 == 0:
            torch.save(netG.state_dict(), './SRResNet_PL.pt')

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """

        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """

        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise,
                                       volatile=True)  # totally freeze netG
            # instead of noise, use image
            fake = autograd.Variable(netG(real_data_v).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        # noise = torch.randn(args.batch_size, args.dim).cuda()
        # noisev = autograd.Variable(noise)
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        # again use real data instead of noise
        fake = netG(real_data_v)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(),
                                                   4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(),
                                                  4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        # plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(), 4))

        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))

            # utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)
            utils.generate_sr_image(iteration, netG, save_dir, args,
                                    real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
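
# --- Editor's note: hedged sketch, not part of the original examples. ---
# ops.calc_gradient_penalty(args, netD, real_data, fake_data) is not shown in
# these snippets. The discriminator updates above and below follow WGAN-GP
# (Gulrajani et al., 2017), so an interpolation-based penalty consistent with
# that call signature would look roughly like this; the args.gp_lambda
# attribute name is an assumption:
import torch
import torch.autograd as autograd

def calc_gradient_penalty(args, netD, real_data, fake_data):
    gp_lambda = getattr(args, 'gp_lambda', 10.0)  # penalty coefficient (assumed name)
    batch_size = real_data.size(0)
    # random point on the line between each real/fake pair
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)),
                       device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = netD(interpolates)
    grads = autograd.grad(outputs=d_interpolates, inputs=interpolates,
                          grad_outputs=torch.ones_like(d_interpolates),
                          create_graph=True, retain_graph=True)[0]
    grads = grads.view(batch_size, -1)
    # penalize deviation of the gradient norm from 1
    return gp_lambda * ((grads.norm(2, dim=1) - 1) ** 2).mean()
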
Example #3
def train():
    args = load_args()
    train_gen, dev_gen, test_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    np.set_printoptions(precision=4)
    netG, netD, netE = load_models(args)

    optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=1e-4, betas=(0.5, 0.9))
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()

    gen = utils.inf_train_gen(train_gen)

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        real_data = stack_data(args, _data)
        real_data_v = autograd.Variable(real_data)
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()
        """ Update D network """

        for p in netD.parameters():  # reset requires_grad
            p.requires_grad = True  # they are set to False below in netG update
        for i in range(5):
            _data = next(gen)
            real_data = stack_data(args, _data)
            real_data_v = autograd.Variable(real_data)
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise,
                                       volatile=True)  # totally freeze netG
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # train with gradient penalty
            gradient_penalty = ops.calc_gradient_penalty(
                args, netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step()

        # Write logs and save samples

        save_dir = './plots/' + args.dataset
        plot.plot(save_dir, '/disc cost', np.round(D_cost.cpu().data.numpy(),
                                                   4))
        plot.plot(save_dir, '/gen cost', np.round(G_cost.cpu().data.numpy(),
                                                  4))
        plot.plot(save_dir, '/w1 distance',
                  np.round(Wasserstein_D.cpu().data.numpy(), 4))
        plot.plot(save_dir, '/ae cost', np.round(ae_loss.data.cpu().numpy(),
                                                 4))

        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for images, _ in dev_gen():
                imgs = stack_data(args, images)
                imgs_v = autograd.Variable(imgs, volatile=True)
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir, '/dev disc cost',
                      np.round(np.mean(dev_disc_costs), 4))

            # utils.generate_image(iteration, netG, save_dir, args)
            utils.generate_ae_image(iteration, netE, netG, save_dir, args,
                                    real_data_v)
        # Save logs every 100 iters
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
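
# --- Editor's note: hedged sketch, not part of the original examples. ---
# The train(args) example below relies on a few small helpers that are not
# shown. Minimal versions consistent with how they are called (assumptions
# about the real utilities, including the CUDA placement) are:
import torch

def sample_z_like(shape):
    # standard normal noise of the requested shape, on GPU like the models
    return torch.randn(*shape).cuda()

def batch_zero_grad(modules):
    for m in modules:
        m.zero_grad()

def free_params(modules):
    for m in modules:
        for p in m.parameters():
            p.requires_grad = True

def frozen_params(modules):
    for m in modules:
        for p in m.parameters():
            p.requires_grad = False
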
def train(args):
    
    torch.manual_seed(8734)
    
    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    netD = DiscriminatorZ(args).cuda()
    print (netE, W1, W2, W3, W4, W5, netD)

    optimE = optim.Adam(netE.parameters(), lr=0.005, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW1 = optim.Adam(W1.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW2 = optim.Adam(W2.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW3 = optim.Adam(W3.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW4 = optim.Adam(W4.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimW5 = optim.Adam(W5.parameters(), lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimD = optim.Adam(netD.parameters(), lr=5e-5, betas=(0.5, 0.9), weight_decay=1e-4)
    
    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    if args.resume:
        netE, optimE, stats = load_model(args, netE, optimE, 'single_code1')
        W1, optimW1, stats = load_model(args, W1, optimW1, 'single_code1')
        W2, optimW2, stats = load_model(args, W2, optimW2, 'single_code1')
        W3, optimW3, stats = load_model(args, W3, optimW3, 'single_code1')
        W4, optimW4, stats = load_model(args, W4, optimW4, 'single_code1')
        W5, optimW5, stats = load_model(args, W5, optimW5, 'single_code1')
        netD, optimD, stats = load_model(args, netD, optimD, 'single_code1')
        best_test_acc, best_test_loss = stats
        print ('==> resuming models at ', stats)

    cifar_train, cifar_test = load_cifar(args)
    if args.use_x:
        base_gen = datagen.load(args)
        w1_gen = utils.inf_train_gen(base_gen[0])
        w2_gen = utils.inf_train_gen(base_gen[1])
        w3_gen = utils.inf_train_gen(base_gen[2])
        w4_gen = utils.inf_train_gen(base_gen[3])
        w5_gen = utils.inf_train_gen(base_gen[4])

    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    if args.use_x:
        X = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
        X = list(map(lambda x: (x+1e-10).float(), X))

    print ("==> pretraining encoder")
    j = 0
    final = 100.
    e_batch_size = 1000
    if args.load_e:
        netE, optimE, _ = utils.load_model(args, netE, optimE, 'Encoder_cifar.pt')
        print ('==> loading pretrained encoder')
    if args.pretrain_e:
        for j in range(200):
            #x = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
            x = sample_z_like((e_batch_size, args.ze))
            z = sample_z_like((e_batch_size, args.z))
            codes = netE(x)
            for i, code in enumerate(codes):
                code = code.view(e_batch_size, args.z)
                mean_loss, cov_loss = pretrain_loss(code, z)
                loss = mean_loss + cov_loss
                loss.backward(retain_graph=True)
            optimE.step()
            netE.zero_grad()
            print ('Pretrain Enc iter: {}, Mean Loss: {}, Cov Loss: {}'.format(
                j, mean_loss.item(), cov_loss.item()))
            final = loss.item()
            if loss.item() < 0.1:
                print ('Finished Pretraining Encoder')
                break
        utils.save_model(args, netE, optimE)

    print ('==> Begin Training')
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):

            batch_zero_grad([netE, W1, W2, W3, W4, W5, netD])
            #netE.zero_grad(); W1.zero_grad(); W2.zero_grad()
            #W3.zero_grad(); W4.zero_grad(); W5.zero_grad()
            z = sample_z_like((args.batch_size, args.ze,))
            codes = netE(z)
            l1 = W1(codes[0]).mean(0)
            l2 = W2(codes[1]).mean(0)
            l3 = W3(codes[2]).mean(0)
            l4 = W4(codes[3]).mean(0)
            l5 = W5(codes[4]).mean(0)
            
            # Z Adversary 
            free_params([netD])
            frozen_params([netE, W1, W2, W3, W4, W5])
            for code in codes:
                noise = sample_z_like((args.batch_size, args.z))
                d_real = netD(noise)
                d_fake = netD(code)
                d_real_loss = -1 * torch.log((1-d_real).mean())
                d_fake_loss = -1 * torch.log(d_fake.mean())
                d_real_loss.backward(retain_graph=True)
                d_fake_loss.backward(retain_graph=True)
                d_loss = d_real_loss + d_fake_loss
            optimD.step()

            # Generator (Mean test)
            frozen_params([netD])
            free_params([netE, W1, W2, W3, W4, W5])
            d_costs = []
            for code in codes:
                d_costs.append(netD(code))
            d_loss = torch.cat(d_costs).mean()
            correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, target, val=True)
            scaled_loss = (args.beta*loss) + d_loss
            scaled_loss.backward()
               
            optimE.step(); optimW1.step(); optimW2.step()
            optimW3.step(); optimW4.step(); optimW5.step()
            loss = loss.item()
            
            """ Update Statistics """
            if batch_idx % 50 == 0:
                acc = (correct / 1) 
                norm_z1 = np.linalg.norm(l1.data)
                norm_z2 = np.linalg.norm(l2.data)
                norm_z3 = np.linalg.norm(l3.data)
                norm_z4 = np.linalg.norm(l4.data)
                norm_z5 = np.linalg.norm(l5.data)
                print ('**************************************')
                print ('Mean Test: Enc, Dz, Lscale: {} test'.format(args.beta))
                print ('Acc: {}, G Loss: {}, D Loss: {}'.format(acc, loss, d_loss))
                print ('Filter norm: ', norm_z1)
                print ('Filter norm: ', norm_z2)
                print ('Filter norm: ', norm_z3)
                print ('Linear norm: ', norm_z4)
                print ('Linear norm: ', norm_z5)
                print ('best test loss: {}'.format(args.best_loss))
                print ('best test acc: {}'.format(args.best_acc))
                print ('**************************************')
            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((args.batch_size, args.ze,))
                    w1_code, w2_code, w3_code, w4_code, w5_code = netE(z)
                    l1 = W1(w1_code).mean(0)
                    l2 = W2(w2_code).mean(0)
                    l3 = W3(w3_code).mean(0)
                    l4 = W4(w4_code).mean(0)
                    l5 = W5(w5_code).mean(0)
                    min_loss_batch = 10.
                    z_test = [l1, l2, l3, l4, l5]
                    correct, loss = train_clf(args, [l1, l2, l3, l4, l5], data, y, val=True)
                    if loss.item() < min_loss_batch:
                        min_loss_batch = loss.item()
                        z_test = [l1, l2, l3, l4, l5]
                    test_acc += correct.item()
                    test_loss += loss.item()
                #y_acc, y_loss = utils.test_samples(args, z_test, train=True)
                test_loss /= len(cifar_test.dataset)
                test_acc /= len(cifar_test.dataset)
                print ('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                # print ('FC Accuracy: {}, FC Loss: {}'.format(y_acc, y_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print ('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
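
# --- Editor's note: hedged sketch, not part of the original example. ---
# pretrain_loss(code, z) above returns a mean term and a covariance term,
# which suggests it matches the first two moments of the encoder codes to the
# sampled prior z. One way to write such a loss (an assumption about the real
# function) is:
import torch

def pretrain_loss(code, z):
    # match per-dimension means of the two batches
    mean_loss = (code.mean(0) - z.mean(0)).pow(2).mean()
    # match the covariance matrices of the two batches
    code_c = code - code.mean(0, keepdim=True)
    z_c = z - z.mean(0, keepdim=True)
    cov_code = code_c.t() @ code_c / (code.size(0) - 1)
    cov_z = z_c.t() @ z_c / (z.size(0) - 1)
    cov_loss = (cov_code - cov_z).pow(2).mean()
    return mean_loss, cov_loss
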
Example #5
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset)
        dataloader = DataLoader(dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers,
                                pin_memory=True,
                                sampler=train_sampler)
    else:
        dataloader = DataLoader(dataset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.workers,
                                pin_memory=True)

    dataloader = inf_train_gen(dataloader)

    #models
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    vgg = Vgg16(requires_grad=False).to(device)

    if args.pre_train:
        if args.distributed:
            g_checkpoint = torch.load(
                args.checkpoint_path +
                'generator_checkpoint_{}.ckpt'.format(args.last_iter),
                map_location=lambda storage, loc: storage.cuda(args.local_rank
                                                               ))
            d_checkpoint = torch.load(
                args.checkpoint_path +
Example #6
def train(args):
    
    torch.manual_seed(8734)
    
    netE = Encoder(args).cuda()
    W1 = GeneratorW1(args).cuda()
    W2 = GeneratorW2(args).cuda()
    W3 = GeneratorW3(args).cuda()
    W4 = GeneratorW4(args).cuda()
    W5 = GeneratorW5(args).cuda()
    print (netE, W1, W2, W3, W4, W5)

    optimizerE = optim.Adam(netE.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW1 = optim.Adam(W1.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW2 = optim.Adam(W2.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW3 = optim.Adam(W3.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW4 = optim.Adam(W4.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    optimizerW5 = optim.Adam(W5.parameters(), lr=3e-4, betas=(0.5, 0.9), weight_decay=1e-4)
    
    best_test_acc, best_test_loss = 0., np.inf
    args.best_loss, args.best_acc = best_test_loss, best_test_acc
    if args.resume:
        netE, optimizerE, stats = load_model(args, netE, optimizerE, 'single_code1', m0)
        W1, optimizerW1, stats = load_model(args, W1, optimizerW1, 'single_code1', m1)
        W2, optimizerW2, stats = load_model(args, W2, optimizerW2, 'single_code1', m2)
        W3, optimizerW3, stats = load_model(args, W3, optimizerW3, 'single_code1', m3)
        W4, optimizerW4, stats = load_model(args, W4, optimizerW4, 'single_code1', m4)
        W5, optimizerW5, stats = load_model(args, W5, optimizerW5, 'single_code1', m5)
        best_test_acc, best_test_loss = stats
        print ('==> resuming models at ', stats)

    cifar_train, cifar_test = load_cifar()
    if args.use_x:
        base_gen = datagen.load(args)
        w1_gen = utils.inf_train_gen(base_gen[0])
        w2_gen = utils.inf_train_gen(base_gen[1])
        w3_gen = utils.inf_train_gen(base_gen[2])
        w4_gen = utils.inf_train_gen(base_gen[3])
        w5_gen = utils.inf_train_gen(base_gen[4])

    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    if args.use_x:
        X = sample_x(args, [w1_gen, w2_gen, w3_gen, w4_gen, w5_gen], 0)
        X = list(map(lambda x: (x+1e-10).float(), X))
    for _ in range(1000):
        for batch_idx, (data, target) in enumerate(cifar_train):

            netE.zero_grad()
            W1.zero_grad()
            W2.zero_grad()
            W3.zero_grad()
            W4.zero_grad()
            W5.zero_grad()
            """
            if batch_idx % 50 == 0:
                if args.use_x:
                    acc, loss = 0., 0.
                    for i, (x, y) in enumerate(cifar_test):
                        correct, l = train_clf(args, X, x, y, val=True)
                        acc += correct.item()
                        loss += l.item()
                    print ("Functional Net: ", acc/len(cifar_test.dataset),
                            loss/len(cifar_test.dataset))
            """
            z = sample_z_like((args.batch_size, args.ze,))
            code = netE(z)
            l1 = W1(code)
            l2 = W2(code)
            l3 = W3(code)
            l4 = W4(code)
            l5 = W5(code)  # .contiguous().view(args.batch_size, -1)
            
            for (z1, z2, z3, z4, z5) in zip(l1, l2, l3, l4, l5):
                correct, loss = train_clf(args, [z1, z2, z3, z4, z5], data, target, val=True)
                scaled_loss = (1000*loss) #+ z1_loss + z2_loss + z3_loss
                scaled_loss.backward(retain_graph=True)
            optimizerE.step()
            optimizerW1.step()
            optimizerW2.step()
            optimizerW3.step()
            optimizerW4.step()
            optimizerW5.step()
            loss = loss.item()
                
            if batch_idx % 50 == 0:
                acc = (correct / 1) 
                norm_z1 = np.linalg.norm(z1.data)
                norm_z2 = np.linalg.norm(z2.data)
                norm_z3 = np.linalg.norm(z3.data)
                norm_z4 = np.linalg.norm(z4.data)
                norm_z5 = np.linalg.norm(z5.data)
                print ('**************************************')
                print ('100 tied test')
                print ('Acc: {}, Loss: {}'.format(acc, loss))
                print ('Filter norm: ', norm_z1)
                print ('Filter norm: ', norm_z2)
                print ('Linear norm: ', norm_z3)
                print ('Linear norm: ', norm_z4)
                print ('Linear norm: ', norm_z5)
                print ('best test loss: {}'.format(args.best_loss))
                print ('best test acc: {}'.format(args.best_acc))
                print ('**************************************')
            if batch_idx % 100 == 0:
                test_acc = 0.
                test_loss = 0.
                for i, (data, y) in enumerate(cifar_test):
                    z = sample_z_like((args.batch_size, args.ze,))
                    code = netE(z)
                    l1 = W1(code)
                    l2 = W2(code)
                    l3 = W3(code)
                    l4 = W4(code)
                    l5 = W5(code)
                    min_loss_batch = 10.
                    z_test = [l1[0], l2[0], l3[0], l4[0], l5[0]]
                    for (z1, z2, z3, z4, z5) in zip(l1, l2, l3, l4, l5):
                        correct, loss = train_clf(args, [z1, z2, z3, z4, z5], data, y, val=True)
                        if loss.item() < min_loss_batch:
                            min_loss_batch = loss.item()
                            z_test = [z1, z2, z3, z4, z5]
                        test_acc += correct.item()
                        test_loss += loss.item()
                #y_acc, y_loss = utils.test_samples(args, z_test, train=True)
                test_loss /= len(cifar_test.dataset) * 32
                test_acc /= len(cifar_test.dataset) * 32
                print ('Test Accuracy: {}, Test Loss: {}'.format(test_acc, test_loss))
                # print ('FC Accuracy: {}, FC Loss: {}'.format(y_acc, y_loss))
                if test_loss < best_test_loss or test_acc > best_test_acc:
                    print ('==> new best stats, saving')
                    if test_loss < best_test_loss:
                        best_test_loss = test_loss
                        args.best_loss = test_loss
                    if test_acc > best_test_acc:
                        best_test_acc = test_acc
                        args.best_acc = test_acc
def train():
    args = load_args()
    train_gen = utils.dataset_iterator(args)
    dev_gen = utils.dataset_iterator(args)
    torch.manual_seed(1)
    netG, netD, netE = load_models(args)

    if args.use_spectral_norm:
        optimizerD = optim.Adam(filter(lambda p: p.requires_grad,
            netD.parameters()), lr=2e-4, betas=(0.0,0.9))
    else:
        optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.9))
    optimizerE = optim.Adam(netE.parameters(), lr=2e-4, betas=(0.5, 0.9))

    schedulerD = optim.lr_scheduler.ExponentialLR(optimizerD, gamma=0.99)
    schedulerG = optim.lr_scheduler.ExponentialLR(optimizerG, gamma=0.99) 
    schedulerE = optim.lr_scheduler.ExponentialLR(optimizerE, gamma=0.99)
    
    ae_criterion = nn.MSELoss()
    one = torch.FloatTensor([1]).cuda()
    mone = (one * -1).cuda()
    gen = utils.inf_train_gen(train_gen)

    for iteration in range(args.epochs):
        start_time = time.time()
        """ Update AutoEncoder """
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()
        netE.zero_grad()
        _data = next(gen)
        # real_data = stack_data(args, _data)
        real_data = _data
        real_data_v = autograd.Variable(real_data).cuda()
        encoding = netE(real_data_v)
        fake = netG(encoding)
        ae_loss = ae_criterion(fake, real_data_v)
        ae_loss.backward(one)
        optimizerE.step()
        optimizerG.step()

        
        """ Update D network """

        for p in netD.parameters():  
            p.requires_grad = True 
        for i in range(5):
            _data = next(gen)
            # real_data = stack_data(args, _data)
            real_data = _data
            real_data_v = autograd.Variable(real_data).cuda()
            # train with real data
            netD.zero_grad()
            D_real = netD(real_data_v)
            D_real = D_real.mean()
            D_real.backward(mone)
            # train with fake data
            noise = torch.randn(args.batch_size, args.dim).cuda()
            noisev = autograd.Variable(noise, volatile=True)
            fake = autograd.Variable(netG(noisev).data)
            inputv = fake
            D_fake = netD(inputv)
            D_fake = D_fake.mean()
            D_fake.backward(one)

            # train with gradient penalty 
            gradient_penalty = ops.calc_gradient_penalty(args,
                    netD, real_data_v.data, fake.data)
            gradient_penalty.backward()

            D_cost = D_fake - D_real + gradient_penalty
            Wasserstein_D = D_real - D_fake
            optimizerD.step()

        # Update generator network (GAN)
        noise = torch.randn(args.batch_size, args.dim).cuda()
        noisev = autograd.Variable(noise)
        fake = netG(noisev)
        G = netD(fake)
        G = G.mean()
        G.backward(mone)
        G_cost = -G
        optimizerG.step() 

        schedulerD.step()
        schedulerG.step()
        schedulerE.step()
        # Write logs and save samples 
        save_dir = './plots/'+args.dataset
        plot.plot(save_dir, '/disc cost', D_cost.cpu().data.numpy())
        plot.plot(save_dir, '/gen cost', G_cost.cpu().data.numpy())
        plot.plot(save_dir, '/w1 distance', Wasserstein_D.cpu().data.numpy())
        plot.plot(save_dir, '/ae cost', ae_loss.data.cpu().numpy())
        
        # Calculate dev loss and generate samples every 100 iters
        if iteration % 100 == 99:
            dev_disc_costs = []
            for i, (images, _) in enumerate(dev_gen):
                # imgs = stack_data(args, images) 
                imgs = images
                imgs_v = autograd.Variable(imgs, volatile=True).cuda()
                D = netD(imgs_v)
                _dev_disc_cost = -D.mean().cpu().data.numpy()
                dev_disc_costs.append(_dev_disc_cost)
            plot.plot(save_dir ,'/dev disc cost', np.mean(dev_disc_costs))
            
            utils.generate_image(iteration, netG, save_dir, args)
            # utils.generate_ae_image(iteration, netE, netG, save_dir, args, real_data_v)

        # Save logs every 100 iters 
        if (iteration < 5) or (iteration % 100 == 99):
            plot.flush()
        plot.tick()
        if iteration % 100 == 0:
            utils.save_model(netG, optimizerG, iteration,
                    'models/{}/G_{}'.format(args.dataset, iteration))
            utils.save_model(netD, optimizerD, iteration, 
                    'models/{}/D_{}'.format(args.dataset, iteration))
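
# --- Editor's note: hedged sketch, not part of the original examples. ---
# utils.save_model(net, optimizer, iteration, path) is not shown. A minimal
# checkpoint writer consistent with the calls directly above (an assumption
# about the real helper; the earlier example calls utils.save_model(args,
# netE, optimE) with a different signature, which this does not cover) is:
import os
import torch

def save_model(net, optimizer, iteration, path):
    # create the target directory, e.g. models/<dataset>/, if needed
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'iteration': iteration}, path)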