class Trainer():
    def __init__(self, config):
        self.config = config
        self.summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

        ## model
        print(toGreen('Loading Model...'))
        self.model = create_model(config)
        self.model.print()

        ## checkpoint manager
        self.ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                         config.max_ckpt_num)
        if config.load_pretrained:
            print(toGreen('Loading pretrained Model...'))
            load_result = self.model.get_network().load_state_dict(
                torch.load(os.path.join('./ckpt', config.pre_ckpt_name)))
            # resume with the step-decayed learning rate for the starting epoch
            lr = config.lr_init * (config.decay_rate**(config.epoch_start // config.decay_every))
            print(toRed('\tlearning rate: {}'.format(lr)))
            print(toRed('\tload result: {}'.format(load_result)))
            self.model.set_optim(lr)

        ## training vars
        self.max_epoch = 10000
        self.epoch_range = np.arange(config.epoch_start, self.max_epoch)
        self.itr_global = 0
        self.itr = 0
        self.err_epoch_train = 0
        self.err_epoch_test = 0

    def train(self):
        print(toYellow('======== TRAINING START ========='))
        for epoch in self.epoch_range:
            ## TRAIN ##
            self.itr = 0
            self.err_epoch_train = 0
            self.model.train()
            while True:
                if self.iteration(epoch):
                    break
            err_epoch_train = self.err_epoch_train / self.itr

            ## TEST ##
            self.itr = 0
            self.err_epoch_test = 0
            self.model.eval()
            while True:
                with torch.no_grad():
                    if self.iteration(epoch, is_train=False):
                        break
            err_epoch_test = self.err_epoch_test / self.itr

            ## LOG
            if epoch % self.config.write_ckpt_every_epoch == 0:
                self.ckpt_manager.save_ckpt(self.model.get_network(),
                                            epoch + 1,
                                            score=err_epoch_train)
                remove_file_end_with(self.config.LOG_DIR.sample, '*.png')

            self.summary.add_scalar('loss/epoch_train', err_epoch_train, epoch)
            self.summary.add_scalar('loss/epoch_test', err_epoch_test, epoch)

    def iteration(self, epoch, is_train=True):
        lr = None
        itr_time = time.time()

        is_end = self.model.iteration(epoch, is_train)
        if is_end:
            return True

        inputs = self.model.visuals['inputs']
        errs = self.model.visuals['errs']
        outs = self.model.visuals['outs']
        lr = self.model.visuals['lr']
        num_itr = self.model.visuals['num_itr']

        if is_train:
            lr = self.model.update(epoch)
            if self.itr % self.config.write_log_every_itr == 0:
                try:
                    self.summary.add_scalar('loss/itr', errs['total'].item(),
                                            self.itr_global)
                    vutils.save_image(inputs['input'],
                                      '{}/{}_{}_1_input.png'.format(
                                          self.config.LOG_DIR.sample, epoch, self.itr),
                                      nrow=3, padding=0, normalize=False)
                    i = 2
                    for key, val in outs.items():
                        if val is not None:
                            vutils.save_image(val,
                                              '{}/{}_{}_{}_out_{}.png'.format(
                                                  self.config.LOG_DIR.sample, epoch, self.itr, i, key),
                                              nrow=3, padding=0, normalize=False)
                            i += 1
                    vutils.save_image(inputs['gt'],
                                      '{}/{}_{}_{}_gt.png'.format(
                                          self.config.LOG_DIR.sample, epoch, self.itr, i),
                                      nrow=3, padding=0, normalize=False)
                except Exception as ex:
                    print('saving error: ', ex)

            print_logs('TRAIN', self.config.mode, epoch, itr_time, self.itr,
                       num_itr, errs=errs, lr=lr)
            self.err_epoch_train += errs['total'].item()
            self.itr += 1
            self.itr_global += 1
        else:
            print_logs('TEST', self.config.mode, epoch, itr_time, self.itr,
                       num_itr, errs=errs)
            self.err_epoch_test += errs['total'].item()
            self.itr += 1

        gc.collect()
        return is_end
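# For reference, a minimal sketch of the step-decay schedule assumed by the
# trainers in this file: it mirrors the learning rate computed in
# Trainer.__init__ above (lr_init * decay_rate ** (epoch // decay_every)) and
# matches the adjust_learning_rate(optimizer, epoch, decay_rate, decay_every,
# lr_init) call used in the next trainer. The real helper lives elsewhere in
# the repo; the name here is hypothetical.
def adjust_learning_rate_sketch(optimizer, epoch, decay_rate, decay_every, lr_init):
    lr = lr_init * (decay_rate**(epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr  # apply the decayed rate to every param group
    return lr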
def train(config):
    summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

    ## inputs (frame pairs at t and t-1)
    inputs = {'b_t_1': None, 'b_t': None, 's_t_1': None, 's_t': None}
    inputs = collections.OrderedDict(sorted(inputs.items(), key=lambda t: t[0]))

    ## model
    print(toGreen('Loading Model...'))
    moduleNetwork = Network().to(device)
    moduleNetwork.apply(weights_init)
    # frozen copy kept at the pretrained weights; used only under no_grad
    # to produce the reference flow and mask
    moduleNetwork_gt = Network().to(device)
    print(moduleNetwork)

    ## checkpoint manager
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode, config.max_ckpt_num)
    moduleNetwork.load_state_dict(torch.load('./network/network-default.pytorch'))
    moduleNetwork_gt.load_state_dict(torch.load('./network/network-default.pytorch'))

    ## data loader
    print(toGreen('Loading Data Loader...'))
    data_loader = Data_Loader(config, is_train=True, name='train',
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config, is_train=False, name='test',
                                   thread_num=config.thread_num)
    data_loader.init_data_loader(inputs)
    data_loader_test.init_data_loader(inputs)

    ## loss, optim
    print(toGreen('Building Loss & Optim...'))
    MSE_sum = torch.nn.MSELoss(reduction='sum')
    MSE_mean = torch.nn.MSELoss()
    optimizer = optim.Adam(moduleNetwork.parameters(), lr=config.lr_init,
                           betas=(config.beta1, 0.999))
    errs = collections.OrderedDict()

    print(toYellow('======== TRAINING START ========='))
    max_epoch = 10000
    itr = 0
    for epoch in np.arange(max_epoch):
        # train
        while True:
            itr_time = time.time()

            inputs, is_end = data_loader.get_feed()
            if is_end:
                break

            if config.loss == 'image':
                # flows for every pairing of frame t (b/s) with frame t-1 (b/s)
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear', align_corners=False)
                    s_t_warped_ss_mask_gt = warp(
                        tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                        tensorFlow=flow_ss_gt)

                s_t_warped_bb = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_bb)
                s_t_warped_bs = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_bs)
                s_t_warped_sb = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_sb)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_ss)

                # warping all-ones tensors yields validity masks for each flow
                s_t_warped_bb_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                    tensorFlow=flow_bb)
                s_t_warped_bs_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                    tensorFlow=flow_bs)
                s_t_warped_sb_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                    tensorFlow=flow_sb)
                s_t_warped_ss_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                    tensorFlow=flow_ss)

                optimizer.zero_grad()
                errs['MSE_bb'] = MSE_sum(s_t_warped_bb * s_t_warped_bb_mask,
                                         inputs['s_t']) / s_t_warped_bb_mask.sum()
                errs['MSE_bs'] = MSE_sum(s_t_warped_bs * s_t_warped_bs_mask,
                                         inputs['s_t']) / s_t_warped_bs_mask.sum()
                errs['MSE_sb'] = MSE_sum(s_t_warped_sb * s_t_warped_sb_mask,
                                         inputs['s_t']) / s_t_warped_sb_mask.sum()
                errs['MSE_ss'] = MSE_sum(s_t_warped_ss * s_t_warped_ss_mask,
                                         inputs['s_t']) / s_t_warped_ss_mask.sum()
                errs['MSE_bb_mask_shape'] = MSE_mean(s_t_warped_bb_mask, s_t_warped_ss_mask_gt)
                errs['MSE_bs_mask_shape'] = MSE_mean(s_t_warped_bs_mask, s_t_warped_ss_mask_gt)
                errs['MSE_sb_mask_shape'] = MSE_mean(s_t_warped_sb_mask, s_t_warped_ss_mask_gt)
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask, s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_bb'] + errs['MSE_bs'] + errs['MSE_sb'] + errs['MSE_ss'] \
                    + errs['MSE_bb_mask_shape'] + errs['MSE_bs_mask_shape'] \
                    + errs['MSE_sb_mask_shape'] + errs['MSE_ss_mask_shape']

            elif config.loss == 'image_ss':
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear', align_corners=False)
                    s_t_warped_ss_mask_gt = warp(
                        tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                        tensorFlow=flow_ss_gt)

                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_ss)
                s_t_warped_ss_mask = warp(
                    tensorInput=torch.ones_like(inputs['s_t_1'], device=device),
                    tensorFlow=flow_ss)

                optimizer.zero_grad()
                errs['MSE_ss'] = MSE_sum(s_t_warped_ss * s_t_warped_ss_mask,
                                         inputs['s_t']) / s_t_warped_ss_mask.sum()
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask, s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_ss'] + errs['MSE_ss_mask_shape']

            elif config.loss == 'flow_only':
                flow_bb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_bs = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['b_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_sb = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['b_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                flow_ss = torch.nn.functional.interpolate(
                    input=moduleNetwork(inputs['s_t'], inputs['s_t_1']),
                    size=(config.height, config.width), mode='bilinear', align_corners=False)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'], tensorFlow=flow_ss)
                with torch.no_grad():
                    flow_ss_gt = torch.nn.functional.interpolate(
                        input=moduleNetwork_gt(inputs['s_t'], inputs['s_t_1']),
                        size=(config.height, config.width), mode='bilinear', align_corners=False)

                optimizer.zero_grad()
                # liteflow_flow_only: supervise flows directly with the reference flow
                errs['MSE_bb_ss'] = MSE_mean(flow_bb, flow_ss_gt)
                errs['MSE_bs_ss'] = MSE_mean(flow_bs, flow_ss_gt)
                errs['MSE_sb_ss'] = MSE_mean(flow_sb, flow_ss_gt)
                errs['MSE_ss_ss'] = MSE_mean(flow_ss, flow_ss_gt)
                errs['total'] = errs['MSE_bb_ss'] + errs['MSE_bs_ss'] \
                    + errs['MSE_sb_ss'] + errs['MSE_ss_ss']

            errs['total'].backward()
            optimizer.step()
            lr = adjust_learning_rate(optimizer, epoch, config.decay_rate,
                                      config.decay_every, config.lr_init)

            if itr % config.write_log_every_itr == 0:
                summary.add_scalar('loss/loss_mse', errs['total'].item(), itr)
                vutils.save_image(inputs['s_t_1'].detach().cpu(),
                                  '{}/{}_1_input.png'.format(config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                vutils.save_image(s_t_warped_ss.detach().cpu(),
                                  '{}/{}_2_warped_ss.png'.format(config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                vutils.save_image(inputs['s_t'].detach().cpu(),
                                  '{}/{}_3_gt.png'.format(config.LOG_DIR.sample, itr),
                                  nrow=3, padding=0, normalize=False)
                if config.loss == 'image_ss':
                    vutils.save_image(s_t_warped_ss_mask.detach().cpu(),
                                      '{}/{}_4_s_t_warped_ss_mask.png'.format(config.LOG_DIR.sample, itr),
                                      nrow=3, padding=0, normalize=False)
                elif config.loss != 'flow_only':
                    vutils.save_image(s_t_warped_bb_mask.detach().cpu(),
                                      '{}/{}_4_s_t_warped_bb_mask.png'.format(config.LOG_DIR.sample, itr),
                                      nrow=3, padding=0, normalize=False)

            if itr % config.refresh_image_log_every_itr == 0:
                remove_file_end_with(config.LOG_DIR.sample, '*.png')

            print_logs('TRAIN', config.mode, epoch, itr_time, itr,
                       data_loader.num_itr, errs=errs, lr=lr)
            itr += 1

        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(moduleNetwork, epoch, score=errs['total'].item())
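# A minimal sketch of the backward-warping helper used above, assuming
# `tensorFlow` is an (N, 2, H, W) optical-flow field in pixels (x, then y)
# and `tensorInput` is an (N, C, H, W) batch. The real `warp` is defined
# elsewhere in the repo; this shows the standard grid_sample-based pattern.
# Warping an all-ones tensor through the same flow, as done above, produces
# the validity mask: zeros padding marks samples taken from outside the frame.
def warp_sketch(tensorInput, tensorFlow):
    N, _, H, W = tensorInput.size()
    # identity sampling grid in normalized [-1, 1] coordinates
    gx = torch.linspace(-1.0, 1.0, W, device=tensorInput.device).view(1, 1, 1, W).expand(N, -1, H, -1)
    gy = torch.linspace(-1.0, 1.0, H, device=tensorInput.device).view(1, 1, H, 1).expand(N, -1, -1, W)
    grid = torch.cat([gx, gy], dim=1)
    # convert the pixel-space flow into normalized offsets
    flow = torch.cat([tensorFlow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
                      tensorFlow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], dim=1)
    return torch.nn.functional.grid_sample(
        input=tensorInput, grid=(grid + flow).permute(0, 2, 3, 1),
        mode='bilinear', padding_mode='zeros', align_corners=True)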
def train(config, mode):
    ## Managers
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode, config.max_ckpt_num)
    ckpt_manager_itr = CKPT_Manager(config.LOG_DIR.ckpt_itr, mode, config.max_ckpt_num)
    ckpt_manager_init = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode, config.max_ckpt_num)
    ckpt_manager_init_itr = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_itr, mode, config.max_ckpt_num)
    ckpt_manager_perm = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_perm, mode, 1)

    ## DEFINE SESSION
    seed_value = 1
    tf.set_random_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    print(toGreen('Initializing session...'))
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))

    ## DEFINE MODEL
    stabNet = StabNet(config.height, config.width)

    ## DEFINE DATA LOADERS
    print(toGreen('Loading dataloader...'))
    data_loader = Data_Loader(config, is_train=True, thread_num=config.thread_num)
    data_loader_test = Data_Loader(config.TEST, is_train=False, thread_num=config.thread_num)

    ## DEFINE TRAINER
    print(toGreen('Initializing Trainer...'))
    trainer = Trainer(stabNet, [data_loader, data_loader_test], config)

    ## DEFINE SUMMARY WRITER
    print(toGreen('Building summary writer...'))
    if config.is_pretrain:
        writer_scalar_itr_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_itr,
            flush_secs=30, filename_suffix='.scalar_log_itr_init')
        writer_scalar_epoch_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_epoch,
            flush_secs=30, filename_suffix='.scalar_log_epoch_init')
        writer_scalar_epoch_valid_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_valid,
            flush_secs=30, filename_suffix='.scalar_log_epoch_test_init')
        writer_image_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_image,
            flush_secs=30, filename_suffix='.image_log_init')

    writer_scalar_itr = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_itr,
        flush_secs=30, filename_suffix='.scalar_log_itr')
    writer_scalar_epoch = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_epoch,
        flush_secs=30, filename_suffix='.scalar_log_epoch')
    writer_scalar_epoch_valid = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_valid,
        flush_secs=30, filename_suffix='.scalar_log_epoch_test')
    writer_image = tf.summary.FileWriter(
        config.LOG_DIR.log_image,
        flush_secs=30, filename_suffix='.image_log')

    ## INITIALIZE SESSION
    print(toGreen('Initializing network...'))
    sess.run(tf.global_variables_initializer())
    trainer.init_vars(sess)
    ckpt_manager_init.load_ckpt(sess, by_score=False)
    ckpt_manager_perm.load_ckpt(sess)

    if config.is_pretrain:
        print(toYellow('======== PRETRAINING START ========='))
        global_step = 0
        for epoch in range(0, config.PRETRAIN.n_epoch):
            # update learning rate
            trainer.update_learning_rate(epoch, config.PRETRAIN.lr_init,
                                         config.PRETRAIN.lr_decay_rate,
                                         config.PRETRAIN.decay_every, sess)

            errs_total_pretrain = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss.keys(), 0.)
            errs = None
            epoch_time = time.time()
            idx = 0
            while True:
                step_time = time.time()
                feed_dict, is_end = data_loader.feed_the_network()
                if is_end:
                    break

                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                _, lr, errs = sess.run(
                    [trainer.optim_init, trainer.learning_rate, trainer.pretrain_loss],
                    feed_dict)
                errs_total_pretrain = dict_operations(errs_total_pretrain, '+', errs)

                if global_step % config.write_log_every_itr == 0:
                    summary_loss_itr, summary_image = sess.run(
                        [trainer.scalar_sum_itr_init, trainer.image_sum_init],
                        feed_dict)
                    writer_scalar_itr_init.add_summary(summary_loss_itr, global_step)
                    writer_image_init.add_summary(summary_image, global_step)

                # save checkpoint
                if global_step % config.PRETRAIN.write_ckpt_every_itr == 0:
                    ckpt_manager_init_itr.save_ckpt(
                        sess, trainer.pretraining_save_vars,
                        '{:05d}_{:05d}'.format(epoch, global_step),
                        score=errs_total_pretrain['total'] / (idx + 1))

                print_logs('PRETRAIN', mode, epoch, step_time, idx,
                           data_loader.num_itr, errs=errs,
                           coefs=trainer.coef_container, lr=lr)
                global_step += 1
                idx += 1

            # save log
            errs_total_pretrain = dict_operations(errs_total_pretrain, '/',
                                                  data_loader.num_itr)
            summary_loss_epoch_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain))
            writer_scalar_epoch_init.add_summary(summary_loss_epoch_init, epoch)

            ## TEST
            errs_total_pretrain_test = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss_test.keys(), 0.)
            errs = None
            epoch_time_test = time.time()
            idx = 0
            while True:
                step_time = time.time()
                feed_dict, is_end = data_loader_test.feed_the_network()
                if is_end:
                    break

                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                errs = sess.run(trainer.pretrain_loss_test, feed_dict)
                errs_total_pretrain_test = dict_operations(
                    errs_total_pretrain_test, '+', errs)

                print_logs('PRETRAIN TEST', mode, epoch, step_time, idx,
                           data_loader_test.num_itr, errs=errs,
                           coefs=trainer.coef_container)
                idx += 1

            # save log
            errs_total_pretrain_test = dict_operations(
                errs_total_pretrain_test, '/', data_loader_test.num_itr)
            summary_loss_test_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain_test))
            writer_scalar_epoch_valid_init.add_summary(summary_loss_test_init, epoch)

            print_logs('TRAIN SUMMARY', mode, epoch, epoch_time,
                       errs=errs_total_pretrain)
            print_logs('TEST SUMMARY', mode, epoch, epoch_time_test,
                       errs=errs_total_pretrain_test)

            # save checkpoint
            if epoch % config.write_ckpt_every_epoch == 0:
                ckpt_manager_init.save_ckpt(
                    sess, trainer.pretraining_save_vars, epoch,
                    score=errs_total_pretrain_test['total'])

            # reset image log
            if epoch % config.refresh_image_log_every_epoch == 0:
                writer_image_init.close()
                remove_file_end_with(config.PRETRAIN.LOG_DIR.log_image, '*.image_log')
                writer_image_init.reopen()

        if config.pretrain_only:
            return
        else:
            data_loader.reset_to_train_input(stabNet)
            data_loader_test.reset_to_train_input(stabNet)

    print(toYellow('========== TRAINING START =========='))
    global_step = 0
    for epoch in range(0, config.n_epoch):
        # update learning rate
        trainer.update_learning_rate(epoch, config.lr_init,
                                     config.lr_decay_rate, config.decay_every,
                                     sess)

        ## TRAIN
        errs_total_train = collections.OrderedDict.fromkeys(
            trainer.loss.keys(), 0.)
        errs = None
        epoch_time = time.time()
        idx = 0
        while True:
            step_time = time.time()
            feed_dict, is_end = data_loader.feed_the_network()
            if is_end:
                break

            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            _, lr, errs = sess.run(
                [trainer.optim_main, trainer.learning_rate, trainer.loss],
                feed_dict)
            errs_total_train = dict_operations(errs_total_train, '+', errs)

            if global_step % config.write_ckpt_every_itr == 0:
                ckpt_manager_itr.save_ckpt(
                    sess, trainer.save_vars,
                    '{:05d}_{:05d}'.format(epoch, global_step),
                    score=errs_total_train['total'] / (idx + 1))

            if global_step % config.write_log_every_itr == 0:
                summary_loss_itr, summary_image = sess.run(
                    [trainer.scalar_sum_itr, trainer.image_sum], feed_dict)
                writer_scalar_itr.add_summary(summary_loss_itr, global_step)
                writer_image.add_summary(summary_image, global_step)

            print_logs('TRAIN', mode, epoch, step_time, idx,
                       data_loader.num_itr, errs=errs,
                       coefs=trainer.coef_container, lr=lr)
            global_step += 1
            idx += 1

        # SAVE LOGS
        errs_total_train = dict_operations(errs_total_train, '/',
                                           data_loader.num_itr)
        summary_loss_epoch = sess.run(
            trainer.summary_epoch,
            feed_dict=dict_operations(trainer.loss_epoch_placeholder, '=',
                                      errs_total_train))
        writer_scalar_epoch.add_summary(summary_loss_epoch, epoch)

        ## TEST
        errs_total_test = collections.OrderedDict.fromkeys(
            trainer.loss_test.keys(), 0.)
        epoch_time_test = time.time()
        idx = 0
        while True:
            step_time = time.time()
            feed_dict, is_end = data_loader_test.feed_the_network()
            if is_end:
                break

            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            errs = sess.run(trainer.loss_test, feed_dict)
            errs_total_test = dict_operations(errs_total_test, '+', errs)

            print_logs('TEST', mode, epoch, step_time, idx,
                       data_loader_test.num_itr, errs=errs,
                       coefs=trainer.coef_container)
            idx += 1

        # SAVE LOGS
        errs_total_test = dict_operations(errs_total_test, '/',
                                          data_loader_test.num_itr)
        summary_loss_epoch_test = sess.run(
            trainer.summary_epoch,
            feed_dict=dict_operations(trainer.loss_epoch_placeholder, '=',
                                      errs_total_test))
        writer_scalar_epoch_valid.add_summary(summary_loss_epoch_test, epoch)

        ## CKPT
        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(sess, trainer.save_vars, epoch,
                                   score=errs_total_test['total'])

        ## RESET IMAGE SUMMARY
        if epoch % config.refresh_image_log_every_epoch == 0:
            writer_image.close()
            remove_file_end_with(config.LOG_DIR.log_image, '*.image_log')
            writer_image.reopen()

        print_logs('TRAIN SUMMARY', mode, epoch, epoch_time,
                   errs=errs_total_train)
        print_logs('TEST SUMMARY', mode, epoch, epoch_time_test,
                   errs=errs_total_test)
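# A minimal sketch of the dict_operations helper, inferred from its call
# sites above: '+' accumulates per-key losses, '/' divides every value by a
# scalar (to average over iterations), and '=' pairs each loss placeholder
# with its accumulated value to build a feed_dict. This is an assumption
# about the repo's utility, not its actual implementation.
def dict_operations_sketch(lhs, op, rhs):
    if op == '+':
        return collections.OrderedDict((k, lhs[k] + rhs[k]) for k in lhs)
    if op == '/':
        return collections.OrderedDict((k, lhs[k] / rhs) for k in lhs)
    if op == '=':
        # lhs maps names to tf placeholders; rhs maps the same names to values
        return {lhs[k]: rhs[k] for k in lhs}
    raise ValueError('unsupported op: {}'.format(op))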