コード例 #1
0
class Trainer():
    """Epoch-based train/test driver around a model wrapper built by `create_model`.

    Judging from the calls made below, the wrapper is expected to expose
    `print()`, `get_network()`, `set_optim(lr)`, `train()`, `eval()`,
    `iteration(epoch, is_train)`, `update(epoch)` and a `visuals` dict with the
    keys 'inputs', 'errs', 'outs', 'lr' and 'num_itr'.
    """

    def __init__(self, config):
        """Create the summary writer, model and checkpoint manager from `config`.

        When `config.load_pretrained` is set, the network weights are loaded
        from './ckpt/<config.pre_ckpt_name>'. The step-decayed learning rate for
        `config.epoch_start` is passed to `self.model.set_optim`.
        """
        self.config = config
        self.summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

        ## model
        print(toGreen('Loading Model...'))
        self.model = create_model(config)
        self.model.print()

        ## checkpoint manager
        self.ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                         config.max_ckpt_num)
        if config.load_pretrained:
            print(toGreen('Loading pretrained Model...'))
            load_result = self.model.get_network().load_state_dict(
                torch.load(os.path.join('./ckpt', config.pre_ckpt_name)))
            # Step decay: lr_init * decay_rate ** (completed decay periods).
            lr = config.lr_init * (config.decay_rate**(config.epoch_start //
                                                       config.decay_every))
            print(toRed('\tlearning rate: {}'.format(lr)))
            print(toRed('\tload result: {}'.format(load_result)))
            self.model.set_optim(lr)

        ## training vars
        self.max_epoch = 10000
        self.epoch_range = np.arange(config.epoch_start, self.max_epoch)
        self.itr_global = 0  # training iterations over all epochs (TensorBoard step)
        self.itr = 0  # iterations within the current train or test pass
        self.err_epoch_train = 0  # running sum of errs['total'] in the train pass
        self.err_epoch_test = 0  # running sum of errs['total'] in the test pass

    def train(self):
        """Alternate one training pass and one test pass for each epoch.

        The mean per-iteration 'total' loss of both passes is written to
        TensorBoard after every epoch. Every `config.write_ckpt_every_epoch`
        epochs, a checkpoint is saved (scored by the training loss) and the
        sample images are deleted.
        """
        print(toYellow('======== TRAINING START ========='))
        for epoch in self.epoch_range:
            ## TRAIN ##
            self.itr = 0
            self.err_epoch_train = 0
            self.model.train()
            while True:
                if self.iteration(epoch): break
            err_epoch_train = self.err_epoch_train / self.itr

            ## TEST ##
            self.itr = 0
            self.err_epoch_test = 0
            self.model.eval()
            while True:
                with torch.no_grad():
                    if self.iteration(epoch, is_train=False): break
            err_epoch_test = self.err_epoch_test / self.itr

            ## LOG
            if epoch % self.config.write_ckpt_every_epoch == 0:
                self.ckpt_manager.save_ckpt(self.model.get_network(),
                                            epoch + 1,
                                            score=err_epoch_train)
                remove_file_end_with(self.config.LOG_DIR.sample, '*.png')

            self.summary.add_scalar('loss/epoch_train', err_epoch_train, epoch)
            self.summary.add_scalar('loss/epoch_test', err_epoch_test, epoch)

    def iteration(self, epoch, is_train=True):
        """Run one model step, accumulate its loss and print a log line.

        Args:
            epoch: current epoch, forwarded to the model and used in log and
                sample-image names.
            is_train: when True, also calls `self.model.update(epoch)` (its
                return value is the logged learning rate). It also writes the
                TensorBoard scalar and sample images every
                `config.write_log_every_itr` iterations and advances
                `itr_global`.

        Returns:
            True once `self.model.iteration` reports the end of the pass,
            otherwise its (falsy) result.
        """
        itr_time = time.time()

        is_end = self.model.iteration(epoch, is_train)
        if is_end: return True

        visuals = self.model.visuals
        inputs = visuals['inputs']
        errs = visuals['errs']
        outs = visuals['outs']
        lr = visuals['lr']
        num_itr = visuals['num_itr']

        if is_train:
            lr = self.model.update(epoch)

            if self.itr % self.config.write_log_every_itr == 0:
                self._write_train_samples(epoch, inputs, outs, errs)

            print_logs('TRAIN',
                       self.config.mode,
                       epoch,
                       itr_time,
                       self.itr,
                       num_itr,
                       errs=errs,
                       lr=lr)
            self.err_epoch_train += errs['total'].item()
            self.itr += 1
            self.itr_global += 1
        else:
            print_logs('TEST',
                       self.config.mode,
                       epoch,
                       itr_time,
                       self.itr,
                       num_itr,
                       errs=errs)
            self.err_epoch_test += errs['total'].item()
            self.itr += 1

        gc.collect()
        return is_end

    def _write_train_samples(self, epoch, inputs, outs, errs):
        """Best-effort logging of the iteration loss and input/output/gt images.

        Any exception is printed and swallowed so that logging failures never
        interrupt training.
        """
        try:
            sample_dir = self.config.LOG_DIR.sample
            self.summary.add_scalar('loss/itr', errs['total'].item(),
                                    self.itr_global)
            vutils.save_image(inputs['input'],
                              '{}/{}_{}_1_input.png'.format(
                                  sample_dir, epoch, self.itr),
                              nrow=3,
                              padding=0,
                              normalize=False)
            # File index: 1 is the input, outputs start at 2, and gt comes last.
            # The index advances even for None outputs, keeping positions stable.
            i = 2
            for key, val in outs.items():
                if val is not None:
                    vutils.save_image(val,
                                      '{}/{}_{}_{}_out_{}.png'.format(
                                          sample_dir, epoch, self.itr, i, key),
                                      nrow=3,
                                      padding=0,
                                      normalize=False)
                i += 1

            vutils.save_image(inputs['gt'],
                              '{}/{}_{}_{}_gt.png'.format(
                                  sample_dir, epoch, self.itr, i),
                              nrow=3,
                              padding=0,
                              normalize=False)
        except Exception as ex:
            print('saving error: ', ex)
コード例 #2
0
def _resize_flow(network, tensor_first, tensor_second, height, width):
    """Run `network` on a frame pair and bilinearly resize its output to (height, width)."""
    return torch.nn.functional.interpolate(
        input=network(tensor_first, tensor_second),
        size=(height, width),
        mode='bilinear',
        align_corners=False)


def _save_sample(tensor, path):
    """Save a detached CPU copy of `tensor` as an image grid (3 per row, no padding)."""
    vutils.save_image(tensor.detach().cpu(),
                      path,
                      nrow=3,
                      padding=0,
                      normalize=False)


def train(config):
    """Fine-tune the flow `Network` against a frozen copy of itself.

    The trainable network and a reference copy (`moduleNetwork_gt`) are both
    loaded from './network/network-default.pytorch'. The reference network's
    flow for the pair (s_t, s_t_1) is the target. Presumably the 'b_*' inputs
    are blurry frames and the 's_*' inputs sharp ones; this needs confirming
    against Data_Loader. `config.loss` selects the objective:

    * 'image'     - masked photometric MSE of s_t_1 warped onto s_t with flows
                    predicted from all four (b|s)_t / (b|s)_t_1 pairs, plus an
                    MSE pulling each warped all-ones mask towards the one
                    warped with the reference flow.
    * 'image_ss'  - the same terms, for the (s_t, s_t_1) pair only.
    * 'flow_only' - MSE between each of the four predicted flows and the
                    reference flow.

    Raises:
        ValueError: if `config.loss` is not one of the modes above.
    """
    valid_losses = ('image', 'image_ss', 'flow_only')
    if config.loss not in valid_losses:
        # Fail fast; otherwise an unknown mode only surfaces later as an
        # opaque KeyError on errs['total'] inside the training loop.
        raise ValueError('config.loss must be one of {}, got {!r}'.format(
            valid_losses, config.loss))

    summary = SummaryWriter(config.LOG_DIR.log_scalar_train_itr)

    ## inputs
    inputs = {'b_t_1': None, 'b_t': None, 's_t_1': None, 's_t': None}
    inputs = collections.OrderedDict(sorted(inputs.items(),
                                            key=lambda t: t[0]))

    ## model
    print(toGreen('Loading Model...'))
    moduleNetwork = Network().to(device)
    moduleNetwork.apply(weights_init)
    moduleNetwork_gt = Network().to(device)
    print(moduleNetwork)

    ## checkpoint manager
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, config.mode,
                                config.max_ckpt_num)
    # NOTE(review): this load replaces the weights_init initialisation above.
    moduleNetwork.load_state_dict(
        torch.load('./network/network-default.pytorch'))
    moduleNetwork_gt.load_state_dict(
        torch.load('./network/network-default.pytorch'))

    ## data loader
    print(toGreen('Loading Data Loader...'))
    data_loader = Data_Loader(config,
                              is_train=True,
                              name='train',
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config,
                                   is_train=False,
                                   name="test",
                                   thread_num=config.thread_num)

    data_loader.init_data_loader(inputs)
    data_loader_test.init_data_loader(inputs)

    ## loss, optim
    print(toGreen('Building Loss & Optim...'))
    MSE_sum = torch.nn.MSELoss(reduction='sum')
    MSE_mean = torch.nn.MSELoss()
    optimizer = optim.Adam(moduleNetwork.parameters(),
                           lr=config.lr_init,
                           betas=(config.beta1, 0.999))
    errs = collections.OrderedDict()

    def resized_flow(network, tensor_first, tensor_second):
        # Flow from `network`, resized to the configured training resolution.
        return _resize_flow(network, tensor_first, tensor_second,
                            config.height, config.width)

    print(toYellow('======== TRAINING START ========='))
    max_epoch = 10000
    itr = 0  # global iteration counter (TensorBoard step / sample names)
    for epoch in np.arange(max_epoch):
        err_epoch = 0.  # running sum of errs['total'] in this epoch
        itr_epoch = 0  # iterations completed in this epoch

        # train
        while True:
            itr_time = time.time()

            inputs, is_end = data_loader.get_feed()
            if is_end: break

            if config.loss == 'image':
                flow_bb = resized_flow(moduleNetwork, inputs['b_t'], inputs['b_t_1'])
                flow_bs = resized_flow(moduleNetwork, inputs['b_t'], inputs['s_t_1'])
                flow_sb = resized_flow(moduleNetwork, inputs['s_t'], inputs['b_t_1'])
                flow_ss = resized_flow(moduleNetwork, inputs['s_t'], inputs['s_t_1'])

                with torch.no_grad():
                    flow_ss_gt = resized_flow(moduleNetwork_gt, inputs['s_t'],
                                              inputs['s_t_1'])
                    s_t_warped_ss_mask_gt = warp(tensorInput=torch.ones_like(
                        inputs['s_t_1'], device=device),
                                                 tensorFlow=flow_ss_gt)

                s_t_warped_bb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bb)
                s_t_warped_bs = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_bs)
                s_t_warped_sb = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_sb)
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)

                # Warping an all-ones tensor shows where each warp produced values.
                s_t_warped_bb_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_bb)
                s_t_warped_bs_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_bs)
                s_t_warped_sb_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_sb)
                s_t_warped_ss_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_ss)

                optimizer.zero_grad()

                # Summed squared error normalised by the mask's total weight.
                errs['MSE_bb'] = MSE_sum(
                    s_t_warped_bb * s_t_warped_bb_mask,
                    inputs['s_t']) / s_t_warped_bb_mask.sum()
                errs['MSE_bs'] = MSE_sum(
                    s_t_warped_bs * s_t_warped_bs_mask,
                    inputs['s_t']) / s_t_warped_bs_mask.sum()
                errs['MSE_sb'] = MSE_sum(
                    s_t_warped_sb * s_t_warped_sb_mask,
                    inputs['s_t']) / s_t_warped_sb_mask.sum()
                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()

                errs['MSE_bb_mask_shape'] = MSE_mean(s_t_warped_bb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_bs_mask_shape'] = MSE_mean(s_t_warped_bs_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_sb_mask_shape'] = MSE_mean(s_t_warped_sb_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)

                errs['total'] = (errs['MSE_bb'] + errs['MSE_bs'] +
                                 errs['MSE_sb'] + errs['MSE_ss'] +
                                 errs['MSE_bb_mask_shape'] +
                                 errs['MSE_bs_mask_shape'] +
                                 errs['MSE_sb_mask_shape'] +
                                 errs['MSE_ss_mask_shape'])

            elif config.loss == 'image_ss':
                flow_ss = resized_flow(moduleNetwork, inputs['s_t'], inputs['s_t_1'])
                with torch.no_grad():
                    flow_ss_gt = resized_flow(moduleNetwork_gt, inputs['s_t'],
                                              inputs['s_t_1'])
                    s_t_warped_ss_mask_gt = warp(tensorInput=torch.ones_like(
                        inputs['s_t_1'], device=device),
                                                 tensorFlow=flow_ss_gt)

                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)
                s_t_warped_ss_mask = warp(tensorInput=torch.ones_like(
                    inputs['s_t_1'], device=device),
                                          tensorFlow=flow_ss)

                optimizer.zero_grad()

                errs['MSE_ss'] = MSE_sum(
                    s_t_warped_ss * s_t_warped_ss_mask,
                    inputs['s_t']) / s_t_warped_ss_mask.sum()
                errs['MSE_ss_mask_shape'] = MSE_mean(s_t_warped_ss_mask,
                                                     s_t_warped_ss_mask_gt)
                errs['total'] = errs['MSE_ss'] + errs['MSE_ss_mask_shape']

            else:  # 'flow_only' (config.loss was validated above)
                flow_bb = resized_flow(moduleNetwork, inputs['b_t'], inputs['b_t_1'])
                flow_bs = resized_flow(moduleNetwork, inputs['b_t'], inputs['s_t_1'])
                flow_sb = resized_flow(moduleNetwork, inputs['s_t'], inputs['b_t_1'])
                flow_ss = resized_flow(moduleNetwork, inputs['s_t'], inputs['s_t_1'])

                # Only used for the sample images below.
                s_t_warped_ss = warp(tensorInput=inputs['s_t_1'],
                                     tensorFlow=flow_ss)

                with torch.no_grad():
                    flow_ss_gt = resized_flow(moduleNetwork_gt, inputs['s_t'],
                                              inputs['s_t_1'])

                optimizer.zero_grad()

                # liteflow_flow_only
                errs['MSE_bb_ss'] = MSE_mean(flow_bb, flow_ss_gt)
                errs['MSE_bs_ss'] = MSE_mean(flow_bs, flow_ss_gt)
                errs['MSE_sb_ss'] = MSE_mean(flow_sb, flow_ss_gt)
                errs['MSE_ss_ss'] = MSE_mean(flow_ss, flow_ss_gt)
                errs['total'] = (errs['MSE_bb_ss'] + errs['MSE_bs_ss'] +
                                 errs['MSE_sb_ss'] + errs['MSE_ss_ss'])

            errs['total'].backward()
            optimizer.step()

            lr = adjust_learning_rate(optimizer, epoch, config.decay_rate,
                                      config.decay_every, config.lr_init)

            if itr % config.write_log_every_itr == 0:
                summary.add_scalar('loss/loss_mse', errs['total'].item(), itr)
                _save_sample(inputs['s_t_1'],
                             '{}/{}_1_input.png'.format(
                                 config.LOG_DIR.sample, itr))
                _save_sample(s_t_warped_ss,
                             '{}/{}_2_warped_ss.png'.format(
                                 config.LOG_DIR.sample, itr))
                _save_sample(inputs['s_t'],
                             '{}/{}_3_gt.png'.format(
                                 config.LOG_DIR.sample, itr))

                if config.loss == 'image_ss':
                    _save_sample(s_t_warped_ss_mask,
                                 '{}/{}_4_s_t_wapred_ss_mask.png'.format(
                                     config.LOG_DIR.sample, itr))
                elif config.loss != 'flow_only':
                    _save_sample(s_t_warped_bb_mask,
                                 '{}/{}_4_s_t_wapred_bb_mask.png'.format(
                                     config.LOG_DIR.sample, itr))

            if itr % config.refresh_image_log_every_itr == 0:
                remove_file_end_with(config.LOG_DIR.sample, '*.png')

            print_logs('TRAIN',
                       config.mode,
                       epoch,
                       itr_time,
                       itr,
                       data_loader.num_itr,
                       errs=errs,
                       lr=lr)
            err_epoch += errs['total'].item()
            itr_epoch += 1
            itr += 1

        # Score the checkpoint with the epoch-mean loss instead of a single
        # (noisy) last-batch loss. Skip it when the epoch produced no batches.
        if epoch % config.write_ckpt_every_epoch == 0 and itr_epoch > 0:
            ckpt_manager.save_ckpt(moduleNetwork,
                                   epoch,
                                   score=err_epoch / itr_epoch)
コード例 #3
0
def train(config, mode):
    """Train StabNet in a TF1 session: an optional pretraining stage, then main training.

    Each stage runs its configured number of epochs. Every epoch makes one
    full pass over the training loader, running the optimizer, and one full
    pass over the test loader, evaluating losses only. Each epoch writes
    per-iteration and per-epoch TensorBoard summaries and saves checkpoints
    according to the `write_ckpt_every_*` settings.

    Args:
        config: experiment configuration. `config.PRETRAIN` holds the
            pretraining-stage settings and `config.TEST` the test-loader ones.
        mode: passed through to the checkpoint managers and `print_logs`.

    Returns:
        None. Returns right after pretraining when `config.pretrain_only` is set.
    """
    ## Managers
    print(toGreen('Loading checkpoint manager...'))
    ckpt_manager = CKPT_Manager(config.LOG_DIR.ckpt, mode, config.max_ckpt_num)
    ckpt_manager_itr = CKPT_Manager(config.LOG_DIR.ckpt_itr, mode,
                                    config.max_ckpt_num)
    ckpt_manager_init = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt, mode,
                                     config.max_ckpt_num)
    ckpt_manager_init_itr = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_itr,
                                         mode, config.max_ckpt_num)
    ckpt_manager_perm = CKPT_Manager(config.PRETRAIN.LOG_DIR.ckpt_perm, mode,
                                     1)

    ## DEFINE SESSION
    # Seed every RNG in play for reproducibility.
    seed_value = 1
    tf.set_random_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    print(toGreen('Initializing session...'))
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))

    ## DEFINE MODEL
    stabNet = StabNet(config.height, config.width)

    ## DEFINE DATA LOADERS
    print(toGreen('Loading dataloader...'))
    data_loader = Data_Loader(config,
                              is_train=True,
                              thread_num=config.thread_num)
    data_loader_test = Data_Loader(config.TEST,
                                   is_train=False,
                                   thread_num=config.thread_num)

    ## DEFINE TRAINER
    print(toGreen('Initializing Trainer...'))
    trainer = Trainer(stabNet, [data_loader, data_loader_test], config)

    ## DEFINE SUMMARY WRITER
    print(toGreen('Building summary writer...'))
    if config.is_pretrain:
        # These writers are only referenced inside the pretraining stage below.
        writer_scalar_itr_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_itr,
            flush_secs=30,
            filename_suffix='.scalor_log_itr_init')
        writer_scalar_epoch_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_train_epoch,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_init')
        writer_scalar_epoch_valid_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_scalar_valid,
            flush_secs=30,
            filename_suffix='.scalor_log_epoch_test_init')
        writer_image_init = tf.summary.FileWriter(
            config.PRETRAIN.LOG_DIR.log_image,
            flush_secs=30,
            filename_suffix='.image_log_init')

    writer_scalar_itr = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_itr,
        flush_secs=30,
        filename_suffix='.scalor_log_itr')
    writer_scalar_epoch = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_train_epoch,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch')
    writer_scalar_epoch_valid = tf.summary.FileWriter(
        config.LOG_DIR.log_scalar_valid,
        flush_secs=30,
        filename_suffix='.scalor_log_epoch_test')
    writer_image = tf.summary.FileWriter(config.LOG_DIR.log_image,
                                         flush_secs=30,
                                         filename_suffix='.image_log')

    ## INITIALIZE SESSION
    print(toGreen('Initializing network...'))
    sess.run(tf.global_variables_initializer())
    trainer.init_vars(sess)

    # Restore from the pretraining manager (not selected by score) and then
    # from the single-slot "perm" manager.
    ckpt_manager_init.load_ckpt(sess, by_score=False)
    ckpt_manager_perm.load_ckpt(sess)

    if config.is_pretrain:
        print(toYellow('======== PRETRAINING START ========='))
        global_step = 0
        for epoch in range(0, config.PRETRAIN.n_epoch):
            # update learning rate
            trainer.update_learning_rate(epoch, config.PRETRAIN.lr_init,
                                         config.PRETRAIN.lr_decay_rate,
                                         config.PRETRAIN.decay_every, sess)

            ## TRAIN
            errs_total_pretrain = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss.keys(), 0.)
            errs = None
            epoch_time = time.time()
            idx = 0
            while True:
                step_time = time.time()

                feed_dict, is_end = data_loader.feed_the_network()
                if is_end: break
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                _, lr, errs = sess.run([
                    trainer.optim_init, trainer.learning_rate,
                    trainer.pretrain_loss
                ], feed_dict)
                errs_total_pretrain = dict_operations(errs_total_pretrain, '+',
                                                      errs)

                if global_step % config.write_log_every_itr == 0:
                    summary_loss_itr, summary_image = sess.run(
                        [trainer.scalar_sum_itr_init, trainer.image_sum_init],
                        feed_dict)
                    writer_scalar_itr_init.add_summary(summary_loss_itr,
                                                       global_step)
                    writer_image_init.add_summary(summary_image, global_step)

                # save checkpoint, scored by the running mean loss of this epoch
                if global_step % config.PRETRAIN.write_ckpt_every_itr == 0:
                    ckpt_manager_init_itr.save_ckpt(
                        sess,
                        trainer.pretraining_save_vars,
                        '{:05d}_{:05d}'.format(epoch, global_step),
                        score=errs_total_pretrain['total'] / (idx + 1))

                print_logs('PRETRAIN',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container,
                           lr=lr)

                global_step += 1
                idx += 1

            # save log
            errs_total_pretrain = dict_operations(errs_total_pretrain, '/',
                                                  data_loader.num_itr)
            summary_loss_epoch_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain))
            writer_scalar_epoch_init.add_summary(summary_loss_epoch_init,
                                                 epoch)

            ## TEST
            errs_total_pretrain_test = collections.OrderedDict.fromkeys(
                trainer.pretrain_loss_test.keys(), 0.)
            errs = None
            epoch_time_test = time.time()
            idx = 0
            while True:
                step_time = time.time()

                feed_dict, is_end = data_loader_test.feed_the_network()
                if is_end: break
                feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
                errs = sess.run(trainer.pretrain_loss_test, feed_dict)

                errs_total_pretrain_test = dict_operations(
                    errs_total_pretrain_test, '+', errs)

                print_logs('PRETRAIN TEST',
                           mode,
                           epoch,
                           step_time,
                           idx,
                           data_loader_test.num_itr,
                           errs=errs,
                           coefs=trainer.coef_container)
                idx += 1

            # save log
            errs_total_pretrain_test = dict_operations(
                errs_total_pretrain_test, '/', data_loader_test.num_itr)
            summary_loss_test_init = sess.run(
                trainer.summary_epoch_init,
                feed_dict=dict_operations(trainer.loss_epoch_init_placeholder,
                                          '=', errs_total_pretrain_test))
            writer_scalar_epoch_valid_init.add_summary(summary_loss_test_init,
                                                       epoch)

            print_logs('TRAIN SUMMARY',
                       mode,
                       epoch,
                       epoch_time,
                       errs=errs_total_pretrain)
            print_logs('TEST SUMMARY',
                       mode,
                       epoch,
                       epoch_time_test,
                       errs=errs_total_pretrain_test)

            # save checkpoint
            if epoch % config.write_ckpt_every_epoch == 0:
                ckpt_manager_init.save_ckpt(
                    sess,
                    trainer.pretraining_save_vars,
                    epoch,
                    score=errs_total_pretrain_test['total'])

            # reset image log: close the writer, delete its files, reopen it
            if epoch % config.refresh_image_log_every_epoch == 0:
                writer_image_init.close()
                remove_file_end_with(config.PRETRAIN.LOG_DIR.log_image,
                                     '*.image_log')
                writer_image_init.reopen()

        if config.pretrain_only:
            return
        else:
            data_loader.reset_to_train_input(stabNet)
            data_loader_test.reset_to_train_input(stabNet)

    print(toYellow('========== TRAINING START =========='))
    global_step = 0
    for epoch in range(0, config.n_epoch):
        # update learning rate
        trainer.update_learning_rate(epoch, config.lr_init,
                                     config.lr_decay_rate, config.decay_every,
                                     sess)

        ## TRAIN
        errs_total_train = collections.OrderedDict.fromkeys(
            trainer.loss.keys(), 0.)
        errs = None
        epoch_time = time.time()
        idx = 0
        # Run the full epoch until the loader reports its end, as in the
        # pretraining stage. The epoch average below divides by num_itr.
        while True:
            step_time = time.time()

            feed_dict, is_end = data_loader.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            _, lr, errs = sess.run(
                [trainer.optim_main, trainer.learning_rate, trainer.loss],
                feed_dict)

            errs_total_train = dict_operations(errs_total_train, '+', errs)

            # save checkpoint, scored by the running mean loss of this epoch
            if global_step % config.write_ckpt_every_itr == 0:
                ckpt_manager_itr.save_ckpt(
                    sess,
                    trainer.save_vars,
                    '{:05d}_{:05d}'.format(epoch, global_step),
                    score=errs_total_train['total'] / (idx + 1))

            if global_step % config.write_log_every_itr == 0:
                summary_loss_itr, summary_image = sess.run(
                    [trainer.scalar_sum_itr, trainer.image_sum], feed_dict)
                writer_scalar_itr.add_summary(summary_loss_itr, global_step)
                writer_image.add_summary(summary_image, global_step)

            print_logs('TRAIN',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container,
                       lr=lr)
            global_step += 1
            idx += 1

        # SAVE LOGS
        errs_total_train = dict_operations(errs_total_train, '/',
                                           data_loader.num_itr)
        summary_loss_epoch = sess.run(trainer.summary_epoch,
                                      feed_dict=dict_operations(
                                          trainer.loss_epoch_placeholder, '=',
                                          errs_total_train))
        writer_scalar_epoch.add_summary(summary_loss_epoch, epoch)

        ## TEST
        errs_total_test = collections.OrderedDict.fromkeys(
            trainer.loss_test.keys(), 0.)
        epoch_time_test = time.time()
        idx = 0
        # NOTE(review): unlike the pretraining test loop, `errs` is not reset
        # to None here. The first adjust_loss_coef call therefore sees the last
        # training errs. Confirm this is intended.
        while True:
            step_time = time.time()

            feed_dict, is_end = data_loader_test.feed_the_network()
            if is_end: break
            feed_dict = trainer.adjust_loss_coef(feed_dict, epoch, errs)
            errs = sess.run(trainer.loss_test, feed_dict)

            errs_total_test = dict_operations(errs_total_test, '+', errs)
            print_logs('TEST',
                       mode,
                       epoch,
                       step_time,
                       idx,
                       data_loader_test.num_itr,
                       errs=errs,
                       coefs=trainer.coef_container)
            idx += 1

        # SAVE LOGS
        errs_total_test = dict_operations(errs_total_test, '/',
                                          data_loader_test.num_itr)
        summary_loss_epoch_test = sess.run(trainer.summary_epoch,
                                           feed_dict=dict_operations(
                                               trainer.loss_epoch_placeholder,
                                               '=', errs_total_test))
        writer_scalar_epoch_valid.add_summary(summary_loss_epoch_test, epoch)

        ## CKPT
        if epoch % config.write_ckpt_every_epoch == 0:
            ckpt_manager.save_ckpt(sess,
                                   trainer.save_vars,
                                   epoch,
                                   score=errs_total_test['total'])

        ## RESET IMAGE SUMMARY: close the writer, delete its files, reopen it
        if epoch % config.refresh_image_log_every_epoch == 0:
            writer_image.close()
            remove_file_end_with(config.LOG_DIR.log_image, '*.image_log')
            writer_image.reopen()

        print_logs('TRAIN SUMMARY',
                   mode,
                   epoch,
                   epoch_time,
                   errs=errs_total_train)
        print_logs('TEST SUMMARY',
                   mode,
                   epoch,
                   epoch_time_test,
                   errs=errs_total_test)