def _generate_fid_images(netG, data_dir, reference_data, cpu_inference, data_size, b_size=50):
    """Run ``netG`` over the inputs of ``reference_data`` and save up to ``data_size`` PNGs.

    Images are written to ``data_dir`` as ``0.png``, ``1.png``, ... in dataloader
    order.

    Args:
        netG: generator module; it is switched to eval mode and left there.
        data_dir: existing directory that receives the PNG files.
        reference_data: root passed to ``PairedImageDataset``; its items are
            expected to carry an ``'input'`` tensor.
        cpu_inference: when True, inputs are kept on the CPU. The caller is
            responsible for ``netG`` living on the same device.
        data_size: maximum number of images to write.
        b_size: inference batch size.

    Returns:
        The number of images actually written. This can be smaller than
        ``data_size`` when the dataset is shorter.
    """
    eval_dataloader = DataLoader(PairedImageDataset(reference_data), batch_size=b_size,
                                 shuffle=False, num_workers=4, drop_last=False)
    # Stage inputs directly on the target device. The old fixed-size CUDA buffer
    # required CUDA even for CPU inference, and broke on a short final batch.
    device = torch.device('cpu') if cpu_inference else torch.device('cuda')
    netG.eval()
    n_written = 0
    for eval_batch in tqdm(eval_dataloader):
        if n_written >= data_size:
            break
        input_img = eval_batch['input'].to(device=device, dtype=torch.float32)
        with torch.no_grad():
            output_img = netG(input_img)
        # The last batch may be smaller than b_size (drop_last=False). It may
        # also need truncating so that exactly data_size images are written.
        n_keep = min(output_img.size(0), data_size - n_written)
        imgs = output_img[:n_keep].cpu().numpy()
        imgs = np.moveaxis(imgs, 1, -1)  # assumes NCHW output -> NHWC for imsave
        imgs = np.clip((imgs + 1) / 2, 0, 1)  # (-1,1) -> (0,1)
        for img in imgs:
            imsave(os.path.join(data_dir, '%s.png' % n_written), img_as_ubyte(img))
            n_written += 1
    return n_written


def compute_fid(netG, data_dir, reference_data, cpu_inference=False, data_size=50000, delete_cache=False):
    """Compute the FID between images generated by ``netG`` and reference pictures.

    Generated images are cached as PNGs in ``data_dir``. If the directory already
    holds at least ``data_size`` PNGs, they are reused as-is and ``netG`` is not
    run. Pass ``delete_cache=True`` whenever the generator has changed, otherwise
    a stale cache is scored.

    Args:
        netG: generator module. If generation runs, it is left in eval mode.
        data_dir: directory for the generated PNGs (created if missing).
        reference_data: dataset root. Its ``distil_pics/`` subfolder is the
            reference image set for the FID.
        cpu_inference: run the generator inputs on the CPU instead of CUDA.
        data_size: number of images to generate. It is also the cache-hit
            threshold.
        delete_cache: remove existing PNGs in ``data_dir`` before checking the
            cache. It is also forwarded to ``calculate_fid_given_paths``.

    Returns:
        Whatever ``calculate_fid_given_paths`` returns (the FID score).
    """
    original_data_path = reference_data + "/distil_pics/"
    if delete_cache:
        for file_img in glob.glob(os.path.join(data_dir, '*.png')):
            os.remove(file_img)
    os.makedirs(data_dir, exist_ok=True)
    # Count only the files this function produces, matching the deletion pattern.
    if len(glob.glob(os.path.join(data_dir, '*.png'))) < data_size:
        print("Here is no generated data, so I generate it using provided model")
        _generate_fid_images(netG, data_dir, reference_data, cpu_inference, data_size)
    paths = [data_dir, original_data_path]
    # Positional args presumably mean (batch_size=32, cuda=True, dims=2048).
    # NOTE(review): confirm against calculate_fid_given_paths' signature.
    return calculate_fid_given_paths(paths, 32, True, 2048, delete_cache)
args.base_model_str, 'pth', 'latest.pth') netD.load_state_dict(torch.load('initial_weights/netD_B_seed_1.pth.tar')) print('load D from %s' % 'initial_weights') start_epoch = 0 best_FID = 1e9 loss_G_lst, loss_G_perceptual_lst, loss_G_GAN_lst, loss_D_lst = [], [], [], [] # Dataset loader: img shape=(256,256) dataset_dir = os.path.join(foreign_dir, 'datasets', args.dataset) soft_data_dir = os.path.join(foreign_dir, 'train_set_result', args.dataset) transforms_ = [ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ] # (0,1) -> (-1,1) dataloader = DataLoader(PairedImageDataset(dataset_dir, soft_data_dir, transforms_=transforms_, mode=args.task), batch_size=args.batch_size, shuffle=True, num_workers=args.cpus, drop_last=True) dataloader_test = DataLoader(ImageDataset(os.path.join(dataset_dir, 'test', source_str), transforms_=transforms_), batch_size=1, shuffle=False, num_workers=args.cpus) # FID img dirs: test_img_generation_dir_temp = os.path.join(results_dir, 'test_set_generation_temp')