def computeTCL(net, model, img_fake, img1, img2, c_trg):
    """Return the masked RMS difference between ``img_fake`` and a warped ``net(img2, c_trg)``.

    Two flow fields are estimated with ``computeRAFT`` (``img2 -> img1`` and
    ``img1 -> img2`` argument order), ``fbcCheckTorch`` turns them into a mask,
    and the output of ``net`` for ``img2`` is warped with the second flow.

    Args:
        net: Callable invoked as ``net(img2, c_trg)``.
        model: Flow model handed to ``computeRAFT``.
        img_fake: Tensor compared against the warped result.
        img1: First input frame.
        img2: Second input frame (the one that is re-stylized and warped).
        c_trg: Target-domain argument forwarded to ``net``.

    Returns:
        ``sqrt(mean((mask * (img_fake - warped)) ** 2))``.
    """
    # presumably forward / backward optical flow between the frames -- confirm
    # against computeRAFT's argument convention.
    flow_a = computeRAFT(model, img2, img1)
    flow_b = computeRAFT(model, img1, img2)
    valid_mask = fbcCheckTorch(flow_a, flow_b)

    restyled = net(img2, c_trg)
    warped = warp(restyled, flow_b)

    masked_diff = valid_mask * (img_fake - warped)
    return (masked_diff**2).mean()**0.5
def forward_train(self) -> None:
    """Run the training forward pass for two frames of each domain.

    Reads ``self.real_A``, ``self.real_A2``, ``self.real_B`` and
    ``self.real_B2`` and stores every intermediate result on ``self``:

    * translated and reconstructed images for both frames
      (``fake_*``, ``rec_*``, ``fake_*2``, ``rec_*2``),
    * flow fields from ``self.computeRAFT`` on the real, translated and
      reconstructed frame pairs (``ff_*`` / ``bf_*``),
    * the output of ``self.netM_A`` / ``self.netM_B`` applied to the ``bf``
      flow of the real frames (``bf_M_A`` / ``bf_M_B``),
    * the first translated frame warped with that output
      (``warp_B`` / ``warp_A``),
    * a mask from ``fbcCheckTorch`` on the two real-frame flows
      (``mask_A`` / ``mask_B``).

    Nothing is returned.
    """
    # First frame: translation and cycle reconstruction.
    self.fake_B = self.netG_A(self.real_A)  # G_A(A)
    self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
    self.fake_A = self.netG_B(self.real_B)  # G_B(B)
    self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))

    # Second frame: same generators applied to the *2 inputs.
    self.fake_B2 = self.netG_A(self.real_A2)  # G_A(A2)
    self.rec_A2 = self.netG_B(self.fake_B2)  # G_B(G_A(A2))
    self.fake_A2 = self.netG_B(self.real_B2)  # G_B(B2)
    self.rec_B2 = self.netG_A(self.fake_A2)  # G_A(G_B(B2))

    # Domain A flows. NOTE(review): "ff"/"bf" presumably mean forward /
    # backward flow -- confirm against computeRAFT's argument convention.
    self.ff_real_A = self.computeRAFT(self.real_A, self.real_A2)
    self.bf_real_A = self.computeRAFT(self.real_A2, self.real_A)
    self.bf_fake_B = self.computeRAFT(self.fake_B2, self.fake_B)
    self.bf_rec_A = self.computeRAFT(self.rec_A2, self.rec_A)

    # netM_A maps the real-A flow to the flow used to warp the translated frame.
    self.bf_M_A = self.netM_A(self.bf_real_A)
    self.warp_B = warp(self.fake_B, self.bf_M_A)
    self.mask_A = fbcCheckTorch(self.ff_real_A, self.bf_real_A)

    # Domain B: mirror of the block above.
    self.ff_real_B = self.computeRAFT(self.real_B, self.real_B2)
    self.bf_real_B = self.computeRAFT(self.real_B2, self.real_B)
    self.bf_fake_A = self.computeRAFT(self.fake_A2, self.fake_A)
    self.bf_rec_B = self.computeRAFT(self.rec_B2, self.rec_B)

    self.bf_M_B = self.netM_B(self.bf_real_B)
    self.warp_A = warp(self.fake_A, self.bf_M_B)
    self.mask_B = fbcCheckTorch(self.ff_real_B, self.bf_real_B)
def evaluate_sintel(self,
                    args,
                    n_styles,
                    epochs,
                    n_epochs,
                    emphasis_parameter,
                    sintel_dir="D:/Datasets/MPI-Sintel-complete/",
                    batchsize=16,
                    learning_rate=1e-3,
                    dset='FC2'):
    """Stylize three MPI-Sintel clips and record TCL-ST, TCL-LT and runtime.

    For each of the clips ``alley_2``, ``market_6`` and ``temple_2`` (from
    ``<sintel_dir>/training/final``) and each style ``y`` in ``1..3``, every
    frame is stylized with ``self.infer_method``, written to
    ``G:/Code/LBST/eval_sintel/<self.method>/<clip>_s<y>/frame_%04d.png`` and
    compared against warped earlier outputs:

    * TCL-ST: masked RMS difference to the previous stylized frame (``i > 0``).
    * TCL-LT: masked RMS difference to the oldest buffered stylized frame
      (``i >= 5``; the buffer is trimmed so that this is five frames back).
    * DT: wall-clock time of ``self.infer_method`` plus clamping, in
      milliseconds.

    The per-clip/style means are passed to ``save_dict_as_json``.

    Args:
        args: Forwarded to ``initRaftModel``.
        n_styles: Number of styles of ``FastStyleNet``. If greater than one, a
            single checkpoint is loaded once; otherwise one checkpoint per
            style is loaded.
        epochs: Unused in this function.
        n_epochs: Epoch number that selects the ``epoch_<n_epochs>.pth``
            checkpoint file.
        emphasis_parameter: Unused in this function.
        sintel_dir: Root directory of the MPI-Sintel dataset.
        batchsize: Unused in this function.
        learning_rate: Unused in this function.
        dset: Sub-directory of ``self.train_dir`` that holds the checkpoints.
    """

    def _mean(values):
        # An empty list (no frame reached the metric's look-back) gives NaN,
        # as before, but without NumPy's empty-slice RuntimeWarning.
        return float(np.array(values).mean()) if values else float("nan")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    out_path = "G:/Code/LBST/eval_sintel/" + self.method + "/"
    raft_model = initRaftModel(args)
    num_domains = 4

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    # Only a fixed subset of the training split is evaluated.
    train_dir = os.path.join(sintel_dir, "training", "final")
    vid_list = ["alley_2", "market_6", "temple_2"]
    video_list = [os.path.join(train_dir, vid) for vid in vid_list]

    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()

    # Checkpoint run directories, addressed by sorted position below.
    tmp_dir = self.train_dir + dset + '/' + self.method + '/'
    tmp_list = sorted(os.listdir(tmp_dir))

    if self.method == "ruder":
        self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
        self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
    else:
        self.model = FastStyleNet(3, n_styles).to(self.device)

    multi_style_loaded = False

    for vid, vid_dir in zip(vid_list, video_list):
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        for y in range(1, num_domains):
            y_trg = torch.tensor(y, dtype=torch.long, device=device)
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            os.makedirs(vid_path, exist_ok=True)

            gray = y == 3

            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []

            if n_styles > 1:
                # A multi-style network needs its weights only once.
                if not multi_style_loaded:
                    self.model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp_list[y - 1] +
                                   '/epoch_' + str(n_epochs) + '.pth'))
                    multi_style_loaded = True
            else:
                # Single-style networks: one checkpoint per style.
                self.model.load_state_dict(
                    torch.load(tmp_dir + '/' + tmp_list[y - 1] + '/epoch_' +
                               str(n_epochs) + '.pth'))

            # NOTE(review): the nesting of this branch relative to the
            # checkpoint loading above could not be recovered with certainty;
            # it is applied for every style here -- confirm.
            if self.method == "ruder":
                pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + str(
                    y - 1) + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                self.pre_style_model.load_state_dict(
                    torch.load(pre_style_path))

            past_sty_list = []

            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                img, img_last, img_past = imgs

                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)

                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    x_fake_last = past_sty_list[-1]
                    warp_last = warp(torch.clamp(x_fake_last, 0.0, 1.0),
                                     bf_last)
                else:
                    # First frame has no predecessor.
                    mask_last = None
                    warp_last = None

                # CUDA kernels run asynchronously; synchronize on both sides
                # so the interval covers the actual computation.
                if use_cuda:
                    torch.cuda.synchronize()
                t_start = time.perf_counter()
                x_fake = self.infer_method((img, mask_last, warp_last),
                                           y_trg - 1)
                x_fake = torch.clamp(x_fake, 0.0, 1.0)
                if use_cuda:
                    torch.cuda.synchronize()
                t_end = time.perf_counter()

                past_sty_list.append(x_fake)
                dt_vals.append((t_end - t_start) * 1000)

                if i > 0:
                    tcl_st = ((mask_last *
                               (x_fake - warp_last))**2).mean()**0.5
                    tcl_st_vals.append(tcl_st.cpu().numpy())

                if i >= 5:
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    warp_past = warp(past_sty_list[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                    # Drop the oldest output so index 0 stays five frames back.
                    past_sty_list.pop(0)

                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                save_image(x_fake[0], ncol=1, filename=filename, gray=gray)

            tcl_st_dict["TCL-ST_" + key] = _mean(tcl_st_vals)
            tcl_lt_dict["TCL-LT_" + key] = _mean(tcl_lt_vals)
            dt_dict["DT_" + key] = _mean(dt_vals)

    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def eval_sintel(net, args, sintel_dir="G:/Datasets/MPI-Sintel-complete/"):
    """Stylize every MPI-Sintel clip with ``net`` and record TCL-ST, TCL-LT and runtime.

    All clips found under ``<sintel_dir>/training/final`` and
    ``<sintel_dir>/test/final`` are processed for the styles ``y`` in
    ``1..3``. Each frame is produced by ``net.run`` from an initialization
    ``pre`` (the input frame for the first frame, afterwards a masked blend of
    the warped previous output and the input frame), written to
    ``<cwd>/eval_sintel/<args.weight_tcl>/<clip>_s<y>/frame_%04d.png`` and
    scored:

    * TCL-ST: masked RMS difference between the output and ``pre`` (``i > 0``).
    * TCL-LT: masked RMS difference to the oldest buffered output, warped
      (``i >= 5``; the buffer is trimmed so that this is five frames back).
    * DT: wall-clock time of ``net.run`` in milliseconds.

    The per-clip/style means are passed to ``save_dict_as_json``.

    Args:
        net: Style-transfer object providing ``set_shapes``, ``set_style``,
            ``run``, ``postp`` and ``postp2``; its ``batch_size`` is set to 1.
        args: Forwarded to ``initRaftModel``; ``args.weight_tcl`` is used in
            the output path and passed to ``net.run``.
        sintel_dir: Root directory of the MPI-Sintel dataset.
    """

    def _mean(values):
        # An empty list (no frame reached the metric's look-back) gives NaN,
        # as before, but without NumPy's empty-slice RuntimeWarning.
        return float(np.array(values).mean()) if values else float("nan")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'
    out_path = os.getcwd() + "/eval_sintel/" + str(args.weight_tcl) + "/"
    raft_model = initRaftModel(args)

    pyr_shapes = [(109, 256), (218, 512), (436, 1024)]
    net.set_shapes(pyr_shapes)
    num_domains = 4
    net.batch_size = 1

    transform = T.Compose([
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  # turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                    std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255)),
    ])

    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = sorted(os.listdir(train_dir))
    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = sorted(os.listdir(test_dir))

    vid_list = train_list + test_list
    video_list = [os.path.join(train_dir, vid) for vid in train_list]
    video_list += [os.path.join(test_dir, vid) for vid in test_list]

    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()

    for vid, vid_dir in zip(vid_list, video_list):
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)

        for y in range(1, num_domains):
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            os.makedirs(vid_path, exist_ok=True)

            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []

            net.set_style(y - 1)
            styled_past = []

            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                img, img_last, img_past = imgs

                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)

                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    # Warped previous output where the mask is set, the plain
                    # input frame elsewhere.
                    pre = mask_last * warp(styled_past[-1],
                                           bf_last) + (1 - mask_last) * img
                else:
                    # First frame has no predecessor: start from the input
                    # and use an all-zero mask. The mask is created only here
                    # so that it never replaces a computed mask on later
                    # frames.
                    pre = img
                    mask_last = torch.zeros(
                        (1, ) + img.shape[2:]).to(device).unsqueeze(1)

                # CUDA kernels run asynchronously; synchronize on both sides
                # so the interval covers the actual computation.
                if use_cuda:
                    torch.cuda.synchronize()
                t_start = time.perf_counter()
                x_fake = net.run(pre, img, y - 1, mask_last, args.weight_tcl)
                if use_cuda:
                    torch.cuda.synchronize()
                t_end = time.perf_counter()

                dt_vals.append((t_end - t_start) * 1000)
                styled_past.append(x_fake)

                if i > 0:
                    tcl_st = ((mask_last * (x_fake - pre))**2).mean()**0.5
                    tcl_st_vals.append(tcl_st.cpu().numpy())

                if i >= 5:
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    warp_past = warp(styled_past[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                    # Drop the oldest output so index 0 stays five frames back.
                    styled_past.pop(0)

                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                # The third style (index 2) uses the alternative post-processor.
                if y - 1 == 2:
                    out_img = net.postp2(x_fake.data[0].cpu())
                else:
                    out_img = net.postp(x_fake.data[0].cpu())
                out_img.save(filename)

            tcl_st_dict["TCL-ST_" + key] = _mean(tcl_st_vals)
            tcl_lt_dict["TCL-LT_" + key] = _mean(tcl_lt_vals)
            dt_dict["DT_" + key] = _mean(dt_vals)

    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def computeTCL(net, model, img_fake, img1, img2, sid):
    """Return the masked RMS difference between ``img_fake`` and a warped ``net.run(img2, sid)``.

    Two flow fields are estimated with ``computeRAFT`` (``img2 -> img1`` and
    ``img1 -> img2`` argument order), ``fbcCheckTorch`` derives a mask from
    them, and the result of ``net.run`` for ``img2`` is warped with the second
    flow before the comparison.

    Args:
        net: Object whose ``run`` method is called as ``net.run(img2, sid)``.
        model: Flow model handed to ``computeRAFT``.
        img_fake: Tensor compared against the warped result.
        img1: First input frame.
        img2: Second input frame (the one that is re-stylized and warped).
        sid: Style argument forwarded to ``net.run``.

    Returns:
        ``sqrt(mean((mask * (img_fake - warped)) ** 2))``.
    """
    flow_21 = computeRAFT(model, img2, img1)
    flow_12 = computeRAFT(model, img1, img2)
    consistency = fbcCheckTorch(flow_21, flow_12)

    # NOTE(review): net.run is called with two positional arguments here --
    # confirm this matches the signature of the net in use.
    stylized = net.run(img2, sid)
    warped = warp(stylized, flow_12)

    residual = consistency * (img_fake - warped)
    return (residual**2).mean()**0.5