def __init__(self, config):
    """Set up training state: summary writer, model, checkpoint manager,
    optional pretrained weights, and per-epoch bookkeeping counters.

    NOTE(review): this top-level `def __init__(self, ...)` is an orphan
    fragment — it duplicates `Trainer.__init__` defined later in this file
    and is not attached to any class here; verify whether it should be removed.
    """
    self.config = config
    # Scalar logger for per-iteration training metrics.
    self.summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)
    ## model
    print(toGreen('Loading Model...'))
    self.model = create_model(config)
    self.model.print()
    ## checkpoint manager
    self.ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                     config.max_ckpt_num)
    if config.load_pretrained:
        print(toGreen('Loading pretrained Model...'))
        load_result = self.model.get_network().load_state_dict(
            torch.load(os.path.join('./ckpt', config.pre_ckpt_name)))
        # Resume LR as if decay had been applied every `decay_every` epochs
        # up to `epoch_start`.
        lr = config.lr_init * (config.decay_rate**(config.epoch_start //
                                                   config.decay_every))
        print(toRed('\tlearning rate: {}'.format(lr)))
        print(toRed('\tload result: {}'.format(load_result)))
        self.model.set_optim(lr)
    ## training vars
    self.max_epoch = 10000
    # Epoch iteration resumes from `epoch_start` when loading a checkpoint.
    self.epoch_range = np.arange(config.epoch_start, self.max_epoch)
    self.itr_global = 0       # iteration counter across all epochs
    self.itr = 0              # iteration counter within the current epoch
    self.err_epoch_train = 0  # accumulated train loss for the current epoch
    self.err_epoch_test = 0   # accumulated test loss for the current epoch
def evaluate(config, mode):
    """Run stabilization evaluation (TensorFlow 1.x graph model).

    For each video in the hard-coded evaluation set, reads the paired
    stable/unstable clips, feeds sliding windows of frames through the
    evaluation graph, and writes a side-by-side (input | prediction)
    AVI per video under a date-stamped save directory.
    """
    date = datetime.datetime.now().strftime('%Y_%m_%d/%H-%M')
    save_path = os.path.join(config.LOG_DIR.save, date, config.eval_mode)
    exists_or_mkdir(save_path)
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, 10)
    #ckpt_manager = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode, 10)
    batch_size = config.batch_size
    sample_num = config.sample_num
    # Frame offsets (relative indices) of the temporal window fed per step.
    skip_length = np.array(config.skip_length)
    ## DEFINE SESSION
    sess = tf.Session(config = tf.ConfigProto(allow_soft_placement = True,
                                              log_device_placement = False))
    ## DEFINE MODEL
    print(toGreen('Building model...'))
    num_control_points = 5
    # Warping grid parameters: one per control point on a 5x5 grid.
    param_dim = num_control_points ** 2
    inputs, outputs = get_evaluation_model(sample_num, param_dim,
                                           num_control_points, config.height,
                                           config.width)
    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    sess.run(tf.global_variables_initializer())
    print(toGreen('Loading checkpoint...'))
    ckpt_manager.load_ckpt(sess, by_score = config.load_ckpt_by_score)
    print(toYellow('======== EVALUATION START ========='))
    # NOTE(review): hard-coded machine-specific dataset root — consider
    # moving into config.
    offset = '/data1/junyonglee/video_stab/eval'
    file_path = os.path.join(offset, 'train_unstab')
    test_video_list = np.array(sorted(tl.files.load_file_list(
        path = file_path, regx = '.*', printable = False)))
    for k in np.arange(len(test_video_list)):
        test_video_name = test_video_list[k]
        eval_path_stab = os.path.join(offset, 'train_stab')
        eval_path_unstab = os.path.join(offset, 'train_unstab')
        cap_stab = cv2.VideoCapture(
            os.path.join(eval_path_stab, test_video_name))
        cap_unstab = cv2.VideoCapture(
            os.path.join(eval_path_unstab, test_video_name))
        # cap.get(7) == CAP_PROP_FRAME_COUNT
        total_frame_num = int(cap_unstab.get(7))
        base = os.path.basename(test_video_name)
        base_name = os.path.splitext(base)[0]
        fourcc = cv2.VideoWriter_fourcc('M','J','P','G')
        # cap.get(5) == CAP_PROP_FPS
        fps = cap_unstab.get(5)
        # Output is input and prediction side by side, hence 2x width.
        out = cv2.VideoWriter(
            os.path.join(save_path, str(k) + '_' + config.eval_mode + '_' +
                         base_name + '_out.avi'),
            fourcc, fps, (2 * config.width, config.height))
        print(toYellow('reading filename: {}, total frame: {}'.format(
            test_video_name, total_frame_num)))
        # read frame
        def read_frame(cap):
            # Returns (ok, frame); frame is RGB in [0, 1], resized to the
            # model resolution. When ok is False, frame is cap.read()'s value.
            ref, frame = cap.read()
            if ref != False:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame / 255., (config.width, config.height))
            return ref, frame
        # reading all frames in the video
        total_frames_stab = []
        total_frames_unstab = []
        print(toGreen('reading all frames...'))
        while True:
            ref_stab, frame_stab = read_frame(cap_stab)
            ref_unstab, frame_unstab = read_frame(cap_unstab)
            if ref_stab == False or ref_unstab == False:
                break
            total_frames_stab.append(frame_stab)
            total_frames_unstab.append(frame_unstab)
        # duplicate first frames 32 times
        # Seed the history window with ground-truth stable frames so the
        # first predictions have a valid past.
        for i in np.arange(skip_length[-1] - skip_length[0]):
            total_frames_unstab[i] = total_frames_stab[i]
        print(toGreen('stabilizaing video...'))
        total_frame_num = len(total_frames_unstab)
        total_frames_stab = np.array(total_frames_stab)
        total_frames_unstab = np.array(total_frames_unstab)
        sample_idx = skip_length
        for frame_idx in range(skip_length[-1] - skip_length[0],
                               total_frame_num):
            # Stack the window's frames along channels -> (1, H, W, 3*N).
            batch_frames = total_frames_unstab[sample_idx]
            batch_frames = np.expand_dims(
                np.concatenate(batch_frames, axis = 2), axis = 0)
            feed_dict = {
                inputs['patches_t']: batch_frames,
                # Last frame of the window (channels 18:) is the current
                # unstable frame. NOTE(review): 18 assumes 7 history frames
                # of 3 channels each — confirm against skip_length.
                inputs['u_t']: batch_frames[:, :, :, 18:]
            }
            s_t_pred = np.squeeze(sess.run(outputs['s_t_pred'], feed_dict))
            output = np.uint8(np.concatenate(
                [total_frames_unstab[sample_idx[-1]].copy(), s_t_pred],
                axis = 1) * 255.)
            output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
            out.write(np.uint8(output))
            # Feed ground truth back into the history (teacher forcing).
            total_frames_unstab[sample_idx[-1]] = \
                total_frames_stab[sample_idx[-1]]
            print('{}/{} {}/{} frame index: {}'.format(
                k + 1, len(test_video_list), frame_idx,
                int(total_frame_num - 1), sample_idx), flush = True)
            sample_idx = sample_idx + 1  # slide the window by one frame
        cap_stab.release()
        cap_unstab.release()
        out.release()
def train(config):
    """Train the flow network (PyTorch) against a frozen ground-truth copy.

    Three mutually exclusive loss modes selected by config.loss:
    'image'    — warped-image MSE for all four frame pairings plus mask-shape
                 terms against the frozen network's mask;
    'image_ss' — warped-image MSE for the stable/stable pairing only;
    'flow_only'— direct flow MSE of each pairing against the frozen
                 network's stable/stable flow.
    """
    summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)
    ## inputs
    # b_* = blurry/unstable frames, s_* = stable frames; *_t_1 is the
    # previous time step. NOTE(review): naming inferred — confirm.
    inputs = {'b_t_1': None, 'b_t': None, 's_t_1': None, 's_t': None}
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))
    ## model
    print(toGreen('Loading Model...'))
    moduleNetwork = Network().to(device)
    moduleNetwork.apply(weights_init)
    # Frozen reference copy used only inside torch.no_grad() blocks.
    moduleNetwork_gt = Network().to(device)
    print(moduleNetwork)
    ## checkpoint manager
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                config.max_ckpt_num)
    # Both networks start from the same pretrained weights.
    moduleNetwork.load_state_dict(
        torch.load('./network/network-default.pytorch'))
    moduleNetwork_gt.load_state_dict(
        torch.load('./network/network-default.pytorch'))
    ## data loader
    print(toGreen('Loading Data Loader...'))
    data_loader = Data_Loader(config, is_train=True, name='train',
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config, is_train=False, name="test",
                                   thread_num=config.thread_num)
    data_loader.init_data_loader(inputs)
    data_loader_test.init_data_loader(inputs)
    ## loss, optim
    print(toGreen('Building Loss & Optim...'))
    # Sum-reduced MSE is normalized by the warp-mask area below, giving a
    # mean over valid (in-bounds) pixels only.
    MSE_sum = torch.nn.MSELoss(reduction='sum')
    MSE_mean = torch.nn.MSELoss()
    optimizer = optim.Adam(moduleNetwork.parameters(), lr=config.lr_init,
                           betas=(config.beta1, 0.999))
    errs = collections.OrderedDict()
    print(toYellow('======== TRAINING START ========='))
    max_epoch = 10000
    itr = 0
    for epoch in np.arange(max_epoch):
        # train
        while True:
            itr_time = time.time()
            inputs, is_end = data_loader.get_feed()
            if is_end: break
            if config.loss == 'image':
                # Flow for every (current, previous) frame pairing, upsampled
                # to the full training resolution.
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                # Reference flow/mask from the frozen network (no grads).
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear',
                        align_corners=False)
                    s_t_warped_ss_mask_gt = warp(
                        tensorInput=torch.ones_like(inputs['s_t_1'],
                                                    device=device),
                        tensorFlow=flow_ss_gt)
                # Warp the previous stable frame with each flow; masks track
                # which pixels stayed in-bounds.
                s_t_warped_bb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bb)
                s_t_warped_bs = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bs)
                s_t_warped_sb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_sb)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                s_t_warped_bb_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow_bb)
                s_t_warped_bs_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow_bs)
                s_t_warped_sb_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow_sb)
                s_t_warped_ss_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow_ss)
                optimizer.zero_grad()
                # Masked photometric losses (normalized by valid area).
                errs['MSE_bb'] = MSE_sum(
                    s_t_warped_bb * s_t_warped_bb_mask,
                    inputs['s_t']) / s_t_warped_bb_mask.sum()
                errs['MSE_bs'] = MSE_sum(
                    s_t_warped_bs * s_t_warped_bs_mask,
                    inputs['s_t']) / s_t_warped_bs_mask.sum()
                errs['MSE_sb'] = MSE_sum(
                    s_t_warped_sb * s_t_warped_sb_mask,
                    inputs['s_t']) / s_t_warped_sb_mask.sum()
                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()
                # Mask-shape regularizers against the frozen network's mask.
                errs['MSE_bb_mask_shape'] = MSE_mean(s_t_warped_bb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_bs_mask_shape'] = MSE_mean(s_t_warped_bs_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_sb_mask_shape'] = MSE_mean(s_t_warped_sb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_bb'] + errs['MSE_bs'] + errs['MSE_sb'] + errs['MSE_ss'] \
                    + errs['MSE_bb_mask_shape'] + errs['MSE_bs_mask_shape'] + errs['MSE_sb_mask_shape'] + errs['MSE_ss_mask_shape']
            if config.loss == 'image_ss':
                # Stable/stable pairing only.
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear',
                        align_corners=False)
                    s_t_warped_ss_mask_gt = warp(
                        tensorInput=torch.ones_like(inputs['s_t_1'],
                                                    device=device),
                        tensorFlow=flow_ss_gt)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                s_t_warped_ss_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'],
                                                device=device),
                    tensorFlow=flow_ss)
                optimizer.zero_grad()
                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_ss'] + errs['MSE_ss_mask_shape']
            if config.loss == 'flow_only':
                # Supervise each pairing's flow directly with the frozen
                # network's stable/stable flow.
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear',
                    align_corners=False)
                # Warped frame kept only for the periodic image dump below.
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear',
                        align_corners=False)
                optimizer.zero_grad()
                # liteflow_flow_only
                errs['MSE_bb_ss'] = MSE_mean(flow_bb, flow_ss_gt)
                errs['MSE_bs_ss'] = MSE_mean(flow_bs, flow_ss_gt)
                errs['MSE_sb_ss'] = MSE_mean(flow_sb, flow_ss_gt)
                errs['MSE_ss_ss'] = MSE_mean(flow_ss, flow_ss_gt)
                errs['total'] = errs['MSE_bb_ss'] + errs['MSE_bs_ss'] + errs[
                    'MSE_sb_ss'] + errs['MSE_ss_ss']
            errs['total'].backward()
            optimizer.step()
            lr = adjust_learning_rate(optimizer, epoch, config.decay_rate,
                                      config.decay_every, config.lr_init)
            if itr % config.write_log_every_itr == 0:
                summary.add_scalar('loss/loss_mse', errs['total'].item(), itr)
                # Dump input / warped / GT triplets for visual inspection.
                vutils.save_image(inputs['s_t_1'].detach().cpu(),
                                  '{}/{}_1_input.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                vutils.save_image(s_t_warped_ss.detach().cpu(),
                                  '{}/{}_2_warped_ss.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                vutils.save_image(inputs['s_t'].detach().cpu(),
                                  '{}/{}_3_gt.png'.format(
                                      config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                if config.loss == 'image_ss':
                    vutils.save_image(s_t_warped_ss_mask.detach().cpu(),
                                      '{}/{}_4_s_t_wapred_ss_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3, padding=0, normalize=False)
                elif config.loss != 'flow_only':
                    vutils.save_image(s_t_warped_bb_mask.detach().cpu(),
                                      '{}/{}_4_s_t_wapred_bb_mask.png'.format(
                                          config.LOG_DIR.sample, itr),
                                      nrow=3, padding=0, normalize=False)
            if itr % config.refresh_image_log_every_itr == 0:
                remove_file_end_with(config.LOG_DIR.sample, '*.png')
            print_logs('TRAIN', config.mode, epoch, itr_time, itr,
                       data_loader.num_itr, errs=errs, lr=lr)
            itr += 1
        if epoch % config.write_ckpt_every_epoch == 0:
            # Score is the last iteration's total loss for this epoch.
            ckpt_manager.save_ckpt(moduleNetwork, epoch,
                                   score=errs['total'].item())
class Trainer():
    """Drives epoch-level training/testing of a model created via
    `create_model`, delegating per-iteration work to the model object and
    handling logging, image dumps, and checkpointing."""

    def __init__(self, config):
        """Build the model, checkpoint manager, optional pretrained weights,
        and per-epoch bookkeeping counters.

        Args:
            config: project config namespace (LOG_DIR, mode, lr schedule, ...).
        """
        self.config = config
        # Scalar logger for per-iteration training metrics.
        self.summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)
        ## model
        print(toGreen('Loading Model...'))
        self.model = create_model(config)
        self.model.print()
        ## checkpoint manager
        self.ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                         config.max_ckpt_num)
        if config.load_pretrained:
            print(toGreen('Loading pretrained Model...'))
            load_result = self.model.get_network().load_state_dict(
                torch.load(os.path.join('./ckpt', config.pre_ckpt_name)))
            # Resume LR as if decay had been applied every `decay_every`
            # epochs up to `epoch_start`.
            lr = config.lr_init * (config.decay_rate**(
                config.epoch_start // config.decay_every))
            print(toRed('\tlearning rate: {}'.format(lr)))
            print(toRed('\tload result: {}'.format(load_result)))
            self.model.set_optim(lr)
        ## training vars
        self.max_epoch = 10000
        self.epoch_range = np.arange(config.epoch_start, self.max_epoch)
        self.itr_global = 0       # iteration counter across all epochs
        self.itr = 0              # iteration counter within the current epoch
        self.err_epoch_train = 0  # accumulated train loss, current epoch
        self.err_epoch_test = 0   # accumulated test loss, current epoch

    def train(self):
        """Run the full train/test loop over `self.epoch_range`, writing
        per-epoch scalar summaries and periodic checkpoints."""
        print(toYellow('======== TRAINING START ========='))
        for epoch in self.epoch_range:
            ## TRAIN ##
            self.itr = 0
            self.err_epoch_train = 0
            self.model.train()
            while True:
                if self.iteration(epoch):
                    break
            err_epoch_train = self.err_epoch_train / self.itr
            ## TEST ##
            self.itr = 0
            self.err_epoch_test = 0
            self.model.eval()
            while True:
                with torch.no_grad():
                    if self.iteration(epoch, is_train=False):
                        break
            err_epoch_test = self.err_epoch_test / self.itr
            ## LOG
            if epoch % self.config.write_ckpt_every_epoch == 0:
                # Checkpoint is scored by mean train loss of this epoch.
                self.ckpt_manager.save_ckpt(self.model.get_network(),
                                            epoch + 1,
                                            score=err_epoch_train)
                remove_file_end_with(self.config.LOG_DIR.sample, '*.png')
            self.summary.add_scalar('loss/epoch_train', err_epoch_train,
                                    epoch)
            self.summary.add_scalar('loss/epoch_test', err_epoch_test, epoch)

    def iteration(self, epoch, is_train=True):
        """Run one model iteration; return True when the data stream for the
        current epoch is exhausted.

        Args:
            epoch: current epoch index (forwarded to the model).
            is_train: True for the training pass, False for evaluation.
        Returns:
            bool: True if the epoch is finished (no data consumed this call).
        """
        lr = None
        itr_time = time.time()
        is_end = self.model.iteration(epoch, is_train)
        if is_end:
            return True
        inputs = self.model.visuals['inputs']
        errs = self.model.visuals['errs']
        outs = self.model.visuals['outs']
        lr = self.model.visuals['lr']
        num_itr = self.model.visuals['num_itr']
        if is_train:
            lr = self.model.update(epoch)
            # BUGFIX: was bare `config.write_log_every_itr`, which raises
            # NameError unless a global `config` happens to exist.
            if self.itr % self.config.write_log_every_itr == 0:
                try:
                    self.summary.add_scalar('loss/itr',
                                            errs['total'].item(),
                                            self.itr_global)
                    vutils.save_image(inputs['input'],
                                      '{}/{}_{}_1_input.png'.format(
                                          self.config.LOG_DIR.sample, epoch,
                                          self.itr),
                                      nrow=3, padding=0, normalize=False)
                    i = 2
                    # Dump every non-None network output, numbered after the
                    # input image.
                    for key, val in outs.items():
                        if val is not None:
                            vutils.save_image(
                                val,
                                '{}/{}_{}_{}_out_{}.png'.format(
                                    self.config.LOG_DIR.sample, epoch,
                                    self.itr, i, key),
                                nrow=3, padding=0, normalize=False)
                            i += 1
                    vutils.save_image(inputs['gt'],
                                      '{}/{}_{}_{}_gt.png'.format(
                                          self.config.LOG_DIR.sample, epoch,
                                          self.itr, i),
                                      nrow=3, padding=0, normalize=False)
                except Exception as ex:
                    # Image dumping is best-effort; never kill training.
                    print('saving error: ', ex)
            print_logs('TRAIN', self.config.mode, epoch, itr_time, self.itr,
                       num_itr, errs=errs, lr=lr)
            self.err_epoch_train += errs['total'].item()
            self.itr += 1
            self.itr_global += 1
        else:
            print_logs('TEST', self.config.mode, epoch, itr_time, self.itr,
                       num_itr, errs=errs)
            self.err_epoch_test += errs['total'].item()
            self.itr += 1
        gc.collect()
        return is_end
def evaluate(config, mode):
    """Run StabNet evaluation (TensorFlow 1.x) over every video in
    config.unstab_path.

    Maintains per-frame stable (S) and unstable (U) camera-path caches,
    initializes them from the first full window, then autoregressively
    predicts each subsequent stabilized frame, writing a 3-wide
    (input | prediction | prediction-without-C0) AVI per video.
    """
    date = datetime.datetime.now().strftime('%Y_%m_%d/%H-%M')
    save_path = os.path.join(config.LOG_DIR.save, date, config.eval_mode)
    exists_or_mkdir(save_path)
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, 10)
    #ckpt_manager = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode, 10)
    batch_size = config.batch_size
    sample_num = config.sample_num
    # Temporal stride between the frames of each input window.
    skip_length = config.skip_length
    ## DEFINE SESSION
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    ## DEFINE MODEL
    print(toGreen('Building model...'))
    stabNet = StabNet(config.height, config.width, config.F_dim,
                      is_train=False)
    # Graphs for seeding the stable/unstable path caches and for the
    # per-frame autoregressive prediction step.
    stable_path_net = stabNet.get_stable_path_init(sample_num)
    unstable_path_net = stabNet.get_unstable_path_init(sample_num)
    outputs_net = stabNet.init_evaluation_model(sample_num)
    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    sess.run(tf.global_variables_initializer())
    print(toGreen('Loading checkpoint...'))
    ckpt_manager.load_ckpt(sess, by_score=config.load_ckpt_by_score)
    print(toYellow('======== EVALUATION START ========='))
    test_video_list = np.array(
        sorted(
            tl.files.load_file_list(path=config.unstab_path,
                                    regx='.*',
                                    printable=False)))
    for k in np.arange(len(test_video_list)):
        test_video_name = test_video_list[k]
        cap = cv2.VideoCapture(
            os.path.join(config.unstab_path, test_video_name))
        # cap.get(5) == CAP_PROP_FPS
        fps = cap.get(5)
        resize_h = config.height
        resize_w = config.width
        # out_h = int(cap.get(4))
        # out_w = int(cap.get(3))
        out_h = resize_h
        out_w = resize_w
        # refine_temp = np.ones((h, w))
        # refine_temp = refine_image(refine_temp)
        # [h, w] = refine_temp.shape[:2]
        # cap.get(7) == CAP_PROP_FRAME_COUNT
        total_frame_num = int(cap.get(7))
        fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
        base = os.path.basename(test_video_name)
        base_name = os.path.splitext(base)[0]
        # 3 panels side by side, hence 3x width.
        out = cv2.VideoWriter(
            os.path.join(
                save_path,
                str(k) + '_' + config.eval_mode + '_' + base_name +
                '_out.avi'), fourcc, fps, (3 * out_w, out_h))
        print(
            toYellow('reading filename: {}, total frame: {}'.format(
                test_video_name, total_frame_num)))
        # read frame
        def refine_frame(frame):
            # Scale to [0, 1] and model resolution.
            return cv2.resize(frame / 255., (resize_w, resize_h))
        def read_frame(cap):
            # Returns (ok, frame); frame converted BGR->RGB and resized.
            ref, frame = cap.read()
            if ref != False:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (out_w, out_h))
            return ref, frame
        # reading all frames in the video
        total_frames = []
        print(toGreen('reading all frames...'))
        while True:
            ref, frame = read_frame(cap)
            if ref == False:
                break
            total_frames.append(refine_frame(frame))
        # duplicate first frames 30 times
        # Pad the start so the first real frame has a full history window.
        for i in np.arange((sample_num - 1) * skip_length):
            total_frames.insert(0, total_frames[0])
        print(toGreen('stabilizaing video...'))
        total_frame_num = len(total_frames)
        total_frames = np.array(total_frames)
        # Per-frame caches: S = stable path, U = unstable path,
        # C_0_list = per-frame correction terms (zeros for padded frames).
        S = [None] * total_frame_num
        U = [None] * total_frame_num
        C_0_list = [None] * total_frame_num
        S_t_1_seq = None
        U_t_1_seq = None
        for i in np.arange((sample_num - 1) * skip_length):
            C_0_list[i] = np.zeros([1, out_h, out_w, 2])
        sample_idx = np.arange(0, 0 + sample_num * skip_length, skip_length)
        for frame_idx in range((sample_num - 1) * skip_length,
                               total_frame_num):
            batch_frames = total_frames[sample_idx]
            # -> (1, sample_num, H, W, C)
            batch_frames = np.expand_dims(np.concatenate(np.expand_dims(
                batch_frames, axis=0), axis=0), axis=0)
            if U[sample_idx[0]] is None:
                # First window of this video: seed S/U caches from the
                # path-initialization graphs.
                feed_dict = {
                    stabNet.inputs['IS']: batch_frames[:, :-1, :, :, :]
                }
                stable_path_init = sess.run(stable_path_net, feed_dict)
                feed_dict = {
                    stabNet.inputs['IU']: batch_frames[:, :-1, :, :, :]
                }
                unstable_path_init = sess.run(unstable_path_net, feed_dict)
                S_t_1_seq = stable_path_init['S_t_1_seq']
                U_t_1_seq = unstable_path_init['U_t_1_seq']
                for i in np.arange(S_t_1_seq.shape[1]):
                    S[sample_idx[i + 1]] = S_t_1_seq[:, i:i + 1, :, :, :]
                    U[sample_idx[i + 1]] = U_t_1_seq[:, i:i + 1, :, :, :]
            # Rebuild the history sequences from the caches for the frames
            # strictly inside the window.
            idxs = sample_idx[1:-1]
            i = 0
            for idx in idxs:
                if i == 0:
                    S_t_1_seq = S[idx]
                    U_t_1_seq = U[idx]
                else:
                    S_t_1_seq = np.concatenate([S_t_1_seq, S[idx]], axis=1)
                    U_t_1_seq = np.concatenate([U_t_1_seq, U[idx]], axis=1)
                i += 1
            C_0 = C_0_list[sample_idx[0]]
            feed_dict = {
                stabNet.inputs['IU']:
                batch_frames[:, :-1, :, :, :],
                stabNet.inputs['Iu']:
                np.expand_dims(batch_frames[:, -1, :, :, :], axis=1),
                stabNet.inputs['U_t_1_seq']: U_t_1_seq,
                stabNet.inputs['S_t_1_seq']: S_t_1_seq,
                stabNet.inputs['C_0']: C_0,
            }
            Is_pred, Is_pred_wo_C0, S_t_pred, U_t, C_0 = sess.run([
                outputs_net['Is_pred'], outputs_net['Is_pred_wo_C0'],
                outputs_net['S_t_pred_seq'], outputs_net['U_t_seq'],
                outputs_net['B_t_wo_C0']
            ], feed_dict)
            # Feed predictions back into the caches (autoregressive).
            S[sample_idx[-1]] = S_t_pred
            U[sample_idx[-1]] = U_t
            C_0_list[sample_idx[-1]] = C_0
            Is_pred = np.squeeze(Is_pred)
            Is_pred_wo_C0 = np.squeeze(Is_pred_wo_C0)
            output = np.uint8(
                np.concatenate(
                    (total_frames[frame_idx].copy(), Is_pred, Is_pred_wo_C0),
                    axis=1) * 255.)
            output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
            out.write(np.uint8(output))
            print('{}/{} {}/{} frame index: {}'.format(
                k + 1, len(test_video_list), frame_idx,
                int(total_frame_num - 1), sample_idx),
                  flush=True)
            sample_idx = sample_idx + 1  # slide the window by one frame
        cap.release()
        out.release()
def train(config, mode):
    """Two-phase StabNet training (TensorFlow 1.x): an optional pretraining
    phase followed by main training, each with its own train/test loops,
    summary writers, and checkpoint managers.

    Args:
        config: project config namespace (LOG_DIR, PRETRAIN, TEST, ...).
        mode: experiment mode tag, forwarded to checkpoint managers/logs.
    """
    ## Managers
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode,
                                config.max_ckpt_num)
    ckpt_manager_itr = CKPT_Manager(config.LOG_DIR.ckpt_itr, mode,
                                    config.max_ckpt_num)
    ckpt_manager_init = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode,
                                     config.max_ckpt_num)
    ckpt_manager_init_itr = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_itr,
                                         mode, config.max_ckpt_num)
    # Permanent (never rotated) pretraining checkpoint.
    ckpt_manager_perm = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_perm, mode,
                                     1)
    ## DEFINE SESSION
    # Fixed seeds for reproducibility across TF / NumPy / random.
    seed_value = 1
    tf.set_random_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
    print(toGreen('Initializing session...'))
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    ## DEFINE MODEL
    stabNet = StabNet(config.height, config.width)
    ## DEFINE DATA LOADERS
    print(toGreen('Loading dataloader...'))
    data_loader = Data_Loader(config, is_train=True,
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config.TEST, is_train=False,
                                   thread_num=config.thread_num)
    ## DEFINE TRAINER
    print(toGreen('Initializing Trainer...'))
    trainer = Trainer(stabNet, [data_loader, data_loader_test], config)
    ## DEFINE SUMMARY WRITER
    print(toGreen('Building summary writer...'))
    if config.is_pretrain:
        # Pretraining-phase writers (separate log dirs).
        writer_scalar_itr_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_itr,
            flush_secs=30,
            filename_suffix='.scalor_log_itr_init')
        writer_scalar_epoch_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_epoch,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_init')
        writer_scalar_epoch_valid_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_valid,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_test_init')
        writer_image_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_image,
            flush_secs=30,
            filename_suffix='.image_log_init')
    # Main-phase writers.
    writer_scalar_itr = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_itr,
        flush_secs=30,
        filename_suffix='.scalor_log_itr')
    writer_scalar_epoch = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_epoch,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch')
    writer_scalar_epoch_valid = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_valid,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch_test')
    writer_image = tf.summary.FileWriter(config.LOG_DIR.log_image,
                                         flush_secs=30,
                                         filename_suffix='.image_log')
    ## INITIALIZE SESSION
    print(toGreen('Initializing network...'))
    sess.run(tf.global_variables_initializer())
    trainer.init_vars(sess)
    # Best-effort restore of any previous pretraining state.
    ckpt_manager_init.load_ckpt(sess, by_score=False)
    ckpt_manager_perm.load_ckpt(sess)
    if config.is_pretrain:
        print(toYellow('======== PRETRAINING START ========='))
        global_step = 0
        for epoch in range(0, config.PRETRAIN.n_epoch):
            #for epoch in range(0, 1):
            # update learning rate
            trainer.update_learning_rate(epoch, config.PRETRAIN.lr_init,
                                         config.PRETRAIN.lr_decay_rate,
                                         config.PRETRAIN.decay_every, sess)
            # Running sums of each pretrain loss term over the epoch.
            errs_total_pretrain = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss.keys(), 0.)
            errs = None
            epoch_time = time.time()
            idx = 0
            while True:
                #for idx in range(0, 2):
                step_time = time.time()
                feed_dict, is_end = data_loader.feed_the_network()
                if is_end: break
                # Loss coefficients may depend on the previous step's errors.
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                _, lr, errs = sess.run([
                    trainer.optim_init, trainer.learning_rate,
                    trainer.pretrain_loss
                ], feed_dict)
                errs_total_pretrain = dict_operations(errs_total_pretrain,
                                                      '+', errs)
                if global_step % config.write_log_every_itr == 0:
                    summary_loss_itr, summary_image = sess.run(
                        [trainer.scalar_sum_itr_init,
                         trainer.image_sum_init], feed_dict)
                    writer_scalar_itr_init.add_summary(summary_loss_itr,
                                                       global_step)
                    writer_image_init.add_summary(summary_image, global_step)
                # save checkpoint
                if (global_step) % config.PRETRAIN.write_ckpt_every_itr == 0:
                    ckpt_manager_init_itr.save_ckpt(
                        sess,
                        trainer.pretraining_save_vars,
                        '{:05d}_{:05d}'.format(epoch, global_step),
                        score=errs_total_pretrain['total'] / (idx + 1))
                print_logs('PRETRAIN',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container,
                           lr=lr)
                global_step += 1
                idx += 1
            # save log
            errs_total_pretrain = dict_operations(errs_total_pretrain, '/',
                                                  data_loader.num_itr)
            summary_loss_epoch_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(
                    trainer.loss_epoch_init_placeholder, '=',
                    errs_total_pretrain))
            writer_scalar_epoch_init.add_summary(summary_loss_epoch_init,
                                                 epoch)
            ## TEST
            errs_total_pretrain_test = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss_test.keys(), 0.)
            errs = None
            epoch_time_test = time.time()
            idx = 0
            while True:
                #for idx in range(0, 2):
                step_time = time.time()
                feed_dict, is_end = data_loader_test.feed_the_network()
                if is_end: break
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                errs = sess.run(trainer.pretrain_loss_test, feed_dict)
                errs_total_pretrain_test = dict_operations(
                    errs_total_pretrain_test, '+', errs)
                print_logs('PRETRAIN TEST',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader_test.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container)
                idx += 1
            # save log
            errs_total_pretrain_test = dict_operations(
                errs_total_pretrain_test, '/', data_loader_test.num_itr)
            summary_loss_test_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(
                    trainer.loss_epoch_init_placeholder, '=',
                    errs_total_pretrain_test))
            writer_scalar_epoch_valid_init.add_summary(summary_loss_test_init,
                                                       epoch)
            print_logs('TRAIN SUMMARY',
                       mode,
                       epoch,
                       epoch_time,
                       errs=errs_total_pretrain)
            print_logs('TEST SUMMARY',
                       mode,
                       epoch,
                       epoch_time_test,
                       errs=errs_total_pretrain_test)
            # save checkpoint
            if epoch % config.write_ckpt_every_epoch == 0:
                ckpt_manager_init.save_ckpt(
                    sess,
                    trainer.pretraining_save_vars,
                    epoch,
                    score=errs_total_pretrain_test['total'])
            # reset image log
            if epoch % config.refresh_image_log_every_epoch == 0:
                writer_image_init.close()
                remove_file_end_with(config.PRETRAIN.LOG_DIR.log_image,
                                     '*.image_log')
                writer_image_init.reopen()
        if config.pretrain_only:
            return
        else:
            # Switch the loaders from pretraining input to main training
            # input before the main phase.
            data_loader.reset_to_train_input(stabNet)
            data_loader_test.reset_to_train_input(stabNet)
    print(toYellow('========== TRAINING START =========='))
    global_step = 0
    for epoch in range(0, config.n_epoch):
        #for epoch in range(0, 1):
        # update learning rate
        trainer.update_learning_rate(epoch, config.lr_init,
                                     config.lr_decay_rate,
                                     config.decay_every, sess)
        ## TRAIN
        errs_total_train = collections.OrderedDict.fromkeys(
            trainer.loss.keys(), 0.)
        errs = None
        epoch_time = time.time()
        idx = 0
        # NOTE(review): debug leftover — the original `while True:` loop is
        # commented out and replaced by a 2-iteration loop, so each epoch
        # trains on only 2 batches. Confirm whether this was intentional.
        #while True:
        for idx in range(0, 2):
            step_time = time.time()
            feed_dict, is_end = data_loader.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            _, lr, errs = sess.run(
                [trainer.optim_main, trainer.learning_rate, trainer.loss],
                feed_dict)
            errs_total_train = dict_operations(errs_total_train, '+', errs)
            if global_step % config.write_ckpt_every_itr == 0:
                ckpt_manager_itr.save_ckpt(
                    sess,
                    trainer.save_vars,
                    '{:05d}_{:05d}'.format(epoch, global_step),
                    score=errs_total_train['total'] / (idx + 1))
            if global_step % config.write_log_every_itr == 0:
                summary_loss_itr, summary_image = sess.run(
                    [trainer.scalar_sum_itr, trainer.image_sum], feed_dict)
                writer_scalar_itr.add_summary(summary_loss_itr, global_step)
                writer_image.add_summary(summary_image, global_step)
            print_logs('TRAIN',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container,
                       lr=lr)
            global_step += 1
            idx += 1
        # SAVE LOGS
        # NOTE(review): divides by the full num_itr even though the loop
        # above may have run fewer steps — tied to the debug leftover above.
        errs_total_train = dict_operations(errs_total_train, '/',
                                           data_loader.num_itr)
        summary_loss_epoch = sess.run(trainer.summary_epoch,
                                      feed_dict=dict_operations(
                                          trainer.loss_epoch_placeholder,
                                          '=', errs_total_train))
        writer_scalar_epoch.add_summary(summary_loss_epoch, epoch)
        ## TEST
        errs_total_test = collections.OrderedDict.fromkeys(
            trainer.loss_test.keys(), 0.)
        epoch_time_test = time.time()
        idx = 0
        # NOTE(review): same debug leftover as the train loop above.
        #while True:
        for idx in range(0, 2):
            step_time = time.time()
            feed_dict, is_end = data_loader_test.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            errs = sess.run(trainer.loss_test, feed_dict)
            errs_total_test = dict_operations(errs_total_test, '+', errs)
            print_logs('TEST',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader_test.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container)
            idx += 1
        # SAVE LOGS
        errs_total_test = dict_operations(errs_total_test, '/',
                                          data_loader_test.num_itr)
        summary_loss_epoch_test = sess.run(trainer.summary_epoch,
                                           feed_dict=dict_operations(
                                               trainer.loss_epoch_placeholder,
                                               '=', errs_total_test))
        writer_scalar_epoch_valid.add_summary(summary_loss_epoch_test, epoch)
        ## CKPT
        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(sess,
                                   trainer.save_vars,
                                   epoch,
                                   score=errs_total_test['total'])
        ## RESET IMAGE SUMMARY
        if epoch % config.refresh_image_log_every_epoch == 0:
            writer_image.close()
            remove_file_end_with(config.LOG_DIR.log_image, '*.image_log')
            writer_image.reopen()
        print_logs('TRAIN SUMMARY',
                   mode,
                   epoch,
                   epoch_time,
                   errs=errs_total_train)
        print_logs('TEST SUMMARY',
                   mode,
                   epoch,
                   epoch_time_test,
                   errs=errs_total_test)
def evaluate(config, mode):
    """Run color-transfer evaluation (PyTorch): for each paired
    input/reference image (plus segmentation maps), pad, run the model,
    crop the padding back off, and save input/reference/output PNGs under a
    checkpoint- and date-stamped directory.
    """
    print(toGreen('Loading checkpoint manager...'))
    print(config.LOG_DIR.ckpt, mode)
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode, 10)
    date = datetime.datetime.now().strftime('%Y.%m.%d.(%H%M)')
    print(toYellow('======== EVALUATION START ========='))
    ##################################################
    ## DEFINE MODEL
    print(toGreen('Initializing model'))
    model = create_model(config)
    model.eval()
    model.print()
    inputs = {
        'inp': None,
        'ref': None,
        'inp_hist': None,
        'ref_hist': None,
        'seg_inp': None,
        'seg_ref': None
    }
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))
    ## INITIALIZING VARIABLE
    print(toGreen('Initializing variables'))
    result, ckpt_name = ckpt_manager.load_ckpt(
        model.get_network(),
        by_score=config.EVAL.load_ckpt_by_score,
        name=config.EVAL.ckpt_name)
    print(result)
    save_path_root = os.path.join(config.EVAL.LOG_DIR.save,
                                  config.EVAL.eval_mode, ckpt_name, date)
    exists_or_mkdir(save_path_root)
    # Re-export the evaluated weights alongside the results.
    torch.save(model.get_network().state_dict(),
               os.path.join(save_path_root, ckpt_name + '.pytorch'))
    ##################################################
    inp_folder_path_list, _, _ = load_file_list(config.EVAL.inp_path)
    ref_folder_path_list, _, _ = load_file_list(config.EVAL.ref_path)
    inp_segmap_folder_path_list, _, _ = load_file_list(
        config.EVAL.inp_segmap_path)
    ref_segmap_folder_path_list, _, _ = load_file_list(
        config.EVAL.ref_segmap_path)
    print(toGreen('Starting Color Trasfer'))
    itr = 0
    for folder_idx in np.arange(len(inp_folder_path_list)):
        # Parallel folder lists: input, reference, and their seg maps.
        inp_folder_path = inp_folder_path_list[folder_idx]
        ref_video_path = ref_folder_path_list[folder_idx]
        inp_segmap_folder_path = inp_segmap_folder_path_list[folder_idx]
        ref_segmap_video_path = ref_segmap_folder_path_list[folder_idx]
        _, inp_file_path_list, _ = load_file_list(inp_folder_path)
        _, ref_file_path_list, _ = load_file_list(ref_video_path)
        _, inp_segmap_file_path_list, _ = load_file_list(
            inp_segmap_folder_path)
        _, ref_segmap_file_path_list, _ = load_file_list(
            ref_segmap_video_path)
        for file_idx in np.arange(len(inp_file_path_list)):
            inp_path = inp_file_path_list[file_idx]
            ref_path = ref_file_path_list[file_idx]
            inp_segmap_path = inp_segmap_file_path_list[file_idx]
            ref_segmap_path = ref_segmap_file_path_list[file_idx]
            # inputs['inp'], inputs['inp_hist'] = torch.FloatTensor(_read_frame_cv(inp_path, config).transpose(0, 3, 1, 2)).cuda()
            inputs['inp'], inputs['inp_hist'] = _read_frame_cv(
                inp_path, config)
            inputs['ref'], inputs['ref_hist'] = _read_frame_cv(
                ref_path, config)
            inputs['seg_inp'], inputs['seg_ref'] = _read_segmap(
                inputs['inp'], inp_segmap_path, ref_segmap_path,
                config.is_rep)
            # Replication-pad by p px to reduce border artifacts; cropped
            # back off after inference.
            p = 30
            reppad = torch.nn.ReplicationPad2d(p)
            # NHWC numpy -> NCHW CUDA tensors for every entry.
            for key, val in inputs.items():
                inputs[key] = torch.FloatTensor(inputs[key].transpose(
                    0, 3, 1, 2)).cuda()
            inputs['inp'] = reppad(inputs['inp'])
            inputs['ref'] = reppad(inputs['ref'])
            inputs['seg_inp'] = reppad(inputs['seg_inp'])
            inputs['seg_ref'] = reppad(inputs['seg_ref'])
            with torch.no_grad():
                outs = model.get_results(inputs)
            # Crop the p-px padding off inputs and all outputs.
            inputs['inp'] = inputs['inp'][:, :, p:(inputs['inp'].size(2) - p),
                                          p:(inputs['inp'].size(3) - p)]
            inputs['ref'] = inputs['ref'][:, :, p:(inputs['ref'].size(2) - p),
                                          p:(inputs['ref'].size(3) - p)]
            outs['result'] = outs['result'][:, :,
                                            p:(outs['result'].size(2) - p),
                                            p:(outs['result'].size(3) - p)]
            outs['result_idt'] = outs['result_idt'][:, :, p:(
                outs['result_idt'].size(2) -
                p), p:(outs['result_idt'].size(3) - p)]
            outs['seg_inp'] = outs['seg_inp'][:, :,
                                              p:(outs['seg_inp'].size(2) - p),
                                              p:(outs['seg_inp'].size(3) - p)]
            outs['seg_ref'] = outs['seg_ref'][:, :,
                                              p:(outs['seg_ref'].size(2) - p),
                                              p:(outs['seg_ref'].size(3) - p)]
            file_name = os.path.basename(
                inp_file_path_list[file_idx]).split('.')[0]
            save_path = save_path_root
            exists_or_mkdir(save_path)
            vutils.save_image(LAB2RGB_cv(inputs['inp'].detach().cpu(),
                                         config.type),
                              '{}/{}_1_inp.png'.format(save_path, itr),
                              nrow=3,
                              padding=0,
                              normalize=False)
            vutils.save_image(LAB2RGB_cv(inputs['ref'].detach().cpu(),
                                         config.type),
                              '{}/{}_2_ref.png'.format(save_path, itr),
                              nrow=3,
                              padding=0,
                              normalize=False)
            i = 3
            out_keys = ['result', 'result_idt', 'seg_inp', 'seg_ref']
            for key, val in outs.items():
                if key in out_keys:
                    if val is not None:
                        if 'seg' in key:
                            # Identity outputs are skipped unless enabled.
                            if config.identity is False and 'idt' in key:
                                continue
                            # Seg maps saved as RGB, scaled down for
                            # visibility.
                            vutils.save_image(
                                LAB2RGB_cv(val.detach().cpu(), 'rgb') / 8.,
                                '{}/{}_{}_out_{}.png'.format(
                                    save_path, itr, i, key),
                                nrow=3,
                                padding=0,
                                normalize=False)
                        else:
                            if config.identity is False and 'idt' in key:
                                continue
                            vutils.save_image(LAB2RGB_cv(
                                val.detach().cpu(), config.type),
                                              '{}/{}_{}_out_{}.png'.format(
                                                  save_path, itr, i, key),
                                              nrow=3,
                                              padding=0,
                                              normalize=False)
                        i += 1
            #PSNR
            print(
                '[{}][{}/{}'.format(ckpt_name, folder_idx + 1,
                                    len(inp_folder_path_list)),
                '{}/{}]'.format(file_idx + 1, len(inp_file_path_list)))
            itr += 1