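# Imports inferred from usage in the code below (this file is an excerpt of a
# larger project). Project-internal helpers and models (FastStyleNet, Vgg16,
# warp, denormalize, save_image, get_loaderFC2, create_task_folders,
# calculate_fid_given_paths, initRaftModel, computeRAFT, fbcCheckTorch,
# SingleSintelVideo, SintelDataset, FlyingChairs2Dataset, Hollywood2Dataset,
# COCODataset, save_dict_as_json, utils) are assumed to come from the
# project's own modules and are not imported here.
import os
import time
from collections import OrderedDict

import cv2
import imageio
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from skimage import io, transform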
# setup_method used by the single-frame variants: the network only takes the
# 3 RGB input channels and no pre-trained per-frame model is needed.
def setup_method(self, run_id, emphasis_parameter):
    run_id += self.concat_id(emphasis_parameter)

    if run_id[0] == 'm':
        n_styles = int(run_id[4])
    else:
        n_styles = 1

    self.model = FastStyleNet(3, n_styles).to(self.device)

    return run_id
# Base class shared by the fast style transfer variants (Johnson, Dumoulin,
# Ruder): training loop, FlyingChairs2 and MPI-Sintel evaluation, and common
# helpers (VGG features, Gram matrices, warping, model loading).
class FastStyle():
    def __init__(self, debug=True):
        #self.train_dir = 'F:/runs/'
        #self.train_dir = '/home/tomstrident/projects/LBST/runs/'
        self.train_dir = 'G:/Code/LBST/runs/'
        self.debug = debug
        self.device = 'cuda'
        self.method = []

        self.VGG16_MEAN = [0.485, 0.456, 0.406]
        self.VGG16_STD = [0.229, 0.224, 0.225]

        #self.sid_styles = ['autoportrait', 'edtaonisl', 'composition', 'edtaonisl', 'udnie', 'starry_night']#'candy',
        self.sid_styles = ['s1_starry_night', 's2_the_scream', 's3_take_on_me']  #'candy',

        style_grid = np.arange(0, len(self.sid_styles), dtype=np.float32)
        self.style_id_grid = torch.Tensor(style_grid).to(self.device).float()

        if debug and not os.path.exists("debug/"):
            os.mkdir("debug/")

    def vectorize_parameters(self, params, n_styles):
        vec_params = [p * np.ones(n_styles) for p in np.array(params)]
        return np.array(vec_params).T

    def concat_id(self, params):
        run_id = ""
        for j, p in enumerate(params):
            for pi in p:
                run_id += "_" + self.loss_letters[j] + ("%d" % np.log10(pi))
        return run_id + "/"

    def train(self, sid=2, epochs=3, emphasis_parameter=[1e0, 1e1],
              batchsize=16, learning_rate=1e-3, dset='FC2'):
        if isinstance(sid, list):
            styles = [self.sid_styles[sidx] for sidx in sid]
            run_id = "msid%d_ep%d_bs%d_lr%d" % (len(sid), epochs, batchsize,
                                                np.log10(learning_rate))
            emphasis_parameter = self.vectorize_parameters(emphasis_parameter, len(sid))
        else:
            styles = [self.sid_styles[sid]]
            run_id = "sid%d_ep%d_bs%d_lr%d" % (sid, epochs, batchsize,
                                               np.log10(learning_rate))
            emphasis_parameter = self.vectorize_parameters(emphasis_parameter, 1)

        #self.train_dir = self.train_dir[:8] + dset + '/' + self.method + '/'
        self.train_dir = self.train_dir + dset + '/' + self.method + '/'
        run_id = self.setup_method(run_id, emphasis_parameter.T)
        adv_train_dir = self.train_dir + run_id
        print(adv_train_dir)

        if not os.path.exists(adv_train_dir):
            os.makedirs(adv_train_dir)

        if os.path.exists(adv_train_dir + '/epoch_' + str(epochs - 1) + '.pth'):
            print('Warning: config already exists! Returning ...')
            return

        self.prep_training(batch_sz=batchsize, styles=styles, dset=dset)
        self.adam = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        loss_list = []
        n_styles = len(self.styles)
        style_grid = np.arange(0, n_styles)
        style_id_grid = torch.LongTensor(style_grid).to(self.device)

        for epoch in range(epochs):
            for itr, (imgs, masks, flows) in enumerate(self.dataloader):
                imgs = torch.split(imgs, 3, dim=1)
                self.prep_adam(itr)

                if n_styles > 1:
                    style_id = style_id_grid[np.random.randint(0, n_styles)]
                else:
                    style_id = 0

                losses, styled_img, loss_string = self.train_method(
                    imgs, masks, flows, emphasis_parameter[style_id], style_id)
                self.adam.step()

                if (itr + 1) % 1000 == 0:
                    torch.save(self.model.state_dict(),
                               '%sfinal_epoch_%d_itr_%d.pth' % (adv_train_dir, epoch, itr // 1000))

                if (itr) % 1000 == 0 and self.debug:
                    imageio.imsave('debug/%d_%d_img1.png' % (epoch, itr),
                                   imgs[0].cpu().numpy()[0].transpose(1, 2, 0))
                    imageio.imsave('debug/%d_%d_styled_img1.png' % (epoch, itr),
                                   styled_img.detach().cpu().numpy()[0].transpose(1, 2, 0))
                    out_string = "[%d/%d][%d/%d] sid%d" % (epoch, epochs, itr,
                                                           len(self.dataloader), style_id)
                    print(out_string + loss_string)

                loss_list.append(torch.FloatTensor(losses).detach().cpu().numpy())

            torch.save(self.model.state_dict(), '%sepoch_%d.pth' % (adv_train_dir, epoch))

        loss_list = np.array(loss_list)
        np.save(adv_train_dir + "loss_list.npy", loss_list)

    #============================================================================

    def infer(self, sid, n_styles, epochs, n_epochs, emphasis_parameter,
              batchsize=16, learning_rate=1e-3, dset='FC2', sintel_id='temple_2',
              sintel_path='D:/Datasets/', vid_fps=20, out_img_path=None, out_img_num=[10]):
        if n_styles > 1:
            run_id = "msid%d_ep%d_bs%d_lr%d" % (n_styles, epochs, batchsize,
                                                np.log10(learning_rate))
        else:
            run_id = "sid%d_ep%d_bs%d_lr%d" % (sid, epochs, batchsize,
                                               np.log10(learning_rate))

        emphasis_parameter = self.vectorize_parameters(emphasis_parameter, n_styles)
        #self.train_dir = self.train_dir[:8] + dset + '/' + self.method + '/'
        self.train_dir = self.train_dir + dset + '/' + self.method + '/'
        run_id = self.setup_method(run_id, emphasis_parameter.T)
        #infer_id = run_id[:4] + str(sid) + run_id[5:-1]

        print(self.train_dir + run_id + 'epoch_' + str(n_epochs) + '.pth')
        self.model.load_state_dict(
            torch.load(self.train_dir + run_id + 'epoch_' + str(n_epochs) + '.pth'))

        writer = imageio.get_writer('styled_' + self.method + str(sid) + '.mp4', fps=vid_fps)
        dataloader = DataLoader(SintelDataset(sintel_path, sintel_id), batch_size=1)

        warped = []
        mask = []
        cst_list = []
        lt_cst_list = []
        ft_count = []
        styled_list = []
        #debug_path = 'C:/Users/Tom/Documents/Python Scripts/Masters Project/debug/'

        style_grid = np.arange(0, len(self.sid_styles), dtype=np.float32)
        style_id_grid = torch.Tensor(style_grid).to(self.device).float()
        style_id = style_id_grid[sid]

        for itr, (frame, mask, flow, lt_data) in enumerate(dataloader):
            if itr > 0:
                flow = flow[0].permute(1, 2, 0).cpu().numpy()
                warped = self.warp_image(styled_list[-1], flow)

            t_start = time.time()
            torch_output = self.infer_method((frame, mask, warped), style_id)
            t_end = time.time()
            ft_count.append(t_end - t_start)

            torch_output = torch.clamp(torch_output, 0.0, 1.0)
            styled_frame = torch_output[0].permute(1, 2, 0).detach().cpu().numpy()
            #imageio.imwrite(debug_path + '/img' + str(itr) + '.png', (styled_frame*255.0).astype(np.uint8))

            if itr > 0:
                #imageio.imwrite(debug_path + '/warp' + str(itr) + '.png', (warped*255.0).astype(np.uint8))
                #imageio.imwrite(debug_path + '/mask' + str(itr) + '.png', (mask*255.0).astype(np.uint8))
                mask = mask[0].permute(1, 2, 0).cpu().numpy()
                cst = ((mask * (warped - styled_frame))**2).mean()
                cst_list.append(cst)
                #print('FPS:', 1/ft_count[-1], 'CST:', cst_list[-1])

            styled_list.append(styled_frame)
            lt_len = 5

            if not (itr - lt_len < 0 or itr == len(dataloader) - 1):
                lt_flow, lt_mask = lt_data
                lt_flow = lt_flow[0].permute(1, 2, 0).cpu().numpy()
                lt_mask = lt_mask[0].permute(1, 2, 0).cpu().numpy()
                f_idx2 = itr - lt_len + 1
                #imageio.imwrite(debug_path + '/styled_frame2.png', (styled_list[f_idx1]*255.0).astype(np.uint8))
                #imageio.imwrite(debug_path + '/styled_frame1.png', (styled_list[f_idx2]*255.0).astype(np.uint8))
                warped = self.warp_image(styled_list[f_idx2], lt_flow)
                #imageio.imwrite(debug_path + '/warp' + '.png', (warped*255.0).astype(np.uint8))
                #imageio.imwrite(debug_path + '/wmask' + '.png', (lt_mask*255.0).astype(np.uint8))
                lt_cst = ((lt_mask[0] * (warped - styled_frame))**2).mean()
                lt_cst_list.append(lt_cst)

            real_fid = len(dataloader) - 1 - itr
            if out_img_path is not None and real_fid in out_img_num:
                #imageio.imwrite(self.train_dir + infer_path + '_c.png', (np_f*255.0).astype(np.uint8))
                print(out_img_path + dset + "_" + run_id[:-1] + "_" + str(real_fid) + ".png")
                imageio.imwrite(out_img_path + dset + "_" + run_id[:-1] + "_" + str(real_fid) + ".png",
                                (styled_frame * 255.0).astype(np.uint8))

            cv2.imshow('frame', styled_frame[:, :, [2, 1, 0]])
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            #writer.append_data((styled_frame*255.0).astype(np.uint8))

        cv2.destroyAllWindows()

        for styled_frame in styled_list[::-1]:
            writer.append_data((styled_frame * 255.0).astype(np.uint8))
        writer.close()

        ft_count = np.array(ft_count[3:])
        fps_count = np.array([1 / x for x in ft_count])
        avg_ft = ft_count.mean()
        avg_fps = fps_count.mean()
        #opl_ft = np.percentile(np.sort(ft_count), 1)
        #opl_fps = np.percentile(np.sort(fps_count), 1)
        #oph_ft = self.high_percentile(ft_count, 5)
        #opl_fps2 = self.high_percentile(fps_count, 5)

        mse_cst = (np.array(cst_list).mean())**0.5
        mse_lt_cst = (np.array(lt_cst_list).mean())**0.5

        print('consistency mse:', mse_cst)
        print('lt consistency mse:', mse_lt_cst)
        print('avg ft:', avg_ft * 1000, avg_fps)
        #print('opl ft:', opl_ft*1000, 1/opl_ft, oph_ft, opl_fps, opl_fps2)

        return avg_ft * 1000, avg_fps, mse_cst, mse_lt_cst

    # Temporal-consistency (TCL-ST / TCL-LT) and runtime (DT) evaluation on
    # selected MPI-Sintel sequences.
    def evaluate_sintel(self, args, n_styles, epochs, n_epochs, emphasis_parameter,
                        sintel_dir="D:/Datasets/MPI-Sintel-complete/",
                        batchsize=16, learning_rate=1e-3, dset='FC2'):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        out_path = "G:/Code/LBST/eval_sintel/" + self.method + "/"
        raft_model = initRaftModel(args)
        num_domains = 4

        transform = []
        transform.append(transforms.ToTensor())
        transform.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
        transform = transforms.Compose(transform)

        train_dir = os.path.join(sintel_dir, "training", "final")
        train_list = os.listdir(train_dir)
        train_list.sort()

        test_dir = os.path.join(sintel_dir, "test", "final")
        test_list = os.listdir(test_dir)
        test_list.sort()

        #video_list = [os.path.join(train_dir, vid) for vid in train_list]
        #video_list += [os.path.join(test_dir, vid) for vid in test_list]
        video_list = [
            os.path.join(train_dir, "alley_2"),
            os.path.join(train_dir, "market_6"),
            os.path.join(train_dir, "temple_2")
        ]
        #vid_list = train_list + test_list
        vid_list = ["alley_2", "market_6", "temple_2"]

        tcl_st_dict = OrderedDict()
        tcl_lt_dict = OrderedDict()
        dt_dict = OrderedDict()

        #emphasis_parameter = self.vectorize_parameters(emphasis_parameter, n_styles)
        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = os.listdir(tmp_dir)
        tmp_list.sort()
        #run_id = self.setup_method(run_id, emphasis_parameter.T)

        if self.method == "ruder":
            self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
            self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
        else:
            self.model = FastStyleNet(3, n_styles).to(self.device)

        first = True
        for j, vid_dir in enumerate(video_list):
            vid = vid_list[j]
            #print(vid_dir)
            sintel_dset = SingleSintelVideo(vid_dir, transform)
            loader = data.DataLoader(dataset=sintel_dset, batch_size=1,
                                     shuffle=False, num_workers=0)

            for y in range(1, num_domains):
                y_trg = torch.Tensor([y])[0].type(torch.LongTensor).to(device)
                key = vid + "_s" + str(y)
                vid_path = os.path.join(out_path, key)

                if not os.path.exists(vid_path):
                    os.makedirs(vid_path)

                if y == 3:
                    gray = True
                else:
                    gray = False

                tcl_st_vals = []
                tcl_lt_vals = []
                dt_vals = []

                #if n_styles > 1:
                #  run_id = "msid%d_ep%d_bs%d_lr%d" % (n_styles, epochs, batchsize, np.log10(learning_rate))
                #else:
                #  run_id = "sid%d_ep%d_bs%d_lr%d" % (y - 1, epochs, batchsize, np.log10(learning_rate))

                if n_styles > 1:
                    if first:
                        self.model.load_state_dict(
                            torch.load(tmp_dir + '/' + tmp_list[y - 1] + '/epoch_' + str(n_epochs) + '.pth'))
                        first = False
                else:
                    #print(tmp_dir + '/' + tmp_list[y-1] + '/epoch_' + str(n_epochs) + '.pth')
                    self.model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp_list[y - 1] + '/epoch_' + str(n_epochs) + '.pth'))

                    if self.method == "ruder":
                        pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + str(y - 1) + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                        self.pre_style_model.load_state_dict(torch.load(pre_style_path))

                past_sty_list = []
                for i, imgs in enumerate(tqdm(loader, total=len(loader))):
                    img, img_last, img_past = imgs
                    img = img.to(device)
                    img_last = img_last.to(device)
                    img_past = img_past.to(device)
                    #save_image(img[0], ncol=1, filename="blah.png")

                    if i > 0:
                        ff_last = computeRAFT(raft_model, img_last, img)
                        bf_last = computeRAFT(raft_model, img, img_last)
                        mask_last = fbcCheckTorch(ff_last, bf_last)
                        x_fake_last = past_sty_list[-1]  #self.infer_method((img_last, None, None), y_trg - 1)
                        warp_last = warp(torch.clamp(x_fake_last, 0.0, 1.0), bf_last)
                    else:
                        mask_last = None
                        warp_last = None

                    #mask, x_warp
                    t_start = time.time()
                    x_fake = self.infer_method((img, mask_last, warp_last), y_trg - 1)
                    x_fake = torch.clamp(x_fake, 0.0, 1.0)
                    t_end = time.time()

                    past_sty_list.append(x_fake)
                    dt_vals.append((t_end - t_start) * 1000)

                    if i > 0:
                        tcl_st = ((mask_last * (x_fake - warp_last))**2).mean()**0.5
                        tcl_st_vals.append(tcl_st.cpu().numpy())

                    if i >= 5:
                        ff_past = computeRAFT(raft_model, img_past, img)
                        bf_past = computeRAFT(raft_model, img, img_past)
                        mask_past = fbcCheckTorch(ff_past, bf_past)
                        #torch.clamp(self.infer_method((img_past, None, None), y_trg - 1), 0.0, 1.0)
                        warp_past = warp(past_sty_list[0], bf_past)
                        tcl_lt = ((mask_past * (x_fake - warp_past))**2).mean()**0.5
                        tcl_lt_vals.append(tcl_lt.cpu().numpy())
                        '''
                        print(img.shape)
                        print(img_past.shape)
                        print(warp_past.shape)
                        print(x_fake.shape)
                        print(past_sty_list[0].shape)
                        save_image(denormalize(img[0]), ncol=1, filename="blah1.png")
                        save_image(denormalize(img_past[0]), ncol=1, filename="blah2.png")
                        save_image(warp_past[0], ncol=1, filename="blah3.png")
                        save_image(x_fake[0], ncol=1, filename="blah4.png")
                        save_image(past_sty_list[0][0], ncol=1, filename="blah5.png")
                        save_image(mask_past*warp_past, ncol=1, filename="blah6.png")
                        blah'''
                        past_sty_list.pop(0)

                    filename = os.path.join(vid_path, "frame_%04d.png" % i)
                    save_image(x_fake[0], ncol=1, filename=filename, gray=gray)

                tcl_st_dict["TCL-ST_" + key] = float(np.array(tcl_st_vals).mean())
                tcl_lt_dict["TCL-LT_" + key] = float(np.array(tcl_lt_vals).mean())
                dt_dict["DT_" + key] = float(np.array(dt_vals).mean())

        save_dict_as_json("TCL-ST", tcl_st_dict, out_path, num_domains)
        save_dict_as_json("TCL-LT", tcl_lt_dict, out_path, num_domains)
        save_dict_as_json("DT", dt_dict, out_path, num_domains)

    # FID and temporal-consistency evaluation on the stylized FlyingChairs2 data.
    @torch.no_grad()
    def evaluate_fc2(self, args, n_styles, epochs, n_epochs, emphasis_parameter,
                     batchsize=16, learning_rate=1e-3, dset='FC2'):
        print('Calculating evaluation metrics...')
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        data_dir = "G:/Datasets/FC2/DATAFiles/"
        style_dir = "G:/Datasets/FC2/styled-files/"
        temp_dir = "G:/Datasets/FC2/styled-files3/"
        eval_dir = os.getcwd() + "/eval_fc2/" + self.method + "/"
        num_workers = 0
        args.batch_size = 4

        domains = os.listdir(style_dir)
        domains.sort()
        num_domains = len(domains)
        print('Number of domains: %d' % num_domains)
        print("Batch Size:", args.batch_size)

        _, eval_loader = get_loaderFC2(data_dir, style_dir, temp_dir,
                                       args.batch_size, num_workers, num_domains)

        tmp_dir = self.train_dir + dset + '/' + self.method + '/'
        tmp_list = os.listdir(tmp_dir)
        tmp_list.sort()

        models = []
        pre_models = []

        if n_styles > 1:
            model = FastStyleNet(3, n_styles).to(self.device)
            model.load_state_dict(
                torch.load(tmp_dir + '/' + tmp_list[0] + '/epoch_' + str(n_epochs) + '.pth'))
        else:
            if self.method == "ruder":
                for tmp in tmp_list:
                    model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
                    model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp + '/epoch_' + str(n_epochs) + '.pth'))
                    models.append(model)

                    pre_style_path = "G:/Code/LBST/runs/johnson/FC2/johnson/sid" + tmp[3] + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
                    model = FastStyleNet(3, n_styles).to(self.device)
                    model.load_state_dict(torch.load(pre_style_path))
                    pre_models.append(model)
            else:
                for tmp in tmp_list:
                    model = FastStyleNet(3, n_styles).to(self.device)
                    model.load_state_dict(
                        torch.load(tmp_dir + '/' + tmp + '/epoch_' + str(n_epochs) + '.pth'))
                    models.append(model)

        generate_new = True
        tcl_dict = {}

        # prepare
        for d in range(1, num_domains):
            src_domain = "style0"
            trg_domain = "style" + str(d)

            t1 = '%s2%s' % (src_domain, trg_domain)
            t2 = '%s2%s' % (trg_domain, src_domain)

            tcl_dict[t1] = []
            tcl_dict[t2] = []

            if generate_new:
                create_task_folders(eval_dir, t1)
                #create_task_folders(eval_dir, t2)

        # generate
        for i, x_src_all in enumerate(tqdm(eval_loader, total=len(eval_loader))):
            x_real, x_real2, y_org, x_ref, y_trg, mask, flow = x_src_all

            x_real = x_real.to(self.device)
            x_real2 = x_real2.to(self.device)
            y_org = y_org.to(self.device)
            x_ref = x_ref.to(self.device)
            y_trg = y_trg.to(self.device)
            mask = mask.to(self.device)
            flow = flow.to(self.device)

            N = x_real.size(0)

            for k in range(N):
                y_org_np = y_org[k].cpu().numpy()
                y_trg_np = y_trg[k].cpu().numpy()

                src_domain = "style" + str(y_org_np)
                trg_domain = "style" + str(y_trg_np)

                if src_domain == trg_domain or y_trg_np == 0:
                    continue

                task = '%s2%s' % (src_domain, trg_domain)

                if n_styles > 1:
                    self.model = model
                else:
                    self.model = models[y_trg_np - 1]
                    if self.method == "ruder":
                        self.pre_style_model = pre_models[y_trg_np - 1]

                x_fake = self.infer_method((x_real, None, None), y_trg[k] - 1)
                #x_fake = torch.clamp(x_fake, 0.0, 1.0)
                x_warp = warp(x_fake, flow)
                x_fake2 = self.infer_method((x_real2, mask, x_warp), y_trg[k] - 1)
                #x_fake2 = torch.clamp(x_fake2, 0.0, 1.0)

                tcl_err = ((mask * (x_fake2 - x_warp))**2).mean(dim=(1, 2, 3))**0.5
                tcl_dict[task].append(tcl_err[k].cpu().numpy())

                path_ref = os.path.join(eval_dir, task + "/ref")
                path_fake = os.path.join(eval_dir, task + "/fake")

                if generate_new:
                    filename = os.path.join(path_ref, '%.4i.png' % (i * args.batch_size + (k + 1)))
                    save_image(denormalize(x_ref[k]), ncol=1, filename=filename)
                    filename = os.path.join(path_fake, '%.4i.png' % (i * args.batch_size + (k + 1)))
                    save_image(x_fake[k], ncol=1, filename=filename)

        # evaluate
        print("computing fid, lpips and tcl")
        tasks = [dir for dir in os.listdir(eval_dir)
                 if os.path.isdir(os.path.join(eval_dir, dir))]
        tasks.sort()

        # fid and lpips
        fid_values = OrderedDict()
        #lpips_dict = OrderedDict()
        tcl_values = OrderedDict()

        for task in tasks:
            print(task)
            path_ref = os.path.join(eval_dir, task + "/ref")
            path_fake = os.path.join(eval_dir, task + "/fake")

            tcl_data = tcl_dict[task]
            print("TCL", len(tcl_data))
            tcl_mean = np.array(tcl_data).mean()
            print(tcl_mean)
            tcl_values['TCL_%s' % (task)] = float(tcl_mean)

            print("FID")
            fid_value = calculate_fid_given_paths(paths=[path_ref, path_fake],
                                                  img_size=256,
                                                  batch_size=args.batch_size)
            fid_values['FID_%s' % (task)] = fid_value

        # calculate the average FID for all tasks
        fid_mean = 0
        for key, value in fid_values.items():
            fid_mean += value / len(fid_values)
        fid_values['FID_mean'] = fid_mean

        # report FID values
        filename = os.path.join(eval_dir, 'FID.json')
        utils.save_json(fid_values, filename)

        # calculate the average TCL for all tasks
        tcl_mean = 0
        for _, value in tcl_values.items():
            tcl_mean += value / len(tcl_values)
        tcl_values['TCL_mean'] = float(tcl_mean)

        # report TCL values
        filename = os.path.join(eval_dir, 'TCL.json')
        utils.save_json(tcl_values, filename)

    def setup_train(self):
        raise NotImplementedError("Please Implement this method")

    def train_method(self):
        raise NotImplementedError("Please Implement this method")

    def infer_method(self):
        raise NotImplementedError("Please Implement this method")

    def setup_method(self):
        raise NotImplementedError("Please Implement this method")

    def loadStyles(self, style_name_list, style_size=512):
        styles = []
        for i, style_name in enumerate(style_name_list):
            style = io.imread('styles/' + style_name + '.jpg')
            style = torch.from_numpy(
                transform.resize(style, (style_size, style_size))).to(
                    self.device).permute(2, 0, 1).float().unsqueeze(0)

            if self.debug:
                imageio.imsave('debug/0_0_style_' + str(i) + '.png',
                               style.cpu().numpy()[0].transpose(1, 2, 0))

            style = self.normalize(style)
            styled_featuresR = self.vgg(style)
            style_GM = [self.gram_matrix(f) for f in styled_featuresR]
            styles.append(style_GM)

        return styles

    def load_model(self, model_path):
        print('loading model ...')
        self.model.load_state_dict(torch.load(self.train_dir + model_path))

    def prep_training(self, batch_sz=16, styles=['composition'], dset='FC2'):
        #dset_path = 'F:/Datasets/' + dset + '/'
        dset_path = '/home/tomstrident/datasets/' + dset + '/'

        if dset == 'FC2':
            self.dataloader = DataLoader(FlyingChairs2Dataset(dset_path, batch_sz),
                                         batch_size=batch_sz)  #, num_workers=4
        elif dset == 'HW2':
            self.dataloader = DataLoader(Hollywood2Dataset(dset_path, batch_sz),
                                         batch_size=batch_sz)
        elif dset == 'CO2':
            self.dataloader = DataLoader(COCODataset(dset_path, batch_sz),
                                         batch_size=batch_sz)
        else:
            assert False, "Invalid dataset specified error!"

        self.train_dir = self.train_dir[:5] + dset + '/'
        self.L2distance = nn.MSELoss().to(self.device)
        self.L2distancematrix = nn.MSELoss(reduction='none').to(self.device)
        self.vgg = Vgg16().to(self.device)
        #self.vgg = Vgg19().to(self.device)

        for param in self.vgg.parameters():
            param.requires_grad = False

        self.styles = self.loadStyles(styles)
        self.adam = []

    def prep_adam(self, itr, batch_sz=16):
        self.adam.zero_grad()
        if (itr + 1) % np.int32(500 / batch_sz) == 0:
            for param in self.adam.param_groups:
                param['lr'] = max(param['lr'] / 1.2, 1e-4)

    def calc_tv_loss(self, I):
        sij = I[:, :, :-1, :-1]
        si1j = I[:, :, :-1, 1:]
        sij1 = I[:, :, 1:, :-1]
        tv_mat1 = torch.norm(sij1 - sij, dim=1)**2
        tv_mat2 = torch.norm(si1j - sij, dim=1)**2
        return torch.sum((tv_mat1 + tv_mat2)**0.5)

    def load_mp4(self, video_path):
        reader = imageio.get_reader(video_path + '.mp4')
        fps = reader.get_meta_data()['fps']
        num_f = reader.count_frames()
        print(num_f)
        return num_f, fps, reader

    def gram_matrix(self, inp):
        b, c, h, w = inp.size()
        features = inp.view(b, c, h * w)
        G = torch.bmm(features, features.transpose(1, 2))
        return G.div(h * w)

    def normalize(self, img):
        mean = img.new_tensor(self.VGG16_MEAN).view(-1, 1, 1)
        std = img.new_tensor(self.VGG16_STD).view(-1, 1, 1)
        return (img - mean) / std

    def warp_image(self, A, flow):
        h, w = flow.shape[:2]
        x = (flow[..., 0] + np.arange(w)).astype(A.dtype)
        y = (flow[..., 1] + np.arange(h)[:, np.newaxis]).astype(A.dtype)
        W_m = cv2.remap(A, x, y, cv2.INTER_LINEAR)
        return W_m.reshape(A.shape)

    def styleFrame(self, frame, sid):
        style_id = torch.from_numpy(np.float32([sid])).to(self.device).float()[0]
        torch_f = torch.from_numpy(frame).to(self.device).permute(2, 0, 1).float().unsqueeze(0)
        torch_m = torch.zeros(1, 1, frame.shape[0], frame.shape[1])
        torch_w = torch_f
        torch_output = self.infer_method((torch_f, torch_m, torch_w), style_id)
        torch_output = torch.clamp(torch_output, 0.0, 1.0)
        styled_frame = torch_output[0].permute(1, 2, 0).detach().cpu().numpy()
        return styled_frame

    def loadModel(self, sid, n_styles, epochs, n_epochs, emphasis_parameter,
                  batchsize=6, learning_rate=1e-3, dset='FC2'):
        if n_styles > 1:
            run_id = "msid%d_ep%d_bs%d_lr%d" % (n_styles, epochs, batchsize,
                                                np.log10(learning_rate))
        else:
            run_id = "sid%d_ep%d_bs%d_lr%d" % (sid, epochs, batchsize,
                                               np.log10(learning_rate))

        emphasis_parameter = self.vectorize_parameters(emphasis_parameter, n_styles)
        self.train_dir = self.train_dir[:8] + dset + '/' + self.method + '/'
        run_id = self.setup_method(run_id, emphasis_parameter.T)

        print(self.train_dir + run_id + 'epoch_' + str(n_epochs) + '.pth')
        # loadModelID expects (n_styles, model_id)
        self.loadModelID(n_styles, self.train_dir + run_id + 'epoch_' + str(n_epochs) + '.pth')

    def loadModelID(self, n_styles, model_id):
        self.model = FastStyleNet(3, n_styles).to(self.device)
        self.model.load_state_dict(torch.load(model_id))
# Ruder-style video style transfer: the network takes the current frame, an
# occlusion mask and the previously stylized frame warped by optical flow
# (3 + 1 + 3 input channels); the first frame is stylized by a pre-trained
# single-frame model.
class Ruder(FastStyle):
    def __init__(self, n_styles=1):
        FastStyle.__init__(self)
        self.method = 'ruder'
        self.train_dir += self.method + '/'
        self.loss_labels = ['total', 'content', 'style', 'temporal_loss']
        self.loss_letters = ["a", "b", "c"]

    def roll(self, p):
        return True if np.random.random() < p else False

    def setup_method(self, run_id, emphasis_parameter):
        run_id += self.concat_id(emphasis_parameter)

        if run_id[0] == 'm':
            pre_style_path = self.train_dir[:5] + "FC2/dumoulin/msid" + run_id[4] + "_ep20_bs16_lr-3_a0_a0_a0_b1_b1_b1/epoch_19.pth"
            n_styles = int(run_id[4])
        else:
            pre_style_path = self.train_dir[:5] + "FC2/johnson/sid" + run_id[3] + "_ep20_bs16_lr-3_a0_b1_d-4/epoch_19.pth"
            n_styles = 1

        self.model = FastStyleNet(3 + 1 + 3, n_styles).to(self.device)
        self.pre_style_model = FastStyleNet(3, n_styles).to(self.device)
        self.pre_style_model.load_state_dict(torch.load(pre_style_path))
        self.first_frame = True

        return run_id

    def train_method(self, imgs, masks, flows, emphasis_parameter, style_id):
        alpha, beta, gamma = emphasis_parameter

        masks = torch.split(masks, 1, dim=1)
        flows = torch.split(flows, 2, dim=1)

        rand_roll = self.roll(0.5)

        if rand_roll:
            _, styled_img1 = self.pre_style_model(imgs[0], s_id=style_id)
            styled_img1 /= 255.0

            warped1 = warp(styled_img1, flows[0])
            _, styled_img2 = self.model(torch.cat((imgs[1], masks[0], warped1), 1), s_id=style_id)
            styled_img2 /= 255.0

            loss_img = imgs[1]
            loss_styled = styled_img2
            loss_warped = warped1

            if len(imgs) > 2:
                warped2 = warp(styled_img2, flows[1])
                _, styled_img3 = self.model(torch.cat((imgs[2], masks[1], warped2), 1), s_id=style_id)
                styled_img3 /= 255.0

                loss_img = imgs[2]
                loss_styled = styled_img3
                loss_warped = warped2

            if len(imgs) > 4:
                warped3 = warp(styled_img3, flows[2])
                _, styled_img4 = self.model(torch.cat((imgs[3], masks[2], warped3), 1), s_id=style_id)
                styled_img4 /= 255.0

                warped4 = warp(styled_img4, flows[3])
                _, styled_img5 = self.model(torch.cat((imgs[4], masks[3], warped4), 1), s_id=style_id)
                styled_img5 /= 255.0

                loss_img = imgs[4]
                loss_styled = styled_img5
                loss_warped = warped4
        else:
            _, styled_img2 = self.model(torch.cat((imgs[1], 0.0 * masks[0], 0.0 * imgs[1]), 1), s_id=style_id)
            styled_img2 /= 255.0

            loss_img = imgs[1]
            loss_styled = styled_img2
            loss_warped = styled_img2

        styled_features = self.vgg(self.normalize(loss_styled))
        img_features = self.vgg(self.normalize(loss_img))

        content_loss = alpha * self.L2distance(styled_features[2], img_features[2])

        style_loss = 0
        for i, gram_s in enumerate(self.styles[style_id]):
            gram_img1 = self.gram_matrix(styled_features[i])
            style_loss += ((gram_img1 - gram_s)**2).mean()  #float(weight)*
        style_loss *= beta

        if rand_roll:
            temporal_loss = gamma * ((masks[-1] * (loss_warped - loss_styled))**2).mean()
        else:
            temporal_loss = 0.0

        loss = content_loss + style_loss + temporal_loss
        # order follows self.loss_labels / loss_string: total, content, style, temporal
        losses = tuple([loss, content_loss, style_loss, temporal_loss])
        loss_string = " L: %.4f CL: %.4f SL: %.4f TL: %.4f" % losses

        loss.backward()

        return losses, styled_img2, loss_string

    def infer_method(self, params, style_id):
        torch_f = params[0]

        if params[1] is None:  #self.first_frame:
            _, styled_frame = self.pre_style_model(torch_f, s_id=style_id)
            #self.first_frame = False
        else:
            torch_m = params[1]
            torch_w = params[2]  #torch.from_numpy(params[2]).to(self.device).permute(2, 0, 1).float().unsqueeze(0)
            _, styled_frame = self.model(torch.cat((torch_f, torch_m, torch_w), 1), s_id=style_id)

        return styled_frame / 255.0
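# Minimal usage sketch (assumed, not part of the original file): train the Ruder
# variant on FlyingChairs2. The emphasis parameters are hypothetical
# (alpha, beta, gamma) weights for the content, style and temporal losses, and
# sid selects an entry of FastStyle.sid_styles.
if __name__ == '__main__':
    fs = Ruder()
    fs.train(sid=0, epochs=20, emphasis_parameter=[1e0, 1e1, 1e2],
             batchsize=16, learning_rate=1e-3, dset='FC2')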