def train(self):  # pylint: disable=too-many-locals
  """Train the model."""
  mode = utils.TRAIN
  train_model = self.build(mode)

  # Supervisor
  with tf.name_scope("train"):
    global_step = tf.train.get_or_create_global_step()

    train_op = self.get_train_op(train_model.loss_op, global_step)

    checkpoint_dir = get_checkpoint_dir(self.config)

    # scaffold
    scaffold = self.get_scaffold(mode, global_step,
                                 train_model.iterator.initializer)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_dir,
        scaffold=scaffold,
        save_checkpoint_steps=self.save_checkpoint_steps,
        config=self.session_conf) as sess:
      # Training loop. For each batch...
      data_size = self.config['data']['train_data_size']
      num_epochs = self.config["data"]["task"]['epochs']
      num_batch = int(math.ceil(data_size * num_epochs / self.batch_size))
      num_batch_per_epoch = int(data_size / self.batch_size)
      logging.info(
          "num_batch: {}, num_batch_per_epoch: {}, num_epochs: {}".format(
              num_batch, num_batch_per_epoch, num_epochs))
      for i in range(num_batch):
        _, _, out_loss = sess.run(
            [train_op, global_step, train_model.loss_op])
        if i % self.print_every == 0 or i == num_batch - 1:
          logging.info(
              "Training for epoch {}: [ {:.2%} ] loss is {:g}".format(
                  int(i / num_batch_per_epoch),
                  (i % num_batch_per_epoch) / num_batch_per_epoch, out_loss))
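
# NOTE: `self.get_scaffold(...)` is used above but not shown in this section.
# A minimal sketch of what it could look like, assuming it only needs to run
# the dataset iterator initializer alongside the standard local/table
# initializers when MonitoredTrainingSession starts (hypothetical sketch,
# not the actual implementation):
def get_scaffold(self, mode, global_step, iterator_init_op):
  """Build a tf.train.Scaffold that also initializes the input iterator."""
  del mode, global_step  # unused in this sketch
  local_init_op = tf.group(
      tf.local_variables_initializer(),
      tf.tables_initializer(),
      iterator_init_op,
      name="local_init_op")
  return tf.train.Scaffold(local_init_op=local_init_op)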
def train_and_eval(self):  # pylint: disable=too-many-locals
  """Train and evaluate the model."""
  # train related
  g_train = tf.Graph()
  with g_train.as_default():
    logging.info("Compiling train model ...")
    train_model = self.build(utils.TRAIN)

  # eval related
  g_eval = tf.Graph()
  with g_eval.as_default():
    logging.info("Compiling eval model ...")
    eval_model = self.build(utils.EVAL)
    eval_model.sess = tf.Session(config=self.session_conf, graph=g_eval)
    eval_model.saver = tf.train.Saver()

  # start train
  with g_train.as_default():
    # Supervisor
    with tf.name_scope("train"):
      global_step = tf.train.get_or_create_global_step()

      train_op = self.get_train_op(train_model.loss_op, global_step)

      checkpoint_dir = get_checkpoint_dir(self.config)

      # scaffold
      scaffold = self.get_scaffold(utils.TRAIN, global_step,
                                   train_model.iterator.initializer)

      with tf.train.MonitoredTrainingSession(
          checkpoint_dir=checkpoint_dir,
          scaffold=scaffold,
          save_checkpoint_steps=self.save_checkpoint_steps,
          config=self.session_conf) as sess:
        # Training loop. For each batch...
        train_data_size = self.config['data']['train_data_size']
        num_batch = math.ceil(train_data_size * self.num_epochs /
                              self.batch_size)
        num_batch_per_epoch = math.ceil(train_data_size / self.batch_size)
        logging.info("Total data size: {}, batch num: {}, "
                     "batch num per epoch: {}".format(
                         train_data_size, num_batch, num_batch_per_epoch))
        for i in range(num_batch):
          # Evaluate on the separate eval graph each time the monitored
          # session has just saved a checkpoint.
          if i % self.save_checkpoint_steps == 0 and i != 0:
            self.eval_or_infer_core(eval_model, utils.EVAL)
          _, _, out_loss = sess.run(
              [train_op, global_step, train_model.loss_op])
          # Log progress every `print_every` steps and at epoch boundaries.
          if i % self.print_every == 0 or i == num_batch - 1 or (
              i + 1) % num_batch_per_epoch == 0 or i % num_batch_per_epoch == 0:
            logging.info(
                "Training for epoch {}: [ {:.2%} ] loss is {:g}".format(
                    int(i / num_batch_per_epoch),
                    (i % num_batch_per_epoch) / num_batch_per_epoch, out_loss))

  eval_model.sess.close()
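
# NOTE: `self.eval_or_infer_core(...)` is called above but not shown here. The
# eval session and saver created in `train_and_eval` are typically used to
# restore the checkpoint that training just wrote. A minimal, hypothetical
# sketch of that restore step (structure assumed, not the actual
# implementation):
def eval_or_infer_core(self, model, mode):
  """Restore the latest checkpoint into the eval graph, then evaluate."""
  ckpt = tf.train.latest_checkpoint(get_checkpoint_dir(self.config))
  if ckpt is None:
    logging.warning("No checkpoint found, skipping evaluation.")
    return
  with model.sess.graph.as_default():
    model.saver.restore(model.sess, ckpt)
  # ... run the evaluation loop on `model` with `model.sess` here ...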
def __init__(self, config):
  super().__init__(config)
  self.model_compiled = False
  self.model_path = config['solver']['saver']['model_path']
  self.checkpoint_dir = get_checkpoint_dir(self.config)
  # Create one shared session up front and register it as the Keras backend
  # session, so Keras layers and raw TF ops run against the same graph.
  self.session_conf = get_session_conf(self.config)
  self.session = tf.Session(config=self.session_conf)
  tf.keras.backend.set_session(self.session)
  self.metrics = self.get_metrics()
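
# NOTE: `get_session_conf(...)` is used above but not defined in this section.
# A minimal sketch of a typical implementation, ignoring any session options
# that may live in the config (hypothetical sketch, not the actual helper):
def get_session_conf(config):
  """Build a tf.ConfigProto for the solver session."""
  del config  # this sketch uses fixed defaults
  session_conf = tf.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=False)
  session_conf.gpu_options.allow_growth = True
  return session_conf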