def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
    """Compute the FID between the two image folders in ``paths``.

    Args:
        paths: sequence of two directory paths whose images are compared.
        img_size: side length images are resized to by the eval loader.
        batch_size: batch size for the evaluation data loaders.

    Returns:
        The Frechet distance between Gaussians fit to the InceptionV3
        activations of the two folders.
    """
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = InceptionV3().eval().to(device)
    loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]

    mu, cov = [], []
    for loader in loaders:
        actvs = []
        # FIX: inference only — disable autograd so activations do not
        # retain computation graphs (the original accumulated
        # grad-tracking tensors across the whole dataset in memory).
        with torch.no_grad():
            for x in tqdm(loader, total=len(loader)):
                actv = inception(x.to(device))
                actvs.append(actv)
        actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
        # Per-folder Gaussian statistics of the pooled activations.
        mu.append(np.mean(actvs, axis=0))
        cov.append(np.cov(actvs, rowvar=False))
    fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
    return fid_value
def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
    """Return the FID between the two image folders given in ``paths``.

    Paddle ("porch" shim) port: loads pretrained InceptionV3 weights from
    disk, pools activations per folder, fits a Gaussian to each set, and
    returns their Frechet distance cast to a plain float dtype.
    """
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = porch.device('cuda' if porch.cuda.is_available() else 'cpu')
    inception = InceptionV3("./metrics/inception_v3_pretrained.pdparams")
    inception.eval()
    loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]

    means, covariances = [], []
    for loader in loaders:
        features = []
        for batch in tqdm(loader, total=len(loader)):
            # Loader yields a tuple; element 0 is the image batch, which
            # the shim converts from a VarBase into a tensor.
            sample = porch.varbase_to_tensor(batch[0])
            features.append(inception(sample))
        features = porch.cat(features, dim=0).numpy()
        means.append(np.mean(features, axis=0))
        covariances.append(np.cov(features, rowvar=False))

    fid_value = frechet_distance(means[0], covariances[0],
                                 means[1], covariances[1])
    return fid_value.astype(float)
def calculate_metrics(nets, args, step, mode):
    """Evaluate LPIPS and FID over every source->target domain pair.

    For each target domain, translates validation images from every other
    domain, saves the fakes under ``args.eval_dir`` (for FID), measures
    LPIPS diversity among the outputs, and writes a JSON report.

    Args:
        nets: namespace of networks (generator, style_encoder,
            mapping_network, fan) used for synthesis.
        args: config; reads val_img_dir, eval_dir, img_size,
            val_batch_size, latent_dim, num_outs_per_domain, w_hpf.
        step: training step, used only in the report filename.
        mode: 'latent' (styles from random z) or 'reference'
            (styles encoded from reference images).
    """
    print('Calculating evaluation metrics...')
    assert mode in ['latent', 'reference']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One sub-directory of val_img_dir per domain; sorted so domain
    # indices are deterministic across runs.
    domains = os.listdir(args.val_img_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)

    lpips_dict = OrderedDict()
    for trg_idx, trg_domain in enumerate(domains):
        src_domains = [x for x in domains if x != trg_domain]

        if mode == 'reference':
            path_ref = os.path.join(args.val_img_dir, trg_domain)
            loader_ref = get_eval_loader(root=path_ref,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False,
                                         drop_last=True)

        for src_idx, src_domain in enumerate(src_domains):
            path_src = os.path.join(args.val_img_dir, src_domain)
            loader_src = get_eval_loader(root=path_src,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False)

            task = '%s2%s' % (src_domain, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            # Start from an empty output folder for this task.
            shutil.rmtree(path_fake, ignore_errors=True)
            os.makedirs(path_fake)

            lpips_values = []
            # FIX: only build the reference iterator when a reference
            # loader exists. The original ran iter(loader_ref)
            # unconditionally, which raises NameError in 'latent' mode
            # where loader_ref is never created. The except handler
            # below already treats a missing iter_ref as "(re)build",
            # so reference-mode behavior is unchanged.
            if mode == 'reference':
                iter_ref = iter(loader_ref)
            print('Generating images and calculating LPIPS for %s...'
                  % task)
            for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))):
                N = x_src.size(0)
                x_src = x_src.to(device)
                y_trg = torch.tensor([trg_idx] * N).to(device)
                masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

                # generate num_outs_per_domain outputs from the same input
                group_of_images = []
                for j in range(args.num_outs_per_domain):
                    if mode == 'latent':
                        z_trg = torch.randn(N, args.latent_dim).to(device)
                        s_trg = nets.mapping_network(z_trg, y_trg)
                    else:
                        try:
                            x_ref = next(iter_ref).to(device)
                        except (NameError, StopIteration):
                            # restart the reference stream when exhausted
                            iter_ref = iter(loader_ref)
                            x_ref = next(iter_ref).to(device)
                        if x_ref.size(0) > N:
                            x_ref = x_ref[:N]
                        s_trg = nets.style_encoder(x_ref, y_trg)

                    x_fake = nets.generator(x_src, s_trg, masks=masks)
                    group_of_images.append(x_fake)

                    # save generated images to calculate FID later
                    # (filename encodes j, so this sits inside the j loop)
                    for k in range(N):
                        filename = os.path.join(
                            path_fake,
                            '%.4i_%.2i.png' % (i * args.val_batch_size + (k + 1), j + 1))
                        utils.save_image(x_fake[k], ncol=1, filename=filename)

                lpips_value = calculate_lpips_given_images(group_of_images)
                lpips_values.append(lpips_value)

            # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
            lpips_mean = np.array(lpips_values).mean()
            lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean

        # delete dataloaders before moving to the next target domain
        del loader_src
        if mode == 'reference':
            del loader_ref
            del iter_ref

    # calculate the average LPIPS for all tasks
    lpips_mean = 0
    for _, value in lpips_dict.items():
        lpips_mean += value / len(lpips_dict)
    lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean

    # report LPIPS values
    filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode))
    utils.save_json(lpips_dict, filename)

    # calculate and report fid values
    calculate_fid_for_all_tasks(args, domains, step=step, mode=mode)