示例#1
0
    def __init__(
        self,
        logdir: str = None,
        # model selection info
        loader_key: str = None,
        metric_key: str = None,
        minimize: bool = None,
        min_delta: float = 1e-6,
        save_n_best: int = 1,
        # loading info
        load_on_stage_start: Union[str, Dict[str, str]] = None,
        load_on_stage_end: Union[str, Dict[str, str]] = None,
        # resume: str = None,
        # resume_dir: str = None,
        # checkpointer info
        metrics_filename: str = "_metrics.json",
        mode: str = "all",
        use_logdir_postfix: bool = False,
        use_runner_logdir: bool = False,
    ):
        """Init.

        Validates the configuration and stores it on the callback.

        Args:
            logdir: directory for checkpoints. Required unless
                ``use_runner_logdir`` is True (presumably the directory is
                then taken from the runner -- confirm against the rest of
                the class).
            loader_key: loader whose metric drives checkpoint selection.
                Must be given together with ``metric_key``.
            metric_key: metric that drives checkpoint selection.
                Must be given together with ``loader_key``.
            minimize: whether a lower metric value is better. When model
                selection is enabled and this is None, it defaults to True.
                Without model selection it is ignored and forced to False.
            min_delta: minimal change counted as an improvement; forwarded
                to ``MetricHandler``.
            save_n_best: size of the top-N list tracked in
                ``self.top_best_metrics``; must be >= 0. When 0,
                ``load_on_stage_end`` must be None, "last" or "last_full".
            load_on_stage_start: one of None, "best", "last", "best_full",
                "last_full", or a dict (dict values are not validated here).
            load_on_stage_end: same accepted values as
                ``load_on_stage_start``.
            metrics_filename: file name stored as ``self.metrics_filename``.
            mode: one of "all", "full" or "model".
            use_logdir_postfix: stored as ``self.use_logdir_postfix``.
            use_runner_logdir: allows ``logdir`` to be None.

        Raises:
            AssertionError: if any of the constraints above is violated.
        """
        super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
        possible_states = {
            None,
            "best",
            "last",
            "best_full",
            "last_full",
        }
        assert save_n_best >= 0, "`save_n_best` must be non-negative."
        if save_n_best == 0:
            # No top-N list is kept, so presumably no "best" checkpoint
            # can be loaded at stage end -- only the "last" variants.
            assert load_on_stage_end in (None, "last", "last_full"), (
                "With `save_n_best=0` `load_on_stage_end` must be "
                "None, 'last' or 'last_full'.")
        if isinstance(load_on_stage_start, str):
            assert load_on_stage_start in possible_states, (
                "Unsupported `load_on_stage_start` value.")
        if isinstance(load_on_stage_end, str):
            assert load_on_stage_end in possible_states, (
                "Unsupported `load_on_stage_end` value.")
        # if resume_dir is not None:
        #     assert resume is not None

        if loader_key is not None or metric_key is not None:
            assert loader_key is not None and metric_key is not None, (
                "For checkpoint selection `CheckpointCallback` "
                "requires both `loader_key` and `metric_key` specified.")
            self._use_model_selection = True
            # loss-oriented selection unless the caller says otherwise
            self.minimize = minimize if minimize is not None else True
        else:
            self._use_model_selection = False
            # epoch-num-oriented selection: later epochs are "better"
            self.minimize = False

        assert mode in (
            "all",
            "full",
            "model",
        ), "`CheckpointCallback` could work only in `all`, `full` or `model` modes."

        # checkpointer info
        self.logdir = logdir
        self.mode = mode
        self.metrics_filename = metrics_filename
        self.use_logdir_postfix = use_logdir_postfix
        self.use_runner_logdir = use_runner_logdir
        assert (self.logdir is not None or self.use_runner_logdir
                ), "CheckpointCallback requires specified `logdir`"

        # model selection info
        self.loader_key = loader_key
        self.metric_key = metric_key
        # Use the *resolved* direction: passing the raw `minimize` argument
        # made the comparator disagree with `self.minimize` when it was
        # None (defaulted to True above) or when model selection is off.
        self.is_better = MetricHandler(minimize=self.minimize, min_delta=min_delta)
        self.save_n_best = save_n_best
        # list with topN metrics [(score, filepath, stage_key, stage_epoch_step, epoch metrics)]
        self.top_best_metrics = []
        self.best_score = None

        # loading info
        self.load_on_stage_start = load_on_stage_start
        self.load_on_stage_end = load_on_stage_end