import os
import signal
import sys
import time

from absl import logging  # assumption: absl-style logging; the stdlib `logging` module also supports these calls

# NOTE: KeepBestCheckpointSaver, AverageCheckpointSaver, ModelConfigs and
# to_numpy_or_python_type are project-level utilities assumed to be imported
# from elsewhere in this codebase.


class TrainingStatusRecorder(object):
    """ Manages the training status and keeps track of the best metrics. """

    def __init__(self,
                 model,
                 task,
                 metric,
                 estop_patience=None,
                 best_checkpoint_path=None,
                 auto_average_checkpoints=True,
                 best_avg_checkpoint_path=None,
                 top_checkpoints_to_keep=0):
        """ Initializes the manager for arbitrary evaluation strategies.

        Args:
            model: The custom keras model (inheriting from BaseModel).
            task: The custom task.
            metric: The evaluation metric object.
            estop_patience: An integer. If greater than 0, training will automatically
                shut down once the metric score fails to improve for this many
                consecutive evaluations.
            best_checkpoint_path: The path for checkpoints with the best metric scores.
                If not provided, the default "`model_dir`_best" is used.
            auto_average_checkpoints: A boolean, whether to average the model weights
                over the kept checkpoints. An extra directory for the averaged weights
                will be created. Only effective when `best_checkpoint_path` is provided.
            best_avg_checkpoint_path: The path for saving the averaged checkpoints.
            top_checkpoints_to_keep: An integer, the maximum number of checkpoints to
                keep (`max_to_keep` for the checkpoint manager), and the number of
                latest checkpoints to average when `auto_average_checkpoints` is True.
                If <= 0, no checkpoints will be saved.
        """
        self._model = model
        self._task = task
        self._metric = metric
        self._estop_patience = estop_patience
        self._best_checkpoint_path = best_checkpoint_path
        self._auto_average_checkpoints = auto_average_checkpoints
        self._best_avg_checkpoint_path = best_avg_checkpoint_path
        self._top_checkpoints_to_keep = top_checkpoints_to_keep
        self._keep_best_ckpt_saver = None
        self._average_ckpt_saver = None
        if self._top_checkpoints_to_keep and self._top_checkpoints_to_keep > 0:
            # Keeps the `top_checkpoints_to_keep` checkpoints with the best metric scores.
            self._keep_best_ckpt_saver = KeepBestCheckpointSaver(
                model=self._model,
                directory=self._best_checkpoint_path,
                metric=self._metric,
                max_to_keep=self._top_checkpoints_to_keep)
            ModelConfigs.dump(self._task.model_configs(self._model),
                              self._keep_best_ckpt_saver.directory)
            if self._auto_average_checkpoints:
                # Additionally maintains an averaged checkpoint over the kept checkpoints.
                self._average_ckpt_saver = AverageCheckpointSaver(
                    model=self._model,
                    directory=self._best_avg_checkpoint_path,
                    metric=self._metric,
                    max_to_keep=self._top_checkpoints_to_keep)
                ModelConfigs.dump(self._task.model_configs(self._model),
                                  self._average_ckpt_saver.directory)
        self._best_metric_result = None
        self._bad_count = 0

    @property
    def best(self):
        return self._best_metric_result

    def record(self, step, metric_result):
        """ Records the metrics and keeps track of the best. """
        metric_result = to_numpy_or_python_type(metric_result)
        if (self._best_metric_result is None
                or self._metric.greater_or_eq(metric_result, self._best_metric_result)):
            self._bad_count = 0
            self._best_metric_result = metric_result
        else:
            self._bad_count += 1
        # Re-saves the best checkpoints.
        if self._keep_best_ckpt_saver is not None:
            start_time = time.time()
            stat = self._keep_best_ckpt_saver.save(step, metric_result)
            logging.info("Checking the best checkpoints kept and %s. Elapsed %.2fs",
                         "a new checkpoint was saved" if stat else "no checkpoint was saved",
                         time.time() - start_time)
        if self._average_ckpt_saver is not None:
            start_time = time.time()
            stat = self._average_ckpt_saver.save(step, metric_result)
            if stat:
                logging.info("An averaged checkpoint was saved. Elapsed %.2fs",
                             time.time() - start_time)
        if self._estop_patience is not None:
            logging.info(f"Evaluating {self._metric.flag} at step={step} "
                         f"with bad count={self._bad_count} "
                         f"(early_stop_patience={self._estop_patience}).")
        # Chained comparison: bad_count >= estop_patience and estop_patience > 0.
        if self._estop_patience and self._bad_count >= self._estop_patience > 0:
            logging.info("Hit maximum patience! Early Stop!!!")

            # Kill the current process and exit with code 0.
            def handler(*args):
                sys.exit(0)

            # Register the handler for SIGUSR1, then raise the signal on this process.
            signal.signal(signal.SIGUSR1, handler)
            os.kill(os.getpid(), signal.SIGUSR1)