Example 1
def _build_model_use_tfdata(test_set_tfrecords_dir, ckpt_dir):
    '''
    test_set_tfrecords_dir: '/xxx/xxx/*.tfrecords'
    '''
    g = tf.Graph()
    with g.as_default():
        # region TFRecord+DataSet
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                x_batch, y_batch, Xtheta_batch, Ytheta_batch, lengths_batch, iter_test = get_batch_use_tfdata(
                    test_set_tfrecords_dir, get_theta=True)

        with tf.name_scope('model'):
            test_model = PARAM.SE_MODEL(x_batch,
                                        lengths_batch,
                                        y_batch,
                                        Xtheta_batch,
                                        Ytheta_batch,
                                        behavior=PARAM.SE_MODEL.infer)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(
            os.path.join(PARAM.SAVE_DIR, ckpt_dir))
        if ckpt and ckpt.model_checkpoint_path:
            test_model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)
    g.finalize()
    return sess, test_model, iter_test
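
A minimal usage sketch for the builder above, assuming the model object exposes the `_cleaned` output tensor from Example 4 and that `iter_test` is an initializable tf.data iterator; the checkpoint directory name is hypothetical:

sess, test_model, iter_test = _build_model_use_tfdata('/xxx/xxx/*.tfrecords', 'ckpt_dir_name')  # 'ckpt_dir_name' is hypothetical
sess.run(iter_test.initializer)
try:
    while True:
        # run inference on one batch pulled from the TFRecord pipeline
        cleaned_batch = sess.run(test_model._cleaned)  # output tensor assumed from Example 4
except tf.errors.OutOfRangeError:
    pass  # test set exhausted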
Example 2
def build_session(ckpt_dir, batch_size, finalizeG=True):
    batch_size = None  # the batch_size argument is ignored: None lets the placeholders accept any batch size
    g = tf.Graph()
    with g.as_default():
        with tf.device('/cpu:0'):
            with tf.name_scope('input'):
                x_batch = tf.placeholder(
                    tf.float32,
                    shape=[batch_size, None, PARAM.FFT_DOT],
                    name='x_batch')
                lengths_batch = tf.placeholder(tf.int32,
                                               shape=[batch_size],
                                               name='lengths_batch')
                y_batch = tf.placeholder(
                    tf.float32,
                    shape=[batch_size, None, PARAM.FFT_DOT],
                    name='y_batch')
                x_theta = tf.placeholder(
                    tf.float32,
                    shape=[batch_size, None, PARAM.FFT_DOT],
                    name='x_theta')
                y_theta = tf.placeholder(
                    tf.float32,
                    shape=[batch_size, None, PARAM.FFT_DOT],
                    name='y_theta')
        with tf.name_scope('model'):
            model = PARAM.SE_MODEL(x_batch,
                                   lengths_batch,
                                   y_batch,
                                   x_theta,
                                   y_theta,
                                   behavior=PARAM.SE_MODEL.infer)

        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(init)

        ckpt = tf.train.get_checkpoint_state(
            os.path.join(PARAM.SAVE_DIR, ckpt_dir))
        if ckpt and ckpt.model_checkpoint_path:
            tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            tf.logging.fatal("checkpoint not found.")
            sys.exit(-1)
    if finalizeG:
        g.finalize()
        return sess, model
    return sess, model, g
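
A hedged usage sketch for build_session: feeding the placeholders by their graph names and reading the `_cleaned` tensor are assumptions based on Examples 2 and 4, not code shown in this snippet; the checkpoint directory and input array are placeholders:

import numpy as np

sess, model = build_session('ckpt_dir_name', batch_size=None)   # 'ckpt_dir_name' is hypothetical
x_mag = np.zeros([1, 100, PARAM.FFT_DOT], dtype=np.float32)     # one noisy magnitude spectrogram, 100 frames
cleaned = sess.run(model._cleaned,                               # output tensor assumed from Example 4
                   feed_dict={'input/x_batch:0': x_mag,
                              'input/lengths_batch:0': [100]})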
Example 3
def train():

  g = tf.Graph()
  with g.as_default():
    # region TFRecord+DataSet
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        train_tfrecords, val_tfrecords, _, _ = generate_tfrecord(
            gen=PARAM.GENERATE_TFRECORD)
        if PARAM.GENERATE_TFRECORD:
          print("TFRecords preparation over.")
          # exit(0)  # set gen=True and exit to generate tfrecords

        PSIRM = bool(PARAM.PIPLINE_GET_THETA)
        x_batch_tr, y_batch_tr, Xtheta_batch_tr, Ytheta_batch_tr, lengths_batch_tr, iter_train = get_batch_use_tfdata(
            train_tfrecords,
            get_theta=PSIRM)
        x_batch_val, y_batch_val,  Xtheta_batch_val, Ytheta_batch_val, lengths_batch_val, iter_val = get_batch_use_tfdata(
            val_tfrecords,
            get_theta=PSIRM)
    # endregion

    # build model
    with tf.name_scope('model'):
      tr_model = PARAM.SE_MODEL(x_batch_tr,
                                lengths_batch_tr,
                                y_batch_tr,
                                Xtheta_batch_tr,
                                Ytheta_batch_tr,
                                PARAM.SE_MODEL.train)
      tf.get_variable_scope().reuse_variables()
      val_model = PARAM.SE_MODEL(x_batch_val,
                                 lengths_batch_val,
                                 y_batch_val,
                                 Xtheta_batch_val,
                                 Ytheta_batch_val,
                                 PARAM.SE_MODEL.validation)

    utils.tf_tool.show_all_variables()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = PARAM.GPU_RAM_ALLOW_GROWTH
    config.allow_soft_placement = False
    sess = tf.Session(config=config)
    sess.run(init)

    # region resume training
    if PARAM.resume_training.lower() == 'true':
      ckpt = tf.train.get_checkpoint_state(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT))
      if ckpt and ckpt.model_checkpoint_path:
        # tf.logging.info("restore from" + ckpt.model_checkpoint_path)
        tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
        best_path = ckpt.model_checkpoint_path
      else:
        tf.logging.fatal("checkpoint not found")
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines('Training resumed.\n')
    else:
      if os.path.exists(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log')):
        os.remove(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'))
    # endregion

    # region validation before training.
    valstart_time = time.time()
    sess.run(iter_val.initializer)
    mag_loss_prev, logmag_loss_prev = eval_one_epoch(sess,
                                                     val_model)
    cross_val_msg = "CROSSVAL PRERUN LOGBIASNET_LOSS_MASKNET_LOSS (%.4F,%.4F),  Costime %dS" % (
        mag_loss_prev, logmag_loss_prev, time.time()-valstart_time)
    tf.logging.info(cross_val_msg)
    with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
      f.writelines(cross_val_msg+'\n')

    tr_model.assign_lr(sess, PARAM.learning_rate_logbiasnet, PARAM.learning_rate_masknet)
    g.finalize()
    # endregion

    # epochs training
    reject_num = 0
    for epoch in range(PARAM.start_epoch, PARAM.max_epochs):
      sess.run([iter_train.initializer, iter_val.initializer])
      start_time = time.time()

      # train one epoch
      tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr, log_bias = train_one_epoch(sess,
                                                                                         tr_model)
      # Validation
      val_mag_loss, val_logmag_loss = eval_one_epoch(sess,
                                                     val_model)
      end_time = time.time()

      # Determine checkpoint path
      ckpt_name = "nnet_iter%d_lrate(%e,%e)_trloss(%.4f,%.4f)_cvloss(%.4f,%.4f)_avglogbias%f_duration%ds" % (
          epoch + 1, logbiasnet_lr, masknet_lr, tr_mag_loss, tr_logmag_loss, val_mag_loss, val_logmag_loss, np.mean(log_bias), end_time - start_time)
      ckpt_dir = os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT)
      if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
      ckpt_path = os.path.join(ckpt_dir, ckpt_name)

      # Relative loss between previous and current val_loss
      # note: np.abs() makes a worsening validation loss look like an "improvement",
      # so the halving/stopping checks below only fire when the loss barely moves.
      rel_impr_magloss = np.abs(mag_loss_prev - val_mag_loss) / mag_loss_prev
      rel_impr_logmagloss = np.abs(logmag_loss_prev - val_logmag_loss) / logmag_loss_prev
      # Accept or reject new parameters
      msg = ""
      if (PARAM.TRAIN_TYPE == 'BOTH' and (val_mag_loss < mag_loss_prev) and (val_logmag_loss < logmag_loss_prev)) or (
              PARAM.TRAIN_TYPE == 'LOGBIASNET' and (val_mag_loss < mag_loss_prev)) or (
              PARAM.TRAIN_TYPE == 'MASKNET' and (val_logmag_loss < logmag_loss_prev)):
        reject_num = 0
        tr_model.saver.save(sess, ckpt_path)
        # Logging train loss along with validation loss
        mag_loss_prev = val_mag_loss
        logmag_loss_prev = val_logmag_loss
        best_path = ckpt_path
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS (%.4f,%.4f), lrate(%e,%e), Val.LOSS (%.4f,%.4f), avglog_bias %f,\n"
               "    %s, ckpt(%s) saved,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1,
            tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr, val_mag_loss, val_logmag_loss, np.mean(log_bias),
            "NNET Accepted", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      else:
        reject_num += 1
        tr_model.saver.restore(sess, best_path)
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS (%.4f,%.4f), lrate(%e,%e), Val.LOSS (%.4f,%.4f), avglog_bias %f,\n"
               "    %s, ckpt(%s) abandoned,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1,
            tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr, val_mag_loss, val_logmag_loss, np.mean(log_bias),
            "NNET Rejected", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines(msg+'\n')

      # Start halving when improvement is lower than start_halving_impr
      if PARAM.TRAIN_TYPE == 'BOTH':
        if (rel_impr_magloss < PARAM.start_halving_impr) or (rel_impr_logmagloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          logbiasnet_lr *= PARAM.halving_factor
          masknet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess,logbiasnet_lr,masknet_lr)
      elif PARAM.TRAIN_TYPE == 'LOGBIASNET':
        if (rel_impr_magloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          logbiasnet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess,logbiasnet_lr,masknet_lr)
      elif PARAM.TRAIN_TYPE == 'MASKNET':
        if (rel_impr_logmagloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          masknet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess,logbiasnet_lr,masknet_lr)
      else:
        print('Train type error.')
        exit(-1)

      # Stopping criterion
      if rel_impr_magloss < PARAM.end_halving_impr and rel_impr_logmagloss < PARAM.end_halving_impr:
        if epoch < PARAM.min_epochs:
          tf.logging.info(
              "we were supposed to finish, but we continue as "
              "min_epochs : %s" % PARAM.min_epochs)
          continue
        else:
          tf.logging.info(
              "finished, too small rel. improvement (%g,%g)" % (rel_impr_magloss, rel_impr_logmagloss))
          break

    sess.close()
    tf.logging.info("Done training")
Example 4
    def __init__(self,
                 x_mag_spec_batch,
                 lengths_batch,
                 y_mag_spec_batch=None,
                 theta_x_batch=None,
                 theta_y_batch=None,
                 infer=False):
        self._log_bias = tf.get_variable('logbias', [1],
                                         trainable=PARAM.LOG_BIAS_TRAINABEL,
                                         initializer=tf.constant_initializer(
                                             PARAM.INIT_LOG_BIAS))
        self._real_logbias = self._log_bias + DEFAULT_LOG_BIAS
        self._inputs = x_mag_spec_batch
        self._x_mag_spec = self.inputs
        self._norm_x_mag_spec = norm_mag_spec(self._x_mag_spec)
        self._norm_x_logmag_spec = norm_logmag_spec(self._x_mag_spec,
                                                    self._log_bias)

        if not infer:
            self._y_mag_spec = y_mag_spec_batch
            self._norm_y_mag_spec = norm_mag_spec(self._y_mag_spec)
            self._norm_y_logmag_spec = norm_logmag_spec(
                self._y_mag_spec, self._log_bias)

        self._lengths = lengths_batch

        self.batch_size = tf.shape(self._lengths)[0]
        self._model_type = PARAM.MODEL_TYPE

        if PARAM.INPUT_TYPE == 'mag':
            self.net_input = self._norm_x_mag_spec
        elif PARAM.INPUT_TYPE == 'logmag':
            self.net_input = self._norm_x_logmag_spec

        if not infer:
            if PARAM.LABEL_TYPE == 'mag':
                self._labels = self._norm_y_mag_spec
            elif PARAM.LABEL_TYPE == 'logmag':
                self._labels = self._norm_y_logmag_spec

        outputs = self.net_input

        def lstm_cell():
            return tf.contrib.rnn.LSTMCell(
                PARAM.RNN_SIZE,
                forget_bias=1.0,
                use_peepholes=True,
                num_proj=PARAM.LSTM_num_proj,
                initializer=tf.contrib.layers.xavier_initializer(),
                state_is_tuple=True,
                activation=PARAM.LSTM_ACTIVATION)

        lstm_attn_cell = lstm_cell
        if not infer and PARAM.KEEP_PROB < 1.0:

            def lstm_attn_cell():
                return tf.contrib.rnn.DropoutWrapper(
                    lstm_cell(), output_keep_prob=PARAM.KEEP_PROB)

        def GRU_cell():
            return tf.contrib.rnn.GRUCell(
                PARAM.RNN_SIZE,
                # kernel_initializer=tf.contrib.layers.xavier_initializer(),
                activation=PARAM.LSTM_ACTIVATION)

        GRU_attn_cell = GRU_cell
        if not infer and PARAM.KEEP_PROB < 1.0:

            def GRU_attn_cell():
                return tf.contrib.rnn.DropoutWrapper(
                    GRU_cell(), output_keep_prob=PARAM.KEEP_PROB)

        if PARAM.MODEL_TYPE.upper() == 'BLSTM':
            with tf.variable_scope('BLSTM'):

                lstm_fw_cell = tf.contrib.rnn.MultiRNNCell(
                    [lstm_attn_cell() for _ in range(PARAM.RNN_LAYER)],
                    state_is_tuple=True)
                lstm_bw_cell = tf.contrib.rnn.MultiRNNCell(
                    [lstm_attn_cell() for _ in range(PARAM.RNN_LAYER)],
                    state_is_tuple=True)

                lstm_fw_cell = lstm_fw_cell._cells
                lstm_bw_cell = lstm_bw_cell._cells
                result = rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw=lstm_fw_cell,
                    cells_bw=lstm_bw_cell,
                    inputs=outputs,
                    dtype=tf.float32,
                    sequence_length=self._lengths)
                outputs, fw_final_states, bw_final_states = result
        if PARAM.MODEL_TYPE.upper() == 'BGRU':
            with tf.variable_scope('BGRU'):

                gru_fw_cell = tf.contrib.rnn.MultiRNNCell(
                    [GRU_attn_cell() for _ in range(PARAM.RNN_LAYER)],
                    state_is_tuple=True)
                gru_bw_cell = tf.contrib.rnn.MultiRNNCell(
                    [GRU_attn_cell() for _ in range(PARAM.RNN_LAYER)],
                    state_is_tuple=True)

                gru_fw_cell = gru_fw_cell._cells
                gru_bw_cell = gru_bw_cell._cells
                result = rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw=gru_fw_cell,
                    cells_bw=gru_bw_cell,
                    inputs=outputs,
                    dtype=tf.float32,
                    sequence_length=self._lengths)
                outputs, fw_final_states, bw_final_states = result

        with tf.variable_scope('fullconnectOut'):
            # in_size is only set on the bidirectional path; a unidirectional
            # model would need its own reshape and in_size here.
            if self._model_type.upper()[0] == 'B':  # bidirectional
                outputs = tf.reshape(outputs, [-1, 2 * PARAM.LSTM_num_proj])
                in_size = 2 * PARAM.LSTM_num_proj
            out_size = PARAM.OUTPUT_SIZE
            weights = tf.get_variable(
                'weights1', [in_size, out_size],
                initializer=tf.random_normal_initializer(stddev=0.01))
            biases = tf.get_variable('biases1', [out_size],
                                     initializer=tf.constant_initializer(0.0))
            mask = tf.nn.relu(tf.matmul(outputs, weights) + biases)
            self._mask = tf.reshape(mask,
                                    [self.batch_size, -1, PARAM.OUTPUT_SIZE])

        self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30)
        if infer:
            if PARAM.DECODING_MASK_POSITION == 'mag':
                self._cleaned = rm_norm_mag_spec(self._mask *
                                                 self._norm_x_mag_spec)
            elif PARAM.DECODING_MASK_POSITION == 'logmag':
                self._cleaned = rm_norm_logmag_spec(
                    self._mask * self._norm_x_logmag_spec, self._log_bias)
            return

        if PARAM.TRAINING_MASK_POSITION == 'mag':
            self._cleaned = self._mask * self._norm_x_mag_spec
        elif PARAM.TRAINING_MASK_POSITION == 'logmag':
            self._cleaned = self._mask * self._norm_x_logmag_spec
        if PARAM.MASK_TYPE == 'PSM':
            self._labels *= tf.cos(theta_x_batch - theta_y_batch)
        elif PARAM.MASK_TYPE == 'IRM':
            pass
        else:
            tf.logging.error('Mask type error.')
            exit(-1)

        if PARAM.TRAINING_MASK_POSITION != PARAM.LABEL_TYPE:
            if PARAM.LABEL_TYPE == 'mag':
                self._cleaned = normedLogmag2normedMag(self._cleaned,
                                                       self._log_bias)
            elif PARAM.LABEL_TYPE == 'logmag':
                self._cleaned = normedMag2normedLogmag(self._cleaned,
                                                       self._log_bias)
        self._loss = PARAM.LOSS_FUNC(self._cleaned, self._labels)
        if tf.get_variable_scope().reuse:
            return

        self._lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                          PARAM.CLIP_NORM)
        optimizer = tf.train.AdamOptimizer(self.lr)
        #optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self._train_op = optimizer.apply_gradients(zip(grads, tvars))

        self._new_lr = tf.placeholder(tf.float32,
                                      shape=[],
                                      name='new_learning_rate')
        self._lr_update = tf.assign(self._lr, self._new_lr)
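
The `_new_lr` placeholder and `_lr_update` op above are presumably driven by an `assign_lr` method like the one Example 5 calls; the repository's actual method is not shown in these snippets, so this is only a sketch of the usual pattern:

    def assign_lr(self, session, lr_value):
        # feed the new learning rate and run the assign op built in __init__
        session.run(self._lr_update, feed_dict={self._new_lr: lr_value})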
Example 5
def train():

  g = tf.Graph()
  with g.as_default():
    # region TFRecord+DataSet
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        train_tfrecords, val_tfrecords, _, _ = generate_tfrecord(
            gen=PARAM.GENERATE_TFRECORD)
        if PARAM.GENERATE_TFRECORD:
          tf.logging.info("TFRecords preparation over.")
          # exit(0)  # set gen=True and exit to generate tfrecords

        PSIRM = bool(PARAM.PIPLINE_GET_THETA)
        x_batch_tr, y_batch_tr, Xtheta_batch_tr, Ytheta_batch_tr, lengths_batch_tr, iter_train = get_batch_use_tfdata(
            train_tfrecords,
            get_theta=PSIRM)
        x_batch_val, y_batch_val,  Xtheta_batch_val, Ytheta_batch_val, lengths_batch_val, iter_val = get_batch_use_tfdata(
            val_tfrecords,
            get_theta=PSIRM)
    # endregion

    # build model
    with tf.name_scope('model'):
      tr_model = PARAM.SE_MODEL(x_batch_tr,
                                lengths_batch_tr,
                                y_batch_tr,
                                Xtheta_batch_tr,
                                Ytheta_batch_tr,
                                PARAM.SE_MODEL.train)
      tf.get_variable_scope().reuse_variables()
      val_model = PARAM.SE_MODEL(x_batch_val,
                                 lengths_batch_val,
                                 y_batch_val,
                                 Xtheta_batch_val,
                                 Ytheta_batch_val,
                                 PARAM.SE_MODEL.validation)

    utils.tf_tool.show_all_variables()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = PARAM.GPU_RAM_ALLOW_GROWTH
    config.allow_soft_placement = False
    # config.gpu_options.per_process_gpu_memory_fraction = 0.55
    sess = tf.Session(config=config)
    sess.run(init)

    # region resume training
    if PARAM.resume_training.lower() == 'true':
      ckpt = tf.train.get_checkpoint_state(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT))
      if ckpt and ckpt.model_checkpoint_path:
        # tf.logging.info("restore from" + ckpt.model_checkpoint_path)
        tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
        best_path = ckpt.model_checkpoint_path
      else:
        tf.logging.fatal("checkpoint not found")
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines('Training resumed.\n')
    else:
      if os.path.exists(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log')):
        os.remove(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'))
    # endregion

    # region validation before training.
    valstart_time = time.time()
    sess.run(iter_val.initializer)
    loss_prev = eval_one_epoch(sess,
                               val_model)
    cross_val_msg = "\n\nCROSSVAL PRERUN AVG.LOSS %.4F  costime %dS\n" % (
        loss_prev, time.time()-valstart_time)
    tf.logging.info(cross_val_msg)
    with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
      f.writelines(cross_val_msg+'\n')

    g.finalize()
    # endregion

    # epochs training # TODO resume lr form model_name
    lr = PARAM.learning_rate
    tr_model.assign_lr(sess, lr)
    # if PARAM.resume_training.lower() == 'true':
    #   lr = sess.run(tr_model.lr)
    # else:
    #   lr = PARAM.learning_rate
    #   tr_model.assign_lr(sess, lr)
    lr_halving_time = 0
    for epoch in range(PARAM.start_epoch, PARAM.max_epochs):
      sess.run([iter_train.initializer, iter_val.initializer])
      start_time = time.time()

      # train one epoch
      tr_loss, model_lr, log_bias = train_one_epoch(sess,
                                                    tr_model)

      # Validation
      val_loss = eval_one_epoch(sess,
                                val_model)

      end_time = time.time()

      # Determine checkpoint path
      ckpt_name = "nnet_iter%d_lrate%e_trloss%.4f_cvloss%.4f_avglogbias%f_duration%ds" % (
          epoch + 1, model_lr, tr_loss, val_loss, np.mean(log_bias), end_time - start_time)
      ckpt_dir = os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT)
      if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
      ckpt_path = os.path.join(ckpt_dir, ckpt_name)

      # Relative loss between previous and current val_loss
      rel_impr = (loss_prev - val_loss) / loss_prev
      # Accept or reject new parameters
      msg = ""
      if val_loss < loss_prev:
        tr_model.saver.save(sess, ckpt_path)
        # Logging train loss along with validation loss
        loss_prev = val_loss
        best_path = ckpt_path
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS %.4f, lrate %e, Val.LOSS %.4f, avglog_bias %f,\n"
               "    %s, ckpt(%s) saved,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1,
            tr_loss, model_lr, val_loss, np.mean(log_bias),
            "NNET Accepted", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      else:
        tr_model.saver.restore(sess, best_path)
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS %.4f, lrate%e, Val.LOSS %.4f, avglog_bias %f,\n"
               "    %s, ckpt(%s) abandoned,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1,
            tr_loss, model_lr, val_loss, np.mean(log_bias),
            "NNET Rejected", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines(msg+'\n')

      # Start halving when improvement is lower than start_halving_impr
      if rel_impr < PARAM.start_halving_impr:
        lr *= PARAM.halving_factor
        lr_halving_time += 1
        tr_model.assign_lr(sess, lr)

      # Stopping criterion
      if rel_impr < PARAM.end_halving_impr:
        if (epoch < PARAM.min_epochs) or (lr_halving_time<=PARAM.max_lr_halving_time):
          tf.logging.info(
              "we were supposed to finish, but we continue as "
              "now_epoch<=min_epochs(%s<=%s),"
              " or lr_halving_time<=max_lr_halving_time(%s<=%s)"
              "." % (epoch+1,PARAM.min_epochs,lr_halving_time,PARAM.max_lr_halving_time))
          continue
        else:
          tf.logging.info(
              "finished, too small retive improvement %g." % rel_impr)
          break

    sess.close()
    tf.logging.info("Done training")