Example #1
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for a given stage."""
        # Collect the per-stage callback definitions from the config (may be empty).
        callbacks_params = self._config.stages[stage].callbacks or {}

        # Instantiate every callback declared in the config section.
        callbacks: Dict[str, Callback] = {
            name: self._get_callback_from_params(callback_params)
            for name, callback_params in callbacks_params.items()
        }

        # Checks whether a callback of the given type was already supplied.
        def is_callback_exists(callback_fn) -> bool:
            return any(
                callback_isinstance(x, callback_fn) for x in callbacks.values()
            )
        # Utility callbacks (progress bar, timing, sanity check, batch overfit)
        # are added only when requested and not already supplied by the config.
        if self._verbose and not is_callback_exists(TqdmCallback):
            callbacks["_verbose"] = TqdmCallback()
        if self._timeit and not is_callback_exists(TimerCallback):
            callbacks["_timer"] = TimerCallback()
        if self._check and not is_callback_exists(CheckRunCallback):
            callbacks["_check"] = CheckRunCallback()
        if self._overfit and not is_callback_exists(BatchOverfitCallback):
            callbacks["_overfit"] = BatchOverfitCallback()

        # Default checkpointing is attached only when a logdir is configured.
        if self._logdir is not None and not is_callback_exists(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints")
            )

        return callbacks
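
The core pattern here is to inspect the already-registered callbacks by type and install a framework default only when no callback of that type is present. The sketch below isolates that pattern in a self-contained, simplified form: the Callback subclasses are hypothetical stand-ins for the real Catalyst classes, and a plain isinstance check replaces callback_isinstance.

from typing import Dict, Type


class Callback:
    """Minimal stand-in for the framework's base callback class."""


class TqdmCallback(Callback):
    """Stand-in for a progress-bar callback."""


class TimerCallback(Callback):
    """Stand-in for a timing callback."""


def add_default_callbacks(
    callbacks: Dict[str, Callback], verbose: bool, timeit: bool
) -> Dict[str, Callback]:
    """Adds default callbacks unless one of the same type is already registered."""

    def exists(callback_cls: Type[Callback]) -> bool:
        # Plain isinstance stands in for ``callback_isinstance`` used above.
        return any(isinstance(cb, callback_cls) for cb in callbacks.values())

    if verbose and not exists(TqdmCallback):
        callbacks["_verbose"] = TqdmCallback()
    if timeit and not exists(TimerCallback):
        callbacks["_timer"] = TimerCallback()
    return callbacks


user_callbacks: Dict[str, Callback] = {"progress": TqdmCallback()}
result = add_default_callbacks(user_callbacks, verbose=True, timeit=True)
print(sorted(result))  # ['_timer', 'progress'] -- the user's progress bar is kept

Because the check is done by type, a user-supplied progress bar under any key suppresses the default one, while the defaults are registered under reserved "_"-prefixed keys to avoid clashing with user names.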
Example #2
    def get_callbacks(self) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for the experiment."""
        # Order the user-provided callbacks, then check for existing ones by type.
        callbacks = sort_callbacks_by_order(self._callbacks)

        def callback_exists(callback_fn) -> bool:
            return any(
                callback_isinstance(x, callback_fn) for x in callbacks.values()
            )
        if self._verbose and not callback_exists(TqdmCallback):
            callbacks["_verbose"] = TqdmCallback()
        if self._timeit and not callback_exists(TimerCallback):
            callbacks["_timer"] = TimerCallback()
        if self._check and not callback_exists(CheckRunCallback):
            callbacks["_check"] = CheckRunCallback()
        if self._overfit and not callback_exists(BatchOverfitCallback):
            callbacks["_overfit"] = BatchOverfitCallback()
        # Profiling writes TensorBoard-compatible traces under "<logdir>/tb_profile".
        if self._profile and not callback_exists(ProfilerCallback):
            callbacks["_profile"] = ProfilerCallback(
                tensorboard_path=os.path.join(self._logdir, "tb_profile"),
                profiler_kwargs={
                    "activities": [
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    "on_trace_ready": torch.profiler.tensorboard_trace_handler(
                        os.path.join(self._logdir, "tb_profile")
                    ),
                    "with_stack": True,
                    "with_flops": True,
                },
            )

        # Default checkpointing (tracking the best model on the validation metric)
        # is attached only when a logdir is configured.
        if self._logdir is not None and not callback_exists(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints"),
                loader_key=self._valid_loader,
                metric_key=self._valid_metric,
                minimize=self._minimize_valid_metric,
                load_best_on_end=self._load_best_on_end,
            )
        return callbacks
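
The profiler_kwargs in the second example are forwarded to PyTorch's torch.profiler.profile. Below is a minimal standalone sketch of what those settings do, assuming a hypothetical ./logs/tb_profile output directory and a toy matrix-multiply workload in place of a real training loop.

import os

import torch
import torch.profiler

logdir = "./logs"  # hypothetical log directory
trace_dir = os.path.join(logdir, "tb_profile")

# Profile CPU always; add CUDA only if a GPU is actually available.
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

# Same settings as profiler_kwargs above: write TensorBoard-readable traces,
# record Python stacks, and estimate FLOPs for supported operators.
with torch.profiler.profile(
    activities=activities,
    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
    with_stack=True,
    with_flops=True,
) as prof:
    x = torch.randn(64, 128)
    w = torch.randn(128, 32)
    for _ in range(10):
        y = x @ w

# Print a summary of the recorded operators; the full trace is in trace_dir.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))

Traces written by tensorboard_trace_handler can be inspected with the TensorBoard profiler plugin (torch-tb-profiler).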