Code example #1
0
 def restore_model(logs, sess, restore_path, var_list=None):
     """Load pretrained weights from a checkpoint into the session.

     Every global variable is first initialized, then checkpoint values are
     restored on top for the requested subset of variables.

     :param logs: opened log file handle passed to print_and_write
     :param sess: tf.Session to restore into
     :param restore_path: path of the checkpoint to read
     :param var_list: variables to restore; defaults to all global variables
     """
     print_and_write(logs, "Restoring Model in " + str(restore_path))
     sess.run(tf.global_variables_initializer())
     if var_list is None:
         var_list = tf.global_variables()
     loader = tf.train.Saver(var_list=var_list)
     loader.restore(sess, restore_path)
Code example #2
0
    def _updt_val_logs(self, logs, val_outputs, step):
        """Append validation loss and perplexity to the logs.

        Delegates base bookkeeping to the parent class, then reports the
        batch-averaged "loss" and "ppl" entries gathered during validation.
        """
        super()._updt_val_logs(logs, val_outputs, step)

        mean_loss = np.mean(np.array([out["loss"] for out in val_outputs]))
        mean_ppl = np.mean(np.array([out["ppl"] for out in val_outputs]))
        print_and_write(logs, "\nValidation Loss: " + str(mean_loss))
        print_and_write(logs, " Validation PPL: " + str(mean_ppl) + "\n")
Code example #3
0
 def save_model(sess, logs, global_step, saver, save_path):
     """Write a checkpoint of the current session state.

     :param sess: tf.Session holding the variables to checkpoint
     :param logs: opened log file handle passed to print_and_write
     :param global_step: integer step reported in the log message
     :param saver: tf.train.Saver() object performing the save
     :param save_path: file path where to save the checkpoint
     :return: None
     """
     message = "\nSaving Best Model at step " + str(global_step) + "\n"
     print_and_write(logs, message)
     saver.save(sess, save_path)
Code example #4
0
    def _validation_op(self, sess, logs, val_summaries, step, val_size,
                       best_loss, best_step, stop_cond, restore_cond, saver):
        """Evaluate the model on the validation set and update checkpoints.

        Runs validation in fixed-size batches (any trailing partial batch is
        skipped), logs the aggregated metrics, then lets the checkpoint
        manager decide whether this step is a new best.

        :param sess: tf.Session() object
        :param logs: an opened file where to save logs
        :param val_summaries: a list of summaries for validation
        :param step: current time step
        :param val_size: number of examples from the validation to assess
        :param best_loss: current best loss value
        :param best_step: current best step
        :param stop_cond: number of steps without improvements
        :param restore_cond: number of steps without improvements since the
            last restore
        :param saver: tf.train.Saver() object
        :return: updated (best_loss, stop_cond, restore_cond, best_step)
        """
        print_and_write(logs, "STEP:" + str(step))
        print_and_write(logs, "\n\n VALIDATION EVALUATION \n\n")

        n_batches = len(self.dataset.val_x) // val_size
        val_outputs = []
        for batch_idx in range(n_batches):
            lo = batch_idx * val_size
            val_batch = self._get_val_batch((lo, lo + val_size))
            val_outputs.append(
                self.model_val.val_op(sess, val_batch, val_summaries))

        self._updt_val_logs(logs, val_outputs, step)

        val_loss = np.mean(
            np.array([out["to_optimize_loss"] for out in val_outputs]))
        # checkpoint_manager returns (best_loss, stop_cond, restore_cond,
        # best_step), which is exactly this method's contract.
        return self.checkpoint_manager(
            sess,
            logs,
            val_loss,
            best_loss,
            best_step,
            stop_cond,
            restore_cond,
            global_step=step,
            saver=saver,
            save_path=self.save_path)
Code example #5
0
    def _run_train_session(self, logs, train_summaries, val_summaries,
                           val_size):
        """Run the full training loop inside a fresh TF session.

        Initializes (and optionally warm-starts) the graph, then alternates
        training epochs with per-epoch validation until the epoch budget is
        spent or early stopping triggers.

        :param logs: an opened file where to save logs
        :param train_summaries: summaries evaluated at training steps
        :param val_summaries: summaries evaluated during validation
        :param val_size: number of validation examples per evaluation batch
        :return: (best_loss, best_step) achieved on validation
        """
        saver = tf.train.Saver()
        with tf.Session(config=self.get_config_proto()) as sess:
            # NOTE(review): mixes tf.Session with tf.compat.v1 initializer —
            # presumably both resolve to the same TF1-style API; confirm.
            sess.run(tf.compat.v1.global_variables_initializer())
            # Optionally warm-start from a pretrained checkpoint.
            if self.config.restore_model:
                restore_path = os.path.join(self.config.restore_path,
                                            "model.ckpt")
                self.restore_model(logs, sess, restore_path)

            # NOTE(review): the FileWriter handle is discarded, so summaries
            # are never explicitly flushed or closed — confirm intended.
            tf.summary.FileWriter(self.summary_writer_path, sess.graph)

            self.dataset.initialize(sess)

            step = 0
            best_step = -1  # -1 until the first validation pass completes
            best_loss, stop_cond, restore_cond, max_no_improve = self.get_initial_exp_conds(
            )  # early stops conditions
            batch_size = self.config.learning.batch_size

            for e in range(self.config.learning.n_epochs):
                print_and_write(logs, "Epoch: " + str(e))
                for batch in self.dataset.get_batches(batch_size):
                    batch_output = self.model.run_train_op(
                        sess, batch, train_summaries)  # train step

                    # Log training metrics every 100 global steps.
                    if step % 100 == 0:
                        self._updt_train_logs(logs, batch_output, step)
                    step += 1

                # Validate once per epoch; updates early-stop bookkeeping.
                best_loss, stop_cond, restore_cond, best_step = self._validation_op(
                    sess, logs, val_summaries, step, val_size, best_loss,
                    best_step, stop_cond, restore_cond, saver)

                if self.early_stop(logs, stop_cond, max_no_improve):
                    break

            self.callback_hooks(sess)
        return best_loss, best_step
Code example #6
0
    def early_stop(logs, stop_cond, max_no_improve):
        if stop_cond >= max_no_improve:
            print_and_write(logs, "\nEARLY STOPPING\n")
            return True

        return False
Code example #7
0
    def _updt_train_logs(self, logs, batch_outputs, step):
        """Log the training loss and perplexity for the current batch."""
        super()._updt_train_logs(logs, batch_outputs, step)

        loss_val = batch_outputs["loss"]
        ppl_val = batch_outputs["ppl"]
        print_and_write(logs, "\nLoss: " + str(loss_val))
        print_and_write(logs, " PPL: " + str(ppl_val))