Example #1
0
    def export(self, estimator: tf.estimator.Estimator):
        """Point the estimator's ``checkpoint`` file at the best evaluated step.

        Reloads the evaluation summaries, picks the step whose
        ``self.metric`` value is best according to ``self.mode``, selects the
        available checkpoint whose step is closest to it, and rewrites the
        first line of the ``checkpoint`` file in ``estimator.model_dir`` to
        reference that checkpoint. When ``self.use_mlflow`` is set, the
        resulting global step is logged to MLFlow under ``self.tag``.

        Args:
            estimator: Estimator whose ``model_dir`` and ``eval_dir()`` are
                inspected and whose ``checkpoint`` file is rewritten.

        Raises:
            ValueError: If no evaluation summaries or no checkpoints are
                found, if ``self.mode`` is not a recognized ``BestMode``, or
                if the estimator does not report the selected step after the
                ``checkpoint`` file was rewritten.
        """
        # Reload summaries and select best step
        LOGGER.info(f"Reloading summaries from {estimator.model_dir}")
        eval_dir = estimator.eval_dir()
        summaries = list(read_eval_metrics(eval_dir).items())
        if not summaries:
            raise ValueError(f"No evaluation summaries found in {eval_dir}")
        for step, metrics in sorted(summaries):
            LOGGER.info(f"- {step}: {metrics}")
        # Stable sort: among equal metric values the original order is kept,
        # so INCREASE picks the last tied entry and DECREASE the first.
        ranked = sorted(summaries, key=lambda item: item[1][self.metric])
        if self.mode == BestMode.INCREASE:
            best_step, best_metrics = ranked[-1]
        elif self.mode == BestMode.DECREASE:
            best_step, best_metrics = ranked[0]
        else:
            raise ValueError(f"Mode {self.mode} not recognized.")
        LOGGER.info(f"Best summary at step {best_step}: {best_metrics}")

        # List available checkpoints and select the one closest to best_step.
        # The step is parsed from the file name only, so that digits in
        # parent directory names cannot be mistaken for a checkpoint step.
        model_dir = Path(estimator.model_dir)
        checkpoint_steps = []
        for path in model_dir.glob(_CHEKPOINT_PATTERN):
            match = re.search(r"-(\d+)\.index", path.name)
            if match:
                checkpoint_steps.append(int(match.group(1)))
        if not checkpoint_steps:
            raise ValueError(f"No checkpoints found in {model_dir}")
        selected_step = min(checkpoint_steps,
                            key=lambda step: abs(step - best_step))
        LOGGER.info(f"Selected checkpoint {selected_step}")

        # Change checkpoint information
        # NOTE(review): assumes the first line of the "checkpoint" file is
        # the model_checkpoint_path entry - confirm against the TF format.
        checkpoint_file = model_dir / "checkpoint"
        with checkpoint_file.open("r") as file:
            lines = file.read().split("\n")
        lines[0] = f'model_checkpoint_path: "model.ckpt-{selected_step}"'

        with checkpoint_file.open("w") as file:
            file.write("\n".join(lines))

        # Check that change is effective
        global_step = estimator.get_variable_value("global_step")
        if global_step != selected_step:
            msg = f"Changed checkpoint file to use step {selected_step}, but estimator uses {global_step}"
            raise ValueError(msg)

        # Log to MLFlow
        if self.use_mlflow:
            mlflow.log_metric(key=self.tag, value=global_step)
Example #2
0
 def __call__(
         self,
         estimator: tf.estimator.Estimator) -> tf.estimator.SessionRunHook:
     """Build the session hook matching this estimator's role.

     When ``estimator.config.is_chief`` is false a ``_CheckForStoppingHook``
     is returned. Otherwise a ``_StopOnPredicateHook`` is returned, driven
     by ``_no_metric_improvement_fn`` bound to the estimator's evaluation
     directory and this object's ``min_steps``, ``metric``,
     ``max_steps_without_improvement`` and ``mode`` settings.
     """
     # Non-chief workers get the other hook and build no predicate.
     if not estimator.config.is_chief:
         return _CheckForStoppingHook()

     # Predicate with every keyword argument pre-bound from this object.
     no_improvement = partial(
         _no_metric_improvement_fn,
         eval_dir=estimator.eval_dir(),
         min_steps=self.min_steps,
         metric=self.metric,
         max_steps_without_improvement=self.max_steps_without_improvement,
         mode=self.mode,
     )
     hook = _StopOnPredicateHook(
         no_improvement,
         run_every_secs=self.run_every_secs,
         run_every_steps=self.run_every_steps,
         final_step=self.final_step,
     )
     return hook