Example #1
from collections import OrderedDict
from typing import Dict, Union

# assumed import paths; they vary between catalyst versions
from catalyst.core.callback import CallbackNode
from catalyst.utils.distributed import get_rank


def filter_callbacks_by_node(
        callbacks: Union[Dict, OrderedDict]) -> Union[Dict, OrderedDict]:
    """
    Filters callbacks based on running node.
    Deletes worker-only callbacks (``CallbackNode.Worker``) when running
    on the master node and master-only callbacks (``CallbackNode.Master``)
    when running on a worker node.

    Args:
        callbacks (Union[Dict, OrderedDict]): callbacks

    Returns:
        Union[Dict, OrderedDict]: filtered callbacks dictionary.
    """
    # distributed run setting
    output = callbacks.copy()
    rank = get_rank()
    if rank == 0:  # master node
        # remove worker-only callbacks on master node
        for k in list(
                filter(
                    lambda c: output[c].node == CallbackNode.Worker,
                    output,
                )):
            del output[k]
    elif rank > 0:  # worker node
        # remove master-only callbacks on worker nodes
        for k in list(
                filter(
                    lambda c: output[c].node == CallbackNode.Master,
                    output,
                )):
            del output[k]
    return output
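
A small self-contained sketch of the filtering rule above; the SimpleNamespace callbacks and the CallbackNode stand-in are illustrative stubs, not Catalyst objects:

from types import SimpleNamespace

# illustrative stand-in for catalyst.core.callback.CallbackNode
CallbackNode = SimpleNamespace(Master="master", Worker="worker")

callbacks = {
    "checkpoint": SimpleNamespace(node=CallbackNode.Master),  # master-only
    "sampler": SimpleNamespace(node=CallbackNode.Worker),     # worker-only
}

rank = 0  # pretend get_rank() said we are the master process
output = callbacks.copy()
if rank == 0:  # master: drop worker-only callbacks
    for k in [k for k in output if output[k].node == CallbackNode.Worker]:
        del output[k]
elif rank > 0:  # worker: drop master-only callbacks
    for k in [k for k in output if output[k].node == CallbackNode.Master]:
        del output[k]

print(sorted(output))  # ['checkpoint'] - the worker-only callback is gone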
Example #2
    def _get_callbacks(self, stage: str) -> Dict[str, Callback]:
        """
        Inner method for `Callbacks` preparation.

        Takes callbacks from the Experiment
        and filters them for distributed master/worker cases.

        Args:
            stage (str): stage name of interest,
                like "pretraining" / "training" / "finetuning" / etc

        Returns:
            OrderedDict[str, Callback]: Ordered dictionary
                with callbacks for current experiment stage.
        """
        callbacks = self.experiment.get_callbacks(stage)

        # distributed run setting
        rank = utils.get_rank()
        if rank == 0:  # master node
            # remove worker-only callbacks on master node
            for k in list(
                    filter(lambda c: callbacks[c].node == CallbackNode.Worker,
                           callbacks)):
                del callbacks[k]
        elif rank > 0:  # worker node
            # remove master-only callbacks on worker nodes
            for k in list(
                    filter(lambda c: callbacks[c].node == CallbackNode.Master,
                           callbacks)):
                del callbacks[k]

        callbacks = utils.process_callbacks(callbacks)

        return callbacks
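
The filter-then-list idiom used above matters: iterating a dict yields its keys, and wrapping the filter in list() materializes the matching keys before any deletion, so the dictionary is never mutated while it is being iterated. A minimal standalone illustration:

d = {"a": 1, "b": 2, "c": 3}
for k in list(filter(lambda key: d[key] > 1, d)):  # matching keys selected first
    del d[k]                                       # then removed safely
print(d)  # {'a': 1}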
Example #3
    def __init__(
        self,
        *,
        device: Device = None,
        model: StateModel = None,
        criterion: StateCriterion = None,
        optimizer: StateOptimizer = None,
        scheduler: StateScheduler = None,
        callbacks: Dict[str, "Callback"] = None,
        logdir: str = None,
        stage: str = STAGE_INFER_PREFIX,
        num_epochs: int = None,
        main_metric: str = STATE_MAIN_METRIC,
        minimize_metric: bool = True,
        valid_loader: str = LOADER_VALID_PREFIX,
        checkpoint_data: Dict = None,
        is_check_run: bool = False,
        **kwargs,
    ):
        """
        Args:
            @TODO: Docs. Contribution is welcome
        """
        # main part
        # data
        self.loaders: OrderedDict[str, DataLoader] = None
        # components
        self.model: StateModel = model
        self.criterion: StateCriterion = criterion
        self.optimizer: StateOptimizer = optimizer
        self.scheduler: StateScheduler = scheduler
        # extra components - PyTorch device
        self.device: Device = device
        # extra components - Catalyst callbacks
        self.callbacks: Dict[str, "Callback"] = callbacks

        # dataflow - model input, model output, metrics
        self.batch_in = None
        self.batch_out = None
        # let's use flatten storage for batch metrics
        # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...}
        self.batch_metrics = defaultdict(None)
        # just aggregated (aka mean over all batches)
        # batch statistics for loader
        # and global loader metrics, like AUC
        # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...}
        self.loader_metrics = defaultdict(None)
        # summarized metrics for different loaders
        # and global epoch metrics, like lr, momentum
        # epoch_metrics = {
        # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ...,
        # 'lr': ..., 'momentum': ...,
        # }
        self.epoch_metrics = defaultdict(None)

        # validation
        self.is_best_valid = False
        self.valid_metrics = defaultdict(None)
        self.best_valid_metrics = defaultdict(None)

        # pipeline info
        self.distributed_rank = utils.get_rank()
        self.is_distributed_worker = self.distributed_rank > 0

        self.stage_name: str = stage
        self.epoch: int = 1
        self.num_epochs: int = num_epochs or np.iinfo(np.int32).max

        self.loader_name: str = None
        self.loader_step: int = 0
        self.loader_len: int = 0

        self.batch_size: int = 0

        self.global_step: int = 0
        self.global_epoch: int = 1

        # metrics & validation
        self.main_metric: str = main_metric
        self.minimize_metric: bool = minimize_metric
        self.valid_loader: str = valid_loader

        # logging
        self.logdir: Path = Path(logdir) if logdir is not None else None
        # extra checkpoint data for saving in checkpoint files
        self.checkpoint_data: Dict = checkpoint_data or {}

        # other
        self.is_check_run: bool = is_check_run
        self.is_train_loader: bool = False
        self.is_valid_loader: bool = False
        self.is_infer_loader: bool = False
        self.is_infer_stage: bool = self.stage_name.startswith(
            STAGE_INFER_PREFIX)
        self.need_early_stop: bool = False
        self.need_exception_reraise: bool = True
        self.exception: Optional[Exception] = None

        # kwargs
        for k, v in kwargs.items():
            setattr(self, k, v)

        self._freeze()
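
A usage sketch: the class owning this __init__ is not shown in the excerpt, so the name State and the metric name below are assumptions. All arguments are keyword-only, and unknown keyword arguments become attributes via the setattr loop at the end of the constructor.

# `State` is assumed to be the class that owns the __init__ shown above
# (the excerpt does not show the class definition or its import path).
state = State(
    logdir="./logs",
    stage="train",
    num_epochs=10,
    main_metric="accuracy01",  # illustrative metric name
    minimize_metric=False,
    my_extra_flag=True,        # unknown kwargs become attributes via setattr
)
assert state.num_epochs == 10
assert state.my_extra_flag is True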
Example #4
    def _prepare_inner_state(
        self,
        stage: str = settings.stage_infer_prefix,
        device: Device = None,
        model: RunnerModel = None,
        criterion: RunnerCriterion = None,
        optimizer: RunnerOptimizer = None,
        scheduler: RunnerScheduler = None,
        callbacks: Dict[str, "Callback"] = None,
        logdir: str = None,
        num_epochs: int = 1,
        main_metric: str = "loss",
        minimize_metric: bool = True,
        valid_loader: str = settings.loader_valid_prefix,
        checkpoint_data: Dict = None,
        is_check_run: bool = False,
        verbose: bool = False,
        **kwargs,
    ):
        self._unfreeze()

        # main runner components: model and device to run
        self.device: Device = device
        self.model: RunnerModel = model

        # extra experiment components,
        # use `catalyst.core.IExperiment` to setup them
        self.criterion: RunnerCriterion = criterion
        self.optimizer: RunnerOptimizer = optimizer
        self.scheduler: RunnerScheduler = scheduler
        # and callbacks
        self.callbacks: Dict[str, "Callback"] = callbacks or {}

        # the data
        self.loaders: OrderedDict[str, DataLoader] = None
        # and the dataflow - model input, model output
        self.input = None
        self.output = None

        # metrics flow - batch, loader, epoch metrics
        # let's use flatten storage for batch metrics
        # batch_metrics = {'loss': ..., 'accuracy': ..., 'iou': ...}
        self.batch_metrics: Dict = defaultdict(None)
        # just aggregated (aka mean over all batches)
        # batch statistics for loader
        # and global loader metrics, like AUC
        # loader_metrics = {'loss': ..., 'accuracy': ..., `auc`: ...}
        self.loader_metrics: Dict = defaultdict(None)
        # summarized metrics for different loaders
        # and global epoch metrics, like lr, momentum
        # epoch_metrics = {
        # 'train_loss': ..., 'train_auc': ..., 'valid_loss': ...,
        # 'lr': ..., 'momentum': ...,
        # }
        self.epoch_metrics: Dict = defaultdict(None)

        # metrics & validation
        self.main_metric: str = main_metric
        self.minimize_metric: bool = minimize_metric

        # validation
        self.valid_loader: str = valid_loader
        self.valid_metrics: Dict = defaultdict(None)
        self.is_best_valid: bool = False
        self.best_valid_metrics: Dict = defaultdict(None)

        # distributed info
        self.distributed_rank: int = utils.get_rank()
        self.is_distributed_master: bool = not (self.distributed_rank > 0)
        self.is_distributed_worker: bool = self.distributed_rank > 0
        # experiment info
        self.global_sample_step: int = 0
        self.global_batch_step: int = 0
        self.global_epoch: int = 1
        self.verbose: bool = verbose
        self.is_check_run: bool = is_check_run
        self.need_early_stop: bool = False
        self.need_exception_reraise: bool = True
        # stage info
        self.num_epochs: int = num_epochs
        self.stage_name: str = stage
        self.is_infer_stage: bool = self.stage_name.startswith(
            settings.stage_infer_prefix)
        # epoch info
        self.epoch: int = 1
        # loader info
        self.loader_sample_step: int = 0
        self.loader_batch_step: int = 0
        self.loader_name: str = None
        self.loader_len: int = 0
        self.loader_batch_size = 0
        self.is_train_loader: bool = False
        self.is_valid_loader: bool = False
        self.is_infer_loader: bool = True
        # batch info
        self.batch_size: int = 0

        # logging
        self.expdir: Path = None
        self.logdir: Path = Path(logdir) if logdir is not None else None
        # extra checkpoint data for saving in checkpoint files
        self.checkpoint_data: Dict = checkpoint_data or {}

        # extra
        self.exception: Optional[Exception] = None

        # kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)

        self._freeze()
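
For reference, the two distributed flags set above partition ranks as follows. This is a plain-Python sketch; the convention that get_rank returns -1 outside distributed runs, 0 for the master process, and positive values for workers is an assumption about the Catalyst helper, not shown in this excerpt.

for rank in (-1, 0, 1, 2):
    is_distributed_worker = rank > 0
    is_distributed_master = not is_distributed_worker
    print(rank, is_distributed_master, is_distributed_worker)
# -1 True False
# 0 True False
# 1 False True
# 2 False True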