コード例 #1
0
ファイル: torch.py プロジェクト: wwchung91/catalyst
    def load_checkpoint(self, path: str):
        """Read the checkpoint stored at ``path``.

        Thin wrapper that delegates to the module-level ``load_checkpoint``
        helper, passing ``path`` by keyword.

        Args:
            path: location of the checkpoint file to read

        Returns:
            whatever the module-level ``load_checkpoint`` helper returns
            for ``path``
        """
        checkpoint = load_checkpoint(path=path)
        return checkpoint
コード例 #2
0
ファイル: runner.py プロジェクト: catalyst-team/catalyst
    def predict_loader(
        self,
        *,
        loader: DataLoader,
        model: TorchModel = None,
        engine: Union["Engine", str] = None,
        seed: int = 42,
        # extra info
        resume: str = None,
        # engine extra params
        cpu: bool = False,
        fp16: bool = False,
    ) -> Generator:
        """
        Runs model inference on PyTorch DataLoader and returns
        python generator with model predictions from `runner.predict_batch`.

        This method is a generator: none of the setup below (engine
        selection, checkpoint loading, model/loader preparation, seeding)
        runs until the first prediction is requested from the returned
        generator. The setup overwrites ``self.engine`` and ``self.model``.

        Args:
            loader: loader to predict
            model: model to use for prediction; if ``None``, the model
                already stored in ``self.model`` is used
            engine: engine to use for prediction; if falsy, one is chosen
                with ``get_available_engine(cpu=cpu, fp16=fp16)``
            seed: random seed set (via ``set_global_seed``) right before
                iterating over the loader
            resume: path to checkpoint for model; its state dict is loaded
                into the engine-unwrapped model before preparation
            cpu: boolean flag to force CPU usage (only used when ``engine``
                is not given)
            fp16: boolean flag to use half-precision (only used when
                ``engine`` is not given)

        Yields:
            the output of ``self.predict_batch`` for each batch of ``loader``

        Raises:
            ValueError: if neither ``model`` nor ``self.model`` is set
                (raised on the first iteration, not at call time)

        .. note::
            Please follow the `minimal examples`_ sections for use cases.

            .. _`minimal examples`: https://github.com/catalyst-team/catalyst#minimal-examples  # noqa: E501, W505
        """
        # NOTE(review): a ``str`` engine is assigned as-is and then used via
        # ``.prepare``/``.unwrap_model`` below; confirm string engines are
        # resolved to an Engine instance somewhere before this point.
        self.engine = engine or get_available_engine(cpu=cpu, fp16=fp16)

        if model is not None:
            self.model = model
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let ``None`` reach ``engine.prepare``.
        if self.model is None:
            raise ValueError(
                "No model to predict with: pass `model` or set `runner.model`."
            )

        if resume is not None:
            self.engine.wait_for_everyone()
            # Load weights into the bare model, not into any engine wrapper.
            unwrapped_model = self.engine.unwrap_model(self.model)
            unwrapped_model.load_state_dict(load_checkpoint(resume))

        self.model = self.engine.prepare(self.model)
        # Calls ``.train(mode=False)`` on the model (presumably recursing into
        # container-like models) to switch it to inference mode.
        maybe_recursive_call(self.model, "train", mode=False)
        loader = self.engine.prepare(loader)

        set_global_seed(seed)
        for batch in loader:
            yield self.predict_batch(batch)