Beispiel #1
0
    def __init__(self, config: Dict):
        """Build the runner state from an experiment config.

        Args:
            config: experiment config dictionary. It is deep-copied, so the
                runner works on its own copy rather than the caller's object.
                It must contain a ``"stages"`` key; run-level flags are read
                from its ``"args"`` section via ``get_by_keys`` fallbacks.
        """
        super().__init__()
        self._config: Dict = deepcopy(config)
        self._stage_config: Dict = self._config["stages"]

        def from_args(key: str, fallback):
            # Shorthand for ``get_by_keys(self._config, "args", key, ...)``;
            # ``fallback`` is passed as the ``default`` (presumably used when
            # the key path is missing).
            return get_by_keys(self._config, "args", key, default=fallback)

        self._seed: int = from_args("seed", 42)
        self._verbose: bool = from_args("verbose", False)
        self._timeit: bool = from_args("timeit", False)
        self._check: bool = from_args("check", False)
        self._overfit: bool = from_args("overfit", False)

        # Derived run identity; computed after the flags above are set.
        self._name: str = self._get_run_name()
        self._logdir: str = self._get_run_logdir()

        # @TODO: hack for catalyst-dl tune, could be done better
        self._trial = None
Beispiel #2
0
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Return the callbacks for ``stage``.

        Callbacks declared under ``stages -> <stage> -> callbacks`` are built
        through ``REGISTRY``. Then, for each enabled run flag (``verbose``,
        ``timeit``, ``check``, ``overfit``), a default callback is added
        unless a callback of that type is already present. A
        ``CheckpointCallback`` writing to ``<logdir>/checkpoints`` is added
        when a log directory is set and no ``ICheckpointCallback`` exists.

        Args:
            stage: stage name

        Returns:
            ordered mapping from callback name to callback instance
        """
        params = get_by_keys(self._stage_config, stage, "callbacks",
                             default={})
        callbacks = OrderedDict(REGISTRY.get_from_params(**params))

        def has_callback(callback_type) -> bool:
            # True if any callback registered so far matches ``callback_type``.
            return any(
                callback_isinstance(existing, callback_type)
                for existing in callbacks.values())

        # (flag, key, callback class); tuple order fixes insertion order.
        flag_defaults = (
            (self._verbose, "_verbose", TqdmCallback),
            (self._timeit, "_timer", TimerCallback),
            (self._check, "_check", CheckRunCallback),
            (self._overfit, "_overfit", BatchOverfitCallback),
        )
        for enabled, key, callback_cls in flag_defaults:
            if enabled and not has_callback(callback_cls):
                callbacks[key] = callback_cls()

        if self._logdir is not None and not has_callback(ICheckpointCallback):
            checkpoint_dir = os.path.join(self._logdir, "checkpoints")
            callbacks["_checkpoint"] = CheckpointCallback(logdir=checkpoint_dir)

        return callbacks
Beispiel #3
0
    def get_optimizer(self, model: RunnerModel, stage: str) -> RunnerOptimizer:
        """
        Build the optimizer(s) configured for ``stage``.

        Returns ``None`` when the stage config has no ``"optimizer"`` section.
        If the section has a truthy ``_key_value`` entry, every remaining
        entry is a separate optimizer spec, and a dict keyed by entry name is
        returned. Otherwise the section is treated as one spec.

        Args:
            model: model or a dict of models
            stage: current stage name

        Returns:
            an optimizer, a dict of optimizers, or ``None``
        """
        if "optimizer" not in self._stage_config[stage]:
            return None

        # Copy first: ``pop`` below must not mutate the stored stage config.
        spec = deepcopy(
            get_by_keys(self._stage_config, stage, "optimizer", default={}))
        if not spec.pop("_key_value", False):
            return self._get_optimizer_from_params(model=model,
                                                   stage=stage,
                                                   **spec)

        return {
            name: self._get_optimizer_from_params(model=model,
                                                  stage=stage,
                                                  **sub_spec)
            for name, sub_spec in spec.items()
        }
Beispiel #4
0
    def _get_run_logdir(self) -> str:  # noqa: WPS112
        """Resolve the log directory for this run.

        The settings are tried in order. The value ``"none"`` (compared
        case-insensitively) disables a setting.

        1. ``args.logdir``: used verbatim.
        2. ``args.baselogdir``: joined with the run-specific name returned
           by ``self._get_logdir(self._config)``.

        Returns:
            the log directory path, or ``None`` if neither setting applies
        """
        import os  # local import keeps this block self-contained

        exclude_tag = "none"

        logdir = get_by_keys(self._config, "args", "logdir", default=None)
        if logdir is not None and logdir.lower() != exclude_tag:
            return logdir

        baselogdir = get_by_keys(self._config,
                                 "args",
                                 "baselogdir",
                                 default=None)
        if baselogdir is not None and baselogdir.lower() != exclude_tag:
            run_subdir = self._get_logdir(self._config)
            # os.path.join rather than a hand-built "a/b" string: uses the
            # platform separator and avoids "//" when baselogdir already ends
            # with a separator.
            return os.path.join(baselogdir, run_subdir)

        return None
Beispiel #5
0
 def get_criterion(self, stage: str) -> RunnerCriterion:
     """Build the criterion for ``stage`` through ``REGISTRY``.

     Specs are read from ``stages -> <stage> -> criterion``. If the registry
     returns a falsy object (presumably an empty mapping when nothing is
     configured), ``None`` is returned instead.
     """
     params = get_by_keys(self._stage_config, stage, "criterion", default={})
     built = REGISTRY.get_from_params(**params)
     if not built:
         return None
     return built
Beispiel #6
0
 def get_criterion(self, stage: str) -> RunnerCriterion:
     """Build the criterion for ``stage``.

     Returns ``None`` when the stage config has no ``"criterion"`` section.
     Otherwise the section is handed to ``self._get_criterion_from_params``.
     """
     has_criterion = "criterion" in self._stage_config[stage]
     if not has_criterion:
         return None
     params = get_by_keys(self._stage_config, stage, "criterion", default={})
     return self._get_criterion_from_params(**params)
Beispiel #7
0
 def get_scheduler(self, optimizer: RunnerOptimizer,
                   stage: str) -> RunnerScheduler:
     """Build the scheduler for ``stage``.

     Args:
         optimizer: optimizer(s) forwarded to the scheduler factory
         stage: stage name

     Returns:
         the scheduler, or ``None`` if the stage config has no
         ``"scheduler"`` section
     """
     if "scheduler" in self._stage_config[stage]:
         params = get_by_keys(self._stage_config, stage, "scheduler",
                              default={})
         return self._get_scheduler_from_params(optimizer=optimizer, **params)
     return None
Beispiel #8
0
    def get_stage_len(self, stage: str) -> int:
        """Return how many epochs the given stage runs for.

        The value is read from ``stages -> <stage> -> num_epochs``, with
        ``1`` passed as the fallback default.

        Args:
            stage: stage name

        Returns:
            number of epochs in the stage

        Example::

            >>> runner.get_stage_len("pretraining")
            3
        """
        num_epochs = get_by_keys(self._stage_config,
                                 stage,
                                 "num_epochs",
                                 default=1)
        return num_epochs
Beispiel #9
0
    def get_samplers(self, stage: str) -> "OrderedDict[str, Sampler]":
        """
        Build the data samplers for ``stage``.

        Sampler specs are read from
        ``stages -> <stage> -> loaders -> samplers`` and instantiated through
        ``REGISTRY``.

        Args:
            stage: stage name

        Returns:
            ordered mapping of the built samplers
        """
        key_path = (stage, "loaders", "samplers")
        params = get_by_keys(self._stage_config, *key_path, default={})
        return OrderedDict(REGISTRY.get_from_params(**params))
Beispiel #10
0
 def _get_run_name(self) -> str:
     """Return the run name.

     Reads ``args.name`` from the config. The fallback default is
     ``"<get_utcnow_time()>-<get_short_hash(config)>"``.
     """
     fallback = "{}-{}".format(get_utcnow_time(), get_short_hash(self._config))
     return get_by_keys(self._config, "args", "name", default=fallback)