Esempio n. 1
0
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """
        Override of ``BaseExperiment.get_callbacks`` method.

        Starts from the callbacks returned by the parent implementation and,
        for training stages, adds default criterion / optimizer / scheduler
        callbacks for each of those components that is set on the
        experiment but has no matching callback yet.

        Args:
            stage: name of stage. Stages whose name starts with ``infer``
                receive no default callbacks, since the defaults are only
                needed for training stages.

        Returns:
            OrderedDict[str, Callback]: Ordered dictionary of callbacks
                for experiment
        """
        callbacks = super().get_callbacks(stage=stage) or OrderedDict()

        use_amp = (
            self.distributed_params.get("amp", False) and check_amp_available()
        )
        optimizer_callback_cls = (
            AMPOptimizerCallback if use_amp else OptimizerCallback
        )

        # Each entry: (key, interface to search for or None, factory).
        # A None interface means existing callbacks are checked against the
        # factory class itself.
        defaults = []
        if not stage.startswith("infer"):
            # (component, accepted types, default-callback entry)
            training_components = (
                (
                    self._criterion,
                    Criterion,
                    ("_criterion", None, CriterionCallback),
                ),
                (
                    self._optimizer,
                    Optimizer,
                    ("_optimizer", IOptimizerCallback, optimizer_callback_cls),
                ),
                (
                    self._scheduler,
                    (Scheduler, ReduceLROnPlateau),
                    ("_scheduler", ISchedulerCallback, SchedulerCallback),
                ),
            )
            for component, accepted_types, entry in training_components:
                if component is not None and isinstance(
                    component, accepted_types
                ):
                    defaults.append(entry)

        for key, interface, factory in defaults:
            target_cls = interface or factory
            # Only instantiate the default when no existing callback
            # (including defaults added earlier in this loop) matches.
            if not any(
                check_callback_isinstance(existing, target_cls)
                for existing in callbacks.values()
            ):
                callbacks[key] = factory()

        return callbacks
Esempio n. 2
0
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for a given stage.

        Callbacks declared under the stage's ``callbacks_params`` config are
        built first, in config order. Default callbacks are then appended
        for every role not already covered by a callback of the same
        class / interface. Finally the collection is passed to
        ``self._process_callbacks`` together with the stage's position in
        ``self.stages_config``.

        Args:
            stage: name of the stage; must be a key of ``self.stages_config``.
                Stages whose name starts with ``infer`` only get the
                flag-driven defaults (verbose / timer / check / overfit)
                plus ``ExceptionCallback``.

        Returns:
            OrderedDict[str, Callback]: ordered mapping of callback name to
                callback instance for this stage.

        Raises:
            KeyError: presumably, if ``stage`` is not in
                ``self.stages_config`` (raised by the first lookup when
                ``stages_config`` is a plain dict).
        """
        # NOTE: an unknown ``stage`` fails on this very first lookup,
        #       before any callback is built.
        stage_config = self.stages_config[stage]
        callbacks_params = stage_config.get("callbacks_params", {})

        # User-declared callbacks come first, in config order.
        callbacks = OrderedDict(
            (key, self._get_callback(**callback_params))
            for key, callback_params in callbacks_params.items()
        )

        # default_callbacks = [(Name, InterfaceClass, InstanceFactory)]
        # InterfaceClass is what existing callbacks are checked against;
        # None means "check against InstanceFactory itself".
        default_callbacks = []

        is_amp_enabled = (self.distributed_params.get("amp", False)
                          and check_amp_available())
        optimizer_cls = (AMPOptimizerCallback
                         if is_amp_enabled else OptimizerCallback)

        # Flag-driven defaults apply to every stage, inference included.
        if self._verbose:
            default_callbacks.append(("_verbose", None, VerboseLogger))
        if self._check_time:
            default_callbacks.append(("_timer", None, TimerCallback))
        if self._check_run:
            default_callbacks.append(("_check", None, CheckRunCallback))
        if self._overfit:
            default_callbacks.append(("_overfit", None, BatchOverfitCallback))

        if not stage.startswith("infer"):
            default_callbacks.append(("_metrics", None, MetricManagerCallback))
            default_callbacks.append(
                ("_validation", None, ValidationManagerCallback))
            default_callbacks.append(("_console", None, ConsoleLogger))

            # Checkpointing and tensorboard logging need a log directory.
            if self.logdir is not None:
                default_callbacks.append(("_saver", None, CheckpointCallback))
                default_callbacks.append(
                    ("_tensorboard", None, TensorboardLogger))

            # Component callbacks are attached only when the stage config
            # provides non-empty parameters for that component.
            if stage_config.get("criterion_params"):
                default_callbacks.append(
                    ("_criterion", None, CriterionCallback))
            if stage_config.get("optimizer_params"):
                default_callbacks.append(
                    ("_optimizer", IOptimizerCallback, optimizer_cls))
            if stage_config.get("scheduler_params"):
                default_callbacks.append(
                    ("_scheduler", ISchedulerCallback, SchedulerCallback))

        default_callbacks.append(("_exception", None, ExceptionCallback))

        for (
                callback_name,
                callback_interface,
                callback_fn,
        ) in default_callbacks:
            callback_interface = callback_interface or callback_fn
            # Skip the default if a matching callback is already present
            # (user-declared or added earlier in this loop).
            is_already_present = any(
                check_callback_isinstance(x, callback_interface)
                for x in callbacks.values())
            if not is_already_present:
                callbacks[callback_name] = callback_fn()

        # Position of this stage among the configured stages (key order).
        stage_index = list(self.stages_config.keys()).index(stage)
        self._process_callbacks(callbacks, stage_index)

        return callbacks