Exemplo n.º 1
0
def compute_fid(netG,
                data_dir,
                reference_data,
                cpu_inference=False,
                data_size=50000,
                delete_cache=False):
    """Compute the FID between images generated by ``netG`` and reference pictures.

    Args:
        netG: generator network; maps a (batch, 512) latent tensor to images
            in the (-1, 1) range.
        data_dir (str): directory where generated .png samples are cached.
        reference_data (str): root folder holding the ``distil_pics``
            reference directory and the evaluation dataset.
        cpu_inference (bool): run the generator forward pass on CPU.
        data_size (int): number of generated images required in the cache
            before the generation step is skipped.
        delete_cache (bool): remove previously generated .png files first.

    Returns:
        float: the FID value between ``data_dir`` and the reference pictures.
    """
    original_data_path = reference_data + "/distil_pics/"

    if delete_cache:
        for file_img in glob.glob(data_dir + '/*.png'):
            os.remove(file_img)

    if not os.path.exists(data_dir):
        os.makedirs(data_dir)

    # BUG FIX: the cache check previously compared against a hard-coded
    # 50000 instead of the `data_size` parameter.
    if len(glob.glob(data_dir + '/*')) < data_size:
        print(
            "Here is no generated data, so I generate it using provided model")

        b_size = 50

        eval_dataloader = DataLoader(PairedImageDataset(reference_data),
                                     batch_size=b_size,
                                     shuffle=False,
                                     num_workers=4,
                                     drop_last=False)

        input_eval_source = torch.cuda.FloatTensor(b_size, 512)
        netG.eval()
        n_saved = 0  # total images written so far
        for i_eval_img, eval_batch in tqdm(enumerate(eval_dataloader)):
            input_img = Variable(input_eval_source.copy_(eval_batch['input']))
            with torch.no_grad():

                if cpu_inference:
                    input_img = input_img.cpu()
                output_img = netG(input_img)

            for i_img_from_batch in range(b_size):
                img_np = output_img[i_img_from_batch:(
                    i_img_from_batch + 1)].detach().cpu().numpy()

                img_np = np.moveaxis(img_np, 1, -1)  # NCHW -> NHWC
                img_np = np.clip((img_np + 1) / 2, 0, 1)  # (-1,1) -> (0,1)

                imsave(
                    os.path.join(
                        data_dir,
                        '%s.png' % (i_eval_img * b_size + i_img_from_batch)),
                    img_as_ubyte(img_np[0]))

                n_saved += 1
                if n_saved == data_size:
                    break
            # BUG FIX: the original `break` compared the *batch index* to
            # `data_size` and only exited the inner loop, so generation never
            # actually stopped at `data_size` images. Count saved images and
            # stop the outer loop too.
            if n_saved == data_size:
                break

    paths = [data_dir, original_data_path]

    fid = calculate_fid_given_paths(paths, 32, True, 2048, delete_cache)
    return fid
Exemplo n.º 2
0
def validate(args,
             fixed_z,
             fid_stat,
             gen_net: nn.Module,
             writer_dict,
             clean_dir=True):
    """Sample images from ``gen_net`` and log Inception score and FID.

    Args:
        args: namespace with ``path_helper``, ``num_eval_imgs``,
            ``eval_batch_size`` and ``latent_dim`` attributes.
        fixed_z: fixed latent batch, used only for the logged image grid.
        fid_stat: path to the pre-computed FID statistics file.
        gen_net: generator network (switched to eval mode here).
        writer_dict: dict with a TensorBoard ``writer`` and the current
            ``valid_global_steps`` counter (incremented before returning).
        clean_dir (bool): delete the sample buffer directory afterwards.

    Returns:
        tuple: (inception score mean, FID score).
    """
    import shutil  # local import: only needed for buffer cleanup

    writer = writer_dict["writer"]
    global_steps = writer_dict["valid_global_steps"]

    # eval mode
    gen_net = gen_net.eval()

    # generate images for the logged grid
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # buffer directory holding the generated .png samples for FID
    fid_buffer_dir = os.path.join(args.path_helper["sample_path"],
                                  "fid_buffer")
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc="sample images"):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images; map (-1, 1) floats to uint8 HWC
        gen_imgs = (gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy())
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     f"iter{iter_idx}_b{img_idx}.png")
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info("=> calculate inception score")
    mean, std = get_inception_score(img_list)
    print(f"Inception score: {mean}")

    # get fid score
    logger.info("=> calculate fid score")
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)
    print(f"FID score: {fid_score}")

    if clean_dir:
        # BUG FIX: shutil.rmtree instead of shelling out to `rm -r`
        # (portable, no shell interpolation of the path).
        shutil.rmtree(fid_buffer_dir, ignore_errors=True)
    else:
        logger.info(f"=> sampled images are saved to {fid_buffer_dir}")

    writer.add_image("sampled_images", img_grid, global_steps)
    writer.add_scalar("Inception_score/mean", mean, global_steps)
    writer.add_scalar("Inception_score/std", std, global_steps)
    writer.add_scalar("FID_score", fid_score, global_steps)

    writer_dict["valid_global_steps"] = global_steps + 1

    return mean, fid_score
Exemplo n.º 3
0
def validate(args, fixed_z, fid_stat, gen_net: nn.Module, writer_dict, epoch):
    """Sample images from ``gen_net`` with an epoch-dependent seed and log
    Inception score and FID to TensorBoard.

    Args:
        args: namespace with ``random_seed``, ``path_helper``,
            ``num_eval_imgs``, ``eval_batch_size`` and ``latent_dim``.
        fixed_z: fixed latent batch, used only for the logged image grid.
        fid_stat: path to the pre-computed FID statistics file.
        gen_net: generator network (switched to eval mode here).
        writer_dict: dict with a TensorBoard ``writer`` and the current
            ``valid_global_steps`` counter (incremented before returning).
        epoch (int): current epoch; folded into the numpy seed so every
            epoch samples a different but reproducible batch.

    Returns:
        tuple: (inception score mean, FID score).
    """
    import shutil  # local import: only needed for buffer cleanup

    np.random.seed(args.random_seed**2 + epoch)
    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # eval mode
    gen_net = gen_net.eval()

    # generate images for the logged grid
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # buffer directory holding the generated .png samples for FID
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'],
                                  'fid_buffer')
    # BUG FIX: exist_ok=True — a leftover buffer dir (e.g. after an
    # interrupted run) previously made os.makedirs raise FileExistsError.
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in range(eval_iter):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images; map (-1, 1) floats to uint8 HWC
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     'iter%d_b%d.png' % (iter_idx, img_idx))
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info('=> calculate inception score')
    mean, std = get_inception_score(img_list)

    # get fid score
    logger.info('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    # BUG FIX: shutil.rmtree instead of shelling out to `rm -r`
    # (portable, no shell interpolation of the path).
    shutil.rmtree(fid_buffer_dir, ignore_errors=True)

    writer.add_image('sampled_images', img_grid, global_steps)
    writer.add_scalar('Inception_score/mean', mean, global_steps)
    writer.add_scalar('Inception_score/std', std, global_steps)
    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1

    return mean, fid_score
Exemplo n.º 4
0
def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True):
    """Sample images from an epoch-conditioned generator and log Inception
    score and FID to TensorBoard.

    Args:
        args: namespace with ``path_helper``, ``num_eval_imgs``,
            ``eval_batch_size`` and ``latent_dim`` attributes.
        fixed_z: fixed latent batch, used only for the logged image grid.
        fid_stat: path to the pre-computed FID statistics file.
        epoch: passed through to the generator (``gen_net(z, epoch)``).
        gen_net: generator network (switched to eval mode here).
        writer_dict: dict with a TensorBoard ``writer`` and the current
            ``valid_global_steps`` counter (incremented before returning).
        clean_dir (bool): delete the sample buffer directory afterwards.

    Returns:
        tuple: (inception score mean, FID score).
    """
    import shutil  # local import: only needed for buffer cleanup

    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    # eval mode
    gen_net = gen_net.eval()

    # generate images for the logged grid
    sample_imgs = gen_net(fixed_z, epoch)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # buffer directory holding the generated .png samples for FID
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'], 'fid_buffer')
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images; map (-1, 1) floats to uint8 HWC
        gen_imgs = gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir, f'iter{iter_idx}_b{img_idx}.png')
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    logger.info('=> calculate inception score')
    mean, std = get_inception_score(img_list)
    print(f"Inception score: {mean}")

    # get fid score
    logger.info('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat], inception_path=None)
    print(f"FID score: {fid_score}")

    if clean_dir:
        # BUG FIX: shutil.rmtree instead of shelling out to `rm -r`
        # (portable, no shell interpolation of the path).
        shutil.rmtree(fid_buffer_dir, ignore_errors=True)
    else:
        logger.info(f'=> sampled images are saved to {fid_buffer_dir}')

    writer.add_image('sampled_images', img_grid, global_steps)
    writer.add_scalar('Inception_score/mean', mean, global_steps)
    writer.add_scalar('Inception_score/std', std, global_steps)
    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1

    return mean, fid_score
Exemplo n.º 5
0
    def evaluate(self, model, noise_generator, data, device):
        """Generate real and fake samples, then return the FID between them."""
        print("Generating true data.")
        self.generate_true(data)
        print("Generating fake data.")
        self.generate_fake(model, noise_generator)

        print("Evaluating FID Score.")
        fid_value = calculate_fid_given_paths(
            [self.fake_path, self.true_path],
            self.batch_size,
            device,
            2048,  # default feature dimensionality
            8)     # default number of worker processes
        print('FID: ', fid_value)
        return fid_value
Exemplo n.º 6
0
def get_is(args, gen_net: nn.Module, num_img, z_numpy=None, get_is_score=True):
    """Evaluate the generator: progressive states, Inception score and FID.

    Args:
        args: namespace with ``eval_batch_size``, ``latent_dim`` and
            ``path_helper`` attributes.
        gen_net: generator; ``gen_net(z, eval=True)`` must return a tuple of
            (images, internal states).
        num_img: total number of images to sample.
        z_numpy: unused; kept for interface compatibility.
        get_is_score: when False, skip the score computation and return only
            the averaged generator states.

    Returns:
        list: averaged generator states, when ``get_is_score`` is False;
        tuple: (inception score mean, FID score, averaged states) otherwise.
    """
    import shutil  # local import: only needed for buffer cleanup

    # eval mode
    gen_net = gen_net.eval()

    eval_iter = num_img // args.eval_batch_size
    img_list = list()
    state_list = list()
    for i in range(eval_iter):
        # We use a fixed set of random seeds for the reward and progressive
        # states in the search stage to stabilize the training.
        np.random.seed(i)
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
        # Generate a batch of images plus the generator's internal states
        gen_imgs, gen_states = gen_net(z, eval=True)
        # map (-1, 1) floats to uint8 HWC (note: mutates gen_imgs in place)
        gen_imgs2 = gen_imgs.mul_(127.5).add_(127.5).clamp_(0.0,
                                                            255.0).permute(
                                                                0, 2, 3, 1)
        img_list.extend(list(gen_imgs2.to('cpu', torch.uint8).numpy()))
        state_list.extend(list(gen_states.to('cpu').numpy()))
    state = list(np.mean(state_list, axis=0).flatten())
    if not get_is_score:
        return state

    # get inception score
    logger.info('calculate Inception score...')
    mean, std = get_inception_score(img_list)

    # get fid score over (at most) the first 5000 samples
    logger.info('=> calculate fid score')
    fid_buffer_dir = os.path.join(args.path_helper['sample_path'],
                                  'fid_buffer')
    # BUG FIX: shutil.rmtree instead of shelling out to `rm -rf`
    # (portable, no shell interpolation of the path).
    shutil.rmtree(fid_buffer_dir, ignore_errors=True)
    os.makedirs(fid_buffer_dir, exist_ok=True)
    # slice once instead of testing the index on every iteration
    for img_idx, img in enumerate(img_list[:5000]):
        file_name = os.path.join(fid_buffer_dir, f'iter0_b{img_idx}.png')
        imsave(file_name, img)
    fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    assert os.path.exists(fid_stat)
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    return mean, fid_score, state
Exemplo n.º 7
0
def get_fid(model, data_dir, reference_data, qfid=True, use_cache=True):
    """Computes FID or qFID given model

    Args:
        model (nn.Module): quantized model
        data_dir (str): path to directory where intermediate data will be stored
        reference_data (str): data to compare against, must be path to data generated by full precision model in case qfid==True
                              or else path to real data
        qfid (bool, optional): if True returns qFID else FID. Defaults to True.
        use_cache (bool, optional): Use cached data or not. Defaults to True.

    Returns:
        float: FID or qFID
    """
    cache_file = data_dir + f'/fid_full_{qfid}.npy'

    fid_sc = None
    if use_cache:
        # BUG FIX: the original used a bare `except:` around a deliberately
        # raised `Exception` as control flow, which also swallowed
        # KeyboardInterrupt/SystemExit. Catch only a missing or unreadable
        # cache file and fall through to recomputation.
        try:
            fid_sc = np.load(cache_file)[0]
        except (OSError, ValueError):
            fid_sc = None

    if fid_sc is None:
        if qfid:
            fid_sc = compute_fid(model,
                                 cpu_inference=False,
                                 delete_cache=not use_cache,
                                 reference_data=reference_data,
                                 data_dir=data_dir,
                                 data_size=50000)
        else:
            assert len(
                glob.glob(data_dir +
                          '/*')) >= 50000, "Please run compute qfid at first"
            fid_sc = calculate_fid_given_paths([data_dir, reference_data], 32,
                                               True, 2048, not use_cache)

        np.save(cache_file, np.array([fid_sc]))

    return fid_sc
Exemplo n.º 8
0
def calculate_metrics(fid_buffer_dir, num_eval_imgs, eval_batch_size, latent_dim, fid_stat, G, do_IS=False, do_FID=True):
	"""Compute Inception score and/or FID for generator ``G``.

	Args:
		fid_buffer_dir: directory used to buffer generated .png samples for FID.
		num_eval_imgs: total number of images to sample.
		eval_batch_size: generator batch size.
		latent_dim: dimensionality of the latent vector fed to ``G``.
		fid_stat: path to the pre-computed FID statistics file.
		G: generator network (switched to eval mode here).
		do_IS: compute the Inception score (otherwise 0 is returned for it).
		do_FID: compute the FID score (otherwise 0 is returned for it).
			When do_FID is True but do_IS is False, the FID is computed from
			whatever images are already in ``fid_buffer_dir``.

	Returns:
		tuple: (inception score mean, FID score); a metric that was not
		requested is returned as 0.
	"""
	# eval mode
	G = G.eval()

	# BUG FIX: images were previously generated only when BOTH do_IS and
	# do_FID were set, so do_IS=True with do_FID=False crashed with a
	# NameError on `img_list`. Generate whenever the Inception score is
	# requested; write files to the buffer dir only when FID needs them.
	if do_IS:
		if do_FID and not os.path.isdir(fid_buffer_dir):
			os.mkdir(fid_buffer_dir)

		eval_iter = num_eval_imgs // eval_batch_size
		img_list = list()
		for iter_idx in range(eval_iter):
			z = torch.cuda.FloatTensor(np.random.normal(0, 1, (eval_batch_size, latent_dim)))

			# Generate a batch of images; map (-1, 1) floats to uint8 HWC
			gen_imgs = G(z).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
			if do_FID:
				for img_idx, img in enumerate(gen_imgs):
					file_name = os.path.join(fid_buffer_dir, 'iter%d_b%d.png' % (iter_idx, img_idx))
					imsave(file_name, img)
			img_list.extend(list(gen_imgs))

	# get inception score
	if do_IS:
		print('=> calculate inception score')
		mean, std = get_inception_score(img_list)
	else:
		mean, std = 0, 0

	# get fid score
	if do_FID:
		print('=> calculate fid score')
		fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat], inception_path=None)
	else:
		fid_score = 0

	if do_IS and do_FID:
		# BUG FIX: shutil.rmtree instead of shelling out to `rm -r`.
		import shutil
		shutil.rmtree(fid_buffer_dir, ignore_errors=True)

	return mean, fid_score
Exemplo n.º 9
0
def evaluate(args, fixed_z, fid_stat, gen_net: nn.Module):
    """Sample images from ``gen_net`` and return its Inception score and FID.

    Unlike the ``validate`` variants, nothing is logged to TensorBoard; the
    buffer directory ('fid_buffer_test') is always removed afterwards.

    Args:
        args: namespace with ``num_eval_imgs``, ``eval_batch_size`` and
            ``latent_dim`` attributes.
        fixed_z: fixed latent batch, used only to build the (unused) grid.
        fid_stat: path to the pre-computed FID statistics file.
        gen_net: generator network (switched to eval mode here).

    Returns:
        tuple: (inception score mean, FID score).
    """
    import shutil  # local import: only needed for buffer cleanup

    # eval mode
    gen_net = gen_net.eval()

    # generate an image grid (kept for parity with the validate variants)
    sample_imgs = gen_net(fixed_z)
    img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True)

    # buffer directory holding the generated .png samples for FID
    fid_buffer_dir = 'fid_buffer_test'
    os.makedirs(fid_buffer_dir, exist_ok=True)

    eval_iter = args.num_eval_imgs // args.eval_batch_size
    img_list = list()
    for iter_idx in tqdm(range(eval_iter), desc='sample images'):
        z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))

        # Generate a batch of images; map (-1, 1) floats to uint8 HWC
        gen_imgs = gen_net(z).mul_(127.5).add_(127.5).clamp_(
            0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for img_idx, img in enumerate(gen_imgs):
            file_name = os.path.join(fid_buffer_dir,
                                     'iter%d_b%d.png' % (iter_idx, img_idx))
            imsave(file_name, img)
        img_list.extend(list(gen_imgs))

    # get inception score
    print('=> calculate inception score')
    mean, std = get_inception_score(img_list)

    # get fid score
    print('=> calculate fid score')
    fid_score = calculate_fid_given_paths([fid_buffer_dir, fid_stat],
                                          inception_path=None)

    # BUG FIX: shutil.rmtree instead of shelling out to `rm -r`
    # (portable, no shell interpolation of the path).
    shutil.rmtree(fid_buffer_dir, ignore_errors=True)
    return mean, fid_score
Exemplo n.º 10
0
        for i, batch_test in enumerate(dataloader_test):
            # Set model input
            input_img_test = Variable(input_source_test.copy_(batch_test))
            # Generate output
            output_img_test = 0.5 * (netG(input_img_test).data + 1.0)
            # Save image files
            save_image(
                output_img_test,
                os.path.join(test_img_generation_dir_temp,
                             '%04d.png' % (i + 1)))

            sys.stdout.write('\rGenerated images %04d of %04d' %
                             (i + 1, len(dataloader_test)))
        print()
        # find FID:
        FID = calculate_fid_given_paths(img_paths_for_FID)
        if FID < best_FID:
            best_FID = FID
            torch.save(netG.state_dict(),
                       os.path.join(pth_dir, 'best_FID_netG.pth'))
            # torch.save(netD.state_dict(), os.path.join(pth_dir, 'best_FID_netD.pth'))
            best_FID_epoch = epoch

        f_FID = open(os.path.join(results_dir, 'FID_log.txt'), 'a+')
        f_FID.write('epoch %d: FID %s\n' % (epoch, FID))
        f_FID.close()
    ## End save best FID

    # save latest:
    save_ckpt_finetune(epoch,
                       netG,