Example
class TrainingStatusRecorder(object):
    """ Manage the training status with the best metrics. """
    def __init__(self,
                 model,
                 task,
                 metric,
                 estop_patience=None,
                 best_checkpoint_path=None,
                 auto_average_checkpoints=True,
                 best_avg_checkpoint_path=None,
                 top_checkpoints_to_keep=0):
        """ Initializes manager for arbitrary evaluation strategies.

        Args:
            model: The custom keras model (inheriting from BaseModel).
            task: The custom task.
            metric: The evaluation metric object.
            estop_patience: An integer. If greater than 0, the training process will automatically shut
                down once no better metric score has been obtained for this many consecutive evaluations.
            best_checkpoint_path: The path for saving checkpoints with the best metric scores if provided,
                otherwise the default "`model_dir`_best" will be used.
            best_avg_checkpoint_path: The path for saving the averaged checkpoints.
            auto_average_checkpoints: A boolean, whether to average the model weights of the kept
                checkpoints. An extra directory for the averaged weights will be created. It only takes
                effect when `best_checkpoint_path` is provided.
            top_checkpoints_to_keep: An integer, the maximum number of checkpoints to keep
                (`max_to_keep` for the checkpoint manager), and the number of latest checkpoints to average
                if `auto_average_checkpoints` is True. If <= 0, no checkpoints will be saved.
        """
        self._model = model
        self._task = task
        self._metric = metric
        self._estop_patience = estop_patience
        self._best_checkpoint_path = best_checkpoint_path
        self._auto_average_checkpoints = auto_average_checkpoints
        self._best_avg_checkpoint_path = best_avg_checkpoint_path
        self._top_checkpoints_to_keep = top_checkpoints_to_keep
        self._keep_best_ckpt_saver = None
        self._average_ckpt_saver = None
        if self._top_checkpoints_to_keep and self._top_checkpoints_to_keep > 0:
            self._keep_best_ckpt_saver = KeepBestCheckpointSaver(
                model=self._model,
                directory=self._best_checkpoint_path,
                metric=self._metric,
                max_to_keep=self._top_checkpoints_to_keep)
            ModelConfigs.dump(self._task.model_configs(self._model),
                              self._keep_best_ckpt_saver.directory)
            if self._auto_average_checkpoints:
                self._average_ckpt_saver = AverageCheckpointSaver(
                    model=self._model,
                    directory=self._best_avg_checkpoint_path,
                    metric=self._metric,
                    max_to_keep=self._top_checkpoints_to_keep)
                ModelConfigs.dump(self._task.model_configs(self._model),
                                  self._average_ckpt_saver.directory)
        self._best_metric_result = None
        self._bad_count = 0

    @property
    def best(self):
        return self._best_metric_result

    def record(self, step, metric_result):
        """ Records the metrics and keep the best. """
        metric_result = to_numpy_or_python_type(metric_result)
        if (self._best_metric_result is None or self._metric.greater_or_eq(
                metric_result, self._best_metric_result)):
            self._bad_count = 0
            self._best_metric_result = metric_result
        else:
            self._bad_count += 1

        # re-save the best checkpoint
        if self._keep_best_ckpt_saver is not None:
            start_time = time.time()
            stat = self._keep_best_ckpt_saver.save(step, metric_result)
            logging.info(
                "Checking the best checkpoints kept and %s. Elapsed %.2fs",
                "a new checkpoint was saved"
                if stat else "no checkpoint was saved.",
                time.time() - start_time)
        if self._average_ckpt_saver is not None:
            start_time = time.time()
            stat = self._average_ckpt_saver.save(step, metric_result)
            if stat:
                logging.info("An averaged checkpoint was saved. Elapsed %.2fs",
                             time.time() - start_time)

        if self._estop_patience is not None:
            logging.info(
                f"Evaluating {self._metric.flag} at step={step} with bad count={self._bad_count} "
                f"(early_stop_patience={self._estop_patience}).")
        if self._estop_patience and self._bad_count >= self._estop_patience > 0:
            logging.info("Hit maximum patience! Early Stop!!!")

            # kill self and exit with code=0
            def handler(*args):
                sys.exit(0)

            # register for signal
            signal.signal(signal.SIGUSR1, handler)
            os.kill(os.getpid(), signal.SIGUSR1)
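
A minimal usage sketch follows (illustrative only: `build_model`, `build_task`, `build_metric`, and `evaluate` are hypothetical placeholders, not functions defined in the code above). The recorder is constructed once, and `record(step, metric_result)` is called after each evaluation round; it keeps the best result, re-saves the best and averaged checkpoints, and raises SIGUSR1 to terminate the process once `estop_patience` consecutive evaluations fail to improve.

# Hypothetical driver loop around TrainingStatusRecorder (sketch only).
# build_model/build_task/build_metric/evaluate stand in for project-specific
# code; they are NOT part of the snippet above.
model = build_model()
task = build_task()
metric = build_metric()

recorder = TrainingStatusRecorder(
    model=model,
    task=task,
    metric=metric,
    estop_patience=5,                        # stop after 5 evaluations with no improvement
    best_checkpoint_path="model_dir_best",
    auto_average_checkpoints=True,
    best_avg_checkpoint_path="model_dir_best_avg",
    top_checkpoints_to_keep=10)              # keep (and average) at most 10 checkpoints

for step in range(1000, 100001, 1000):
    # `evaluate` is assumed to return whatever the metric object can compare,
    # e.g. a float score or a dict such as {"BLEU": 27.3}.
    metric_result = evaluate(model, task, step)
    recorder.record(step, metric_result)     # may save checkpoints or trigger early stop
print("Best metric result:", recorder.best)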