Example #1
def calculate_fid_for_all_tasks(args, domains, step, mode):
    print('Calculating FID for all tasks...')
    fid_values = OrderedDict()
    for trg_domain in domains:
        src_domains = [x for x in domains if x != trg_domain]

        for src_domain in src_domains:
            task = '%s2%s' % (src_domain, trg_domain)
            path_real = os.path.join(args.train_img_dir, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            print('Calculating FID for %s...' % task)
            fid_value = calculate_fid_given_paths(
                paths=[path_real, path_fake],
                img_size=args.img_size,
                batch_size=args.val_batch_size)
            fid_values['FID_%s/%s' % (mode, task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for _, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_%s/mean' % mode] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
    utils.save_json(fid_values, filename)
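All of these examples call a shared calculate_fid_given_paths helper that the listing never shows. Below is a minimal sketch of what such a helper typically does, assuming a recent torchvision: pool InceptionV3 features for each directory, fit a Gaussian, and take the closed-form Fréchet distance. The extra keyword arguments seen in some examples (use_cache, trg_domain, dataset_dir) are project-specific and omitted; the canonical FID also uses a TF-ported Inception, so this sketch's absolute numbers will not match published scores.

# Sketch only; not the repository's implementation.
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from scipy import linalg
from torchvision import transforms
from torchvision.models import Inception_V3_Weights, inception_v3


@torch.no_grad()
def _activations(path, model, img_size, batch_size, device):
    # run every image under `path` through the 2048-d pooled Inception features
    tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    exts = {'.png', '.jpg', '.jpeg'}
    files = sorted(p for p in Path(path).rglob('*') if p.suffix.lower() in exts)
    feats = []
    for i in range(0, len(files), batch_size):
        imgs = [tf(Image.open(f).convert('RGB')) for f in files[i:i + batch_size]]
        feats.append(model(torch.stack(imgs).to(device)).cpu().numpy())
    return np.concatenate(feats, axis=0)


def calculate_fid_given_paths(paths, img_size, batch_size):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = inception_v3(weights=Inception_V3_Weights.DEFAULT)
    model.fc = torch.nn.Identity()  # keep pooled features, drop the classifier
    model.eval().to(device)

    stats = []
    for path in paths:
        act = _activations(path, model, img_size, batch_size, device)
        stats.append((act.mean(axis=0), np.cov(act, rowvar=False)))
    (mu1, s1), (mu2, s2) = stats

    # Frechet distance between the two fitted Gaussians
    covmean = linalg.sqrtm(s1 @ s2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(((mu1 - mu2) ** 2).sum() + np.trace(s1 + s2 - 2.0 * covmean))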
Example #2
def calculate_total_fid(nets_ema, args, step, keep_samples=False):
    target_path = args.eval_path
    sample_path = get_sample_path(args.eval_dir, step)
    generate_samples(nets_ema, args, sample_path)
    fid = calculate_fid_given_paths(paths=[target_path, sample_path],
                                    img_size=args.img_size,
                                    batch_size=args.eval_batch_size,
                                    use_cache=args.eval_cache)
    if not keep_samples:
        delete_dir(sample_path)
    return fid
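A hypothetical call site for this helper; names such as args.eval_every are assumed, not taken from the repository:

# hypothetical: evaluate periodically during training
if step % args.eval_every == 0:
    fid = calculate_total_fid(nets_ema, args, step)
    print('step %d: FID = %.3f' % (step, fid))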
Example #3
    def evaluate(self):
        args = self.args
        assert args.eval_path != "", "eval_path shouldn't be empty"
        target_path = args.eval_path
        sample_path = self.sample()
        fid = calculate_fid_given_paths(paths=[target_path, sample_path],
                                        img_size=args.img_size,
                                        batch_size=args.eval_batch_size)
        print(f"FID is: {fid}")
        send_message(f"Sample {args.sample_id}'s FID is {fid}")
        if not args.keep_all_eval_samples:
            delete_dir(sample_path)
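delete_dir (used here and in Example #2) is presumably a thin wrapper around shutil.rmtree; a minimal sketch under that assumption:

import shutil

def delete_dir(path):
    # remove the tree of generated samples; ignore it if already gone (assumed helper)
    shutil.rmtree(path, ignore_errors=True)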
Example #4
def calculate_fid(args, sample_path):
    fid_list = []
    for src_domain in args.domains:
        target_domains = [
            domain for domain in args.domains if domain != src_domain
        ]
        for trg_domain in target_domains:
            task = f"{src_domain}2{trg_domain}"
            # real images come from the target domain of the translation
            path_real = os.path.join(args.eval_path, trg_domain)
            path_fake = os.path.join(sample_path, task)
            print(f'Calculating FID for {task}...')
            fid = calculate_fid_given_paths(paths=[path_real, path_fake],
                                            img_size=args.img_size,
                                            batch_size=args.eval_batch_size,
                                            use_cache=args.eval_cache)
            fid_list.append(fid)
            write_record(f"FID for {task}: {fid}", args.record_file)
    fid_mean = sum(fid_list) / len(fid_list)
    write_record(f"FID mean: {fid_mean}", args.record_file)
    return fid_mean
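write_record is not shown in the listing; a plausible minimal implementation that both echoes and appends to the record file:

def write_record(message, record_file):
    # print and append one line to the experiment log (assumed helper)
    print(message)
    with open(record_file, 'a') as f:
        f.write(message + '\n')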
Example #5
def calculate_fid_for_all_tasks(args, domains, step, mode, dataset_dir=''):
    print('Calculating FID for all tasks...')
    fid_values = OrderedDict()
    for trg_domain in domains:
        task = trg_domain
        path_real = args.val_img_dir
        print('Calculating FID for %s...' % task)
        fid_value = calculate_fid_given_paths(paths=[path_real, args.eval_dir],
                                              img_size=args.img_size,
                                              batch_size=args.val_batch_size,
                                              trg_domain=trg_domain,
                                              dataset_dir=dataset_dir)
        fid_values['FID_%s/%s' % (mode, task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for _, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_%s/mean' % mode] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
    utils.save_json(fid_values, filename)
    return fid_values, fid_mean
Example #6
    def evaluate_fc2(self,
                     args,
                     n_styles,
                     epochs,
                     n_epochs,
                     emphasis_parameter,
                     batchsize=16,
                     learning_rate=1e-3,
                     dset='FC2'):
        print('Calculating evaluation metrics...')
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        data_dir = "G:/Datasets/FC2/DATAFiles/"
        style_dir = "G:/Datasets/FC2/styled-files/"
        temp_dir = "G:/Datasets/FC2/styled-files3/"
        eval_dir = os.getcwd() + "/eval_fc2/" + self.method + "/"

        num_workers = 0
        args.batch_size = 4

        domains = os.listdir(style_dir)
        domains.sort()
        num_domains = len(domains)
        print('Number of domains: %d' % num_domains)
        print("Batch Size:", args.batch_size)

        _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir,
                                       args.batch_size, num_workers,
                                       num_domains)

        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = os.listdir(tmp_dir)
        tmp_list.sort()

        models = []
        pre_models = []
        if n_styles > 1:
            model = FastStyleNet(3, n_styles).to(self.device)
            model.load_state_dict(
                torch.load(tmp_dir + '/' + tmp_list[0] + '/epoch_' +
                           str(n_epochs) + '.pth'))
        else:
            if self.method == "ruder":
                for tmp in tmp_list:
                    model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
                    model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp + '/epoch_' +
                                   str(n_epochs) + '.pth'))
                    models.append(model)
                    pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + tmp[
                        3] + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                    model = FastStyleNet(3, n_styles).to(self.device)
                    model.load_state_dict(torch.load(pre_style_path))
                    pre_models.append(model)
            else:
                for tmp in tmp_list:
                    model = FastStyleNet(3, n_styles).to(self.device)
                    model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp + '/epoch_' +
                                   str(n_epochs) + '.pth'))
                    models.append(model)

        generate_new = True

        tcl_dict = {}
        # prepare
        for d in range(1, num_domains):
            src_domain = "style0"
            trg_domain = "style" + str(d)

            t1 = '%s2%s' % (src_domain, trg_domain)
            t2 = '%s2%s' % (trg_domain, src_domain)

            tcl_dict[t1] = []
            tcl_dict[t2] = []

            if generate_new:
                create_task_folders(eval_dir, t1)
                #create_task_folders(eval_dir, t2)

        # generate
        for i, x_src_all in enumerate(tqdm(eval_loader,
                                           total=len(eval_loader))):
            x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

            x_real = x_real.to(self.device)
            x_real2 = x_real2.to(self.device)
            y_org = y_org.to(self.device)
            x_ref = x_ref.to(self.device)
            y_trg = y_trg.to(self.device)
            mask = mask.to(self.device)
            flow = flow.to(self.device)

            N = x_real.size(0)

            for k in range(N):
                y_org_np = y_org[k].cpu().numpy()
                y_trg_np = y_trg[k].cpu().numpy()
                src_domain = "style" + str(y_org_np)
                trg_domain = "style" + str(y_trg_np)

                if src_domain == trg_domain or y_trg_np == 0:
                    continue

                task = '%s2%s' % (src_domain, trg_domain)

                if n_styles > 1:
                    self.model = model
                else:
                    self.model = models[y_trg_np - 1]

                if self.method == "ruder":
                    self.pre_style_model = pre_models[y_trg_np - 1]

                x_fake = self.infer_method((x_real, None, None), y_trg[k] - 1)
                #x_fake = torch.clamp(x_fake, 0.0, 1.0)
                x_warp = warp(x_fake, flow)
                x_fake2 = self.infer_method((x_real2, mask, x_warp),
                                            y_trg[k] - 1)
                #x_fake2 = torch.clamp(x_fake2, 0.0, 1.0)

                tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2,
                                                                     3))**0.5

                tcl_dict[task].append(tcl_err[k].cpu().numpy())

                path_ref = os.path.join(eval_dir, task + "/ref")
                path_fake = os.path.join(eval_dir, task + "/fake")

                if generate_new:
                    filename = os.path.join(
                        path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                    save_image(denormalize(x_ref[k]),
                               ncol=1,
                               filename=filename)

                filename = os.path.join(
                    path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
                save_image(x_fake[k], ncol=1, filename=filename)

        # evaluate
        print("computing fid, lpips and tcl")

        tasks = [
            dir for dir in os.listdir(eval_dir)
            if os.path.isdir(os.path.join(eval_dir, dir))
        ]
        tasks.sort()

        # fid and lpips
        fid_values = OrderedDict()
        #lpips_dict = OrderedDict()
        tcl_values = OrderedDict()
        for task in tasks:
            print(task)
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")

            tcl_data = tcl_dict[task]

            print("TCL", len(tcl_data))
            tcl_mean = np.array(tcl_data).mean()
            print(tcl_mean)
            tcl_values['TCL_%s' % (task)] = float(tcl_mean)

            print("FID")
            fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                                  img_size=256,
                                                  batch_size=args.batch_size)
            fid_values['FID_%s' % (task)] = fid_value

        # calculate the average FID for all tasks
        fid_mean = 0
        for key, value in fid_values.items():
            fid_mean += value / len(fid_values)

        fid_values['FID_mean'] = fid_mean

        # report FID values
        filename = os.path.join(eval_dir, 'FID.json')
        utils.save_json(fid_values, filename)

        # calculate the average TCL for all tasks
        tcl_mean = 0
        for _, value in tcl_values.items():
            tcl_mean += value / len(tcl_values)

        tcl_values['TCL_mean'] = float(tcl_mean)

        # report TCL values
        filename = os.path.join(eval_dir, 'TCL.json')
        utils.save_json(tcl_values, filename)
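The inline TCL expression above is the masked RMSE between the stylized second frame and the flow-warped stylized first frame; factored out here for readability (same math as in the loop):

def temporal_consistency_error(x_fake2, x_warp, mask):
    # per-sample RMSE over the non-occluded region:
    # sqrt(mean(mask * (stylized_frame2 - warp(stylized_frame1))^2))
    sq = (mask * (x_fake2 - x_warp)) ** 2
    return sq.mean(dim=(1, 2, 3)) ** 0.5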
Example #7
def calculate_metrics(nets, args, step, mode, eval_loader):
    print('Calculating evaluation metrics...')
    assert mode in ['latent', 'reference']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    domains = os.listdir(args.style_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    
    #generate_new = True

    #num_files = sum([len(files) for r, d, files in os.walk(args.eval_dir)])
    #print("num_files", num_files, len(eval_loader), (1 + args.num_outs_per_domain)*len(eval_loader)*args.batch_size)

    #if num_files != (1 + args.num_outs_per_domain)*len(eval_loader):
    #shutil.rmtree(args.eval_dir, ignore_errors=True)
    #os.makedirs(args.eval_dir)
    generate_new = True
    
    tcl_dict = {}
    # prepare
    for d in range(1, num_domains):
      src_domain = "style0"
      trg_domain = "style" + str(d)
      
      t1 = '%s2%s' % (src_domain, trg_domain)
      t2 = '%s2%s' % (trg_domain, src_domain)
      
      tcl_dict[t1] = []
      tcl_dict[t2] = []
      
      if generate_new:
        create_task_folders(args, t1)
        create_task_folders(args, t2)

    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
      x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all
      
      x_real = x_real.to(device)
      x_real2 = x_real2.to(device)
      y_org = y_org.to(device)
      x_ref = x_ref.to(device)
      y_trg = y_trg.to(device)
      mask = mask.to(device)
      flow = flow.to(device)
      
      N = x_real.size(0)
      masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

      for j in range(args.num_outs_per_domain):
        if mode == 'latent':
          z_trg = torch.randn(N, args.latent_dim).to(device)
          s_trg = nets.mapping_network(z_trg, y_trg)
        else:
          s_trg = nets.style_encoder(x_ref, y_trg)

        x_fake = nets.generator(x_real, s_trg, masks=masks)
        x_fake2 = nets.generator(x_real2, s_trg, masks=masks)

        x_warp = warp(x_fake, flow)
        tcl_err = ((mask*(x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5
      
        for k in range(N):
          src_domain = "style" + str(y_org[k].cpu().numpy())
          trg_domain = "style" + str(y_trg[k].cpu().numpy())
          
          if src_domain == trg_domain:
            continue
          
          task = '%s2%s' % (src_domain, trg_domain)
          
          tcl_dict[task].append(tcl_err[k].cpu().numpy())

          path_ref = os.path.join(args.eval_dir, task + "/ref")
          path_fake = os.path.join(args.eval_dir, task + "/fake")

          #if not os.path.exists(path_ref):
          #  os.makedirs(path_ref)
            
          #if not os.path.exists(path_fake):
          #  os.makedirs(path_fake)
          
          if generate_new:
            filename = os.path.join(path_ref, '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1))
            utils.save_image(x_ref[k], ncol=1, filename=filename)
          
          filename = os.path.join(path_fake, '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1))
          utils.save_image(x_fake[k], ncol=1, filename=filename)

          #filename = os.path.join(args.eval_dir, task + "/tcl_losses.txt")
          #with open(filename, "a") as text_file:
          #  text_file.write(str(tcl_err[k].cpu().numpy()) + "\n")
    
    # evaluate
    print("computing fid, lpips and tcl")

    tasks = [dir for dir in os.listdir(args.eval_dir) if os.path.isdir(os.path.join(args.eval_dir, dir))]
    tasks.sort()

    # fid and lpips
    fid_values = OrderedDict()
    lpips_dict = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
      print(task)
      path_ref = os.path.join(args.eval_dir, task + "/ref")
      path_fake = os.path.join(args.eval_dir, task + "/fake")
      #path_tcl = os.path.join(args.eval_dir, task + "/tcl_losses.txt")
    
      fake_group = load_images(path_fake)
      
      #with open(path_tcl, "r") as text_file:
      #  tcl_data = text_file.read()
      
      #tcl_data = tcl_data.split("\n")[:-1]
      #tcl_data = [float(td) for td in tcl_data]
      tcl_data = tcl_dict[task]
        
      print("TCL", len(tcl_data))
      tcl_mean = np.array(tcl_data).mean()
      print(tcl_mean)
      tcl_values['TCL_%s/%s' % (mode, task)] = float(tcl_mean)
      
      lpips_values = []
      fake_chunks = chunks(fake_group, args.num_outs_per_domain)
      for cidx in range(len(fake_chunks)):
        lpips_value = calculate_lpips_given_images(fake_chunks[cidx])
        lpips_values.append(lpips_value)
      
      print("LPIPS")
      # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
      lpips_mean = np.array(lpips_values).mean()
      lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean

      print("FID")
      fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake], img_size=args.img_size, batch_size=args.val_batch_size)
      fid_values['FID_%s/%s' % (mode, task)] = fid_value
    
    # calculate the average LPIPS for all tasks
    lpips_mean = 0
    for _, value in lpips_dict.items():
        lpips_mean += value / len(lpips_dict)
    lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean

    # report LPIPS values
    filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode))
    utils.save_json(lpips_dict, filename)
    
    # calculate the average FID for all tasks
    fid_mean = 0
    for _, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_%s/mean' % mode] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
    utils.save_json(fid_values, filename)
    
    # calculate the average TCL for all tasks
    tcl_mean = 0
    for _, value in tcl_values.items():
      print(value, len(tcl_values))
      tcl_mean += value / len(tcl_values)
    print(tcl_mean)
    tcl_values['TCL_%s/mean' % mode] = float(tcl_mean)

    # report TCL values
    filename = os.path.join(args.eval_dir, 'TCL_%.5i_%s.json' % (step, mode))
    utils.save_json(tcl_values, filename)
Example #8
def evaluate_fc2(args):
    print('Calculating evaluation metrics...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_dir = "G:/Datasets/FC2/DATAFiles/"
    style_dir = "G:/Datasets/FC2/styled-files/"
    temp_dir = "G:/Datasets/FC2/styled-files3/"
    num_workers = 0
    args.batch_size = 4

    domains = os.listdir(style_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    print("Batch Size:", args.batch_size)

    _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir,
                                   args.batch_size, num_workers, num_domains)

    model_list = os.listdir(args.checkpoints_dir)
    model_list.sort()

    model = create_model(args)
    model.setup(args)

    #generate_new = True

    #num_files = sum([len(files) for r, d, files in os.walk(args.eval_dir)])
    #print("num_files", num_files, len(eval_loader), (1 + args.num_outs_per_domain)*len(eval_loader)*args.batch_size)

    #if num_files != (1 + args.num_outs_per_domain)*len(eval_loader):
    #shutil.rmtree(args.eval_dir, ignore_errors=True)
    #os.makedirs(args.eval_dir)
    generate_new = True

    tcl_dict = {}
    models = []
    # prepare
    for d in range(1, num_domains):
        src_domain = "style0"
        trg_domain = "style" + str(d)

        t1 = '%s2%s' % (src_domain, trg_domain)
        t2 = '%s2%s' % (trg_domain, src_domain)

        tcl_dict[t1] = []
        tcl_dict[t2] = []

        if generate_new:
            create_task_folders(args, t1)
            #create_task_folders(args, t2)

        args.name = model_list[d - 1]
        model = create_model(args)
        model.setup(args)
        models.append(model)

    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
        x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

        x_real = x_real.to(device)
        x_real2 = x_real2.to(device)
        y_org = y_org.to(device)
        x_ref = x_ref.to(device)
        y_trg = y_trg.to(device)
        mask = mask.to(device)
        flow = flow.to(device)

        N = x_real.size(0)

        for k in range(N):
            y_org_np = y_org[k].cpu().numpy()
            y_trg_np = y_trg[k].cpu().numpy()
            src_domain = "style" + str(y_org_np)
            trg_domain = "style" + str(y_trg_np)

            if src_domain == trg_domain or y_trg_np == 0:
                continue

            task = '%s2%s' % (src_domain, trg_domain)

            # the original else-branch (AtoB=False with y_org_np) was unreachable,
            # since y_trg_np == 0 is already skipped by the continue above
            x_fake = models[y_trg_np - 1].forward_eval(x_real)
            x_fake2 = models[y_trg_np - 1].forward_eval(x_real2, x_real, x_fake)

            x_warp = warp(x_fake, flow)
            tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5

            tcl_dict[task].append(tcl_err[k].cpu().numpy())

            path_ref = os.path.join(args.eval_dir, task + "/ref")
            path_fake = os.path.join(args.eval_dir, task + "/fake")

            #if not os.path.exists(path_ref):
            #  os.makedirs(path_ref)

            #if not os.path.exists(path_fake):
            #  os.makedirs(path_fake)

            if generate_new:
                filename = os.path.join(
                    path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                utils.save_image(x_ref[k], ncol=1, filename=filename)

            filename = os.path.join(
                path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
            utils.save_image(x_fake[k], ncol=1, filename=filename)

            #filename = os.path.join(args.eval_dir, task + "/tcl_losses.txt")
            #with open(filename, "a") as text_file:
            #  text_file.write(str(tcl_err[k].cpu().numpy()) + "\n")

    # evaluate
    print("computing fid, lpips and tcl")

    tasks = [
        dir for dir in os.listdir(args.eval_dir)
        if os.path.isdir(os.path.join(args.eval_dir, dir))
    ]
    tasks.sort()

    # fid and lpips
    fid_values = OrderedDict()
    #lpips_dict = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
        print(task)
        path_ref = os.path.join(args.eval_dir, task + "/ref")
        path_fake = os.path.join(args.eval_dir, task + "/fake")
        #path_tcl = os.path.join(args.eval_dir, task + "/tcl_losses.txt")

        #fake_group = load_images(path_fake)

        #with open(path_tcl, "r") as text_file:
        #  tcl_data = text_file.read()

        #tcl_data = tcl_data.split("\n")[:-1]
        #tcl_data = [float(td) for td in tcl_data]
        tcl_data = tcl_dict[task]

        print("TCL", len(tcl_data))
        tcl_mean = np.array(tcl_data).mean()
        print(tcl_mean)
        tcl_values['TCL_%s' % (task)] = float(tcl_mean)
        # LPIPS computation disabled:
        # lpips_values = []
        # fake_chunks = chunks(fake_group, 1)
        # for cidx in range(len(fake_chunks)):
        #     lpips_value = calculate_lpips_given_images(fake_chunks[cidx])
        #     lpips_values.append(lpips_value)
        #
        # print("LPIPS")
        # # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
        # lpips_mean = np.array(lpips_values).mean()
        # lpips_dict['LPIPS_%s' % (task)] = lpips_mean

        print("FID")
        fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                              img_size=256,
                                              batch_size=args.batch_size)
        fid_values['FID_%s' % (task)] = fid_value
    # LPIPS reporting disabled:
    # # calculate the average LPIPS for all tasks
    # lpips_mean = 0
    # for _, value in lpips_dict.items():
    #     lpips_mean += value / len(lpips_dict)
    # lpips_dict['LPIPS_mean'] = lpips_mean
    #
    # # report LPIPS values
    # filename = os.path.join(args.eval_dir, 'LPIPS.json')
    # utils.save_json(lpips_dict, filename)

    # calculate the average FID for all tasks
    fid_mean = 0
    #fid_means = [[], [], []]
    for key, value in fid_values.items():
        #for d in range(1, num_domains):
        #  if str(d) in key:
        #    fid_means[d-1].append(value)
        fid_mean += value / len(fid_values)

    #for d in range(1, num_domains):
    #  fid_values['FID_s%d_mean' % d] = np.array(fid_means[d-1]).mean()

    fid_values['FID_mean'] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID.json')
    utils.save_json(fid_values, filename)

    # calculate the average TCL for all tasks
    tcl_mean = 0
    #tcl_means = [[], [], []]
    for _, value in tcl_values.items():
        #for d in range(1, num_domains):
        #  if str(d) in key:
        #    tcl_means[d-1].append(value)
        #print(value, len(tcl_values))
        tcl_mean += value / len(tcl_values)
    #print(tcl_mean)
    #for d in range(1, num_domains):
    #  tcl_values['TCL_s%d_mean' % d] = np.array(tcl_means[d-1]).mean()

    tcl_values['TCL_mean'] = float(tcl_mean)

    # report TCL values
    filename = os.path.join(args.eval_dir, 'TCL.json')
    utils.save_json(tcl_values, filename)
Example #9
def eval_fc2(net, args):
    print('Calculating evaluation metrics...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_dir = "G:/Datasets/FC2/DATAFiles/"
    style_dir = "G:/Datasets/FC2/styled-files/"
    temp_dir = "G:/Datasets/FC2/styled-files3/"

    #data_dir = "/srv/local/tomstrident/datasets/FC2/DATAFiles/"
    #style_dir = "/srv/local/tomstrident/datasets/FC2/styled-files/"
    #temp_dir = "/srv/local/tomstrident/datasets/FC2/styled-files3/"

    eval_dir = os.getcwd() + "/eval_fc2/" + str(args.weight_tcl) + "/"

    num_workers = 0
    net.batch_size = 1  #args.batch_size

    pyr_shapes = [(64, 64), (128, 128), (256, 256)]
    net.set_shapes(pyr_shapes)

    transform = T.Compose([  #T.Resize(pyr_shapes[-1]),
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  #turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])

    domains = os.listdir(style_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    print("Batch Size:", args.batch_size)

    _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir, transform,
                                   args.batch_size, num_workers, num_domains)

    generate_new = True

    tcl_dict = {}
    # prepare
    for d in range(1, num_domains):
        src_domain = "style0"
        trg_domain = "style" + str(d)

        t1 = '%s2%s' % (src_domain, trg_domain)
        t2 = '%s2%s' % (trg_domain, src_domain)

        tcl_dict[t1] = []
        tcl_dict[t2] = []

        if generate_new:
            create_task_folders(eval_dir, t1)
            #create_task_folders(eval_dir, t2)

    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
        x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

        x_real = x_real.to(device)
        x_real2 = x_real2.to(device)
        y_org = y_org.to(device)
        x_ref = x_ref.to(device)
        y_trg = y_trg.to(device)
        mask = mask.to(device)
        flow = flow.to(device)

        mask_zero = torch.zeros(mask.shape).to(device)

        N = x_real.size(0)
        #y = y_trg.cpu().numpy()

        for k in range(N):
            y_org_np = y_org[k].cpu().numpy()
            y_trg_np = y_trg[k].cpu().numpy()
            src_domain = "style" + str(y_org_np)
            trg_domain = "style" + str(y_trg_np)

            if src_domain == trg_domain or y_trg_np == 0:
                continue

            task = '%s2%s' % (src_domain, trg_domain)
            net.set_style(y_trg_np - 1)

            x_fake = net.run(x_real, x_real, y_trg_np - 1, mask_zero,
                             args.weight_tcl)
            x_warp = warp(x_fake, flow)
            #x_fake2 = net.run(mask*x_warp  + (1 - mask)*x_real2, x_real2, y_trg_np - 1, mask)
            x_fake2 = net.run(x_warp, x_real2, y_trg_np - 1, mask,
                              args.weight_tcl)

            tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5

            tcl_dict[task].append(tcl_err[k].cpu().numpy())

            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")

            if generate_new:
                filename = os.path.join(
                    path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                if y_trg_np - 1 == 2:
                    out_img = net.postp2(x_ref.data[0].cpu())
                else:
                    out_img = net.postp(x_ref.data[0].cpu())
                out_img.save(filename)

            filename = os.path.join(
                path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
            if y_trg_np - 1 == 2:
                out_img = net.postp2(x_fake.data[0].cpu())
            else:
                out_img = net.postp(x_fake.data[0].cpu())
            out_img.save(filename)

    # evaluate
    print("computing fid, lpips and tcl")

    tasks = [
        dir for dir in os.listdir(eval_dir)
        if os.path.isdir(os.path.join(eval_dir, dir))
    ]
    tasks.sort()

    # fid and lpips
    fid_values = OrderedDict()
    #lpips_dict = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
        print(task)
        path_ref = os.path.join(eval_dir, task + "/ref")
        path_fake = os.path.join(eval_dir, task + "/fake")

        tcl_data = tcl_dict[task]

        print("TCL", len(tcl_data))
        tcl_mean = np.array(tcl_data).mean()
        print(tcl_mean)
        tcl_values['TCL_%s' % (task)] = float(tcl_mean)

        print("FID")
        fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                              img_size=256,
                                              batch_size=args.batch_size)
        fid_values['FID_%s' % (task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for key, value in fid_values.items():
        fid_mean += value / len(fid_values)

    fid_values['FID_mean'] = fid_mean

    # report FID values
    filename = os.path.join(eval_dir, 'FID.json')
    utils.save_json(fid_values, filename)

    # calculate the average TCL for all tasks
    tcl_mean = 0
    for _, value in tcl_values.items():
        tcl_mean += value / len(tcl_values)

    tcl_values['TCL_mean'] = float(tcl_mean)

    # report TCL values
    filename = os.path.join(eval_dir, 'TCL.json')
    utils.save_json(tcl_values, filename)
Example #10
    def eval(self):
        self.restore_model(self.test_iters)

        print('Calculating evaluation metrics...')
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        eval_dir = os.getcwd() + "/eval/"
        print(eval_dir)
        eval_loader = self.eval_loader

        if not os.path.exists(eval_dir):
            os.makedirs(eval_dir)

        num_domains = 4

        #generate_new = True

        #num_files = sum([len(files) for r, d, files in os.walk(args.eval_dir)])
        #print("num_files", num_files, len(eval_loader), (1 + args.num_outs_per_domain)*len(eval_loader)*args.batch_size)

        #if num_files != (1 + args.num_outs_per_domain)*len(eval_loader):
        #shutil.rmtree(args.eval_dir, ignore_errors=True)
        #os.makedirs(args.eval_dir)
        generate_new = True

        tcl_dict = {}
        # prepare
        for d in range(1, num_domains):
            src_domain = "style0"
            trg_domain = "style" + str(d)

            t1 = '%s2%s' % (src_domain, trg_domain)
            t2 = '%s2%s' % (trg_domain, src_domain)

            tcl_dict[t1] = []
            tcl_dict[t2] = []

            if generate_new:
                create_task_folders(eval_dir, t1)
                #create_task_folders(eval_dir, t2)

        # generate
        for i, x_src_all in enumerate(tqdm(eval_loader,
                                           total=len(eval_loader))):
            x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

            x_real = x_real.to(self.device)
            x_real2 = x_real2.to(self.device)
            y_org = y_org.to(self.device)
            x_ref = x_ref.to(self.device)
            y_trg = y_trg.to(self.device)
            mask = mask.to(self.device)
            flow = flow.to(self.device)

            N = x_real.size(0)

            for k in range(N):
                y_org_np = y_org[k].cpu().numpy()
                y_trg_np = y_trg[k].cpu().numpy()
                src_domain = "style" + str(y_org_np)
                trg_domain = "style" + str(y_trg_np)

                if src_domain == trg_domain or y_trg_np == 0:
                    continue

                task = '%s2%s' % (src_domain, trg_domain)

                b_size = x_real.shape[0]
                c_trg = np.zeros((b_size, self.c_dim))

                for bs in range(b_size):
                    c_trg[bs, y_trg[bs]] = 1

                c_trg = torch.tensor(c_trg).float().to(self.device)

                x_fake = self.G(x_real, c_trg)
                x_fake2 = self.G(x_real2, c_trg)

                x_warp = warp(x_fake, flow)
                tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2,
                                                                     3))**0.5

                tcl_dict[task].append(tcl_err[k].cpu().numpy())

                path_ref = os.path.join(eval_dir, task + "/ref")
                path_fake = os.path.join(eval_dir, task + "/fake")

                #if not os.path.exists(path_ref):
                #  os.makedirs(path_ref)

                #if not os.path.exists(path_fake):
                #  os.makedirs(path_fake)

                if generate_new:
                    filename = os.path.join(
                        path_ref, '%.4i.png' % (i * self.batch_size + (k + 1)))
                    utils.save_image(x_ref[k], ncol=1, filename=filename)

                filename = os.path.join(
                    path_fake, '%.4i.png' % (i * self.batch_size + (k + 1)))
                utils.save_image(x_fake[k], ncol=1, filename=filename)

                #filename = os.path.join(args.eval_dir, task + "/tcl_losses.txt")
                #with open(filename, "a") as text_file:
                #  text_file.write(str(tcl_err[k].cpu().numpy()) + "\n")

        # evaluate
        print("computing fid, lpips and tcl")

        tasks = [
            dir for dir in os.listdir(eval_dir)
            if os.path.isdir(os.path.join(eval_dir, dir))
        ]
        tasks.sort()

        # fid and lpips
        fid_values = OrderedDict()
        #lpips_dict = OrderedDict()
        tcl_values = OrderedDict()
        for task in tasks:
            print(task)
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")
            #path_tcl = os.path.join(args.eval_dir, task + "/tcl_losses.txt")

            #fake_group = load_images(path_fake)

            #with open(path_tcl, "r") as text_file:
            #  tcl_data = text_file.read()

            #tcl_data = tcl_data.split("\n")[:-1]
            #tcl_data = [float(td) for td in tcl_data]
            tcl_data = tcl_dict[task]

            print("TCL", len(tcl_data))
            tcl_mean = np.array(tcl_data).mean()
            print(tcl_mean)
            tcl_values['TCL_%s' % (task)] = float(tcl_mean)
            # LPIPS computation disabled:
            # lpips_values = []
            # fake_chunks = chunks(fake_group, 1)
            # for cidx in range(len(fake_chunks)):
            #     lpips_value = calculate_lpips_given_images(fake_chunks[cidx])
            #     lpips_values.append(lpips_value)
            #
            # print("LPIPS")
            # # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
            # lpips_mean = np.array(lpips_values).mean()
            # lpips_dict['LPIPS_%s' % (task)] = lpips_mean

            print("FID")
            fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                                  img_size=256,
                                                  batch_size=self.batch_size)
            fid_values['FID_%s' % (task)] = fid_value
        # LPIPS reporting disabled:
        # # calculate the average LPIPS for all tasks
        # lpips_mean = 0
        # for _, value in lpips_dict.items():
        #     lpips_mean += value / len(lpips_dict)
        # lpips_dict['LPIPS_mean'] = lpips_mean
        #
        # # report LPIPS values
        # filename = os.path.join(args.eval_dir, 'LPIPS.json')
        # utils.save_json(lpips_dict, filename)

        # calculate the average FID for all tasks
        fid_mean = 0
        #fid_means = [[], [], []]
        for key, value in fid_values.items():
            #for d in range(1, num_domains):
            #  if str(d) in key:
            #    fid_means[d-1].append(value)
            fid_mean += value / len(fid_values)

        #for d in range(1, num_domains):
        #  fid_values['FID_s%d_mean' % d] = np.array(fid_means[d-1]).mean()

        fid_values['FID_mean'] = fid_mean

        # report FID values
        filename = os.path.join(eval_dir, 'FID.json')
        utils.save_json(fid_values, filename)

        # calculate the average TCL for all tasks
        tcl_mean = 0
        #tcl_means = [[], [], []]
        for _, value in tcl_values.items():
            #for d in range(1, num_domains):
            #  if str(d) in key:
            #    tcl_means[d-1].append(value)
            #print(value, len(tcl_values))
            tcl_mean += value / len(tcl_values)
        #print(tcl_mean)
        #for d in range(1, num_domains):
        #  tcl_values['TCL_s%d_mean' % d] = np.array(tcl_means[d-1]).mean()

        tcl_values['TCL_mean'] = float(tcl_mean)

        # report TCL values
        filename = os.path.join(eval_dir, 'TCL.json')
        utils.save_json(tcl_values, filename)
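In the last example, the hand-rolled np.zeros plus per-sample loop that builds c_trg can be written more directly with F.one_hot; a sketch assuming y_trg is a LongTensor of target labels already on the right device:

import torch
import torch.nn.functional as F

def one_hot_targets(y_trg, c_dim):
    # equivalent to the np.zeros + per-sample assignment loop in eval() above
    return F.one_hot(y_trg, num_classes=c_dim).float()

# e.g. inside eval(): c_trg = one_hot_targets(y_trg, self.c_dim)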