예제 #1
0
파일: tracing.py 프로젝트: zkid18/catalyst
    def on_stage_end(self, runner: "IRunner") -> None:
        """
        Trace the runner's model on the current batch and save it to disk.

        The model and the selected batch inputs are both passed through
        ``runner.engine.sync_device`` before tracing (presumably to place
        them on the engine's device -- confirm against the engine
        implementation). The traced module is written to ``self.filename``
        with ``torch.jit.save``.

        Args:
            runner: runner for experiment; its ``engine``, ``model`` and
                ``batch`` attributes are read
        """
        engine = runner.engine
        synced_model = engine.sync_device(runner.model)
        # Pick only the batch entries listed in ``self.input_key``, in order.
        inputs = engine.sync_device(
            tuple(runner.batch[name] for name in self.input_key)
        )
        scripted = trace_model(
            model=synced_model,
            batch=inputs,
            method_name=self.method_name,
        )
        torch.jit.save(scripted, self.filename)
예제 #2
0
    def trace(
        self,
        *,
        model: Model = None,
        batch: Any = None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces model using Torch Jit.

        The model's device, train/eval mode and ``requires_grad`` state are
        recorded before tracing and restored afterwards, including when
        tracing or saving raises an exception.

        Args:
            model: model to trace; if given, it replaces ``self.model``
            batch: batch to forward through the model to trace
            logdir (str, optional): If specified,
                the result will be written to the directory
            loader (DataLoader, optional): if batch is not specified, the batch
                will be ``next(iter(loader))``
            method_name: model's method name that will be traced
            mode: ``train`` or ``eval``
            requires_grad: flag to trace with gradients
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`);
                when the resolved settings contain ``opt_level``,
                ``device`` is overridden with ``"cuda"``
            device: Torch device or a string; if explicitly ``None``,
                ``self.device`` is used (initialised via ``get_device()``
                when it is unset)
            predict_params: additional parameters for model forward

        Returns:
            ScriptModule: traced model

        Raises:
            ValueError: if `batch` and `loader` are Nones
            AssertionError: if neither ``model`` nor ``self.model`` is set
        """
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified")
            batch = next(iter(loader))

        if model is not None:
            self.model = model
        assert self.model is not None, "No model to trace"

        fp16 = _resolve_bool_fp16(fp16)
        opt_level = fp16.get("opt_level", None) if fp16 else None

        if opt_level is not None:
            # An opt_level in the fp16 settings forces tracing on cuda.
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = get_device()
            device = self.device

        # Dumping previous state of the model, we will need it to restore
        device_dump, is_training_dump, requires_grad_dump = (
            self.device,
            self.model.training,
            get_requires_grad(self.model),
        )

        # function to run prediction on batch
        def predict_fn(model, inputs, **kwargs):  # noqa: WPS442
            # Temporarily swap the given model into ``self.model`` so that
            # ``predict_batch`` runs against it; always swap back, even
            # when prediction fails.
            model_dump = self.model
            self.model = model
            try:
                return self.predict_batch(inputs, **kwargs)
            finally:
                self.model = model_dump

        try:
            self.model.to(device)

            traced_model = trace_model(
                model=self.model,
                predict_fn=predict_fn,
                batch=batch,
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level,
                device=device,
                predict_params=predict_params,
            )

            if logdir is not None:
                save_traced_model(
                    model=traced_model,
                    logdir=logdir,
                    method_name=method_name,
                    mode=mode,
                    requires_grad=requires_grad,
                    opt_level=opt_level,
                )
        finally:
            # Restore previous state of the model, whether or not tracing
            # or saving succeeded, so a failure does not leave the model
            # on the tracing device or in the tracing mode.
            if is_training_dump:
                self.model.train()
            else:
                self.model.eval()
            set_requires_grad(self.model, requires_grad_dump)
            self.model.to(device_dump)

        return traced_model