Esempio n. 1
0
 def eval_fid(fake_images):
     """Save a batch of generated images to disk and return their FID score.

     Each ``fake_images[i]`` is written as ``<i>.jpg`` into
     ``<opt.output_path>/<opt.version>/test`` (created if missing). FID is then
     computed between that directory and ``test_images_path``.

     NOTE(review): ``opt``, ``device``, ``test_images_path``, ``save_image`` and
     ``fid_score`` are taken from module scope.
     NOTE(review): files from an earlier call that saved more images are not
     removed. If the FID helper reads the whole directory (presumably it does),
     those stale files would be scored too. Confirm before relying on this.
     """
     test_dir = os.path.join(opt.output_path, opt.version, "test")
     os.makedirs(test_dir, exist_ok=True)

     print("Saving images generated for testing...")
     n_images = fake_images.size(0)
     for idx in range(n_images):
         target = "{}/{}.jpg".format(test_dir, idx)
         save_image(fake_images[idx, :, :, :], target)

     print("Calculating FID...")
     path_pair = (test_dir, test_images_path)
     return fid_score.calculate_fid_given_paths(path_pair, opt.batch_size, device)
Esempio n. 2
0
 def eval_fid(gen_images_path, eval_images_path):
     """Return the FID score between two image directories.

     Args:
         gen_images_path: directory holding the generated images.
         eval_images_path: directory holding the reference images.

     NOTE(review): ``opt.batch_size``, ``device`` and ``fid_score`` are taken
     from module scope.
     """
     print("Calculating FID...")
     paths = (gen_images_path, eval_images_path)
     return fid_score.calculate_fid_given_paths(paths, opt.batch_size, device)
Esempio n. 3
0
def _load_generator_state(netG, checkpoint_path):
    """Load a generator checkpoint into ``netG``, dropping any ``'module.'`` prefix.

    Only the first key of the saved state dict is inspected. If it contains
    ``'module.'`` (presumably because the checkpoint was saved from a
    (Distributed)DataParallel wrapper), every key is replaced by the text after
    its last ``'module.'`` before loading.
    """
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    first_key = next(iter(state_dict), '')
    if 'module.' in first_key:
        state_dict = OrderedDict(
            (key.split('module.')[-1], value)
            for key, value in state_dict.items())
    netG.load_state_dict(state_dict)


def _save_generated_images(fake_imgs, keys, prefix, out_dir):
    """Write a batch of generator outputs as PNG files and return how many were written.

    Image ``j`` is saved to ``{out_dir}/{prefix}_{keys[j]}.png``. Each image is
    transposed from channel-first to channel-last for PIL. Pixel values are
    assumed to lie in [-1, 1] (TODO confirm against netG's output layer).
    """
    batch = fake_imgs.detach().cpu().numpy()
    # Map [-1, 1] -> [0, 255]. Clip so out-of-range values saturate instead of
    # wrapping around (or hitting platform-dependent behavior) in the uint8 cast.
    batch = np.clip((batch + 1.0) * 127.5, 0, 255).astype(np.uint8)
    batch = np.transpose(batch, (0, 2, 3, 1))  # (N, C, H, W) -> (N, H, W, C)
    for j, array in enumerate(batch):
        Image.fromarray(array).save(f'{out_dir}/{prefix}_{keys[j]}.png')
    return len(batch)


def sampling(text_encoder, image_encoder, netG, dataloader, num_samples,
             metric, output_dir, logger):
    """Score every saved generator checkpoint and log the best one.

    For each ``netG_*.pth`` file in ``{output_dir}/models`` (sorted by name,
    skipping the first ``start_epoch`` entries):

    1. The weights are loaded into ``netG``.
    2. Images are generated from the sentence embeddings yielded by
       ``dataloader``, cycling through it at most 11 times, until at least
       ``num_samples`` PNGs have been written to ``img_dir``.
    3. When ``metric`` is ``'fid'`` or ``'both'``, FID is computed once
       between ``eval/coco_val.npz`` and that directory. The lowest value and
       its epoch are tracked.

    One line is logged per checkpoint that reached ``num_samples`` images,
    followed by a final summary line.

    Args:
        text_encoder: currently unused (only needed by the disabled
            R-precision evaluation); kept for interface compatibility.
        image_encoder: currently unused; see ``text_encoder``.
        netG: generator, called as ``netG(noise, sent_embs)``.
        dataloader: yields ``(imgs, sent_embs, keys)`` batches.
        num_samples: minimum number of images to generate per checkpoint.
        metric: ``'fid'`` or ``'both'`` enables FID; any other value skips it.
        output_dir: run directory containing the ``models`` sub-directory.
        logger: receives the per-checkpoint and summary lines.

    NOTE(review): ``start_epoch``, ``img_dir``, ``truncated_z_sample`` and
    ``calculate_fid_given_paths`` are resolved from module scope.
    NOTE(review): R-precision scoring is disabled, so the summary always
    reports ``r_epoch`` and ``R_mean`` as 0.
    NOTE(review): ``img_dir`` is never cleared. If the dataloader shuffles,
    images from a previous checkpoint with keys that are not regenerated would
    be included in the FID. Confirm.
    """
    model_dir = f'{output_dir}/models'
    # NOTE(review): this is a lexicographic sort, so it only matches epoch order
    # if the epoch numbers in the filenames are zero-padded.
    model_list = sorted(glob.glob(f'{model_dir}/netG_*.pth'))[start_epoch:]

    results = {'r_epoch': 0, 'R_mean': 0, 'f_epoch': 0, 'fid': 1000}
    # Parenthesized intent: the original `a or b and c` applied the sample-count
    # guard only to 'both', so 'fid' was scored on partial image sets.
    compute_fid = metric in ('fid', 'both')

    for model in model_list:
        epoch = model.split('netG_')[-1].replace('.pth', '')
        _load_generator_state(netG, model)
        netG.eval()
        netG.cuda()

        cnt = 0
        for ii in range(11):
            for data in tqdm(dataloader):
                with torch.no_grad():
                    _, sent_embs, keys = data
                    sent_embs = sent_embs.cuda()
                    # Fixed seed: every batch is drawn from the same noise
                    # sequence, keeping generation reproducible across checkpoints.
                    noise = truncated_z_sample(sent_embs.size(0), 100, seed=100)
                    noise = torch.from_numpy(noise).float().cuda()
                    fake_imgs = netG(noise, sent_embs)
                cnt += _save_generated_images(fake_imgs, keys, ii, img_dir)
                if cnt >= num_samples:
                    break
            if cnt >= num_samples:
                break

        if cnt < num_samples:
            # Dataloader ran out before num_samples images: nothing to score/log.
            continue

        s_r = ''  # R-precision is disabled.
        s_fid = ''
        if compute_fid:
            paths = ['eval/coco_val.npz', f'{img_dir}/']
            with torch.no_grad():
                fid_value = calculate_fid_given_paths(paths, 50, True, 2048)
            s_fid = f'FID: {fid_value}'
            if fid_value < results['fid']:
                results['f_epoch'] = epoch
                results['fid'] = fid_value

        s = f'epoch : {epoch} {s_r} {s_fid}'
        logger.info(s)

    s_res = f"Best models is {results['r_epoch']} with R mean : {results['R_mean']} {results['f_epoch']} with fid : {results['fid']}"
    logger.info(s_res)