def computeTCL(net, model, img_fake, img1, img2, c_trg):
    """Return a masked temporal-consistency RMSE for ``img_fake``.

    ``img2`` is restylized with ``net(img2, c_trg)``, warped with
    ``computeRAFT(model, img1, img2)`` and compared with ``img_fake``. Only
    pixels selected by the ``fbcCheckTorch`` mask count; that mask is
    presumably a forward/backward flow-consistency mask (verify against
    ``fbcCheckTorch``).

    Args:
        net: generator invoked as ``net(img2, c_trg)``.
        model: flow model passed to ``computeRAFT``.
        img_fake: stylized result to evaluate.
        img1: frame that ``img_fake`` corresponds to.
        img2: neighbouring frame that is restylized and warped.
        c_trg: style/target code forwarded to ``net``.

    Returns:
        ``mean((mask * (img_fake - warped)) ** 2) ** 0.5``.
    """
    raft_2_1 = computeRAFT(model, img2, img1)
    raft_1_2 = computeRAFT(model, img1, img2)
    valid = fbcCheckTorch(raft_2_1, raft_1_2)
    restyled = net(img2, c_trg)
    warped = warp(restyled, raft_1_2)
    residual = valid * (img_fake - warped)
    return (residual ** 2).mean() ** 0.5
    def forward_train(self):
        """Forward two consecutive frames of both domains through the model.

        Stores on ``self``:
        * the CycleGAN translations and reconstructions for frame 1
          (``real_A``, ``real_B``) and frame 2 (``real_A2``, ``real_B2``);
        * ``computeRAFT`` results between the real frames, the fakes and the
          reconstructions;
        * ``netM_*`` applied to ``bf_real_*``;
        * the frame-1 fake warped with that output;
        * the ``fbcCheckTorch`` mask from the real-frame flows.

        Nothing is returned.
        """

        def translate_and_back(gen_to, gen_back, src):
            # One half of a CycleGAN cycle: translate, then map back.
            moved = gen_to(src)
            return moved, gen_back(moved)

        def frame_pair_flows(first, second, fake1, fake2, rec1, rec2, net_m):
            # Per-domain flow bookkeeping. The call order is kept identical to
            # the attribute-by-attribute version, since networks may hold state.
            fwd_real = self.computeRAFT(first, second)
            bwd_real = self.computeRAFT(second, first)
            bwd_fake = self.computeRAFT(fake2, fake1)
            bwd_rec = self.computeRAFT(rec2, rec1)
            refined = net_m(bwd_real)
            warped_fake = warp(fake1, refined)
            consistency = fbcCheckTorch(fwd_real, bwd_real)
            return (fwd_real, bwd_real, bwd_fake, bwd_rec, refined,
                    warped_fake, consistency)

        # frame 1
        self.fake_B, self.rec_A = translate_and_back(self.netG_A, self.netG_B,
                                                     self.real_A)
        self.fake_A, self.rec_B = translate_and_back(self.netG_B, self.netG_A,
                                                     self.real_B)
        # frame 2
        self.fake_B2, self.rec_A2 = translate_and_back(
            self.netG_A, self.netG_B, self.real_A2)
        self.fake_A2, self.rec_B2 = translate_and_back(
            self.netG_B, self.netG_A, self.real_B2)

        (self.ff_real_A, self.bf_real_A, self.bf_fake_B, self.bf_rec_A,
         self.bf_M_A, self.warp_B, self.mask_A) = frame_pair_flows(
             self.real_A, self.real_A2, self.fake_B, self.fake_B2,
             self.rec_A, self.rec_A2, self.netM_A)

        (self.ff_real_B, self.bf_real_B, self.bf_fake_A, self.bf_rec_B,
         self.bf_M_B, self.warp_A, self.mask_B) = frame_pair_flows(
             self.real_B, self.real_B2, self.fake_A, self.fake_A2,
             self.rec_B, self.rec_B2, self.netM_B)
Пример #3
0
    def forward_train(self):
        """Run the two-frame forward pass with flow-guided fusion.

        Frame 1 (``real_A``/``real_B``, commented below as t-1) and frame 2
        (``real_A2``/``real_B2``, commented as t) both pass through the
        CycleGAN generators. Then, for each domain:
        * ``computeRAFT(frame2, frame1)`` produces the flow used to warp the
          frame-1 fake.
        * ``netF_*`` fuses that warp with the frame-2 fake.
        * ``generateMask`` compares real frame 2 with the warped real frame 1.
        * VGG-19 features are taken of the fused image and of real frame 2.
        * A ``rec3D_*`` reconstruction is built from the fused image.

        All results are stored as attributes on ``self``; nothing is
        returned. The original docstring states this is called by both
        <optimize_parameters> and <test>.
        """
        # 1st frame (t-1)
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)  XT-1
        self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))

        # 2nd frame (t)
        self.fake_B2 = self.netG_A(self.real_A2)  # G_A(A)
        self.rec_A2 = self.netG_B(self.fake_B2)  # G_B(G_A(A))
        self.fake_A2 = self.netG_B(self.real_B2)  # G_B(B)
        self.rec_B2 = self.netG_A(self.fake_A2)  # G_A(G_B(B))

        # A side: warp the frame-1 fake_B with the real-frame flow and fuse it
        # with the frame-2 fake_B.
        self.bf_real_A = self.computeRAFT(self.real_A2, self.real_A)
        self.warp_B = warp(self.fake_B, self.bf_real_A)
        self.fuse_B = self.netF_A(self.fake_B2, self.warp_B)  #XT
        # Mask built from real frame 2 and real frame 1 warped with the same flow.
        self.mask_A = self.generateMask(self.real_A2,
                                        warp(self.real_A, self.bf_real_A))
        self.vgg_fuse_B = self.netVGG_19(self.fuse_B)
        self.vgg_real_A2 = self.netVGG_19(self.real_A2)

        # Reconstruct frame-2 A from the fused B frame.
        # NOTE(review): below, netF_B fuses fake_A2 with warp_A, i.e. two
        # netG_B outputs. Here its second input warps fake_B, a netG_A output.
        # Confirm this mix is intended.
        self.bf_fake_B = self.computeRAFT(self.fuse_B, self.fake_B)
        self.rec3D_A2 = self.netF_B(self.netG_B(self.fuse_B),
                                    warp(self.fake_B, self.bf_fake_B))

        # B side: mirror of the A side with the generators swapped.
        self.bf_real_B = self.computeRAFT(self.real_B2, self.real_B)
        self.warp_A = warp(self.fake_A, self.bf_real_B)
        self.fuse_A = self.netF_B(self.fake_A2, self.warp_A)
        self.mask_B = self.generateMask(self.real_B2,
                                        warp(self.real_B, self.bf_real_B))
        self.vgg_fuse_A = self.netVGG_19(self.fuse_A)
        self.vgg_real_B2 = self.netVGG_19(self.real_B2)

        # NOTE(review): above, netF_A fuses fake_B2 with warp_B, i.e. two
        # netG_A outputs. Here its second input warps fake_A, a netG_B output.
        # Confirm this mix is intended.
        self.bf_fake_A = self.computeRAFT(self.fuse_A, self.fake_A)
        self.rec3D_B2 = self.netF_A(self.netG_A(self.fuse_A),
                                    warp(self.fake_A, self.bf_fake_A))
        '''
Пример #4
0
def evaluate_sintel(args, sintel_dir="D:/Datasets/MPI-Sintel-complete/"):
    """Stylize every MPI-Sintel "final" video per style and save metrics.

    The function processes each video in ``training/final`` and
    ``test/final``, once for each ``y`` in ``1 .. num_domains - 1``:

    * The checkpoint named ``model_list[y - 1]`` (from ``./checkpoints``) is
      loaded with ``create_model``/``setup``.
    * Frames are stylized one at a time with ``model.forward_eval``. After
      the first frame, the previous output is warped with
      ``model.computeRAFT(img, img_last)`` and passed in as well.
    * Each frame is saved to ``out_path/<video>_s<y>/frame_XXXX.png``.

    Three metrics are collected per video and style:

    * ``TCL-ST``: ``computeTCL`` against ``img_last`` (frames 1+).
    * ``TCL-LT``: ``computeTCL`` against ``img_past`` (frames 5+).
    * ``DT``: mean ``forward_eval`` time per frame in milliseconds.

    Each metric dict is written with ``save_dict_as_json`` into ``out_path``.

    Args:
        args: option namespace. ``args.checkpoints_dir`` and ``args.name``
            are overwritten here before each ``create_model`` call.
        sintel_dir: root directory of MPI-Sintel-complete.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = "G:/Code/ConGAN/eval/"
    raft_model = initRaftModel(args, device)

    def _sync():
        # CUDA work is queued asynchronously. Wait for it so the wall-clock
        # timings measure the computation, not just the kernel launches.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    num_domains = 4

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = sorted(os.listdir(train_dir))

    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = sorted(os.listdir(test_dir))

    video_list = [os.path.join(train_dir, vid) for vid in train_list]
    video_list += [os.path.join(test_dir, vid) for vid in test_list]
    vid_list = train_list + test_list

    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()

    # On Windows this is the same string as os.getcwd() + "\\checkpoints\\";
    # on other platforms it is a valid directory with a trailing separator.
    args.checkpoints_dir = os.path.join(os.getcwd(), "checkpoints", "")
    model_list = sorted(os.listdir(args.checkpoints_dir))

    for vid, vid_dir in zip(vid_list, video_list):
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        for y in range(1, num_domains):
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            os.makedirs(vid_path, exist_ok=True)

            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []

            args.name = model_list[y - 1]
            model = create_model(args)
            model.setup(args)
            x_fake = None

            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                img, img_last, img_past = imgs

                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)

                if i > 0:
                    # Warp the previous stylized frame onto the current one.
                    _, _, H, W = img.size()
                    bf = model.computeRAFT(img, img_last)[:, :, :H, :W]
                    wrp = warp(x_fake, bf)

                _sync()
                t_start = time.time()
                if i == 0:
                    x_fake = model.forward_eval(img)
                else:
                    x_fake = model.forward_eval(img, wrp)
                _sync()
                t_end = time.time()

                dt_vals.append((t_end - t_start) * 1000)

                # NOTE(review): confirm computeTCL's signature for this call;
                # the variants defined in this file take six positional
                # arguments.
                if i > 0:
                    tcl_st = computeTCL(model, raft_model, x_fake, img,
                                        img_last)
                    tcl_st_vals.append(tcl_st.cpu().numpy())

                if i >= 5:
                    tcl_lt = computeTCL(model, raft_model, x_fake, img,
                                        img_past)
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())

                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                save_image(x_fake[0], ncol=1, filename=filename)

            tcl_st_dict["TCL-ST_" + key] = float(np.array(tcl_st_vals).mean())
            tcl_lt_dict["TCL-LT_" + key] = float(np.array(tcl_lt_vals).mean())
            dt_dict["DT_" + key] = float(np.array(dt_vals).mean())

    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
    def evaluate_fc2(self,
                     args,
                     n_styles,
                     epochs,
                     n_epochs,
                     emphasis_parameter,
                     batchsize=16,
                     learning_rate=1e-3,
                     dset='FC2'):
        """Stylize the FC2 evaluation split and report per-task TCL and FID.

        Checkpoints ``epoch_<n_epochs>.pth`` are loaded from
        ``self.train_dir + dset + '/' + self.method + '/'``:

        * ``n_styles > 1``: one multi-style ``FastStyleNet``.
        * Otherwise: one network per run directory. For
          ``self.method == "ruder"``, a 7-input-channel network plus a
          pre-stylization network from the johnson runs.

        Samples whose target style is 0, or equal to the source style, are
        skipped. For each remaining sample:

        1. ``x_real`` is stylized with ``infer_method``.
        2. The result is warped with the dataset flow.
        3. ``infer_method((x_real2, mask, x_warp), ...)`` stylizes frame 2.
        4. The masked RMSE between the frame-2 result and the warp is
           recorded as TCL.

        Reference and stylized images are written under
        ``eval_dir/<task>/{ref,fake}``. FID is computed per task, and
        ``FID.json`` and ``TCL.json`` (each with a ``*_mean`` entry) are
        saved in ``eval_dir``.

        ``epochs``, ``emphasis_parameter``, ``batchsize`` and
        ``learning_rate`` are accepted for interface compatibility but are not
        used here.

        Raises:
            ValueError: if ``self.method == "ruder"`` and ``n_styles > 1``.
                That branch never loads the per-style pre-stylization networks
                that the ruder path indexes. Previously this surfaced as an
                IndexError part-way through generation.
        """
        print('Calculating evaluation metrics...')

        # Fail fast instead of crashing on pre_models[...] mid-run.
        if self.method == "ruder" and n_styles > 1:
            raise ValueError(
                'evaluate_fc2: method "ruder" requires n_styles == 1; no '
                'pre-stylization models are loaded for multi-style runs')

        data_dir = "G:/Datasets/FC2/DATAFiles/"
        style_dir = "G:/Datasets/FC2/styled-files/"
        temp_dir = "G:/Datasets/FC2/styled-files3/"
        eval_dir = os.getcwd() + "/eval_fc2/" + self.method + "/"

        num_workers = 0
        args.batch_size = 4

        domains = sorted(os.listdir(style_dir))
        num_domains = len(domains)
        print('Number of domains: %d' % num_domains)
        print("Batch Size:", args.batch_size)

        _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir,
                                       args.batch_size, num_workers,
                                       num_domains)

        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = sorted(os.listdir(tmp_dir))

        def _checkpoint(run):
            # Path of the epoch-n_epochs checkpoint inside run directory `run`.
            return tmp_dir + '/' + run + '/epoch_' + str(n_epochs) + '.pth'

        models = []
        pre_models = []
        if n_styles > 1:
            model = FastStyleNet(3, n_styles).to(self.device)
            model.load_state_dict(torch.load(_checkpoint(tmp_list[0])))
        elif self.method == "ruder":
            for tmp in tmp_list:
                model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
                model.load_state_dict(torch.load(_checkpoint(tmp)))
                models.append(model)
                # tmp[3] is presumably the style index in run names of the
                # form "sid<k>_..." (verify against the run directory names).
                pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + tmp[
                    3] + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                model = FastStyleNet(3, n_styles).to(self.device)
                model.load_state_dict(torch.load(pre_style_path))
                pre_models.append(model)
        else:
            for tmp in tmp_list:
                model = FastStyleNet(3, n_styles).to(self.device)
                model.load_state_dict(torch.load(_checkpoint(tmp)))
                models.append(model)

        generate_new = True

        tcl_dict = {}
        # prepare
        for d in range(1, num_domains):
            src_domain = "style0"
            trg_domain = "style" + str(d)

            t1 = '%s2%s' % (src_domain, trg_domain)
            t2 = '%s2%s' % (trg_domain, src_domain)

            tcl_dict[t1] = []
            tcl_dict[t2] = []

            if generate_new:
                create_task_folders(eval_dir, t1)

        # generate
        for i, x_src_all in enumerate(tqdm(eval_loader,
                                           total=len(eval_loader))):
            x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

            x_real = x_real.to(self.device)
            x_real2 = x_real2.to(self.device)
            y_org = y_org.to(self.device)
            x_ref = x_ref.to(self.device)
            y_trg = y_trg.to(self.device)
            mask = mask.to(self.device)
            flow = flow.to(self.device)

            N = x_real.size(0)

            for k in range(N):
                y_org_np = y_org[k].cpu().numpy()
                y_trg_np = y_trg[k].cpu().numpy()
                src_domain = "style" + str(y_org_np)
                trg_domain = "style" + str(y_trg_np)

                if src_domain == trg_domain or y_trg_np == 0:
                    continue

                task = '%s2%s' % (src_domain, trg_domain)

                if n_styles > 1:
                    self.model = model
                else:
                    self.model = models[y_trg_np - 1]

                if self.method == "ruder":
                    self.pre_style_model = pre_models[y_trg_np - 1]

                # The whole batch is stylized with sample k's target style;
                # only row k is used below.
                x_fake = self.infer_method((x_real, None, None), y_trg[k] - 1)
                x_warp = warp(x_fake, flow)
                x_fake2 = self.infer_method((x_real2, mask, x_warp),
                                            y_trg[k] - 1)

                tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2,
                                                                     3))**0.5

                tcl_dict[task].append(tcl_err[k].cpu().numpy())

                path_ref = os.path.join(eval_dir, task + "/ref")
                path_fake = os.path.join(eval_dir, task + "/fake")

                if generate_new:
                    filename = os.path.join(
                        path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                    save_image(denormalize(x_ref[k]),
                               ncol=1,
                               filename=filename)

                filename = os.path.join(
                    path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
                save_image(x_fake[k], ncol=1, filename=filename)

        # evaluate
        print("computing fid, lpips and tcl")

        tasks = sorted(
            dir for dir in os.listdir(eval_dir)
            if os.path.isdir(os.path.join(eval_dir, dir)))

        fid_values = OrderedDict()
        tcl_values = OrderedDict()
        for task in tasks:
            print(task)
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")

            tcl_data = tcl_dict[task]

            print("TCL", len(tcl_data))
            tcl_mean = np.array(tcl_data).mean()
            print(tcl_mean)
            tcl_values['TCL_%s' % (task)] = float(tcl_mean)

            print("FID")
            fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                                  img_size=256,
                                                  batch_size=args.batch_size)
            fid_values['FID_%s' % (task)] = fid_value

        # calculate the average FID for all tasks
        fid_mean = 0
        for value in fid_values.values():
            fid_mean += value / len(fid_values)

        fid_values['FID_mean'] = fid_mean

        # report FID values
        filename = os.path.join(eval_dir, 'FID.json')
        utils.save_json(fid_values, filename)

        # calculate the average TCL for all tasks
        tcl_mean = 0
        for value in tcl_values.values():
            tcl_mean += value / len(tcl_values)

        tcl_values['TCL_mean'] = float(tcl_mean)

        # report TCL values
        filename = os.path.join(eval_dir, 'TCL.json')
        utils.save_json(tcl_values, filename)
    def evaluate_sintel(self,
                        args,
                        n_styles,
                        epochs,
                        n_epochs,
                        emphasis_parameter,
                        sintel_dir="D:/Datasets/MPI-Sintel-complete/",
                        batchsize=16,
                        learning_rate=1e-3,
                        dset='FC2'):
        """Stylize three MPI-Sintel training videos per style; save TCL/DT.

        The videos ``alley_2``, ``market_6`` and ``temple_2`` are processed
        for each style ``y`` in ``1 .. num_domains - 1``:

        * ``self.model`` (and, for ``"ruder"``, ``self.pre_style_model``) is
          loaded from the run directories under
          ``self.train_dir + dset + '/' + self.method``.
        * Frames are stylized with ``self.infer_method``. From frame 1 on, the
          previous output is warped with RAFT flow and passed in together with
          the ``fbcCheckTorch`` mask.
        * Each frame is saved to ``out_path/<video>_s<y>/frame_XXXX.png``.

        Three metrics are collected per video and style:

        * ``TCL-ST``: masked RMSE against the warped previous output.
        * ``TCL-LT``: masked RMSE against the warped output five frames
          earlier (frames 5+).
        * ``DT``: mean inference time per frame in milliseconds.

        ``epochs``, ``emphasis_parameter``, ``batchsize`` and
        ``learning_rate`` are accepted for interface compatibility but unused.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        out_path = "G:/Code/LBST/eval_sintel/" + self.method + "/"
        raft_model = initRaftModel(args)

        def _sync():
            # CUDA work is queued asynchronously. Wait for it so the timings
            # measure the computation, not just the kernel launches.
            if device.type == 'cuda':
                torch.cuda.synchronize()

        num_domains = 4

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        train_dir = os.path.join(sintel_dir, "training", "final")

        # Fixed evaluation subset of the Sintel training split.
        vid_list = ["alley_2", "market_6", "temple_2"]
        video_list = [os.path.join(train_dir, vid) for vid in vid_list]

        tcl_st_dict = OrderedDict()
        tcl_lt_dict = OrderedDict()
        dt_dict = OrderedDict()

        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = sorted(os.listdir(tmp_dir))

        if self.method == "ruder":
            self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
            self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
        else:
            self.model = FastStyleNet(3, n_styles).to(self.device)

        first = True

        for vid, vid_dir in zip(vid_list, video_list):
            sintel_dset = SingleSintelVideo(vid_dir, transform)
            loader = data.DataLoader(dataset=sintel_dset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=0)

            for y in range(1, num_domains):
                y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
                key = vid + "_s" + str(y)
                vid_path = os.path.join(out_path, key)
                os.makedirs(vid_path, exist_ok=True)

                gray = y == 3

                tcl_st_vals = []
                tcl_lt_vals = []
                dt_vals = []

                # A multi-style network is loaded once (from the first run
                # directory); single-style networks are reloaded per style.
                if n_styles > 1:
                    if first:
                        self.model.load_state_dict(
                            torch.load(tmp_dir + '/' + tmp_list[y - 1] +
                                       '/epoch_' + str(n_epochs) + '.pth'))
                        first = False
                else:
                    self.model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp_list[y - 1] +
                                   '/epoch_' + str(n_epochs) + '.pth'))

                if self.method == "ruder":
                    pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + str(
                        y - 1) + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                    self.pre_style_model.load_state_dict(
                        torch.load(pre_style_path))

                # Sliding window of stylized outputs: [0] is five frames back
                # once i >= 5, [-1] is the previous frame.
                past_sty_list = []

                for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                    img, img_last, img_past = imgs

                    img = img.to(device)
                    img_last = img_last.to(device)
                    img_past = img_past.to(device)

                    if i > 0:
                        ff_last = computeRAFT(raft_model, img_last, img)
                        bf_last = computeRAFT(raft_model, img, img_last)
                        mask_last = fbcCheckTorch(ff_last, bf_last)
                        # Stored outputs are already clamped to [0, 1].
                        warp_last = warp(past_sty_list[-1], bf_last)
                    else:
                        mask_last = None
                        warp_last = None

                    _sync()
                    t_start = time.time()
                    x_fake = self.infer_method((img, mask_last, warp_last),
                                               y_trg - 1)
                    x_fake = torch.clamp(x_fake, 0.0, 1.0)
                    _sync()
                    t_end = time.time()

                    past_sty_list.append(x_fake)
                    dt_vals.append((t_end - t_start) * 1000)

                    if i > 0:
                        tcl_st = ((mask_last *
                                   (x_fake - warp_last))**2).mean()**0.5
                        tcl_st_vals.append(tcl_st.cpu().numpy())

                    if i >= 5:
                        ff_past = computeRAFT(raft_model, img_past, img)
                        bf_past = computeRAFT(raft_model, img, img_past)
                        mask_past = fbcCheckTorch(ff_past, bf_past)
                        warp_past = warp(past_sty_list[0], bf_past)
                        tcl_lt = ((mask_past *
                                   (x_fake - warp_past))**2).mean()**0.5
                        tcl_lt_vals.append(tcl_lt.cpu().numpy())
                        past_sty_list.pop(0)

                    filename = os.path.join(vid_path, "frame_%04d.png" % i)
                    save_image(x_fake[0], ncol=1, filename=filename, gray=gray)

                tcl_st_dict["TCL-ST_" + key] = float(
                    np.array(tcl_st_vals).mean())
                tcl_lt_dict["TCL-LT_" + key] = float(
                    np.array(tcl_lt_vals).mean())
                dt_dict["DT_" + key] = float(np.array(dt_vals).mean())

        save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
        save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
        save_dict_as_json("DT", dt_dict, out_path, num_domains)
def eval_fc2(net, args):
    """Evaluate ``net`` on the FC2 evaluation split: per-task TCL and FID.

    Samples whose target style is 0, or equal to the source style, are
    skipped. For each remaining sample:

    1. ``net.run`` stylizes frame 1 with an all-zero mask.
    2. The result is warped with the dataset flow.
    3. The warp is passed as the first argument of ``net.run`` for frame 2,
       together with the dataset mask.
    4. The masked RMSE between the frame-2 result and the warp is recorded
       as TCL.

    Reference and stylized frame-1 images of sample ``k`` are written under
    ``eval_dir/<task>/{ref,fake}``. FID is computed per task, and
    ``FID.json`` and ``TCL.json`` are saved in ``eval_dir``.
    """
    print('Calculating evaluation metrics...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_dir = "G:/Datasets/FC2/DATAFiles/"
    style_dir = "G:/Datasets/FC2/styled-files/"
    temp_dir = "G:/Datasets/FC2/styled-files3/"

    #data_dir = "/srv/local/tomstrident/datasets/FC2/DATAFiles/"
    #style_dir = "/srv/local/tomstrident/datasets/FC2/styled-files/"
    #temp_dir = "/srv/local/tomstrident/datasets/FC2/styled-files3/"

    eval_dir = os.getcwd() + "/eval_fc2/" + str(args.weight_tcl) + "/"

    num_workers = 0
    net.batch_size = 1  #args.batch_size

    pyr_shapes = [(64, 64), (128, 128), (256, 256)]
    net.set_shapes(pyr_shapes)

    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])

    domains = sorted(os.listdir(style_dir))
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    print("Batch Size:", args.batch_size)

    _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir, transform,
                                   args.batch_size, num_workers, num_domains)

    generate_new = True

    tcl_dict = {}
    # prepare
    for d in range(1, num_domains):
        src_domain = "style0"
        trg_domain = "style" + str(d)

        t1 = '%s2%s' % (src_domain, trg_domain)
        t2 = '%s2%s' % (trg_domain, src_domain)

        tcl_dict[t1] = []
        tcl_dict[t2] = []

        if generate_new:
            create_task_folders(eval_dir, t1)

    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
        x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

        x_real = x_real.to(device)
        x_real2 = x_real2.to(device)
        y_org = y_org.to(device)
        x_ref = x_ref.to(device)
        y_trg = y_trg.to(device)
        mask = mask.to(device)
        flow = flow.to(device)

        mask_zero = torch.zeros(mask.shape).to(device)

        N = x_real.size(0)

        for k in range(N):
            y_org_np = y_org[k].cpu().numpy()
            y_trg_np = y_trg[k].cpu().numpy()
            src_domain = "style" + str(y_org_np)
            trg_domain = "style" + str(y_trg_np)

            if src_domain == trg_domain or y_trg_np == 0:
                continue

            task = '%s2%s' % (src_domain, trg_domain)
            net.set_style(y_trg_np - 1)

            x_fake = net.run(x_real, x_real, y_trg_np - 1, mask_zero,
                             args.weight_tcl)
            x_warp = warp(x_fake, flow)
            x_fake2 = net.run(x_warp, x_real2, y_trg_np - 1, mask,
                              args.weight_tcl)

            tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5

            tcl_dict[task].append(tcl_err[k].cpu().numpy())

            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")

            # Style index 2 uses the alternative post-processing.
            postp = net.postp2 if y_trg_np - 1 == 2 else net.postp

            # Save row k, matching the filename index and tcl_err[k]
            # (previously row 0 was saved for every k).
            if generate_new:
                filename = os.path.join(
                    path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                postp(x_ref.data[k].cpu()).save(filename)

            filename = os.path.join(
                path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
            postp(x_fake.data[k].cpu()).save(filename)

    # evaluate
    print("computing fid, lpips and tcl")

    tasks = sorted(
        dir for dir in os.listdir(eval_dir)
        if os.path.isdir(os.path.join(eval_dir, dir)))

    fid_values = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
        print(task)
        path_ref = os.path.join(eval_dir, task + "/ref")
        path_fake = os.path.join(eval_dir, task + "/fake")

        tcl_data = tcl_dict[task]

        print("TCL", len(tcl_data))
        tcl_mean = np.array(tcl_data).mean()
        print(tcl_mean)
        tcl_values['TCL_%s' % (task)] = float(tcl_mean)

        print("FID")
        fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                              img_size=256,
                                              batch_size=args.batch_size)
        fid_values['FID_%s' % (task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for value in fid_values.values():
        fid_mean += value / len(fid_values)

    fid_values['FID_mean'] = fid_mean

    # report FID values
    filename = os.path.join(eval_dir, 'FID.json')
    utils.save_json(fid_values, filename)

    # calculate the average TCL for all tasks
    tcl_mean = 0
    for value in tcl_values.values():
        tcl_mean += value / len(tcl_values)

    tcl_values['TCL_mean'] = float(tcl_mean)

    # report TCL values
    filename = os.path.join(eval_dir, 'TCL.json')
    utils.save_json(tcl_values, filename)
def eval_sintel(net, args):
    """Stylize MPI-Sintel "final" videos frame by frame; save TCL/DT metrics.

    Each video is processed once per style ``y`` in ``1 .. num_domains - 1``:

    * Frame 0 is stylized with ``net.run(img, img, ...)`` and an all-zero
      mask.
    * Later frames start from the previous output warped with RAFT flow
      where the ``fbcCheckTorch`` mask is set, and from the raw frame
      elsewhere. That same mask is passed to ``net.run``.
    * Outputs are saved to ``out_path/<video>_s<y>/frame_XXXX.png``.

    Three metrics are collected per video and style:

    * ``TCL-ST``: masked RMSE against the warped previous output.
    * ``TCL-LT``: masked RMSE against the warped output five frames
      earlier (frames 5+).
    * ``DT``: mean ``net.run`` time per frame in milliseconds.
    """
    sintel_dir = "G:/Datasets/MPI-Sintel-complete/"
    #sintel_dir="/srv/local/tomstrident/datasets/MPI-Sintel-complete/"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = os.getcwd() + "/eval_sintel/" + str(args.weight_tcl) + "/"
    raft_model = initRaftModel(args)

    def _sync():
        # CUDA work is queued asynchronously. Wait for it so the timings
        # measure the computation, not just the kernel launches.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    pyr_shapes = [(109, 256), (218, 512), (436, 1024)]
    net.set_shapes(pyr_shapes)

    num_domains = 4
    net.batch_size = 1

    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])

    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = sorted(os.listdir(train_dir))

    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = sorted(os.listdir(test_dir))

    video_list = [os.path.join(train_dir, vid) for vid in train_list]
    video_list += [os.path.join(test_dir, vid) for vid in test_list]
    vid_list = train_list + test_list

    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()

    for vid, vid_dir in zip(vid_list, video_list):
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        for y in range(1, num_domains):
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            os.makedirs(vid_path, exist_ok=True)

            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []

            net.set_style(y - 1)
            # Sliding window of stylized outputs: [-1] is the previous frame,
            # [0] is five frames back once i >= 5.
            styled_past = []

            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                img, img_last, img_past = imgs

                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)

                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    # Start from the warped previous output where the mask is
                    # set and from the raw frame elsewhere.
                    pre = mask_last * warp(styled_past[-1],
                                           bf_last) + (1 - mask_last) * img
                else:
                    pre = img
                    mask_last = torch.zeros((1, ) + img.shape[2:]).to(
                        device).unsqueeze(1)

                # mask_last must NOT be reset to zeros here. It is both the
                # temporal constraint handed to net.run and the weighting of
                # TCL-ST. A zero mask made TCL-ST identically 0 and disabled
                # the temporal term, unlike eval_fc2, which passes the real
                # mask.

                _sync()
                t_start = time.time()
                x_fake = net.run(pre, img, y - 1, mask_last, args.weight_tcl)
                _sync()
                t_end = time.time()

                dt_vals.append((t_end - t_start) * 1000)
                styled_past.append(x_fake)

                if i > 0:
                    tcl_st = ((mask_last * (x_fake - pre))**2).mean()**0.5
                    tcl_st_vals.append(tcl_st.cpu().numpy())

                if i >= 5:
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    warp_past = warp(styled_past[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                    styled_past.pop(0)

                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                # Style index 2 uses the alternative post-processing.
                postp = net.postp2 if y - 1 == 2 else net.postp
                postp(x_fake.data[0].cpu()).save(filename)

            tcl_st_dict["TCL-ST_" + key] = float(np.array(tcl_st_vals).mean())
            tcl_lt_dict["TCL-LT_" + key] = float(np.array(tcl_lt_vals).mean())
            dt_dict["DT_" + key] = float(np.array(dt_vals).mean())

    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def computeTCL(net, model, img_fake, img1, img2, sid):
    """Return a masked temporal-consistency RMSE for ``img_fake``.

    ``img2`` is restylized with ``net.run(img2, sid)``, warped with
    ``computeRAFT(model, img1, img2)`` and compared with ``img_fake``. Only
    pixels selected by the ``fbcCheckTorch`` mask count.

    NOTE(review): the other ``net.run`` calls in this file pass five
    arguments. Confirm that this two-argument call matches ``net.run``'s
    signature.

    Returns:
        ``mean((mask * (img_fake - warped)) ** 2) ** 0.5``.
    """
    raft_2_1 = computeRAFT(model, img2, img1)
    raft_1_2 = computeRAFT(model, img1, img2)
    valid = fbcCheckTorch(raft_2_1, raft_1_2)
    restyled = net.run(img2, sid)
    warped = warp(restyled, raft_1_2)
    residual = valid * (img_fake - warped)
    return (residual ** 2).mean() ** 0.5