コード例 #1
0
def main(args, _):
    """Entry point for ``catalyst-dl trace``.

    Calls ``trace_model_from_checkpoint`` with the parsed options and hands
    the result to ``save_traced_model``.

    Args:
        args: parsed command-line namespace; the attributes ``logdir``,
            ``method``, ``checkpoint``, ``mode``, ``with_grad``,
            ``opt_level``, ``stage``, ``loader``, ``out_model`` and
            ``out_dir`` are read from it
        _: unused
    """
    logdir, method_name, checkpoint_name = (
        args.logdir,
        args.method,
        args.checkpoint,
    )

    # Options forwarded unchanged to both the tracing and the saving call.
    shared_options = {
        "mode": args.mode,
        "requires_grad": args.with_grad,
        "opt_level": args.opt_level,
    }

    # Any explicit opt level selects "cuda"; without one "cpu" is used.
    device = "cpu" if shared_options["opt_level"] is None else "cuda"

    traced = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        **shared_options,
        device=device,
    )

    save_traced_model(
        model=traced,
        logdir=logdir,
        method_name=method_name,
        **shared_options,
        out_model=args.out_model,
        out_dir=args.out_dir,
        checkpoint_name=checkpoint_name,
    )
コード例 #2
0
    def trace(
        self,
        *,
        model: Model = None,
        batch: Any = None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces model using Torch Jit.

        The model's previous training mode, ``requires_grad`` state and
        device are restored before this method exits, including when
        tracing or saving raises.

        Args:
            model: model to trace; if given, it replaces ``self.model``
            batch: batch to forward through the model to trace
            logdir (str, optional): If specified,
                the result will be written to the directory
            loader (DataLoader, optional): if batch is not specified, the batch
                will be ``next(iter(loader))``
            method_name: model's method name that will be traced
            mode: ``train`` or ``eval``
            requires_grad: flag to trace with gradients
            fp16 (Union[Dict, bool]): fp16 settings (same as in `train`);
                when it yields an ``opt_level``, the device is forced
                to ``"cuda"``
            device: Torch device or a string; if ``None``, ``self.device``
                is used (initialised via ``get_device()`` when unset)
            predict_params: additional parameters for model forward

        Returns:
            ScriptModule: traced model

        Raises:
            ValueError: if `batch` and `loader` are Nones
        """
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified")
            batch = next(iter(loader))

        if model is not None:
            self.model = model
        assert self.model is not None

        fp16 = _resolve_bool_fp16(fp16)
        opt_level = None
        if fp16:
            opt_level = fp16.get("opt_level", None)

        if opt_level is not None:
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = get_device()
            device = self.device

        # Dumping previous state of the model, we will need it to restore
        device_dump, is_training_dump, requires_grad_dump = (
            self.device,
            self.model.training,
            get_requires_grad(self.model),
        )

        # function to run prediction on batch
        def predict_fn(model, inputs, **kwargs):  # noqa: WPS442
            # Temporarily swap ``self.model`` so ``predict_batch`` runs on the
            # model handed in by the tracer; always swap back, even on error.
            model_dump = self.model
            self.model = model
            try:
                return self.predict_batch(inputs, **kwargs)
            finally:
                self.model = model_dump

        try:
            self.model.to(device)

            traced_model = trace_model(
                model=self.model,
                predict_fn=predict_fn,
                batch=batch,
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level,
                device=device,
                predict_params=predict_params,
            )

            if logdir is not None:
                save_traced_model(
                    model=traced_model,
                    logdir=logdir,
                    method_name=method_name,
                    mode=mode,
                    requires_grad=requires_grad,
                    opt_level=opt_level,
                )
        finally:
            # Restore previous state of the model, whether or not tracing
            # succeeded, so a failed trace does not leave the runner's model
            # on another device or in another mode.
            getattr(self.model, "train" if is_training_dump else "eval")()
            set_requires_grad(self.model, requires_grad_dump)
            self.model.to(device_dump)

        return traced_model