def computeTCL(net, model, img_fake, img1, img2, c_trg):
    """Masked RMS temporal-consistency error of ``img_fake`` against ``img2``.

    ``img2`` is re-stylized with ``net(img2, c_trg)`` and warped with the
    flow ``computeRAFT(model, img1, img2)``. The difference to ``img_fake``
    is multiplied by the ``fbcCheckTorch`` mask of the two RAFT flows.
    Returns the square root of the mean squared masked difference.

    NOTE(review): a second ``computeTCL`` defined later in this file rebinds
    this name at import time, so this version is shadowed.
    """
    flow_fwd = computeRAFT(model, img2, img1)
    flow_bwd = computeRAFT(model, img1, img2)
    consistency = fbcCheckTorch(flow_fwd, flow_bwd)

    restyled = net(img2, c_trg)
    warped = warp(restyled, flow_bwd)

    masked_diff = consistency * (img_fake - warped)
    return (masked_diff ** 2).mean() ** 0.5
    def forward_train(self):
        """Run both generators on two frames and cache flow-derived tensors.

        Sets on ``self``:
        * ``fake_B``/``rec_A``/``fake_A``/``rec_B`` for the first frame and
          the same names with suffix ``2`` for the second frame
          (``G_A(A)``, ``G_B(G_A(A))``, ``G_B(B)``, ``G_A(G_B(B))``).
        * RAFT flows between the two frames of the real, fake and
          reconstructed images (``ff_*`` / ``bf_*``).
        * ``bf_M_A``/``bf_M_B``: ``netM_A``/``netM_B`` applied to the
          backward flow of the real frames.
        * ``warp_B``/``warp_A``: the first-frame translation warped with that
          output.
        * ``mask_A``/``mask_B``: ``fbcCheckTorch`` of the real-frame flows.
        """
        gen_ab = self.netG_A
        gen_ba = self.netG_B
        raft = self.computeRAFT

        # First frame: A -> B -> A, then B -> A -> B.
        self.fake_B = gen_ab(self.real_A)
        self.rec_A = gen_ba(self.fake_B)
        self.fake_A = gen_ba(self.real_B)
        self.rec_B = gen_ab(self.fake_A)

        # Second frame: the same two cycles.
        self.fake_B2 = gen_ab(self.real_A2)
        self.rec_A2 = gen_ba(self.fake_B2)
        self.fake_A2 = gen_ba(self.real_B2)
        self.rec_B2 = gen_ab(self.fake_A2)

        # Domain A side: flows, netM_A output, warp of fake_B, mask.
        self.ff_real_A = raft(self.real_A, self.real_A2)
        self.bf_real_A = raft(self.real_A2, self.real_A)
        self.bf_fake_B = raft(self.fake_B2, self.fake_B)
        self.bf_rec_A = raft(self.rec_A2, self.rec_A)
        self.bf_M_A = self.netM_A(self.bf_real_A)
        self.warp_B = warp(self.fake_B, self.bf_M_A)
        self.mask_A = fbcCheckTorch(self.ff_real_A, self.bf_real_A)

        # Domain B side: mirror of the block above.
        self.ff_real_B = raft(self.real_B, self.real_B2)
        self.bf_real_B = raft(self.real_B2, self.real_B)
        self.bf_fake_A = raft(self.fake_A2, self.fake_A)
        self.bf_rec_B = raft(self.rec_B2, self.rec_B)
        self.bf_M_B = self.netM_B(self.bf_real_B)
        self.warp_A = warp(self.fake_A, self.bf_M_B)
        self.mask_B = fbcCheckTorch(self.ff_real_B, self.bf_real_B)
    def evaluate_sintel(self,
                        args,
                        n_styles,
                        epochs,
                        n_epochs,
                        emphasis_parameter,
                        sintel_dir="D:/Datasets/MPI-Sintel-complete/",
                        batchsize=16,
                        learning_rate=1e-3,
                        dset='FC2'):
        """Stylize three MPI-Sintel training clips and save temporal metrics.

        For each clip in ``vid_list`` and each style ``y`` in
        ``1 .. num_domains - 1``, every frame is stylized with
        ``self.infer_method`` (style index ``y - 1``), clamped to [0, 1] and
        saved as ``<out_path>/<clip>_s<y>/frame_XXXX.png``. Per (clip, style)
        the mean of each metric is written to a JSON file:

        * ``TCL-ST``: RMS of ``mask * (x_fake - warp(previous x_fake))``,
          where ``mask`` is ``fbcCheckTorch`` of the RAFT flows between
          ``img`` and ``img_last`` (from the second frame on).
        * ``TCL-LT``: the same against the stylized output produced five
          iterations earlier, warped with the flow between ``img`` and
          ``img_past`` (from the sixth frame on). Assumes the loader's
          ``img_past`` is frame ``i - 5``; TODO confirm in SingleSintelVideo.
        * ``DT``: inference time per frame in milliseconds.

        Args:
            args: options forwarded to ``initRaftModel``.
            n_styles: number of styles in the checkpoint. When > 1, a single
                checkpoint (``tmp_list[0]``) is loaded once. Otherwise the
                checkpoint ``tmp_list[y - 1]`` is loaded for every style.
            epochs, emphasis_parameter, batchsize, learning_rate: unused;
                kept for interface compatibility.
            n_epochs: epoch number of the checkpoint file to load.
            sintel_dir: root of the MPI-Sintel-complete dataset.
            dset: dataset sub-directory under ``self.train_dir``.

        NOTE(review): models are moved to ``self.device`` while frames are
        moved to the local ``device``; the two must agree.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        out_path = "G:/Code/LBST/eval_sintel/" + self.method + "/"
        raft_model = initRaftModel(args)

        num_domains = 4

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

        # Only a fixed subset of the Sintel training clips is evaluated.
        train_dir = os.path.join(sintel_dir, "training", "final")
        vid_list = ["alley_2", "market_6", "temple_2"]
        video_list = [os.path.join(train_dir, vid) for vid in vid_list]

        tcl_st_dict = OrderedDict()
        tcl_lt_dict = OrderedDict()
        dt_dict = OrderedDict()

        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = sorted(os.listdir(tmp_dir))

        if self.method == "ruder":
            self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
            self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
        else:
            self.model = FastStyleNet(3, n_styles).to(self.device)

        # CUDA kernels run asynchronously; without synchronizing, wall-clock
        # timing only measures kernel launch, not the actual inference.
        sync_cuda = device.type == 'cuda'

        def _mean(vals):
            # np.mean([]) warns and returns nan; return nan without the
            # warning (e.g. clips too short to produce a TCL-LT value).
            return float(np.mean(vals)) if vals else float('nan')

        first = True

        for vid, vid_dir in zip(vid_list, video_list):
            sintel_dset = SingleSintelVideo(vid_dir, transform)
            loader = data.DataLoader(dataset=sintel_dset,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=0)

            for y in range(1, num_domains):
                y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
                key = vid + "_s" + str(y)
                vid_path = os.path.join(out_path, key)
                os.makedirs(vid_path, exist_ok=True)

                # Style 3 is saved as grayscale.
                gray = y == 3

                tcl_st_vals = []
                tcl_lt_vals = []
                dt_vals = []

                # A multi-style checkpoint only needs loading once.
                if n_styles <= 1 or first:
                    self.model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp_list[y - 1] +
                                   '/epoch_' + str(n_epochs) + '.pth'))
                    first = False

                if self.method == "ruder":
                    pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + str(
                        y - 1) + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                    self.pre_style_model.load_state_dict(
                        torch.load(pre_style_path))

                # Rolling history of stylized outputs: [-1] is the previous
                # frame's output, [0] the oldest kept one (used for TCL-LT).
                past_sty_list = []

                for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                    img, img_last, img_past = imgs

                    img = img.to(device)
                    img_last = img_last.to(device)
                    img_past = img_past.to(device)

                    if i > 0:
                        ff_last = computeRAFT(raft_model, img_last, img)
                        bf_last = computeRAFT(raft_model, img, img_last)
                        mask_last = fbcCheckTorch(ff_last, bf_last)
                        x_fake_last = past_sty_list[-1]
                        warp_last = warp(torch.clamp(x_fake_last, 0.0, 1.0),
                                         bf_last)
                    else:
                        mask_last = None
                        warp_last = None

                    if sync_cuda:
                        torch.cuda.synchronize()  # exclude pending RAFT work
                    t_start = time.perf_counter()
                    x_fake = self.infer_method((img, mask_last, warp_last),
                                               y_trg - 1)
                    x_fake = torch.clamp(x_fake, 0.0, 1.0)
                    if sync_cuda:
                        torch.cuda.synchronize()  # wait for inference
                    t_end = time.perf_counter()

                    past_sty_list.append(x_fake)
                    dt_vals.append((t_end - t_start) * 1000)

                    if i > 0:
                        tcl_st = ((mask_last *
                                   (x_fake - warp_last))**2).mean()**0.5
                        # .item() also works when the tensor requires grad,
                        # whereas .numpy() raises in that case.
                        tcl_st_vals.append(tcl_st.item())

                    if i >= 5:
                        ff_past = computeRAFT(raft_model, img_past, img)
                        bf_past = computeRAFT(raft_model, img, img_past)
                        mask_past = fbcCheckTorch(ff_past, bf_past)
                        warp_past = warp(past_sty_list[0], bf_past)
                        tcl_lt = ((mask_past *
                                   (x_fake - warp_past))**2).mean()**0.5
                        tcl_lt_vals.append(tcl_lt.item())
                        past_sty_list.pop(0)

                    filename = os.path.join(vid_path, "frame_%04d.png" % i)
                    save_image(x_fake[0], ncol=1, filename=filename, gray=gray)

                tcl_st_dict["TCL-ST_" + key] = _mean(tcl_st_vals)
                tcl_lt_dict["TCL-LT_" + key] = _mean(tcl_lt_vals)
                dt_dict["DT_" + key] = _mean(dt_vals)

        save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
        save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
        save_dict_as_json("DT", dt_dict, out_path, num_domains)
def eval_sintel(net, args):
    """Stylize every MPI-Sintel clip with ``net`` and save temporal metrics.

    For each clip of the training and test ``final`` passes and each style
    ``y`` in ``1 .. num_domains - 1``, every frame is stylized with
    ``net.run``. From the second frame on, the run is initialized with
    ``pre``: the previous output warped by RAFT flow where the
    ``fbcCheckTorch`` mask is set, and the current frame elsewhere. Frames
    are written to ``<cwd>/eval_sintel/<weight_tcl>/<clip>_s<y>/``. Per
    (clip, style) the mean of each metric is saved as JSON:

    * ``TCL-ST``: RMS of ``mask_last * (x_fake - pre)`` (from frame 1 on).
    * ``TCL-LT``: RMS of the masked difference to the output produced five
      iterations earlier, warped with the flow between ``img`` and
      ``img_past`` (from frame 5 on).
    * ``DT``: time spent in ``net.run`` per frame, in milliseconds.

    Args:
        net: stylization object exposing ``set_shapes``, ``set_style``,
            ``run``, ``postp``/``postp2`` and a ``batch_size`` attribute.
        args: options; ``args.weight_tcl`` is forwarded to ``net.run`` and
            names the output directory; ``args`` also goes to
            ``initRaftModel``.
    """
    sintel_dir = "G:/Datasets/MPI-Sintel-complete/"

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = os.getcwd() + "/eval_sintel/" + str(args.weight_tcl) + "/"
    raft_model = initRaftModel(args)

    pyr_shapes = [(109, 256), (218, 512), (436, 1024)]
    net.set_shapes(pyr_shapes)

    num_domains = 4
    net.batch_size = 1

    # Index order [2, 1, 0] turns RGB into BGR, then the per-channel means
    # are subtracted and the result scaled to 0..255 (presumably the Caffe
    # VGG input convention expected by ``net``).
    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961], std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255)),
    ])

    train_dir = os.path.join(sintel_dir, "training", "final")
    test_dir = os.path.join(sintel_dir, "test", "final")
    train_list = sorted(os.listdir(train_dir))
    test_list = sorted(os.listdir(test_dir))

    vid_list = train_list + test_list
    video_list = ([os.path.join(train_dir, vid) for vid in train_list] +
                  [os.path.join(test_dir, vid) for vid in test_list])

    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()

    # CUDA kernels run asynchronously; without synchronizing, wall-clock
    # timing only measures kernel launch, not the actual work.
    sync_cuda = device.type == 'cuda'

    def _mean(vals):
        # np.mean([]) warns and returns nan; return nan without the warning.
        return float(np.mean(vals)) if vals else float('nan')

    for vid, vid_dir in zip(vid_list, video_list):
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        for y in range(1, num_domains):
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            os.makedirs(vid_path, exist_ok=True)

            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []

            net.set_style(y - 1)
            # Rolling history of outputs: [-1] is the previous frame's
            # result, [0] the oldest kept one (used for TCL-LT).
            styled_past = []

            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                img, img_last, img_past = imgs

                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)

                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    # Warped previous output where the mask is set, the
                    # current input frame elsewhere.
                    pre = (mask_last * warp(styled_past[-1], bf_last) +
                           (1 - mask_last) * img)
                else:
                    pre = img
                    # No previous frame: an all-zero mask.
                    mask_last = torch.zeros((1, ) + img.shape[2:]).to(
                        device).unsqueeze(1)
                # The original re-assigned an all-zero mask here
                # unconditionally. That discarded the consistency mask, so
                # TCL-ST was always 0 and args.weight_tcl had no effect.

                if sync_cuda:
                    torch.cuda.synchronize()  # exclude pending RAFT work
                t_start = time.perf_counter()
                x_fake = net.run(pre, img, y - 1, mask_last, args.weight_tcl)
                if sync_cuda:
                    torch.cuda.synchronize()  # wait for net.run to finish
                t_end = time.perf_counter()

                dt_vals.append((t_end - t_start) * 1000)
                # History is only read (warping / metrics); detach so no
                # autograd graph is kept alive across frames.
                styled_past.append(x_fake.detach())

                if i > 0:
                    tcl_st = ((mask_last * (x_fake - pre))**2).mean()**0.5
                    # .item() also works when the tensor requires grad,
                    # whereas .numpy() raises in that case.
                    tcl_st_vals.append(tcl_st.item())

                if i >= 5:
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    warp_past = warp(styled_past[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.item())
                    styled_past.pop(0)

                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                # Style index 2 uses the alternate post-processing postp2.
                postprocess = net.postp2 if y - 1 == 2 else net.postp
                out_img = postprocess(x_fake.detach()[0].cpu())
                out_img.save(filename)

            tcl_st_dict["TCL-ST_" + key] = _mean(tcl_st_vals)
            tcl_lt_dict["TCL-LT_" + key] = _mean(tcl_lt_vals)
            dt_dict["DT_" + key] = _mean(dt_vals)

    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def computeTCL(net, model, img_fake, img1, img2, sid):
    """Masked RMS temporal-consistency error of ``img_fake`` against ``img2``.

    ``img2`` is re-stylized via ``net.run(img2, sid)`` and warped with the
    flow ``computeRAFT(model, img1, img2)``. The difference to ``img_fake``
    is multiplied by the ``fbcCheckTorch`` mask of the two RAFT flows.
    Returns the square root of the mean squared masked difference.

    NOTE(review): ``eval_sintel`` calls ``net.run(pre, img, style, mask,
    weight_tcl)`` with five arguments, but this passes only two, so ``sid``
    lands in the second (image) slot. Verify against ``net.run`` before
    using this helper. It is currently only referenced from a commented-out
    line.
    """
    flow_fwd = computeRAFT(model, img2, img1)
    flow_bwd = computeRAFT(model, img1, img2)
    consistency = fbcCheckTorch(flow_fwd, flow_bwd)

    restyled = net.run(img2, sid)
    warped = warp(restyled, flow_bwd)

    masked_diff = consistency * (img_fake - warped)
    return (masked_diff ** 2).mean() ** 0.5