Example #1
def validate(args,
             fixed_z,
             fid_stat,
             gen_net: nn.Module,
             writer_dict,
             clean_dir=True):
    writer = writer_dict["writer"]
    global_steps = writer_dict["valid_global_steps"]

    # eval mode
    gen_net = gen_net.eval()

    # generate images
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # get fid and inception score
    fid_buffer_dir = os.path.join(args.path_helper["sample_path"],
                                  "fid_buffer")
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc="sample images"):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = (gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy())
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     f"iter{iter_idx}_b{img_idx}.png")
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info("=> calculate inception score")
    mean, std = get_inception_score(img_list)
    print(f"Inception score: {mean}")

    # get fid score
    logger.info("=> calculate fid score")
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)
    print(f"FID score: {fid_score}")

    if clean_dir:
        os.system("rm -r {}".format(fid_buffer_dir))
    else:
        logger.info(f"=> sampled images are saved to {fid_buffer_dir}")

    writer.add_image("sampled_images", img_grid, global_steps)
    writer.add_scalar("Inception_score/mean", mean, global_steps)
    writer.add_scalar("Inception_score/std", std, global_steps)
    writer.add_scalar("FID_score", fid_score, global_steps)

    writer_dict["valid_global_steps"] = global_steps + 1

    return mean, fid_score
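
These validate/evaluate snippets all assume a shared preamble of imports plus two project-local metric helpers. A minimal sketch that would make Example #1 runnable; the module paths for get_inception_score and calculate_fid_given_paths are assumptions, since they vary by project:

import os
import logging

import numpy as np
import torch
import torch.nn as nn
from imageio import imsave  # older projects import imsave from scipy.misc
from torchvision.utils import make_grid
from tqdm import tqdm

# project-local metric helpers -- the module paths below are assumptions
from utils.inception_score import get_inception_score
from utils.fid_score import calculate_fid_given_paths

logger = logging.getLogger(__name__)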
Example #2
    def get_score(opt, netG, netD, dataloader, loss, data_name):

        # eval mode
        netG.eval()
        img_list = list()
        ds = list()
        with torch.no_grad():
            for data in dataloader:
                imgs = data[data_name]
                if len(opt.gpu_ids) != 0:
                    imgs = imgs.cuda()

                conv_imgs = netG(imgs)

                d = [float(loss(netD(im.unsqueeze(0)), True)) for im in conv_imgs]
                conv_imgs = (conv_imgs.mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
                             .permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy())
                img_list.extend(list(conv_imgs))
                ds.extend(d)

        mean_is, std_is = get_inception_score(img_list, splits=1)

        # invert the mean discriminator loss so a lower loss yields a higher
        # score; the epsilon guards against division by zero
        mean_d = 1 / (np.mean(ds) + 0.00005)

        return mean_is, mean_d * 3.0
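
The `1 / (mean + eps)` inversion in Example #2 turns a low discriminator loss on the converted images into a high score. A quick numeric sketch of that behavior (the loss values are illustrative only):

import numpy as np

eps = 0.00005
for ds in ([0.1, 0.2, 0.15], [1.5, 2.0, 1.8]):
    print(np.mean(ds), 1 / (np.mean(ds) + eps))
# low loss  (mean ~0.15) -> score ~6.66
# high loss (mean ~1.77) -> score ~0.57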
Example #3
def get_is(args, gen_net: nn.Module, num_img):
    """
    Get inception score.
    :param args:
    :param gen_net:
    :param num_img:
    :return: Inception score
    """

    # eval mode
    gen_net = gen_net.eval()

    eval_iter = num_img // args.eval_batch_size
    img_list = list()
    for _ in range(eval_iter):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = (gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy())
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info("calculate Inception score...")
    mean, std = get_inception_score(img_list)

    return mean
Example #4
    def calculate_inception_score(self, iters):
        # `iters` was undefined in the original snippet; here it is assumed
        # to be the current generator iteration, passed in by the caller.
        sample_list = []
        for i in range(10):
            z = self.fixed_noise
            samples = self.netG(z)
            sample_list.append(samples.data.cpu().numpy())

        new_sample_list = list(chain.from_iterable(sample_list))
        print("Calculating Inception Score over 8k generated images")
        inception_score = get_inception_score(
            new_sample_list,
            cuda=True,
            batch_size=32,
            resize=True,
            splits=10,
        )

        time = t.time() - self.t_begin
        print("Inception score: {}".format(inception_score))
        print("Generator iter: {}".format(iters))
        print("Time {}".format(time))

        output = str(iters) + ", " + str(inception_score[0]) + "\n"

        return output
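
Because `z = self.fixed_noise` is reused on every pass, the ten batches above are identical, so the score is effectively computed on one batch repeated ten times rather than 8k distinct images. A minimal sketch of the sampling loop with fresh noise per batch, assuming `self.fixed_noise` carries the intended shape and device:

sample_list = []
for _ in range(10):
    # draw new latents with the same shape/device as fixed_noise
    z = torch.randn_like(self.fixed_noise)
    with torch.no_grad():
        samples = self.netG(z)
    sample_list.append(samples.cpu().numpy())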
Example #5
def validate(args, fixed_z, fid_stat, gen_net: nn.Module, writer_dict, epoch):
    np.random.seed(args.random_seed**2 + epoch)
    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # eval mode
    gen_net = gen_net.eval()

    # generate images
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # get fid and inception score
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'],
                                  'fid_buffer')
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in range(eval_iter):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     'iter%d_b%d.png' % (iter_idx, img_idx))
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info('=> calculate inception score')
    mean, std = get_inception_score(img_list)

    # get fid score
    logger.info('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    os.system('rm -r {}'.format(fid_buffer_dir))

    writer.add_image('sampled_images', img_grid, global_steps)
    writer.add_scalar('Inception_score/mean', mean, global_steps)
    writer.add_scalar('Inception_score/std', std, global_steps)
    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1

    return mean, fid_score
Example #6
def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # eval mode
    gen_net = gen_net.eval()

    # generate images
    sample_imgs = gen_net(fixed_z, epoch)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # get fid and inception score
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = (gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
                    .permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy())
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png')
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info('=> calculate inception score')
    mean, std = get_inception_score(img_list)
    print(f"Inception score: {mean}")

    # get fid score
    logger.info('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat], inception_path=None)
    print(f"FID score: {fid_score}")

    if clean_dir:
        os.system('rm -r {}'.format(fid_buffer_dir))
    else:
        logger.info(f'=> sampled images are saved to {fid_buffer_dir}')
    writer.add_image('sampled_images', img_grid, global_steps)
    writer.add_scalar('Inception_score/mean', mean, global_steps)
    writer.add_scalar('Inception_score/std', std, global_steps)
    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1

    return mean, fid_score
Example #7
def fid_is_eval():
    pathes = os.listdir("results_pure_noise/generated/")
    # `finished` is a manually maintained skip list of already-evaluated
    # result files; it is empty here, so every path is kept
    finished = []
    pathes = [path for path in pathes if path not in finished]
    for path in pathes:
        fp = open("FID_IS_result.txt", 'a')
        info = path.split('_')
        dataset = info[0]
        model_name = path
        if dataset == "cifar10":
            data = np.load('results_pure_noise/generated/{}'.format(path))
            mb_size, X_dim, width, height, channels, len_x_train, x_train, y_train, len_x_test, x_test, y_test = data_loader(
                dataset)
            real_set = x_train
            img_set = data['x']
            print("Calculating Fréchet Inception Distance for {}".format(
                model_name))
            print("Calculating Fréchet Inception Distance for {}".format(
                model_name),
                  file=fp)
            fid_set_r = real_set * 255.0
            fid_set_r = fid_set_r.astype(np.uint8)
            fid_set_r = np.transpose(fid_set_r, (0, 3, 1, 2))
            fid_set_i = img_set * 255.0
            fid_set_i = fid_set_i.astype(np.uint8)
            fid_set_i = np.transpose(fid_set_i, (0, 3, 1, 2))
            #fid_set_i = fid_set_i[:256]
            #fid_set_r = fid_set_r[:256]
            fid_score = get_fid(fid_set_r, fid_set_i)
            print("FID: {}".format(fid_score))
            print("FID: {}".format(fid_score), file=fp)
            tf.reset_default_graph()
            fp.close()
            fp = open("FID_IS_result.txt", 'a')

            print("Calculating inception score for {}".format(model_name))
            print("Calculating inception score for {}".format(model_name),
                  file=fp)
            is_set = img_set * 2.0 - 1.0
            is_set = np.transpose(is_set, (0, 3, 1, 2))
            #is_set = is_set[:256]
            mean, std = get_inception_score(is_set)
            print("mean: {} std: {}".format(mean, std))
            print("mean: {} std: {}".format(mean, std), file=fp)
            tf.reset_default_graph()
            fp.close()
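
Example #7 repeats every `print` once for stdout and once for the results file. A small tee-style helper (hypothetical, not part of the original script) would remove the duplication:

def log_both(msg, fp):
    # write msg to stdout and to an already-open file handle
    print(msg)
    print(msg, file=fp)

# usage: log_both("FID: {}".format(fid_score), fp)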
Example #8
def get_is(args, gen_net: nn.Module, num_img, z_numpy=None, get_is_score=True):
    """
    Get inception score.
    :param args:
    :param gen_net:
    :param num_img:
    :return: Inception score
    """

    # eval mode
    gen_net = gen_net.eval()

    eval_iter = num_img // args.eval_batch_size
    img_list = list()
    state_list = list()
    for i in range(eval_iter):
        # We use a fixed set of random seeds for the reward and progressive states in the search stage to stabilize the training
        np.random.seed(i)
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
        # Generate a batch of images
        gen_imgs, gen_states = gen_net(z, eval=True)
        gen_imgs2 = gen_imgs.mul_(127.5).add_(127.5).clamp_(0.0,
                                                            255.0).permute(
                                                                0, 2, 3, 1)
        img_list.extend(list(gen_imgs2.to('cpu', torch.uint8).numpy()))
        state_list.extend(list(gen_states.to('cpu').numpy()))
    state = list(np.mean(state_list, axis=0).flatten())
    if not get_is_score:
        return state
    # get inception score
    logger.info('calculate Inception score...')
    mean, std = get_inception_score(img_list)
    logger.info('=> calculate fid score')
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'],
                                  'fid_buffer')
    os.system('rm -rf {}'.format(fid_buffer_dir))
    os.makedirs(fid_buffer_dir, exist_ok=True)
    for img_idx, img in enumerate(img_list):
        if img_idx < 5000:
            file_name = os.path.join(fid_buffer_dir, f'iter0_b{img_idx}.png')
            imsave(file_name, img)
    fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    assert os.path.exists(fid_stat)
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    return mean, fid_score, state
Example #9
    def get_inception_score(self, num_batches, splits=None):
        all_samples = []
        config = self.config
        if not splits:
            splits = config.inps_splits
        batch_size = 100
        dim_z = config.dim_z
        for i in range(num_batches):
            z = np.random.normal(size=[batch_size, dim_z]).astype(np.float32)
            feed_dict = {self.noise: z, self.is_training: False}
            samples = self.sess.run(self.fake_data, feed_dict=feed_dict)
            all_samples.append(samples)
        all_samples = np.concatenate(all_samples, axis=0)
        all_samples = ((all_samples + 1.) * 255. / 2.).astype(np.int32)
        all_samples = all_samples.reshape((-1, 32, 32, 3))
        return inception_score.get_inception_score(list(all_samples),
                                                   splits=splits)
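
Example #9 rescales tanh-range samples with `(x + 1) * 255 / 2`, whereas the PyTorch snippets use `x * 127.5 + 127.5`; the two expressions are algebraically identical. A quick check:

import numpy as np

x = np.linspace(-1.0, 1.0, 5)
a = (x + 1.0) * 255.0 / 2.0
b = x * 127.5 + 127.5
assert np.allclose(a, b)
print(a)  # [  0.    63.75 127.5  191.25 255.  ]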
Example #10
File: test.py Project: authorsLEAD/LEAD
def validate(args, fixed_z, gen_net: nn.Module, writer_dict):
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    gen_net = gen_net.eval()
    global_steps = writer_dict['valid_global_steps']
    eval_iter = args.num_eval_imgs // args.eval_batch_size
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'],
                                  'fid_buffer')
    os.makedirs(fid_buffer_dir, exist_ok=True)
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     f'iter{iter_idx}_b{img_idx}.png')
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # compute IS
    inception_score, std = get_inception_score(img_list)
    print('------------------------ Inception Score ------------------------')
    print(inception_score)

    # compute FID; `fid_score` and `fid` were undefined in the original
    # snippet, so this block is reconstructed from the near-identical
    # Example #13 below
    sample_list = []
    for i in range(eval_iter):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
        samples = gen_net(z)
        sample_list.append(samples.data.cpu().numpy())

    new_sample_list = list(chain.from_iterable(sample_list))
    fake_image_np = np.concatenate([img[None] for img in new_sample_list], 0)

    real_image_np = []
    for i, (images, _) in enumerate(train_loader):
        real_image_np += [images.data.numpy()]
        batch_size = real_image_np[0].shape[0]
        if len(real_image_np) * batch_size >= fake_image_np.shape[0]:
            break
    real_image_np = np.concatenate(real_image_np, 0)[:fake_image_np.shape[0]]
    fid_score = calculate_fid(real_image_np, fake_image_np, batch_size=300)
    fid = round(fid_score[0][1], 3)
    print('------------------------ FID pytorch ------------------------')
    print(fid_score)

    # Generate a batch of images
    sample_dir = os.path.join(args.path_helper['sample_path'], 'sample_dir')
    Path(sample_dir).mkdir(exist_ok=True)

    sample_imgs = gen_net(fixed_z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
    img_grid = make_grid(sample_imgs, nrow=5).to('cpu', torch.uint8).numpy()
    file_name = os.path.join(
        sample_dir, f'final_fid_{fid}_inception_score{inception_score}.png')
    imsave(file_name, img_grid.swapaxes(0, 1).swapaxes(1, 2))

    writer_dict['valid_global_steps'] = global_steps + 1
    return inception_score, fid
Example #11
def calculate_metrics(fid_buffer_dir, num_eval_imgs, eval_batch_size, latent_dim, fid_stat, G, do_IS=False, do_FID=True):
	# eval mode
	G = G.eval()

	# get fid and inception score
	# generate the image buffer when either metric is requested; the original
	# gate `do_IS and do_FID` left img_list undefined for IS-only runs
	if do_IS or do_FID:
		if not os.path.isdir(fid_buffer_dir):
			os.mkdir(fid_buffer_dir)

		eval_iter = num_eval_imgs // eval_batch_size
		img_list = list()
		for iter_idx in range(eval_iter):
			z = torch.cuda.FloatTensor(np.random.normal(0, 1, (eval_batch_size, latent_dim)))

			# Generate a batch of images
			gen_imgs = G(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
			for img_idx, img in enumerate(gen_imgs):
				file_name = os.path.join(fid_buffer_dir, 'iter%d_b%d.png' % (iter_idx, img_idx))
				imsave(file_name, img)
			img_list.extend(list(gen_imgs))

	# get inception score
	if do_IS:
		print('=> calculate inception score')
		mean, std = get_inception_score(img_list)
	else:
		mean, std = 0, 0

	# get fid score
	if do_FID:
		print('=> calculate fid score')
		fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat], inception_path=None)
	else:
		fid_score = 0

	if do_IS or do_FID:
		os.system('rm -r {}'.format(fid_buffer_dir))

	return mean, fid_score
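
A hedged usage sketch for `calculate_metrics`; the generator variable, buffer path, and sample counts below are placeholders, with the .npz path borrowed from Example #8:

is_mean, fid = calculate_metrics(
    fid_buffer_dir='fid_buffer',           # placeholder path
    num_eval_imgs=5000,
    eval_batch_size=100,
    latent_dim=128,
    fid_stat='fid_stat/fid_stats_cifar10_train.npz',
    G=generator,                           # your trained nn.Module
    do_IS=True,
    do_FID=True,
)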
Example #12
def evaluate(args, fixed_z, fid_stat, gen_net: nn.Module):
    # eval mode
    gen_net = gen_net.eval()

    # generate images
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # get fid and inception score
    fid_buffer_dir = 'fid_buffer_test'
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     'iter%d_b%d.png' % (iter_idx, img_idx))
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    print('=> calculate inception score')
    mean, std = get_inception_score(img_list)

    # get fid score
    print('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    os.system('rm -r {}'.format(fid_buffer_dir))
    return mean, fid_score
Example #13
def validate(args, fixed_z, fid_stat, gen_net: nn.Module, writer_dict):
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    gen_net = gen_net.eval()
    global_steps = writer_dict['valid_global_steps']
    eval_iter = args.num_eval_imgs // args.eval_batch_size
    # compute IS
    IS_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
    os.makedirs(IS_buffer_dir, exist_ok=True)
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(IS_buffer_dir,
                                     f'iter{iter_idx}_b{img_idx}.png')
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    inception_score, std = get_inception_score(img_list)
    print('------------------------Inception Score------------------------')
    print(inception_score)

    # compute FID
    sample_list = []
    for i in range(eval_iter):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
        samples = gen_net(z)
        sample_list.append(samples.data.cpu().numpy())

    new_sample_list = list(chain.from_iterable(sample_list))
    fake_image_np = np.concatenate([img[None] for img in new_sample_list], 0)

    real_image_np = []
    for i, (images, _) in enumerate(train_loader):
        real_image_np += [images.data.numpy()]
        batch_size = real_image_np[0].shape[0]
        if len(real_image_np) * batch_size >= fake_image_np.shape[0]:
            break
    real_image_np = np.concatenate(real_image_np, 0)[:fake_image_np.shape[0]]
    fid_score = calculate_fid(real_image_np, fake_image_np, batch_size=300)
    var_fid = fid_score[0][2]  # computed but not used below
    fid = round(fid_score[0][1], 3)
    print('------------------------fid_score------------------------')
    print(fid_score)

    # Generate a batch of images
    sample_dir = os.path.join(args.path_helper['sample_path'], 'sample_dir')
    Path(sample_dir).mkdir(exist_ok=True)

    sample_imgs = gen_net(fixed_z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0)
    img_grid = make_grid(sample_imgs, nrow=5).to('cpu', torch.uint8).numpy()
    file_name = os.path.join(
        sample_dir, f'final_fid_{fid}_inception_score{inception_score}.png')
    imsave(file_name, img_grid.swapaxes(0, 1).swapaxes(1, 2))

    writer_dict['valid_global_steps'] = global_steps + 1
    return inception_score, fid
Example #14
    def train(self, train_loader):
        self.t_begin = t.time()
        generator_iter = 0
        self.file = open("inception_score_graph.txt", "w")
        dis_params_flatten = parameters_to_vector(self.D.parameters())
        gen_params_flatten = parameters_to_vector(self.G.parameters())

        # just to fill the empty grad buffers
        if self.cuda:
            z = Variable(torch.randn(self.batch_size, 100, 1,
                                     1)).cuda(self.cuda_index)
        else:
            z = Variable(torch.randn(self.batch_size, 100, 1, 1))
        fake_images = self.G(z)
        outputs = self.D(fake_images)
        fake_labels = torch.zeros(self.batch_size)
        fake_labels = Variable(fake_labels).cuda(self.cuda_index)
        d_loss_fake = self.loss(outputs.squeeze(), fake_labels)
        (0.0 * d_loss_fake).backward(create_graph=True)
        d_loss_fake = 0.0
        best_inception_score = 0.0
        d_loss_list = []
        g_loss_list = []
        for epoch in range(self.epochs):
            self.epoch_start_time = t.time()

            for i, (images, _) in enumerate(train_loader):
                # Check if round number of batches
                if i == train_loader.dataset.__len__() // self.batch_size:
                    break

                z = torch.rand((self.batch_size, 100, 1, 1))
                real_labels = torch.ones(self.batch_size)
                fake_labels = torch.zeros(self.batch_size)

                if self.cuda:
                    images, z = Variable(images).cuda(
                        self.cuda_index), Variable(z).cuda(self.cuda_index)
                    real_labels, fake_labels = Variable(real_labels).cuda(
                        self.cuda_index), Variable(fake_labels).cuda(
                            self.cuda_index)
                else:
                    images, z = Variable(images), Variable(z)
                    real_labels, fake_labels = Variable(real_labels), Variable(
                        fake_labels)

                # Train discriminator
                # Compute BCE_Loss using real images
                outputs = self.D(images)
                d_loss_real = self.loss(outputs.squeeze(), real_labels)
                real_score = outputs

                # Compute BCE Loss using fake images
                if self.cuda:
                    z = Variable(torch.randn(self.batch_size, 100, 1,
                                             1)).cuda(self.cuda_index)
                else:
                    z = Variable(torch.randn(self.batch_size, 100, 1, 1))
                fake_images = self.G(z)
                outputs = self.D(fake_images)
                d_loss_fake = self.loss(outputs.squeeze(), fake_labels)
                fake_score = outputs

                # Optimize discriminator
                d_loss = d_loss_real + d_loss_fake
                if self.mode == 'adam':
                    self.D.zero_grad()
                    d_loss.backward()
                    self.d_optimizer.step()
                elif self.mode == 'adam_vjp':
                    gradsD = torch.autograd.grad(outputs=d_loss,
                                                 inputs=(self.D.parameters()),
                                                 create_graph=True)
                    for p, g in zip(self.D.parameters(), gradsD):
                        p.grad = g
                    gen_params_flatten_prev = gen_params_flatten + 0.0
                    gen_params_flatten = parameters_to_vector(
                        self.G.parameters()) + 0.0
                    grad_gen_params_flatten = optim.parameters_grad_to_vector(
                        self.G.parameters())
                    delta_gen_params_flatten = gen_params_flatten - gen_params_flatten_prev
                    vjp_dis = torch.autograd.grad(
                        grad_gen_params_flatten,
                        self.D.parameters(),
                        grad_outputs=delta_gen_params_flatten)
                    self.d_optimizer.step(vjps=vjp_dis)

                # Train generator
                # Compute loss with fake images
                if self.cuda:
                    z = Variable(torch.randn(self.batch_size, 100, 1,
                                             1)).cuda(self.cuda_index)
                else:
                    z = Variable(torch.randn(self.batch_size, 100, 1, 1))
                fake_images = self.G(z)
                outputs = self.D(fake_images)
                # non-zero-sum generator loss
                g_loss = self.loss(outputs.squeeze(), real_labels)
                # zero-sum alternative:
                # g_loss = - self.loss(outputs.squeeze(), fake_labels)
                # Optimize generator
                if self.mode == 'adam':
                    self.D.zero_grad()
                    self.G.zero_grad()
                    g_loss.backward()
                    self.g_optimizer.step()
                elif self.mode == 'adam_vjp':
                    gradsG = torch.autograd.grad(outputs=g_loss,
                                                 inputs=(self.G.parameters()),
                                                 create_graph=True)
                    for p, g in zip(self.G.parameters(), gradsG):
                        p.grad = g

                    dis_params_flatten_prev = dis_params_flatten + 0.0
                    dis_params_flatten = parameters_to_vector(
                        self.D.parameters()) + 0.0
                    grad_dis_params_flatten = optim.parameters_grad_to_vector(
                        self.D.parameters())
                    delta_dis_params_flatten = dis_params_flatten - dis_params_flatten_prev
                    vjp_gen = torch.autograd.grad(
                        grad_dis_params_flatten,
                        self.G.parameters(),
                        grad_outputs=delta_dis_params_flatten)
                    self.g_optimizer.step(vjps=vjp_gen)

                generator_iter += 1

                if generator_iter % 1000 == 0:
                    # Workaround: GPU memory cannot hold much more than 800
                    # samples at once, so generate 800 at a time and stack ten
                    # batches to get 8000 images. The Inception Score is more
                    # reliable this way, since the samples cover more of the
                    # Inception model's classes.
                    sample_list = []
                    # use `_` so the loop does not shadow the batch index `i`,
                    # which is still needed below
                    for _ in range(10):
                        z = Variable(torch.randn(800, 100, 1,
                                                 1)).cuda(self.cuda_index)
                        samples = self.G(z)
                        sample_list.append(samples.data.cpu().numpy())

                    # Flattening list of lists into one list of numpy arrays
                    new_sample_list = list(chain.from_iterable(sample_list))
                    print(
                        "Calculating Inception Score over 8k generated images")
                    # Feeding list of numpy arrays
                    inception_score = get_inception_score(new_sample_list,
                                                          cuda=True,
                                                          batch_size=32,
                                                          resize=True,
                                                          splits=10)
                    print('Epoch-{}'.format(epoch + 1))
                    print(inception_score)
                    # get_inception_score returns (mean, std); compare means
                    if inception_score[0] >= best_inception_score:
                        best_inception_score = inception_score[0]
                        self.save_model()

                    # Denormalize images and save them in grid 8x8
                    z = Variable(torch.randn(800, 100, 1,
                                             1)).cuda(self.cuda_index)
                    samples = self.G(z)
                    samples = samples.mul(0.5).add(0.5)
                    samples = samples.data.cpu()[:64]
                    grid = utils.make_grid(samples)
                    utils.save_image(
                        grid, self.name + '/iter_{}_inception_{}_.png'.format(
                            str(generator_iter).zfill(3),
                            str(inception_score)))

                    time = t.time() - self.t_begin
                    print("Inception score: {}".format(inception_score))
                    print("Generator iter: {}".format(generator_iter))
                    print("Time {}".format(time))

                    # Write to file inception_score, gen_iters, time
                    output = str(generator_iter) + " " + str(time) + " " + str(
                        inception_score[0]) + "\n"
                    self.file.write(output)

                if ((i + 1) % 100) == 0:
                    d_loss_list += [d_loss.item()]
                    g_loss_list += [g_loss.item()]
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1),
                           (i + 1), train_loader.dataset.__len__() //
                           self.batch_size, d_loss.item(), g_loss.item()))

                    z = Variable(
                        torch.randn(self.batch_size, 100, 1,
                                    1).cuda(self.cuda_index))

                    # TensorBoard logging
                    # Log the scalar values
                    info = {'d_loss': d_loss.item(), 'g_loss': g_loss.item()}

                    for tag, value in info.items():
                        self.logger.scalar_summary(tag, value, generator_iter)

                    # Log values and gradients of the parameters
                    for tag, value in self.D.named_parameters():
                        tag = tag.replace('.', '/')
                        self.logger.histo_summary(tag, self.to_np(value),
                                                  generator_iter)
                        self.logger.histo_summary(tag + '/grad',
                                                  self.to_np(value.grad),
                                                  generator_iter)

                    # Log the images while training
                    info = {
                        'real_images':
                        self.real_images(images, self.number_of_images),
                        'generated_images':
                        self.generate_img(z, self.number_of_images)
                    }

                    for tag, images in info.items():
                        self.logger.image_summary(tag, images, generator_iter)

        self.t_end = t.time()
        print('Time of training-{}'.format((self.t_end - self.t_begin)))

        # Save the trained parameters
        self.save_final_model()
        np.save(self.name + '/d_loss', np.array(d_loss_list))
        np.save(self.name + '/g_loss', np.array(g_loss_list))
        self.evaluate(train_loader, self.name + '/discriminator.pkl',
                      self.name + '/generator.pkl')