예제 #1
0
def test_vae_fid(model, args, total_fid_samples):
    """Return the FID between samples from a VAE and precomputed reference stats.

    The sampling budget ``total_fid_samples`` is split evenly (rounded up)
    across ``args.num_process_per_node * args.num_proc_node`` GPUs. Each share
    is passed through a 2048-d InceptionV3 feature extractor. The resulting
    mean and covariance are averaged across GPUs, and the averaged statistics
    are compared with those loaded from ``<args.fid_dir>/<args.dataset>.npz``.

    Args:
        model: the VAE to sample from (handed to ``create_generator_vae``).
        args: run configuration (batch size, process layout, ``fid_dir``,
            ``dataset``, ``distributed``).
        total_fid_samples: total number of samples summed over all GPUs.

    Returns:
        The value of ``calculate_frechet_distance`` (reference stats first).
    """
    device = 'cuda'
    feature_dims = 2048
    world_size = args.num_process_per_node * args.num_proc_node
    per_gpu_samples = int(np.ceil(total_fid_samples / world_size))

    # Build the sample stream from the VAE first. The Inception network gets
    # its own name so the ``model`` argument is never rebound.
    sample_stream = create_generator_vae(model, args.batch_size, per_gpu_samples)
    inception_block = InceptionV3.BLOCK_INDEX_BY_DIM[feature_dims]
    inception = InceptionV3([inception_block], model_dir=args.fid_dir).to(device)
    mu_local, sigma_local = compute_statistics_of_generator(
        sample_stream,
        inception,
        args.batch_size,
        feature_dims,
        device,
        max_samples=per_gpu_samples)

    # The return value of average_tensor is ignored, so it must update the
    # tensor in place. That is why the stats make a numpy -> CUDA tensor ->
    # numpy round trip.
    # NOTE(review): averaging per-GPU covariances only approximates the pooled
    # covariance, because it ignores differences between the per-GPU means.
    mu_tensor = torch.from_numpy(mu_local).cuda()
    sigma_tensor = torch.from_numpy(sigma_local).cuda()
    utils.average_tensor(mu_tensor, args.distributed)
    utils.average_tensor(sigma_tensor, args.distributed)
    mu_avg = mu_tensor.cpu().numpy()
    sigma_avg = sigma_tensor.cpu().numpy()

    reference_path = os.path.join(args.fid_dir, args.dataset + '.npz')
    mu_ref, sigma_ref = load_statistics(reference_path)
    return calculate_frechet_distance(mu_ref, sigma_ref, mu_avg, sigma_avg)
예제 #2
0
 def __init__(self, noise_sampler, config, limit=100):
     """Set up the InceptionV3 feature extractor used for FID estimation.

     Args:
         noise_sampler: stored as ``self.noise_sampler``. Presumably it feeds
             the generator; TODO confirm against the class's other methods.
         config: must expose ``DEVICE`` (the generator device). It may also
             expose ``ESTIMATOR_DEVICE`` to place the Inception model on a
             different device; otherwise ``DEVICE`` is reused.
         limit: stored as ``self.limit`` (default 100).
     """
     super().__init__()
     # Inception feature size used for FID (InceptionV3 supports 64, 192,
     # 768 or 2048).
     self.dims = 2048
     self.config = config
     # The generator and the Inception estimator may live on different
     # devices. ESTIMATOR_DEVICE is optional and falls back to DEVICE.
     self.generator_device = torch.device(config.DEVICE)
     self.device = (torch.device(config.ESTIMATOR_DEVICE)
                    if hasattr(config, 'ESTIMATOR_DEVICE')
                    else self.generator_device)
     inception_block = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
     self.model = InceptionV3([inception_block]).to(self.device).eval()
     self.noise_sampler = noise_sampler
     self.limit = limit
예제 #3
0
파일: fid_score.py 프로젝트: NVlabs/NVAE
def calculate_fid_given_paths(paths, batch_size, device, dims):
    """Return the FID between the image sets at ``paths[0]`` and ``paths[1]``.

    Every entry of *paths* is checked for existence first; the first missing
    one raises ``RuntimeError('Invalid path: ...')``. A single InceptionV3
    model, moved to *device*, then computes the statistics of both paths.

    Args:
        paths: sequence of paths; only the first two are compared.
        batch_size: batch size forwarded to ``_compute_statistics_of_path``.
        device: device the Inception model is moved to.
        dims: feature dimensionality (a key of ``InceptionV3.BLOCK_INDEX_BY_DIM``).

    Raises:
        RuntimeError: if any entry of *paths* does not exist.
    """
    for candidate in paths:
        if os.path.exists(candidate):
            continue
        raise RuntimeError('Invalid path: %s' % candidate)

    inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)

    def stats_of(path):
        # Same Inception model and settings for both image sets.
        return _compute_statistics_of_path(path, inception, batch_size,
                                           dims, device)

    mu_a, sigma_a = stats_of(paths[0])
    mu_b, sigma_b = stats_of(paths[1])
    return calculate_frechet_distance(mu_a, sigma_a, mu_b, sigma_b)
예제 #4
0
def main(args):
    """Compute InceptionV3 statistics for ``args.dataset`` and save them.

    The statistics (mean, covariance) are computed over the training loader,
    or over train + validation for ``celeba_256`` and ``omniglot``. They are
    written to ``<args.fid_dir>/<args.dataset>.npz`` via ``save_statistics``.
    """
    device = 'cuda'
    feature_dims = 2048
    # NOTE: for binary datasets (e.g. MNIST, OMNIGLOT) the data is meant to be
    # used without binarization here. That handling is not visible in this
    # function; presumably it lives in get_loaders_eval (confirm).
    train_queue, valid_queue, _ = get_loaders_eval(args.dataset, args)
    print('len train queue', len(train_queue), 'len val queue',
          len(valid_queue), 'batch size', args.batch_size)

    # These datasets use both splits for their reference statistics.
    use_both_splits = args.dataset in {'celeba_256', 'omniglot'}
    image_source = chain(train_queue, valid_queue) if use_both_splits else train_queue

    inception_block = InceptionV3.BLOCK_INDEX_BY_DIM[feature_dims]
    inception = InceptionV3([inception_block], model_dir=args.fid_dir).to(device)
    mu, sigma = compute_statistics_of_generator(image_source, inception,
                                                args.batch_size, feature_dims,
                                                device, args.max_samples)

    stats_path = os.path.join(args.fid_dir, args.dataset + '.npz')
    print('saving fid stats at %s' % stats_path)
    save_statistics(stats_path, mu, sigma)
예제 #5
0
# Groups of ``paths`` whose value is a dict mapping a condition key to a
# directory.
_KEYED_PATH_GROUPS = ('dynamic_prior', 'conditional_2a1m', 'conditional_1a2m')
# Every group understood by calculate_inception_features_for_gen_evaluation,
# in processing order.
_PATH_GROUP_ORDER = ('random', 'dynamic_prior', 'conditional',
                     'conditional_2a1m', 'conditional_1a2m', 'real')


def _requested_dirs(paths):
    """Yield every input directory named in *paths*, in processing order."""
    for group in _PATH_GROUP_ORDER:
        if group not in paths:
            continue
        if group in _KEYED_PATH_GROUPS:
            yield from paths[group].values()
        else:
            yield paths[group]


def _list_pngs(directory, modality=None):
    """Return the sorted ``*.png`` files of *directory* or its *modality* subfolder."""
    folder = directory if modality is None else os.path.join(directory, modality)
    # glob's order is arbitrary. Sorting makes the saved activation rows
    # reproducible across runs and filesystems.
    return sorted(glob.glob(os.path.join(folder, '*.png')))


def calculate_inception_features_for_gen_evaluation(flags,
                                                    paths,
                                                    modality=None,
                                                    dims=2048,
                                                    batch_size=128):
    """Compute and save InceptionV3 activations for generated and real images.

    For each group present in *paths*, the ``*.png`` files are collected,
    run through ``get_activations`` and saved with ``np.save``. Groups are
    processed in this order:

    * ``'random'``: written to ``flags.dir_gen_eval_fid_random``.
    * ``'dynamic_prior'`` (dict key -> dir): one file per key, written inside
      that key's directory.
    * ``'conditional'``: written to the input directory when *modality* is
      given, otherwise to ``flags.dir_gen_eval_fid_cond_gen``.
    * ``'conditional_2a1m'``, ``'conditional_1a2m'``: same layout as
      ``'dynamic_prior'``.
    * ``'real'``: written to ``flags.dir_gen_eval_fid_real``.

    When *modality* is given, images are read from ``<dir>/<modality>`` and the
    modality name appears in the output filename.

    Args:
        flags: must provide ``inception_state_dict``, ``device`` and the
            ``dir_gen_eval_fid_*`` output directories listed above.
        paths: mapping from group name to a directory, or, for the keyed
            groups, to a dict from key to directory.
        modality: optional image subfolder and filename tag. It is required
            for the keyed groups.
        dims: Inception feature dimensionality (key of ``BLOCK_INDEX_BY_DIM``).
        batch_size: batch size passed to ``get_activations``.

    Raises:
        ValueError: a keyed group is requested while *modality* is None.
        RuntimeError: a requested directory does not exist. All checks run
            before the Inception model is loaded or any file is written.
    """
    keyed_requested = [group for group in _KEYED_PATH_GROUPS if group in paths]
    if modality is None and keyed_requested:
        # These groups always build '<dir>/<modality>' and
        # '<key>_<modality>_activations.npy', so they cannot work without it.
        raise ValueError('modality is required for path groups: %s'
                         % ', '.join(keyed_requested))

    # Validate every input directory up front, so a bad path neither wastes a
    # model load nor leaves partial results behind.
    for directory in _requested_dirs(paths):
        if not os.path.exists(directory):
            raise RuntimeError('Invalid path: %s' % directory)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx],
                        path_state_dict=flags.inception_state_dict)
    model = model.to(flags.device)

    def save_activations(files, filename):
        # NOTE(review): the positional True is passed through unchanged from
        # the original code; presumably it is get_activations' cuda flag (confirm).
        activations = get_activations(files,
                                      model,
                                      batch_size,
                                      dims,
                                      True,
                                      verbose=False)
        np.save(filename, activations)

    def save_keyed_group(dirs_by_key):
        # One activation file per key, written next to that key's images.
        for key, directory in dirs_by_key.items():
            filename = os.path.join(
                directory, key + '_' + modality + '_activations.npy')
            save_activations(_list_pngs(directory, modality), filename)

    if 'random' in paths:
        if modality is not None:
            name = 'random_sampling_' + modality + '_activations.npy'
        else:
            name = 'random_img_activations.npy'
        save_activations(_list_pngs(paths['random'], modality),
                         os.path.join(flags.dir_gen_eval_fid_random, name))

    if 'dynamic_prior' in paths:
        save_keyed_group(paths['dynamic_prior'])

    if 'conditional' in paths:
        dir_cond_gen = paths['conditional']
        if modality is not None:
            # NOTE(review): unlike every other single-directory group, this
            # writes into the *input* directory rather than
            # flags.dir_gen_eval_fid_cond_gen. It is kept because readers of
            # these files may rely on it; confirm it is intended.
            filename = os.path.join(
                dir_cond_gen, 'cond_gen_' + modality + '_activations.npy')
        else:
            filename = os.path.join(flags.dir_gen_eval_fid_cond_gen,
                                    'conditional_img_activations.npy')
        save_activations(_list_pngs(dir_cond_gen, modality), filename)

    for group in ('conditional_2a1m', 'conditional_1a2m'):
        if group in paths:
            save_keyed_group(paths[group])

    if 'real' in paths:
        if modality is not None:
            name = 'real_' + modality + '_activations.npy'
        else:
            name = 'real_img_activations.npy'
        save_activations(_list_pngs(paths['real'], modality),
                         os.path.join(flags.dir_gen_eval_fid_real, name))