Beispiel #1
0
    def _init_state(
        self, *, mode: str, stage: str = None, **kwargs
    ) -> RunnerState:
        """
        Hook for child classes to build a stage-specific RunnerState.

        Device, model, stage, criterion, optimizer and scheduler are read
        from ``self``; extra keyword arguments are forwarded to
        ``RunnerState`` unchanged.

        Note: ``mode`` and ``stage`` are accepted but not read here -- the
        stage passed to ``RunnerState`` is ``self.stage``.

        :return: RunnerState with all necessary parameters.
        """
        previous = self.state
        # When a previous state exists, carry its step and best_metrics over
        # and advance the epoch by one. NOTE(review): passing any of these
        # keys via **kwargs at the same time raises a duplicate-keyword
        # TypeError.
        carried_over = {} if previous is None else {
            "step": previous.step,
            "epoch": previous.epoch + 1,
            "best_metrics": previous.best_metrics,
        }
        return RunnerState(
            device=self.device,
            model=self.model,
            stage=self.stage,
            _criterion=self.criterion,
            _optimizer=self.optimizer,
            _scheduler=self.scheduler,
            **kwargs,
            **carried_over
        )
    def batch_handler(self,
                      *,
                      dct: Dict,
                      model: nn.Module,
                      state: RunnerState = None) -> Dict:
        """
        Batch handler wrapper with main statistics and device management.

        When ``state`` is given, every input value is moved to
        ``state.device``, and ``state.input`` / ``state.bs`` are updated
        before delegating to ``self._batch_handler``. When ``state`` is
        ``None`` (the default), the inputs are forwarded unchanged.

        :param dct: key-value storage with input tensors
        :param model: model to predict with
        :param state: runner state (optional)
        :return: key-value storage with model predictions, as returned by
            ``self._batch_handler``
        """
        if state is not None:
            # Device placement needs the state, so it is done only when a
            # state is available.
            dct = {key: value.to(state.device) for key, value in dct.items()}
            state.input = dct
            # Batch size is the length of the first input value.
            # NOTE(review): assumes all inputs share the same leading
            # (batch) dimension -- confirm.
            first = next(iter(dct.values()), None)
            state.bs = 0 if first is None else len(first)
        return self._batch_handler(dct=dct, model=model)
Beispiel #3
0
    def batch_handler(self,
                      *,
                      dct: Dict,
                      model: nn.Module,
                      state: RunnerState = None) -> Dict:
        """
        Batch handler wrapper with main statistics and device management.

        A 2-element tuple/list is first converted into a dict with the keys
        ``"features"`` and ``"target"``. The model is called on
        ``dct["features"]``.

        :param dct: key-value storage with input tensors, or a 2-element
            ``(features, target)`` tuple/list
        :param model: model to predict with
        :param state: runner state; when given, inputs are moved to
            ``state.device`` and ``state.input`` / ``state.bs`` are updated.
            When ``None`` (the default), inputs are used as-is.
        :return: key-value storage with model predictions under ``"logits"``
        :raises ValueError: if ``dct`` is a tuple/list whose length is not 2
        """
        if isinstance(dct, (tuple, list)):
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if len(dct) != 2:
                raise ValueError(
                    "expected a (features, target) pair, got {} items".format(
                        len(dct)))
            dct = {"features": dct[0], "target": dct[1]}
        if state is not None:
            # Device placement needs the state, so it must happen inside this
            # guard; reading ``state.device`` earlier makes ``state=None``
            # unusable.
            dct = {key: value.to(state.device) for key, value in dct.items()}
            state.input = dct
            state.bs = len(dct["features"])
        logits = model(dct["features"])
        return {"logits": logits}