Example #1
  def train(self):  # pylint: disable=too-many-locals
    """Train the model."""
    mode = utils.TRAIN
    train_model = self.build(mode)

    # Supervisor
    with tf.name_scope("train"):
      global_step = tf.train.get_or_create_global_step()
      train_op = self.get_train_op(train_model.loss_op, global_step)

      checkpoint_dir = get_checkpoint_dir(self.config)

      # scaffold
      scaffold = self.get_scaffold(mode, global_step,
                                   train_model.iterator.initializer)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_dir,
        scaffold=scaffold,
        save_checkpoint_steps=self.save_checkpoint_steps,
        config=self.session_conf) as sess:
      # Training loop. For each batch...
      data_size = self.config['data']['train_data_size']
      num_epochs = self.config["data"]["task"]['epochs']
      num_batch = int(math.ceil(data_size * num_epochs / self.batch_size))
      # Use ceil (as Example #2 does) so a dataset smaller than one batch
      # still yields num_batch_per_epoch >= 1 instead of dividing by zero below.
      num_batch_per_epoch = int(math.ceil(data_size / self.batch_size))
      logging.info(
          "num_batch: {}, num_batch_per_epoch: {}, num_epochs: {}".format(
              num_batch, num_batch_per_epoch, num_epochs))
      for i in range(num_batch):
        _, _, out_loss = sess.run([train_op, global_step, train_model.loss_op])
        if i % self.print_every == 0 or i == num_batch - 1:
          logging.info("Training for epoch {}: [ {:.2%} ] loss is {:g}".format(
              int(i / num_batch_per_epoch),
              (i % num_batch_per_epoch) / num_batch_per_epoch, out_loss))
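
The Scaffold handed to MonitoredTrainingSession above comes from a get_scaffold helper that is not shown. Below is a minimal sketch of one plausible implementation, assuming its only extra job is chaining the dataset iterator initializer into local_init_op; the real helper may also configure a Saver or summaries.

  def get_scaffold(self, mode, global_step, iterator_initializer):
    """Build a Scaffold whose local_init_op also starts the input pipeline."""
    del mode, global_step  # kept only for signature parity with the call above
    # MonitoredTrainingSession runs local_init_op after variables have been
    # restored or initialized, so the iterator is ready before the first step.
    local_init_op = tf.group(tf.local_variables_initializer(),
                             tf.tables_initializer(),
                             iterator_initializer)
    return tf.train.Scaffold(local_init_op=local_init_op)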
Example #2
    def train_and_eval(self):  # pylint: disable=too-many-locals
        """Train and evaluate the model."""
        # train related
        g_train = tf.Graph()
        with g_train.as_default():
            logging.info("Compiling train model ...")
            train_model = self.build(utils.TRAIN)
        # eval related
        g_eval = tf.Graph()
        with g_eval.as_default():
            logging.info("Compiling eval model ...")
            eval_model = self.build(utils.EVAL)
            eval_model.sess = tf.Session(config=self.session_conf,
                                         graph=g_eval)
            eval_model.saver = tf.train.Saver()

        # start train
        with g_train.as_default():
            # Supervisor
            with tf.name_scope("train"):
                global_step = tf.train.get_or_create_global_step()

                train_op = self.get_train_op(train_model.loss_op, global_step)

                checkpoint_dir = get_checkpoint_dir(self.config)

                # scaffold
                scaffold = self.get_scaffold(utils.TRAIN, global_step,
                                             train_model.iterator.initializer)

                with tf.train.MonitoredTrainingSession(
                        checkpoint_dir=checkpoint_dir,
                        scaffold=scaffold,
                        save_checkpoint_steps=self.save_checkpoint_steps,
                        config=self.session_conf) as sess:
                    # Training loop. For each batch...
                    train_data_size = self.config['data']['train_data_size']
                    num_batch = math.ceil(train_data_size * self.num_epochs /
                                          self.batch_size)
                    num_batch_per_epoch = math.ceil(train_data_size /
                                                    self.batch_size)
                    logging.info("Total data size: {}, batch num: {}, "
                                 "batch num per epoch: {}".format(
                                     train_data_size, num_batch,
                                     num_batch_per_epoch))
                    for i in range(num_batch):
                        # Run eval on the side graph once per checkpoint interval.
                        if i % self.save_checkpoint_steps == 0 and i != 0:
                            self.eval_or_infer_core(eval_model, utils.EVAL)
                        _, _, out_loss = sess.run(
                            [train_op, global_step, train_model.loss_op])
                        # Log at print_every intervals and at epoch boundaries.
                        if (i % self.print_every == 0 or i == num_batch - 1
                                or (i + 1) % num_batch_per_epoch == 0
                                or i % num_batch_per_epoch == 0):
                            logging.info(
                                "Training for epoch {}: [ {:.2%} ] loss is {:g}"
                                .format(int(i / num_batch_per_epoch),
                                        (i % num_batch_per_epoch) /
                                        num_batch_per_epoch, out_loss))
        eval_model.sess.close()
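
Example #2 keeps two graphs alive at once: g_train owns the MonitoredTrainingSession, while g_eval owns its own plain Session and Saver, so evaluation never mutates training state. The eval_or_infer_core helper it calls each checkpoint interval is not shown; here is a minimal sketch of its checkpoint-restore step, assuming it reloads the newest training checkpoint into the eval session.

    def eval_or_infer_core(self, model, mode):
        """Restore the newest training checkpoint into model.sess, then run `mode`."""
        ckpt_path = tf.train.latest_checkpoint(get_checkpoint_dir(self.config))
        if ckpt_path is None:
            logging.warning("No checkpoint saved yet, skipping %s.", mode)
            return
        # model.sess and model.saver were created inside g_eval above, so
        # restoring here leaves the training graph untouched.
        model.saver.restore(model.sess, ckpt_path)
        # ... run the eval/infer loop with model.sess ...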
Example #3
  def __init__(self, config):
    super().__init__(config)
    self.model_compiled = False
    self.model_path = config['solver']['saver']['model_path']
    self.checkpoint_dir = get_checkpoint_dir(self.config)
    self.session_conf = get_session_conf(self.config)
    self.session = tf.Session(config=self.session_conf)
    # Route Keras layers/models created later through this session.
    tf.keras.backend.set_session(self.session)
    self.metrics = self.get_metrics()
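
All three examples lean on get_checkpoint_dir and get_session_conf helpers defined elsewhere. A plausible get_session_conf, assuming the solver config exposes the usual ConfigProto knobs (the 'run_config' section and its key names below are hypothetical):

def get_session_conf(config):
  """Build a tf.ConfigProto from the solver config."""
  run_conf = config['solver'].get('run_config', {})  # hypothetical section
  session_conf = tf.ConfigProto(
      allow_soft_placement=run_conf.get('allow_soft_placement', True),
      log_device_placement=run_conf.get('log_device_placement', False))
  # Grow GPU memory on demand rather than reserving it all up front.
  session_conf.gpu_options.allow_growth = True
  return session_conf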