def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
    """Build the ordered callback mapping for ``stage``.

    Callbacks declared under ``<stage> -> "callbacks"`` in the stage config
    are instantiated through ``REGISTRY.get_from_params``. After that, a few
    default callbacks are appended, each only when its runner flag is truthy
    and no callback of the matching type is already present.

    Args:
        stage: name of the stage whose callbacks are requested.

    Returns:
        OrderedDict mapping callback names to callback instances.
    """
    params = get_by_keys(self._stage_config, stage, "callbacks", default={})
    result = OrderedDict(REGISTRY.get_from_params(**params))

    def _has(callback_type) -> bool:
        # Inspects the live mapping, so defaults appended earlier are seen too.
        return any(callback_isinstance(cb, callback_type) for cb in result.values())

    def _make_checkpoint():
        return CheckpointCallback(logdir=os.path.join(self._logdir, "checkpoints"))

    # Each entry: (enabling flag, type that makes the default redundant,
    # key to insert under, zero-argument factory that builds the callback).
    defaults = (
        (self._verbose, TqdmCallback, "_verbose", TqdmCallback),
        (self._timeit, TimerCallback, "_timer", TimerCallback),
        (self._check, CheckRunCallback, "_check", CheckRunCallback),
        (self._overfit, BatchOverfitCallback, "_overfit", BatchOverfitCallback),
        (self._logdir is not None, ICheckpointCallback, "_checkpoint", _make_checkpoint),
    )
    for enabled, guard_type, key, factory in defaults:
        if enabled and not _has(guard_type):
            result[key] = factory()
    return result
def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
    """Return the callbacks for a given stage.

    Callbacks listed in ``self._config.stages[stage].callbacks`` are built in
    declaration order via ``self._get_callback_from_params``. Then default
    callbacks are appended for every truthy runner flag, unless a callback of
    the corresponding type is already present.

    Args:
        stage: name of the stage in ``self._config.stages``.

    Returns:
        OrderedDict mapping callback names to callback instances. This
        matches the declared return annotation; previously a plain ``dict``
        was returned.
    """
    # Local import: the return annotation is only a string, so the module may
    # not import OrderedDict itself.
    from collections import OrderedDict

    callbacks_params = self._config.stages[stage].callbacks or {}
    callbacks: "OrderedDict[str, Callback]" = OrderedDict(
        (name, self._get_callback_from_params(callback_params))
        for name, callback_params in callbacks_params.items()
    )

    def is_callback_exists(callback_fn) -> bool:
        # Checks the live mapping, so defaults added above are considered too.
        return any(callback_isinstance(x, callback_fn) for x in callbacks.values())

    if self._verbose and not is_callback_exists(TqdmCallback):
        callbacks["_verbose"] = TqdmCallback()
    if self._timeit and not is_callback_exists(TimerCallback):
        callbacks["_timer"] = TimerCallback()
    if self._check and not is_callback_exists(CheckRunCallback):
        callbacks["_check"] = CheckRunCallback()
    if self._overfit and not is_callback_exists(BatchOverfitCallback):
        callbacks["_overfit"] = BatchOverfitCallback()
    if self._logdir is not None and not is_callback_exists(ICheckpointCallback):
        callbacks["_checkpoint"] = CheckpointCallback(
            logdir=os.path.join(self._logdir, "checkpoints")
        )
    return callbacks
def get_callbacks(self) -> "OrderedDict[str, Callback]":
    """Return the callbacks for the experiment.

    Starts from ``self._callbacks`` ordered by ``sort_callbacks_by_order``.
    It then appends a default callback for each truthy runner flag, unless a
    callback of the matching type is already present.

    Returns:
        Mapping from callback name to callback instance.

    Raises:
        ValueError: if profiling is requested (``self._profile``) without a
            ``self._logdir``. The profiler traces are written under
            ``<logdir>/tb_profile``, so a log directory is required.
    """
    callbacks = sort_callbacks_by_order(self._callbacks)

    def callback_exists(callback_fn) -> bool:
        # Checks the live mapping, so defaults added above are considered too.
        return any(callback_isinstance(x, callback_fn) for x in callbacks.values())

    if self._verbose and not callback_exists(TqdmCallback):
        callbacks["_verbose"] = TqdmCallback()
    if self._timeit and not callback_exists(TimerCallback):
        callbacks["_timer"] = TimerCallback()
    if self._check and not callback_exists(CheckRunCallback):
        callbacks["_check"] = CheckRunCallback()
    if self._overfit and not callback_exists(BatchOverfitCallback):
        callbacks["_overfit"] = BatchOverfitCallback()
    if self._profile and not callback_exists(ProfilerCallback):
        # Without this guard, os.path.join(None, ...) fails with an opaque TypeError.
        if self._logdir is None:
            raise ValueError(
                "Profiling requires `logdir` to be set: traces are written to "
                "<logdir>/tb_profile."
            )
        profile_dir = os.path.join(self._logdir, "tb_profile")
        callbacks["_profile"] = ProfilerCallback(
            tensorboard_path=profile_dir,
            profiler_kwargs={
                "activities": [
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                "on_trace_ready": torch.profiler.tensorboard_trace_handler(profile_dir),
                "with_stack": True,
                "with_flops": True,
            },
        )
    if self._logdir is not None and not callback_exists(ICheckpointCallback):
        callbacks["_checkpoint"] = CheckpointCallback(
            logdir=os.path.join(self._logdir, "checkpoints"),
            loader_key=self._valid_loader,
            metric_key=self._valid_metric,
            minimize=self._minimize_valid_metric,
            load_best_on_end=self._load_best_on_end,
        )
    return callbacks
def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
    """Assemble the callbacks used for ``stage``.

    The user-supplied ``self._callbacks`` are ordered with
    ``sort_callbacks_by_order``. Default callbacks are then appended for every
    truthy runner flag, unless a callback of the matching type is already in
    the mapping.

    Args:
        stage: stage name. Not read by this implementation, so the same
            callbacks are produced for every stage.

    Returns:
        Mapping from callback name to callback instance: the result of
        ``sort_callbacks_by_order``, extended in place.
    """
    ordered = sort_callbacks_by_order(self._callbacks)

    def _already_has(kind) -> bool:
        # Inspect the live mapping so defaults inserted earlier count as well.
        return any(callback_isinstance(cb, kind) for cb in ordered.values())

    def _make_checkpoint():
        return CheckpointCallback(
            logdir=os.path.join(self._logdir, "checkpoints"),
            loader_key=self._valid_loader,
            metric_key=self._valid_metric,
            minimize=self._minimize_valid_metric,
        )

    # Each entry: (enabling flag, type that makes the default redundant,
    # key to insert under, zero-argument factory that builds the callback).
    defaults = [
        (self._verbose, TqdmCallback, "_verbose", TqdmCallback),
        (self._timeit, TimerCallback, "_timer", TimerCallback),
        (self._check, CheckRunCallback, "_check", CheckRunCallback),
        (self._overfit, BatchOverfitCallback, "_overfit", BatchOverfitCallback),
        (self._logdir is not None, ICheckpointCallback, "_checkpoint", _make_checkpoint),
    ]
    for enabled, kind, key, factory in defaults:
        if enabled and not _already_has(kind):
            ordered[key] = factory()
    return ordered