def _prepare_for_stage(self, stage: str):
        """Build the model, state, callbacks and loggers for ``stage``.

        The global seed is reset from ``self.experiment.initial_seed`` before
        each group of components is created, so every group starts from the
        same RNG state.

        Side effects: sets ``self.model``, ``self.device``, ``self.state``,
        ``self.state.loggers``, ``self.loggers`` and ``self.callbacks``.

        Args:
            stage: name of the stage to prepare.
        """
        utils.set_global_seed(self.experiment.initial_seed)
        # Read progress counters from the previous state *before*
        # ``self.state`` is replaced below.
        migrating_params = {}
        if self.state is not None:
            migrating_params = {
                "step": self.state.step,
                "epoch": self.state.epoch,
            }

        utils.set_global_seed(self.experiment.initial_seed)
        self.model, criterion, optimizer, scheduler, self.device = \
            self._get_experiment_components(stage)

        utils.set_global_seed(self.experiment.initial_seed)
        # Merge into a single dict instead of unpacking two mappings in the
        # call. With two ``**`` unpackings, a stage config that also defines
        # "step"/"epoch" raised TypeError (duplicate keyword argument). Now
        # the values carried over from the previous stage take precedence.
        state_params = dict(self.experiment.get_state_params(stage))
        state_params.update(migrating_params)
        self.state = self.state_fn(
            stage=stage,
            model=self.model,
            device=self.device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            **state_params
        )

        utils.set_global_seed(self.experiment.initial_seed)
        all_callbacks = self.experiment.get_callbacks(stage)

        # Split logger callbacks from the rest in one pass, keeping the
        # original relative order inside each group.
        logger_callbacks = OrderedDict()
        other_callbacks = OrderedDict()
        for name, callback in all_callbacks.items():
            if isinstance(callback, LoggerCallback):
                logger_callbacks[name] = callback
            else:
                other_callbacks[name] = callback

        loggers = utils.process_callbacks(logger_callbacks)
        callbacks = utils.process_callbacks(other_callbacks)

        self.state.loggers = loggers
        self.loggers = loggers
        self.callbacks = callbacks
示例#2
0
    def _get_callbacks(self, stage: str):
        """Return the processed callbacks for ``stage``, filtered by rank.

        On rank 0, callbacks whose ``node`` equals ``CallbackNode.Worker``
        are dropped. On ranks greater than 0, callbacks whose ``node`` equals
        ``CallbackNode.Master`` are dropped. Any other rank (presumably a
        non-distributed run; confirm against ``utils.get_rank``) keeps every
        callback.

        Entries are deleted in place from the mapping returned by
        ``self.experiment.get_callbacks``. The result is then passed through
        ``utils.process_callbacks``.
        """
        stage_callbacks = self.experiment.get_callbacks(stage)

        rank = utils.get_rank()
        if rank == 0:
            # master node: worker-only callbacks do not belong here
            unwanted_node = CallbackNode.Worker
        elif rank > 0:
            # worker node: master-only callbacks do not belong here
            unwanted_node = CallbackNode.Master
        else:
            unwanted_node = None

        if unwanted_node is not None:
            # Collect the keys first; the dict cannot shrink while iterated.
            doomed = [
                name
                for name, callback in stage_callbacks.items()
                if callback.node == unwanted_node
            ]
            for name in doomed:
                del stage_callbacks[name]

        return utils.process_callbacks(stage_callbacks)
示例#3
0
    def _get_callbacks(self, stage: str):
        """Return ``(callbacks, loggers)`` for ``stage``.

        On worker processes (``utils.get_rank() > 0``), instances of
        ``MasterOnlyCallback`` are first removed, in place, from the mapping
        returned by ``self.experiment.get_callbacks``. The remaining
        callbacks are split into ``LoggerCallback`` instances and everything
        else. Relative order is kept within each group, and each group is
        passed through ``utils.process_callbacks``.

        Returns:
            Tuple of (non-logger callbacks, logger callbacks), both as
            returned by ``utils.process_callbacks``.
        """
        callbacks = self.experiment.get_callbacks(stage)

        # Remove master-only callbacks on worker nodes.
        if utils.get_rank() > 0:
            master_only = [
                name for name, callback in callbacks.items()
                if isinstance(callback, MasterOnlyCallback)
            ]
            for name in master_only:
                del callbacks[name]

        # One ordered pass instead of two list-of-tuples filters.
        logger_callbacks = OrderedDict()
        other_callbacks = OrderedDict()
        for name, callback in callbacks.items():
            if isinstance(callback, LoggerCallback):
                logger_callbacks[name] = callback
            else:
                other_callbacks[name] = callback

        loggers = utils.process_callbacks(logger_callbacks)
        callbacks = utils.process_callbacks(other_callbacks)

        return callbacks, loggers
示例#4
0
    def _get_state(
        self,
        stage: str,
        model: Model,
        criterion: Criterion,
        optimizer: Optimizer,
        scheduler: Scheduler,
        device: Device,
        callbacks: Dict[str, Callback],
    ):
        """Create the run state for ``stage`` via ``self._state_fn``.

        Keyword arguments start from
        ``self.experiment.get_state_params(stage)``. When the params'
        ``migrate_from_previous_stage`` flag is truthy (default ``True``)
        and a previous ``self.state`` exists, two things happen:

        * If the previous state has callbacks, those with
          ``type == CallbackType.Experiment`` are copied into ``callbacks``
          (in place, overriding equal keys). The result is re-processed with
          ``utils.process_callbacks``.
        * ``global_step``, ``global_epoch`` and ``resume`` (``None`` if
          absent) are carried over into the state kwargs.

        Returns:
            The object built by ``self._state_fn``.
        """
        state_kwargs = dict(**self.experiment.get_state_params(stage))
        should_migrate = state_kwargs.get("migrate_from_previous_stage", True)
        previous = self.state

        if should_migrate and previous is not None:
            if previous.callbacks is not None:
                experiment_callbacks = {
                    key: value
                    for key, value in previous.callbacks.items()
                    if value.type == CallbackType.Experiment
                }
                callbacks.update(experiment_callbacks)
                callbacks = utils.process_callbacks(callbacks)

            # Continue step/epoch counters across stages.
            state_kwargs["global_step"] = previous.global_step
            state_kwargs["global_epoch"] = previous.global_epoch
            state_kwargs["resume"] = getattr(previous, "resume", None)

        return self._state_fn(
            stage=stage,
            model=model,
            device=device,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            callbacks=callbacks,
            **state_kwargs
        )