import os

import numpy as np
import tensorflow as tf


@staticmethod
def restore_model(logs, sess, restore_path, var_list=None):
    """Restore variables from a pretrained model located at restore_path."""
    print_and_write(logs, "Restoring model from " + str(restore_path))
    # Initialize everything first so variables absent from var_list still get values.
    sess.run(tf.global_variables_initializer())
    var_list = tf.global_variables() if var_list is None else var_list
    restorer = tf.train.Saver(var_list=var_list)
    restorer.restore(sess, restore_path)
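# print_and_write is used throughout this section but not defined in it. A
# minimal sketch of what it presumably does (behavior assumed, not confirmed
# by the source): echo a message to stdout and mirror it into the open log
# file handle.
def print_and_write(logs, message):
    # Assumed helper: print to console and append verbatim to the log file.
    print(message)
    logs.write(message)
    logs.flush()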
def _updt_val_logs(self, logs, val_outputs, step):
    super()._updt_val_logs(logs, val_outputs, step)
    print_and_write(
        logs, "\nValidation Loss: " +
        str(np.mean(np.array([o["loss"] for o in val_outputs]))))
    print_and_write(
        logs, " Validation PPL: " +
        str(np.mean(np.array([o["ppl"] for o in val_outputs]))) + "\n")
@staticmethod
def save_model(sess, logs, global_step, saver, save_path):
    """
    Save a model checkpoint.

    :param sess: tf.Session
    :param logs: log file
    :param global_step: integer, current training step
    :param saver: tf.train.Saver object
    :param save_path: file path where to save the checkpoint
    """
    print_and_write(
        logs, "\nSaving Best Model at step " + str(global_step) + "\n")
    saver.save(sess, save_path)
def _validation_op(self, sess, logs, val_summaries, step, val_size,
                   best_loss, best_step, stop_cond, restore_cond, saver):
    """
    Runs the model on the validation set (or a subset of it).

    :param sess: tf.Session object
    :param logs: an opened file where to save logs
    :param val_summaries: a list of summaries for validation
    :param step: current time step
    :param val_size: number of validation examples evaluated per chunk
    :param best_loss: current best loss value
    :param best_step: current best step
    :param stop_cond: number of validation rounds without improvement
    :param restore_cond: number of validation rounds without improvement
        since the last restore
    :param saver: tf.train.Saver object
    :return: updated values of best_loss, stop_cond, restore_cond, best_step
    """
    print_and_write(logs, "STEP: " + str(step))
    print_and_write(logs, "\n\n VALIDATION EVALUATION \n\n")
    start = 0
    val_outputs = []
    # Evaluate the validation set in chunks of val_size examples.
    for j in range(len(self.dataset.val_x) // val_size):
        val_batch = self._get_val_batch((start, start + val_size))
        val_outputs.append(
            self.model_val.val_op(sess, val_batch, val_summaries))
        start += val_size
    self._updt_val_logs(logs, val_outputs, step)
    val_loss = np.mean(
        np.array([o["to_optimize_loss"] for o in val_outputs]))
    best_loss, stop_cond, restore_cond, best_step = self.checkpoint_manager(
        sess, logs, val_loss, best_loss, best_step, stop_cond, restore_cond,
        global_step=step, saver=saver, save_path=self.save_path)
    return best_loss, stop_cond, restore_cond, best_step
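# checkpoint_manager is called above but not shown in this section. Below is a
# hedged sketch of the bookkeeping it presumably performs, not the actual
# implementation; the config field restore_patience is an assumption.
def checkpoint_manager(self, sess, logs, val_loss, best_loss, best_step,
                       stop_cond, restore_cond, global_step, saver, save_path):
    if val_loss < best_loss:
        # New best validation loss: checkpoint the model, reset both
        # patience counters, and record the step it happened at.
        self.save_model(sess, logs, global_step, saver, save_path)
        return val_loss, 0, 0, global_step
    stop_cond += 1
    restore_cond += 1
    if restore_cond >= self.config.learning.restore_patience:  # assumed field
        # Roll back to the best checkpoint before training further.
        saver.restore(sess, save_path)
        restore_cond = 0
    return best_loss, stop_cond, restore_cond, best_step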
def _run_train_session(self, logs, train_summaries, val_summaries, val_size):
    saver = tf.train.Saver()
    with tf.Session(config=self.get_config_proto()) as sess:
        sess.run(tf.global_variables_initializer())
        if self.config.restore_model:
            restore_path = os.path.join(self.config.restore_path, "model.ckpt")
            self.restore_model(logs, sess, restore_path)
        # Dump the graph so it can be inspected in TensorBoard.
        tf.summary.FileWriter(self.summary_writer_path, sess.graph)
        self.dataset.initialize(sess)
        step = 0
        best_step = -1
        # Early-stopping conditions.
        best_loss, stop_cond, restore_cond, max_no_improve = \
            self.get_initial_exp_conds()
        batch_size = self.config.learning.batch_size
        for e in range(self.config.learning.n_epochs):
            print_and_write(logs, "Epoch: " + str(e))
            for batch in self.dataset.get_batches(batch_size):
                # One training step.
                batch_output = self.model.run_train_op(
                    sess, batch, train_summaries)
                if step % 100 == 0:
                    self._updt_train_logs(logs, batch_output, step)
                step += 1
            # Evaluate on the validation set once per epoch.
            best_loss, stop_cond, restore_cond, best_step = self._validation_op(
                sess, logs, val_summaries, step, val_size, best_loss,
                best_step, stop_cond, restore_cond, saver)
            if self.early_stop(logs, stop_cond, max_no_improve):
                break
        self.callback_hooks(sess)
        return best_loss, best_step
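# get_initial_exp_conds is referenced above but not defined here. A minimal
# sketch under the assumption that it seeds the early-stopping state;
# max_no_improve is presumably read from the config, and the exact field
# name used below is an assumption.
def get_initial_exp_conds(self):
    best_loss = float("inf")   # any real validation loss improves on this
    stop_cond = 0              # validation rounds without improvement
    restore_cond = 0           # rounds without improvement since last restore
    return best_loss, stop_cond, restore_cond, self.config.learning.max_no_improve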
@staticmethod
def early_stop(logs, stop_cond, max_no_improve):
    # Stop training once max_no_improve validation rounds pass without improvement.
    if stop_cond >= max_no_improve:
        print_and_write(logs, "\nEARLY STOPPING\n")
        return True
    return False
def _updt_train_logs(self, logs, batch_outputs, step):
    super()._updt_train_logs(logs, batch_outputs, step)
    print_and_write(logs, "\nLoss: " + str(batch_outputs["loss"]))
    print_and_write(logs, " PPL: " + str(batch_outputs["ppl"]))