예제 #1
0
def get_fid(fakes, model, npz, device, batch_size=1, use_tqdm=True):
    """Return the Frechet distance between stored and generated statistics.

    Args:
        fakes: sequence of tensors; they are concatenated along dim 0.
        model: model handed to ``_compute_statistics_of_ims``.
        npz: mapping that provides the reference statistics under the keys
            ``'mu'`` and ``'sigma'``.
        device: device forwarded to ``_compute_statistics_of_ims``.
        batch_size: batch size forwarded to ``_compute_statistics_of_ims``.
        use_tqdm: forwarded to ``_compute_statistics_of_ims``.

    Returns:
        The value of ``calculate_frechet_distance`` converted to ``float``.
    """
    ref_mu = npz['mu']
    ref_sigma = npz['sigma']

    # Merge the batches, then convert via util.tensor2im and cast to float.
    fake_images = util.tensor2im(torch.cat(fakes, dim=0)).astype(float)

    # NOTE(review): 2048 presumably is the feature dimensionality expected by
    # _compute_statistics_of_ims -- confirm against its signature.
    fake_mu, fake_sigma = _compute_statistics_of_ims(
        fake_images, model, batch_size, 2048, device, use_tqdm=use_tqdm)

    distance = calculate_frechet_distance(ref_mu, ref_sigma,
                                          fake_mu, fake_sigma)
    return float(distance)
예제 #2
0
def main(opt):
    """Compute statistics over a dataset and save them to an ``.npz`` file.

    Every batch yielded by ``create_dataloader(opt)`` is collected, the
    batches are concatenated, converted with ``util.tensor2im`` and passed to
    ``_compute_statistics_of_ims``. The resulting ``mu`` and ``sigma`` are
    written to ``opt.output_path`` with ``np.savez``.

    Args:
        opt: options object; this function reads ``gpu_ids``,
            ``dataset_mode``, ``direction`` and ``output_path`` from it.
    """
    loader = create_dataloader(opt)

    # Use the first listed GPU when any are configured, otherwise the CPU.
    if opt.gpu_ids:
        device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
    else:
        device = torch.device('cpu')

    inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])
    inception_model.to(device)
    inception_model.eval()

    # For 'single'/'aligned' datasets the batch key is the last character of
    # opt.direction; every other dataset mode uses the 'image' key.
    batches = [
        batch[opt.direction[-1]]
        if opt.dataset_mode in ['single', 'aligned']
        else batch['image']
        for batch in tqdm.tqdm(loader)
    ]

    images = util.tensor2im(torch.cat(batches, dim=0)).astype(float)
    mu, sigma = _compute_statistics_of_ims(
        images, inception_model, 32, 2048, device, use_tqdm=True)
    np.savez(opt.output_path, mu=mu, sigma=sigma)
예제 #3
0
def get_fid_new(reals, fakes, model, device, batch_size=1, use_tqdm=True):
    """Return Frechet distances computed with ``median=False`` and ``True``.

    Args:
        reals: sequence of tensors; they are concatenated along dim 0.
        fakes: sequence of tensors; they are concatenated along dim 0.
        model: model handed to ``_compute_statistics_of_ims``.
        device: device forwarded to ``_compute_statistics_of_ims``.
        batch_size: batch size forwarded to ``_compute_statistics_of_ims``.
        use_tqdm: forwarded to ``_compute_statistics_of_ims``.

    Returns:
        A ``(fid_mean, fid_median)`` tuple of floats: the distance obtained
        with ``median=False`` followed by the one obtained with
        ``median=True``.
    """
    real_images = util.tensor2im(torch.cat(reals, dim=0)).astype(float)
    fake_images = util.tensor2im(torch.cat(fakes, dim=0)).astype(float)

    def _distance(use_median):
        # Statistics are recomputed for each mode because the ``median``
        # flag is passed to the statistics routine as well as to the
        # distance routine.
        real_mu, real_sigma = _compute_statistics_of_ims(
            real_images, model, batch_size, 2048, device,
            use_tqdm=use_tqdm, median=use_median)
        fake_mu, fake_sigma = _compute_statistics_of_ims(
            fake_images, model, batch_size, 2048, device,
            use_tqdm=use_tqdm, median=use_median)
        return float(calculate_frechet_distance(
            real_mu, real_sigma, fake_mu, fake_sigma, median=use_median))

    # Tuple elements evaluate left to right: mean variant first, then median.
    return _distance(False), _distance(True)
예제 #4
0
def main(opt):
    """Compute statistics over a dataset and save them to an ``.npz`` file.

    Items yielded by ``create_dataset(opt)`` are stored in a dict keyed by
    the first entry of their path list, so a later item with the same path
    replaces the earlier one. The stored tensors are concatenated, converted
    with ``util.tensor2imgs`` and passed to ``_compute_statistics_of_ims``;
    the resulting ``mu`` and ``sigma`` are written to ``opt.output_path``.

    Args:
        opt: options object; this function reads ``gpu_ids``, ``direction``
            and ``output_path`` from it.
    """
    dataset = create_dataset(opt)

    # Use the first listed GPU when any are configured, otherwise the CPU.
    if opt.gpu_ids:
        device = torch.device('cuda:{}'.format(opt.gpu_ids[0]))
    else:
        device = torch.device('cpu')

    inception_model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]])
    inception_model.to(device)
    inception_model.eval()

    by_path = {}
    for item in dataset:
        # 'AtoB' selects the B-side entries; anything else selects A-side.
        if opt.direction == 'AtoB':
            image_key, path_key = 'B', 'B_paths'
        else:
            image_key, path_key = 'A', 'A_paths'
        image = item[image_key]
        by_path[item[path_key][0]] = image

    stacked = torch.cat(list(by_path.values()), dim=0)
    images = util.tensor2imgs(stacked).astype(float)
    mu, sigma = _compute_statistics_of_ims(
        images, inception_model, 32, 2048, device, use_tqdm=True)
    np.savez(opt.output_path, mu=mu, sigma=sigma)