def __init__(self, config: Dict):
    """Initialize from an experiment config.

    Args:
        config: experiment config. It is deep-copied; its ``stages`` section
            is required (indexed directly). Optional ``args`` keys read here:
            ``seed`` (default 42) and ``verbose`` / ``timeit`` / ``check`` /
            ``overfit`` (each default False). The run name and log directory
            are then derived via ``self._get_run_name()`` and
            ``self._get_run_logdir()``.
    """
    super().__init__()
    # Deep copy so this object and the caller never share mutable config state.
    self._config: Dict = deepcopy(config)
    self._stage_config: Dict = self._config["stages"]

    def _arg(key: str, default):
        # Shorthand for reading ``args.<key>`` from the copied config.
        return get_by_keys(self._config, "args", key, default=default)

    self._seed: int = _arg("seed", 42)
    self._verbose: bool = _arg("verbose", False)
    self._timeit: bool = _arg("timeit", False)
    self._check: bool = _arg("check", False)
    self._overfit: bool = _arg("overfit", False)

    # Both helpers read ``self._config``, so they must run after it is assigned.
    self._name: str = self._get_run_name()
    self._logdir: str = self._get_run_logdir()

    # @TODO: hack for catalyst-dl tune, could be done better
    self._trial = None
def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
    """Return the ordered callback mapping for ``stage``.

    Callbacks listed under the stage's ``callbacks`` config entry are built via
    ``REGISTRY.get_from_params``. After that, a default callback is appended
    for each enabled runner flag, unless a callback of the matching type is
    already present:

    * ``"_verbose"``    -> ``TqdmCallback()``          when ``self._verbose``
    * ``"_timer"``      -> ``TimerCallback()``         when ``self._timeit``
    * ``"_check"``      -> ``CheckRunCallback()``      when ``self._check``
    * ``"_overfit"``    -> ``BatchOverfitCallback()``  when ``self._overfit``
    * ``"_checkpoint"`` -> ``CheckpointCallback`` writing to
      ``<self._logdir>/checkpoints``, when ``self._logdir`` is not None and no
      ``ICheckpointCallback`` is present.
    """
    params = get_by_keys(self._stage_config, stage, "callbacks", default={})
    result = OrderedDict(REGISTRY.get_from_params(**params))

    def _already_has(callback_type) -> bool:
        # Evaluated against the *current* mapping, so defaults appended earlier
        # in this method also count.
        return any(callback_isinstance(cb, callback_type) for cb in result.values())

    flag_defaults = (
        ("_verbose", self._verbose, TqdmCallback),
        ("_timer", self._timeit, TimerCallback),
        ("_check", self._check, CheckRunCallback),
        ("_overfit", self._overfit, BatchOverfitCallback),
    )
    for key, enabled, callback_cls in flag_defaults:
        if enabled and not _already_has(callback_cls):
            result[key] = callback_cls()

    if self._logdir is not None and not _already_has(ICheckpointCallback):
        checkpoint_dir = os.path.join(self._logdir, "checkpoints")
        result["_checkpoint"] = CheckpointCallback(logdir=checkpoint_dir)

    return result
def get_optimizer(self, model: RunnerModel, stage: str) -> RunnerOptimizer:
    """Build the optimizer(s) configured for ``stage``.

    Args:
        model: model or a dict of models
        stage: current stage name (looked up directly in the stage config,
            so an unknown stage raises ``KeyError``)

    Returns:
        ``None`` if the stage config has no ``optimizer`` entry. If that entry
        sets ``_key_value`` truthy, a dict mapping each remaining key to the
        optimizer built from its sub-params. Otherwise a single optimizer built
        from the whole entry. Construction is delegated to
        ``self._get_optimizer_from_params``.
    """
    if "optimizer" not in self._stage_config[stage]:
        return None

    # Work on a deep copy: ``pop`` below (and anything downstream) must not
    # mutate ``self._stage_config``.
    params = deepcopy(get_by_keys(self._stage_config, stage, "optimizer", default={}))

    if params.pop("_key_value", False):
        return {
            name: self._get_optimizer_from_params(model=model, stage=stage, **sub_params)
            for name, sub_params in params.items()
        }
    return self._get_optimizer_from_params(model=model, stage=stage, **params)
def _get_run_logdir(self) -> str:  # noqa: WPS112
    """Resolve the run's log directory from the ``args`` config section.

    Resolution order:
      1. ``args.logdir`` if it is set and not (case-insensitively) ``"none"``;
      2. otherwise ``"<args.baselogdir>/<self._get_logdir(self._config)>"`` if
         ``args.baselogdir`` is set and not ``"none"``;
      3. otherwise ``None``. Note that this can be returned despite the
         ``str`` annotation.
    """
    disabled_tag = "none"

    def _enabled(value) -> bool:
        # A missing value, or the literal "none" in any case, means "not set".
        return value is not None and value.lower() != disabled_tag

    logdir: str = get_by_keys(self._config, "args", "logdir", default=None)
    baselogdir: str = get_by_keys(self._config, "args", "baselogdir", default=None)

    if _enabled(logdir):
        return logdir
    if _enabled(baselogdir):
        run_subdir = self._get_logdir(self._config)
        return f"{baselogdir}/{run_subdir}"
    return None
def get_criterion(self, stage: str) -> RunnerCriterion:
    """Return the criterion configured for ``stage``.

    There is intentionally a single definition of this method. An earlier
    duplicate was always rebound by this one in the class body, so it was
    never called.

    Args:
        stage: stage name; looked up directly in the stage config.

    Returns:
        ``None`` when the stage config has no ``criterion`` entry, otherwise
        whatever ``self._get_criterion_from_params`` builds from that entry.

    Raises:
        KeyError: if ``stage`` is not present in the stage config.
    """
    if "criterion" not in self._stage_config[stage]:
        return None
    criterion_params = get_by_keys(self._stage_config, stage, "criterion", default={})
    return self._get_criterion_from_params(**criterion_params)
def get_scheduler(self, optimizer: RunnerOptimizer, stage: str) -> RunnerScheduler:
    """Return the scheduler configured for ``stage``.

    Args:
        optimizer: forwarded as ``optimizer=`` to
            ``self._get_scheduler_from_params``
        stage: stage name; looked up directly in the stage config, so an
            unknown stage raises ``KeyError``

    Returns:
        ``None`` when the stage config has no ``scheduler`` entry, otherwise
        the object built by ``self._get_scheduler_from_params`` from that entry.
    """
    has_scheduler = "scheduler" in self._stage_config[stage]
    if not has_scheduler:
        return None
    return self._get_scheduler_from_params(
        optimizer=optimizer,
        **get_by_keys(self._stage_config, stage, "scheduler", default={}),
    )
def get_stage_len(self, stage: str) -> int:
    """Return the number of epochs configured for ``stage``.

    Reads the stage's ``num_epochs`` entry, passing ``default=1`` to
    ``get_by_keys`` for when it is not set.

    Args:
        stage: current stage

    Returns:
        number of epochs in stage

    Example::

        >>> runner.get_stage_len("pretraining")
        3
    """
    num_epochs = get_by_keys(self._stage_config, stage, "num_epochs", default=1)
    return num_epochs
def get_samplers(self, stage: str) -> "OrderedDict[str, Sampler]":
    """Return the samplers configured for ``stage``.

    Sampler specs are read from the stage's ``loaders.samplers`` entry
    (empty dict when absent) and built via ``REGISTRY.get_from_params``.

    Args:
        stage: stage name

    Returns:
        ``OrderedDict`` wrapping the mapping produced by the registry
        (keys presumably loader names — TODO confirm against config schema).
    """
    return OrderedDict(
        REGISTRY.get_from_params(
            **get_by_keys(self._stage_config, stage, "loaders", "samplers", default={})
        )
    )
def _get_run_name(self) -> str:
    """Return ``args.name`` from the config, or a generated fallback.

    The fallback is ``"<timestamp>-<config hash>"``, built from
    ``get_utcnow_time()`` and ``get_short_hash(self._config)``. Both are
    always computed, even when ``args.name`` is set.
    """
    timestamp, config_hash = get_utcnow_time(), get_short_hash(self._config)
    return get_by_keys(
        self._config, "args", "name", default=f"{timestamp}-{config_hash}"
    )