def computeTCL(net, model, img_fake, img1, img2, c_trg):
    """Short-term temporal-consistency loss.

    Translates the previous frame `img2` with `net`, warps it to the
    current frame using RAFT backward flow, and returns the RMSE of the
    occlusion-masked difference against `img_fake`.
    """
    # Forward/backward optical flow between the two real frames.
    flow_fwd = computeRAFT(model, img2, img1)
    flow_bwd = computeRAFT(model, img1, img2)
    # Occlusion mask from the forward-backward consistency check.
    occ_mask = fbcCheckTorch(flow_fwd, flow_bwd)
    # Stylize the previous frame and warp it onto the current frame.
    warped_prev = warp(net(img2, c_trg), flow_bwd)
    residual = occ_mask * (img_fake - warped_prev)
    return (residual ** 2).mean() ** 0.5
def forward_train(self):
    """Run forward pass; called by both functions <optimize_parameters> and <test>."""
    # Frame t-1: standard CycleGAN translations and cycle reconstructions.
    self.fake_B = self.netG_A(self.real_A)  # G_A(A)
    self.rec_A = self.netG_B(self.fake_B)  # G_B(G_A(A))
    self.fake_A = self.netG_B(self.real_B)  # G_B(B)
    self.rec_B = self.netG_A(self.fake_A)  # G_A(G_B(B))
    # 2nd frame
    self.fake_B2 = self.netG_A(self.real_A2)  # G_A(A)
    self.rec_A2 = self.netG_B(self.fake_B2)  # G_B(G_A(A))
    self.fake_A2 = self.netG_B(self.real_B2)  # G_B(B)
    self.rec_B2 = self.netG_A(self.fake_A2)  # G_A(G_B(B))
    # Optical flow between consecutive A-domain frames
    # (ff_* = forward flow t-1 -> t, bf_* = backward flow t -> t-1).
    self.ff_real_A = self.computeRAFT(self.real_A, self.real_A2)
    self.bf_real_A = self.computeRAFT(self.real_A2, self.real_A)
    self.bf_fake_B = self.computeRAFT(self.fake_B2, self.fake_B)
    self.bf_rec_A = self.computeRAFT(self.rec_A2, self.rec_A)
    # netM_A maps the real backward flow to the flow used for warping fake_B
    # (presumably a learned flow refinement — confirm against netM_A's definition).
    self.bf_M_A = self.netM_A(self.bf_real_A)
    self.warp_B = warp(self.fake_B, self.bf_M_A)
    # Occlusion mask from forward-backward consistency of the real A flows.
    self.mask_A = fbcCheckTorch(self.ff_real_A, self.bf_real_A)
    #print("real_A", self.real_A.requires_grad)
    #print("fake_B", self.fake_B.requires_grad)
    #print("real_A2", self.real_A2.requires_grad)
    #print("ff_real_A", self.ff_real_A.requires_grad)
    #print("bf_M_A", self.bf_M_A.requires_grad)
    #print("warp_B", self.warp_B.requires_grad)
    #print("mask_A", self.mask_A.requires_grad)
    # Mirror of the flow/warp/mask computation for the B domain.
    self.ff_real_B = self.computeRAFT(self.real_B, self.real_B2)
    self.bf_real_B = self.computeRAFT(self.real_B2, self.real_B)
    self.bf_fake_A = self.computeRAFT(self.fake_A2, self.fake_A)
    self.bf_rec_B = self.computeRAFT(self.rec_B2, self.rec_B)
    self.bf_M_B = self.netM_B(self.bf_real_B)
    self.warp_A = warp(self.fake_A, self.bf_M_B)
    self.mask_B = fbcCheckTorch(self.ff_real_B, self.bf_real_B)
def forward_train(self): """Run forward pass; called by both functions <optimize_parameters> and <test>.""" # 1st frame (t-1) self.fake_B = self.netG_A(self.real_A) # G_A(A) XT-1 self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) # 2nd frame (t) self.fake_B2 = self.netG_A(self.real_A2) # G_A(A) self.rec_A2 = self.netG_B(self.fake_B2) # G_B(G_A(A)) self.fake_A2 = self.netG_B(self.real_B2) # G_B(B) self.rec_B2 = self.netG_A(self.fake_A2) # G_A(G_B(B)) self.bf_real_A = self.computeRAFT(self.real_A2, self.real_A) self.warp_B = warp(self.fake_B, self.bf_real_A) self.fuse_B = self.netF_A(self.fake_B2, self.warp_B) #XT self.mask_A = self.generateMask(self.real_A2, warp(self.real_A, self.bf_real_A)) self.vgg_fuse_B = self.netVGG_19(self.fuse_B) self.vgg_real_A2 = self.netVGG_19(self.real_A2) self.bf_fake_B = self.computeRAFT(self.fuse_B, self.fake_B) self.rec3D_A2 = self.netF_B(self.netG_B(self.fuse_B), warp(self.fake_B, self.bf_fake_B)) self.bf_real_B = self.computeRAFT(self.real_B2, self.real_B) self.warp_A = warp(self.fake_A, self.bf_real_B) self.fuse_A = self.netF_B(self.fake_A2, self.warp_A) self.mask_B = self.generateMask(self.real_B2, warp(self.real_B, self.bf_real_B)) self.vgg_fuse_A = self.netVGG_19(self.fuse_A) self.vgg_real_B2 = self.netVGG_19(self.real_B2) self.bf_fake_A = self.computeRAFT(self.fuse_A, self.fake_A) self.rec3D_B2 = self.netF_A(self.netG_A(self.fuse_A), warp(self.fake_A, self.bf_fake_A)) '''
def evaluate_sintel(args, sintel_dir="D:/Datasets/MPI-Sintel-complete/"):
    """Evaluate checkpointed models on MPI-Sintel: per-video short-term and
    long-term temporal-consistency losses (TCL-ST/TCL-LT) plus per-frame
    inference time (DT), each saved as JSON via save_dict_as_json."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = "G:/Code/ConGAN/eval/"
    raft_model = initRaftModel(args, device)
    #domains = os.listdir(args.style_dir)
    #domains.sort()
    num_domains = 4  #len(domains)
    transform = []
    transform.append(transforms.ToTensor())
    transform.append(
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform)
    # Gather all Sintel "final"-pass videos from both splits.
    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = os.listdir(train_dir)
    train_list.sort()
    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = os.listdir(test_dir)
    test_list.sort()
    video_list = [os.path.join(train_dir, vid) for vid in train_list]
    video_list += [os.path.join(test_dir, vid) for vid in test_list]
    vid_list = train_list + test_list
    # NOTE(review): the plain-dict initializations are dead — immediately
    # overwritten by the OrderedDict ones below.
    tcl_st_dict = {}
    tcl_lt_dict = {}
    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()
    # One checkpoint directory per target style (index y-1 below).
    args.checkpoints_dir = os.getcwd() + "\\checkpoints\\"
    model_list = os.listdir(args.checkpoints_dir)
    model_list.sort()
    for j, vid_dir in enumerate(video_list):
        vid = vid_list[j]
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)
        for y in range(1, num_domains):
            #y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            if not os.path.exists(vid_path):
                os.makedirs(vid_path)
            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []
            # Load the model checkpoint matching target style y.
            args.name = model_list[y - 1]
            #args.model = "cycle_gan"
            model = create_model(args)
            model.setup(args)
            x_fake = []
            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                # img = frame t, img_last = frame t-1, img_past = frame t-5.
                img, img_last, img_past = imgs
                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)
                if i > 0:
                    B, C, H, W = img.size()
                    # Backward flow t -> t-1; cropped to the image extent.
                    bf = model.computeRAFT(img, img_last)[:, :, :H, :W]
                    # Warp the previous stylized frame to the current frame.
                    wrp = warp(x_fake, bf)
                if i == 0:
                    # First frame has no temporal prior; time the raw inference.
                    t_start = time.time()
                    x_fake = model.forward_eval(img)
                    t_end = time.time()
                else:
                    t_start = time.time()
                    x_fake = model.forward_eval(img, wrp)
                    t_end = time.time()
                dt_vals.append((t_end - t_start) * 1000)  # ms per frame
                if i > 0:
                    # NOTE(review): computeTCL is called here with 5 positional
                    # args, while the computeTCL defined above takes 6
                    # (net, model, img_fake, img1, img2, c_trg) — confirm which
                    # helper is in scope; as written this would raise TypeError.
                    tcl_st = computeTCL(model, raft_model, x_fake, img, img_last)
                    tcl_st_vals.append(tcl_st.cpu().numpy())
                if i >= 5:
                    # Long-term consistency against the frame 5 steps back.
                    tcl_lt = computeTCL(model, raft_model, x_fake, img, img_past)
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                save_image(x_fake[0], ncol=1, filename=filename)
            # Per (video, style) averages.
            tcl_st_dict["TCL-ST_" + key] = float(np.array(tcl_st_vals).mean())
            tcl_lt_dict["TCL-LT_" + key] = float(np.array(tcl_lt_vals).mean())
            dt_dict["DT_" + key] = float(np.array(dt_vals).mean())
    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def evaluate_fc2(self,
                 args,
                 n_styles,
                 epochs,
                 n_epochs,
                 emphasis_parameter,
                 batchsize=16,
                 learning_rate=1e-3,
                 dset='FC2'):
    """Evaluate FastStyleNet checkpoints on the FC2 dataset: generates
    stylized frame pairs, computes per-task TCL (masked warped RMSE) and
    FID, and writes TCL.json / FID.json into eval_dir."""
    print('Calculating evaluation metrics...')
    #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_dir = "G:/Datasets/FC2/DATAFiles/"
    style_dir = "G:/Datasets/FC2/styled-files/"
    temp_dir = "G:/Datasets/FC2/styled-files3/"
    eval_dir = os.getcwd() + "/eval_fc2/" + self.method + "/"
    num_workers = 0
    args.batch_size = 4
    domains = os.listdir(style_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    print("Batch Size:", args.batch_size)
    _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir,
                                   args.batch_size, num_workers, num_domains)
    # Checkpoint folders, one per style (sorted so index == style id).
    tmp_dir = self.train_dir + dset + '/' + self.method + '/'
    tmp_list = os.listdir(tmp_dir)
    tmp_list.sort()
    models = []
    pre_models = []
    if n_styles > 1:
        # Single multi-style model: only the first checkpoint dir is used.
        model = FastStyleNet(3, n_styles).to(self.device)
        model.load_state_dict(
            torch.load(tmp_dir + '/' + tmp_list[0] + '/epoch_' +
                       str(n_epochs) + '.pth'))
    else:
        if self.method == "ruder":
            # Ruder variant: main net takes RGB + mask + warped RGB (3+1+3),
            # plus a plain pre-stylization net per style.
            for tmp in tmp_list:
                model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
                model.load_state_dict(
                    torch.load(tmp_dir + '/' + tmp + '/epoch_' +
                               str(n_epochs) + '.pth'))
                models.append(model)
                # tmp[3] extracts the style-id digit from the folder name —
                # presumably names look like "sidN..."; confirm.
                pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + tmp[
                    3] + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                model = FastStyleNet(3, n_styles).to(self.device)
                model.load_state_dict(torch.load(pre_style_path))
                pre_models.append(model)
        else:
            for tmp in tmp_list:
                model = FastStyleNet(3, n_styles).to(self.device)
                model.load_state_dict(
                    torch.load(tmp_dir + '/' + tmp + '/epoch_' +
                               str(n_epochs) + '.pth'))
                models.append(model)
    generate_new = True
    tcl_dict = {}
    # prepare
    for d in range(1, num_domains):
        src_domain = "style0"
        trg_domain = "style" + str(d)
        t1 = '%s2%s' % (src_domain, trg_domain)
        t2 = '%s2%s' % (trg_domain, src_domain)
        tcl_dict[t1] = []
        tcl_dict[t2] = []
        if generate_new:
            create_task_folders(eval_dir, t1)
            #create_task_folders(eval_dir, t2)
    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
        x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all
        x_real = x_real.to(self.device)
        x_real2 = x_real2.to(self.device)
        y_org = y_org.to(self.device)
        x_ref = x_ref.to(self.device)
        y_trg = y_trg.to(self.device)
        mask = mask.to(self.device)
        flow = flow.to(self.device)
        N = x_real.size(0)
        for k in range(N):
            y_org_np = y_org[k].cpu().numpy()
            y_trg_np = y_trg[k].cpu().numpy()
            src_domain = "style" + str(y_org_np)
            trg_domain = "style" + str(y_trg_np)
            # Skip identity translations and the content domain as target.
            if src_domain == trg_domain or y_trg_np == 0:
                continue
            task = '%s2%s' % (src_domain, trg_domain)
            # Select the model for the target style.
            if n_styles > 1:
                self.model = model
            else:
                self.model = models[y_trg_np - 1]
                # Placed inside the single-style branch since pre_models is
                # only populated there — confirm against the original layout.
                if self.method == "ruder":
                    self.pre_style_model = pre_models[y_trg_np - 1]
            # Stylize frame 1, warp it to frame 2, then stylize frame 2
            # conditioned on the warped result.
            x_fake = self.infer_method((x_real, None, None), y_trg[k] - 1)
            #x_fake = torch.clamp(x_fake, 0.0, 1.0)
            x_warp = warp(x_fake, flow)
            x_fake2 = self.infer_method((x_real2, mask, x_warp), y_trg[k] - 1)
            #x_fake2 = torch.clamp(x_fake2, 0.0, 1.0)
            # Masked per-sample RMSE between frame-2 output and warped frame-1 output.
            tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5
            tcl_dict[task].append(tcl_err[k].cpu().numpy())
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")
            if generate_new:
                filename = os.path.join(
                    path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                save_image(denormalize(x_ref[k]), ncol=1, filename=filename)
                filename = os.path.join(
                    path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
                save_image(x_fake[k], ncol=1, filename=filename)
    # evaluate
    print("computing fid, lpips and tcl")
    tasks = [
        dir for dir in os.listdir(eval_dir)
        if os.path.isdir(os.path.join(eval_dir, dir))
    ]
    tasks.sort()
    # fid and lpips
    fid_values = OrderedDict()
    #lpips_dict = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
        print(task)
        path_ref = os.path.join(eval_dir, task + "/ref")
        path_fake = os.path.join(eval_dir, task + "/fake")
        tcl_data = tcl_dict[task]
        print("TCL", len(tcl_data))
        tcl_mean = np.array(tcl_data).mean()
        print(tcl_mean)
        tcl_values['TCL_%s' % (task)] = float(tcl_mean)
        print("FID")
        fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                              img_size=256,
                                              batch_size=args.batch_size)
        fid_values['FID_%s' % (task)] = fid_value
    # calculate the average FID for all tasks
    fid_mean = 0
    for key, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_mean'] = fid_mean
    # report FID values
    filename = os.path.join(eval_dir, 'FID.json')
    utils.save_json(fid_values, filename)
    # calculate the average TCL for all tasks
    tcl_mean = 0
    for _, value in tcl_values.items():
        tcl_mean += value / len(tcl_values)
    tcl_values['TCL_mean'] = float(tcl_mean)
    # report TCL values
    filename = os.path.join(eval_dir, 'TCL.json')
    utils.save_json(tcl_values, filename)
def evaluate_sintel(self,
                    args,
                    n_styles,
                    epochs,
                    n_epochs,
                    emphasis_parameter,
                    sintel_dir="D:/Datasets/MPI-Sintel-complete/",
                    batchsize=16,
                    learning_rate=1e-3,
                    dset='FC2'):
    """Evaluate FastStyleNet checkpoints on three MPI-Sintel clips:
    per-(video, style) TCL-ST, TCL-LT and inference time, saved as JSON."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = "G:/Code/LBST/eval_sintel/" + self.method + "/"
    raft_model = initRaftModel(args)
    num_domains = 4
    transform = []
    transform.append(transforms.ToTensor())
    transform.append(
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform)
    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = os.listdir(train_dir)
    train_list.sort()
    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = os.listdir(test_dir)
    test_list.sort()
    #video_list = [os.path.join(train_dir, vid) for vid in train_list]
    #video_list += [os.path.join(test_dir, vid) for vid in test_list]
    # Restricted to three representative training clips.
    video_list = [
        os.path.join(train_dir, "alley_2"),
        os.path.join(train_dir, "market_6"),
        os.path.join(train_dir, "temple_2")
    ]
    #vid_list = train_list + test_list
    vid_list = ["alley_2", "market_6", "temple_2"]
    # NOTE(review): plain-dict initializations are dead — overwritten below.
    tcl_st_dict = {}
    tcl_lt_dict = {}
    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()
    #emphasis_parameter = self.vectorize_parameters(emphasis_parameter, n_styles)
    tmp_dir = self.train_dir + dset + '/' + self.method + '/'
    tmp_list = os.listdir(tmp_dir)
    tmp_list.sort()
    #run_id = self.setup_method(run_id, emphasis_parameter.T)
    # Ruder variant consumes RGB + mask + warped RGB (3+1+3 channels).
    if self.method == "ruder":
        self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
        self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
    else:
        self.model = FastStyleNet(3, n_styles).to(self.device)
    first = True
    for j, vid_dir in enumerate(video_list):
        vid = vid_list[j]
        #print(vid_dir)
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)
        for y in range(1, num_domains):
            y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            if not os.path.exists(vid_path):
                os.makedirs(vid_path)
            # Style 3 is rendered grayscale when saving frames.
            if y == 3:
                gray = True
            else:
                gray = False
            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []
            #if n_styles > 1:
            #    run_id = "msid%d_ep%d_bs%d_lr%d" % (n_styles, epochs, batchsize, np.log10(learning_rate))
            #else:
            #    run_id = "sid%d_ep%d_bs%d_lr%d" % (y - 1, epochs, batchsize, np.log10(learning_rate))
            if n_styles > 1:
                # Multi-style model: load the checkpoint only once.
                if first:
                    self.model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp_list[y - 1] +
                                   '/epoch_' + str(n_epochs) + '.pth'))
                    first = False
            else:
                #print(tmp_dir + '/' + tmp_list[y-1] + '/epoch_' + str(n_epochs) + '.pth')
                self.model.load_state_dict(
                    torch.load(tmp_dir + '/' + tmp_list[y - 1] + '/epoch_' +
                               str(n_epochs) + '.pth'))
                if self.method == "ruder":
                    pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + str(
                        y - 1) + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                    self.pre_style_model.load_state_dict(
                        torch.load(pre_style_path))
            # Sliding window of past stylized frames (length 5 at steady state).
            past_sty_list = []
            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                # img = frame t, img_last = frame t-1, img_past = frame t-5.
                img, img_last, img_past = imgs
                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)
                #save_image(img[0], ncol=1, filename="blah.png")
                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    x_fake_last = past_sty_list[
                        -1]  #self.infer_method((img_last, None, None), y_trg - 1)
                    # Warp last stylized frame onto the current frame.
                    warp_last = warp(torch.clamp(x_fake_last, 0.0, 1.0),
                                     bf_last)
                else:
                    mask_last = None
                    warp_last = None
                #mask, x_warp
                t_start = time.time()
                x_fake = self.infer_method((img, mask_last, warp_last),
                                           y_trg - 1)
                x_fake = torch.clamp(x_fake, 0.0, 1.0)
                t_end = time.time()
                past_sty_list.append(x_fake)
                dt_vals.append((t_end - t_start) * 1000)  # ms per frame
                if i > 0:
                    # Short-term TCL: masked RMSE vs warped previous output.
                    tcl_st = ((mask_last *
                               (x_fake - warp_last))**2).mean()**0.5
                    tcl_st_vals.append(tcl_st.cpu().numpy())
                if i >= 5:
                    # Long-term TCL vs the output from 5 frames back.
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    #torch.clamp(self.infer_method((img_past, None, None), y_trg - 1), 0.0, 1.0)
                    warp_past = warp(past_sty_list[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                    # Dead debug code kept verbatim from the original.
                    '''
                    print(img.shape)
                    print(img_past.shape)
                    print(warp_past.shape)
                    print(x_fake.shape)
                    print(past_sty_list[0].shape)
                    save_image(denormalize(img[0]), ncol=1, filename="blah1.png")
                    save_image(denormalize(img_past[0]), ncol=1, filename="blah2.png")
                    save_image(warp_past[0], ncol=1, filename="blah3.png")
                    save_image(x_fake[0], ncol=1, filename="blah4.png")
                    save_image(past_sty_list[0][0], ncol=1, filename="blah5.png")
                    save_image(mask_past*warp_past, ncol=1, filename="blah6.png")
                    blah'''
                    past_sty_list.pop(0)
                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                save_image(x_fake[0], ncol=1, filename=filename, gray=gray)
            tcl_st_dict["TCL-ST_" + key] = float(
                np.array(tcl_st_vals).mean())
            tcl_lt_dict["TCL-LT_" + key] = float(
                np.array(tcl_lt_vals).mean())
            dt_dict["DT_" + key] = float(np.array(dt_vals).mean())
    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def eval_fc2(net, args):
    """Evaluate `net` on the FC2 dataset: generates stylized frame pairs at
    multiple pyramid shapes, computes per-task TCL and FID, and writes
    TCL.json / FID.json under eval_dir (keyed by args.weight_tcl)."""
    print('Calculating evaluation metrics...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_dir = "G:/Datasets/FC2/DATAFiles/"
    style_dir = "G:/Datasets/FC2/styled-files/"
    temp_dir = "G:/Datasets/FC2/styled-files3/"
    #data_dir = "/srv/local/tomstrident/datasets/FC2/DATAFiles/"
    #style_dir = "/srv/local/tomstrident/datasets/FC2/styled-files/"
    #temp_dir = "/srv/local/tomstrident/datasets/FC2/styled-files3/"
    eval_dir = os.getcwd() + "/eval_fc2/" + str(args.weight_tcl) + "/"
    num_workers = 0
    net.batch_size = 1  #args.batch_size
    # Multi-resolution pyramid used by net's optimization.
    pyr_shapes = [(64, 64), (128, 128), (256, 256)]
    net.set_shapes(pyr_shapes)
    # Preprocessing matches the Gatys-style VGG input convention:
    # BGR channel order, mean-subtracted, scaled to [0, 255].
    transform = T.Compose([
        #T.Resize(pyr_shapes[-1]),
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  #turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                    std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])
    domains = os.listdir(style_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)
    print("Batch Size:", args.batch_size)
    _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir, transform,
                                   args.batch_size, num_workers, num_domains)
    generate_new = True
    tcl_dict = {}
    # prepare
    for d in range(1, num_domains):
        src_domain = "style0"
        trg_domain = "style" + str(d)
        t1 = '%s2%s' % (src_domain, trg_domain)
        t2 = '%s2%s' % (trg_domain, src_domain)
        tcl_dict[t1] = []
        tcl_dict[t2] = []
        if generate_new:
            create_task_folders(eval_dir, t1)
            #create_task_folders(eval_dir, t2)
    # generate
    for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
        x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all
        x_real = x_real.to(device)
        x_real2 = x_real2.to(device)
        y_org = y_org.to(device)
        x_ref = x_ref.to(device)
        y_trg = y_trg.to(device)
        mask = mask.to(device)
        flow = flow.to(device)
        # All-zero mask used for the first frame (no temporal prior).
        mask_zero = torch.zeros(mask.shape).to(device)
        N = x_real.size(0)
        #y = y_trg.cpu().numpy()
        for k in range(N):
            y_org_np = y_org[k].cpu().numpy()
            y_trg_np = y_trg[k].cpu().numpy()
            src_domain = "style" + str(y_org_np)
            trg_domain = "style" + str(y_trg_np)
            # Skip identity translations and the content domain as target.
            if src_domain == trg_domain or y_trg_np == 0:
                continue
            task = '%s2%s' % (src_domain, trg_domain)
            net.set_style(y_trg_np - 1)
            # Stylize frame 1 (zero mask), warp it forward, then stylize
            # frame 2 conditioned on the warped result + occlusion mask.
            x_fake = net.run(x_real, x_real, y_trg_np - 1, mask_zero,
                             args.weight_tcl)
            x_warp = warp(x_fake, flow)
            #x_fake2 = net.run(mask*x_warp + (1 - mask)*x_real2, x_real2, y_trg_np - 1, mask)
            x_fake2 = net.run(x_warp, x_real2, y_trg_np - 1, mask,
                              args.weight_tcl)
            # Masked per-sample RMSE between frame-2 output and warped frame-1 output.
            tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5
            tcl_dict[task].append(tcl_err[k].cpu().numpy())
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")
            if generate_new:
                filename = os.path.join(
                    path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                # Style index 2 uses an alternate post-processing path.
                if y_trg_np - 1 == 2:
                    out_img = net.postp2(x_ref.data[0].cpu())
                else:
                    out_img = net.postp(x_ref.data[0].cpu())
                out_img.save(filename)
                filename = os.path.join(
                    path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
                if y_trg_np - 1 == 2:
                    out_img = net.postp2(x_fake.data[0].cpu())
                else:
                    out_img = net.postp(x_fake.data[0].cpu())
                out_img.save(filename)
    # evaluate
    print("computing fid, lpips and tcl")
    tasks = [
        dir for dir in os.listdir(eval_dir)
        if os.path.isdir(os.path.join(eval_dir, dir))
    ]
    tasks.sort()
    # fid and lpips
    fid_values = OrderedDict()
    #lpips_dict = OrderedDict()
    tcl_values = OrderedDict()
    for task in tasks:
        print(task)
        path_ref = os.path.join(eval_dir, task + "/ref")
        path_fake = os.path.join(eval_dir, task + "/fake")
        tcl_data = tcl_dict[task]
        print("TCL", len(tcl_data))
        tcl_mean = np.array(tcl_data).mean()
        print(tcl_mean)
        tcl_values['TCL_%s' % (task)] = float(tcl_mean)
        print("FID")
        fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                              img_size=256,
                                              batch_size=args.batch_size)
        fid_values['FID_%s' % (task)] = fid_value
    # calculate the average FID for all tasks
    fid_mean = 0
    for key, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_mean'] = fid_mean
    # report FID values
    filename = os.path.join(eval_dir, 'FID.json')
    utils.save_json(fid_values, filename)
    # calculate the average TCL for all tasks
    tcl_mean = 0
    for _, value in tcl_values.items():
        tcl_mean += value / len(tcl_values)
    tcl_values['TCL_mean'] = float(tcl_mean)
    # report TCL values
    filename = os.path.join(eval_dir, 'TCL.json')
    utils.save_json(tcl_values, filename)
def eval_sintel(net, args):
    """Evaluate `net` on all MPI-Sintel videos: per-(video, style) short- and
    long-term TCL plus inference time, saved as JSON under out_path."""
    sintel_dir = "G:/Datasets/MPI-Sintel-complete/"
    #sintel_dir="/srv/local/tomstrident/datasets/MPI-Sintel-complete/"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    out_path = os.getcwd() + "/eval_sintel/" + str(args.weight_tcl) + "/"
    raft_model = initRaftModel(args)
    # Pyramid shapes matching Sintel's 436x1024 resolution.
    pyr_shapes = [(109, 256), (218, 512), (436, 1024)]
    net.set_shapes(pyr_shapes)
    num_domains = 4
    net.batch_size = 1
    #transform = []
    #transform.append(transforms.ToTensor())
    #transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    #transform = transforms.Compose(transform)
    # Gatys-style VGG preprocessing: BGR, mean-subtracted, [0, 255].
    transform = T.Compose([
        #T.Resize((436, 1024)),
        T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  #turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                    std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])
    # Same pipeline without ToTensor; only referenced from dead code below.
    transform2 = T.Compose([
        #T.Resize((436, 1024)),
        #T.ToTensor(),
        T.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]),  #turn to BGR
        T.Normalize(mean=[0.40760392, 0.45795686, 0.48501961],
                    std=[1, 1, 1]),
        T.Lambda(lambda x: x.mul_(255))
    ])
    train_dir = os.path.join(sintel_dir, "training", "final")
    train_list = os.listdir(train_dir)
    train_list.sort()
    test_dir = os.path.join(sintel_dir, "test", "final")
    test_list = os.listdir(test_dir)
    test_list.sort()
    video_list = [os.path.join(train_dir, vid) for vid in train_list]
    video_list += [os.path.join(test_dir, vid) for vid in test_list]
    #video_list = video_list[:1]
    vid_list = train_list + test_list
    # NOTE(review): plain-dict initializations are dead — overwritten below.
    tcl_st_dict = {}
    tcl_lt_dict = {}
    tcl_st_dict = OrderedDict()
    tcl_lt_dict = OrderedDict()
    dt_dict = OrderedDict()
    for j, vid_dir in enumerate(video_list):
        vid = vid_list[j]
        sintel_dset = SingleSintelVideo(vid_dir, transform)
        loader = data.DataLoader(dataset=sintel_dset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=0)
        for y in range(1, num_domains):
            #y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
            key = vid + "_s" + str(y)
            vid_path = os.path.join(out_path, key)
            if not os.path.exists(vid_path):
                os.makedirs(vid_path)
            tcl_st_vals = []
            tcl_lt_vals = []
            dt_vals = []
            net.set_style(y - 1)
            x_fake = []
            # Sliding window of past stylized frames (length 5 at steady state).
            styled_past = []
            #imgs_past = []
            for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                # img = frame t, img_last = frame t-1, img_past = frame t-5.
                img, img_last, img_past = imgs
                img = img.to(device)
                img_last = img_last.to(device)
                img_past = img_past.to(device)
                if i > 0:
                    ff_last = computeRAFT(raft_model, img_last, img)
                    bf_last = computeRAFT(raft_model, img, img_last)
                    mask_last = fbcCheckTorch(ff_last, bf_last)
                    #pre = warp(styled_past[-1], bf_last)
                    #pre = mask_last*warp(styled_past[-1], bf_last)
                    # Initialization image: warped previous output where the
                    # flow is consistent, the raw current frame elsewhere.
                    pre = mask_last * warp(styled_past[-1],
                                           bf_last) + (1 - mask_last) * img
                    #pre = img
                    #net.postp(pre.data[0].cpu()).save("test%d.png" % i)
                else:
                    pre = img
                    mask_last = torch.zeros((1, ) + img.shape[2:]).to(
                        device).unsqueeze(1)
                    #pre = transform2(torch.randn(img.size())[0]).unsqueeze(0).to(device)
                    #pre = img
                    # NOTE(review): duplicate of the assignment above —
                    # harmless but redundant; kept verbatim.
                    mask_last = torch.zeros(
                        (1, ) + img.shape[2:]).to(device).unsqueeze(1)
                # Dead debug code kept verbatim from the original.
                '''
                if i > 1:
                    ff_last = computeRAFT(raft_model, imgs_past[-2], img)
                    bf_last = computeRAFT(raft_model, img, imgs_past[-2])
                    mask_last = torch.clamp(mask_last - fbcCheckTorch(ff_last, bf_last), 0.0, 1.0)
                    pre = mask_last*warp(styled_past[-2], bf_last) + (1 - mask_last)*pre
                    #net.postp(pre.data[0].cpu()).save("test%d.png" % i)
                    #blah
                '''
                #save_image(img[0], ncol=1, filename="blah.png")
                t_start = time.time()
                x_fake = net.run(pre, img, y - 1, mask_last, args.weight_tcl)
                t_end = time.time()
                #save_image(x_fake[0], ncol=1, filename="blah2.png")
                #blah
                dt_vals.append((t_end - t_start) * 1000)  # ms per frame
                styled_past.append(x_fake)
                #imgs_past.append(img)
                if i > 0:
                    # Short-term TCL: masked RMSE vs the initialization image.
                    tcl_st = ((mask_last * (x_fake - pre))**2).mean()**0.5
                    tcl_st_vals.append(tcl_st.cpu().numpy())
                    #blah
                    #tcl_st = computeTCL(net, raft_model, x_fake, img, img_last, y - 1)
                    #tcl_st_vals.append(tcl_st.cpu().numpy())
                if i >= 5:
                    # Long-term TCL vs the output from 5 frames back.
                    ff_past = computeRAFT(raft_model, img_past, img)
                    bf_past = computeRAFT(raft_model, img, img_past)
                    mask_past = fbcCheckTorch(ff_past, bf_past)
                    warp_past = warp(styled_past[0], bf_past)
                    tcl_lt = ((mask_past *
                               (x_fake - warp_past))**2).mean()**0.5
                    tcl_lt_vals.append(tcl_lt.cpu().numpy())
                    styled_past.pop(0)
                    #imgs_past.pop(0)
                filename = os.path.join(vid_path, "frame_%04d.png" % i)
                #save_image(x_fake[0], ncol=1, filename=filename)
                # Style index 2 uses an alternate post-processing path.
                if y - 1 == 2:
                    out_img = net.postp2(x_fake.data[0].cpu())
                else:
                    out_img = net.postp(x_fake.data[0].cpu())
                out_img.save(filename)
            tcl_st_dict["TCL-ST_" + key] = float(np.array(tcl_st_vals).mean())
            tcl_lt_dict["TCL-LT_" + key] = float(np.array(tcl_lt_vals).mean())
            dt_dict["DT_" + key] = float(np.array(dt_vals).mean())
    save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
    save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
    save_dict_as_json("DT", dt_dict, out_path, num_domains)
def computeTCL(net, model, img_fake, img1, img2, sid):
    """Short-term temporal-consistency loss for a style network.

    Stylizes the previous frame `img2` via `net.run`, warps it to the
    current frame with RAFT backward flow, and returns the RMSE of the
    occlusion-masked difference against `img_fake`.
    """
    # Flow in both directions between the two real frames.
    fwd_flow = computeRAFT(model, img2, img1)
    bwd_flow = computeRAFT(model, img1, img2)
    # Mask out pixels failing the forward-backward consistency check.
    occ_mask = fbcCheckTorch(fwd_flow, bwd_flow)
    # Stylize the previous frame with style `sid`, then warp to the current frame.
    warped_prev = warp(net.run(img2, sid), bwd_flow)
    diff = occ_mask * (img_fake - warped_prev)
    return (diff ** 2).mean() ** 0.5