def export(self, estimator: tf.estimator.Estimator):
    """Point the estimator's ``checkpoint`` file at the best evaluated checkpoint.

    Reloads the evaluation summaries from ``estimator.eval_dir()`` and picks the
    step whose ``self.metric`` is best: the largest value for
    ``BestMode.INCREASE``, the smallest for ``BestMode.DECREASE``. It then
    selects the saved checkpoint whose step is closest to that best step. The
    first line of ``<model_dir>/checkpoint`` is rewritten to reference that
    checkpoint, and the method checks that the estimator now reports the
    selected global step. If ``self.use_mlflow`` is set, the step is logged to
    MLFlow under ``self.tag``.

    Args:
        estimator: Estimator whose ``model_dir`` holds the checkpoints and whose
            ``eval_dir()`` holds the evaluation summaries.

    Raises:
        ValueError: If ``self.mode`` is not ``BestMode.INCREASE`` or
            ``BestMode.DECREASE``, if no evaluation summaries or no checkpoints
            are found, or if the estimator does not pick up the selected
            checkpoint after the rewrite.
    """
    # Validate the mode before reading or modifying anything on disk.
    if self.mode not in (BestMode.INCREASE, BestMode.DECREASE):
        raise ValueError(f"Mode {self.mode} not recognized.")

    # Reload summaries (ordered by step) and select the best step.
    LOGGER.info(f"Reloading summaries from {estimator.model_dir}")
    summaries = sorted(read_eval_metrics(estimator.eval_dir()).items())
    if not summaries:
        raise ValueError(f"No evaluation summaries found in {estimator.eval_dir()}")
    for step, metrics in summaries:
        LOGGER.info(f"- {step}: {metrics}")

    # Stable sort by metric. Among ties, INCREASE keeps the latest step and
    # DECREASE keeps the earliest step.
    ranked = sorted(summaries, key=lambda item: item[1][self.metric])
    if self.mode == BestMode.INCREASE:
        best_step, best_metrics = ranked[-1]
    else:
        best_step, best_metrics = ranked[0]
    LOGGER.info(f"Best summary at step {best_step}: {best_metrics}")

    # List available checkpoints and select the one closest to best_step.
    # Parse the step from the file name only, so directory names cannot
    # interfere with the match.
    step_pattern = re.compile(r"-(\d+)\.index$")
    checkpoint_steps = []
    for path in Path(estimator.model_dir).glob(_CHEKPOINT_PATTERN):
        match = step_pattern.search(path.name)
        if match is not None:
            checkpoint_steps.append(int(match.group(1)))
    if not checkpoint_steps:
        raise ValueError(f"No checkpoints found in {estimator.model_dir}")
    # Break distance ties toward the earlier step, so the result does not
    # depend on filesystem glob order.
    selected_step = min(checkpoint_steps, key=lambda step: (abs(step - best_step), step))
    LOGGER.info(f"Selected checkpoint {selected_step}")

    # Rewrite only the first line (the active checkpoint path) of the state
    # file and leave the remaining lines untouched.
    checkpoint_file = Path(estimator.model_dir, "checkpoint")
    lines = checkpoint_file.read_text(encoding="utf-8").split("\n")
    lines[0] = f'model_checkpoint_path: "model.ckpt-{selected_step}"'
    checkpoint_file.write_text("\n".join(lines), encoding="utf-8")

    # Check that the change is effective.
    global_step = estimator.get_variable_value("global_step")
    if global_step != selected_step:
        msg = f"Changed checkpoint file to use step {selected_step}, but estimator uses {global_step}"
        raise ValueError(msg)

    # Log to MLFlow
    if self.use_mlflow:
        mlflow.log_metric(key=self.tag, value=global_step)
def __call__(self, estimator: tf.estimator.Estimator) -> tf.estimator.SessionRunHook:
    """Return the stopping hook appropriate to this worker's role.

    Non-chief workers get a ``_CheckForStoppingHook``. The chief worker gets a
    ``_StopOnPredicateHook`` that wraps ``_no_metric_improvement_fn``, bound to
    ``estimator.eval_dir()`` and this object's ``min_steps``, ``metric``,
    ``max_steps_without_improvement`` and ``mode``. That hook is scheduled by
    ``run_every_secs`` / ``run_every_steps`` and receives ``final_step``.

    Args:
        estimator: Estimator whose ``config.is_chief`` decides which hook is
            built, and whose eval directory feeds the predicate.

    Returns:
        A session-run hook instance.
    """
    # Guard clause: only the chief builds the predicate-driven hook.
    if not estimator.config.is_chief:
        return _CheckForStoppingHook()

    no_improvement = partial(
        _no_metric_improvement_fn,
        eval_dir=estimator.eval_dir(),
        min_steps=self.min_steps,
        metric=self.metric,
        max_steps_without_improvement=self.max_steps_without_improvement,
        mode=self.mode,
    )
    return _StopOnPredicateHook(
        no_improvement,
        run_every_secs=self.run_every_secs,
        run_every_steps=self.run_every_steps,
        final_step=self.final_step,
    )