Exemplo n.º 1
0
def main(args, _):
    """
    Entry point for the ``catalyst-dl trace`` command.

    Traces a model through ``trace_model_from_checkpoint`` and stores the
    traced module on disk with ``torch.jit.save``.

    Args:
        args: parsed command-line namespace; the attributes read here are
            ``logdir``, ``method``, ``checkpoint``, ``mode``, ``with_grad``,
            ``opt_level``, ``stage``, ``loader``, ``out_model`` and
            ``out_dir``
        _: second positional argument, not used by this function
    """
    logdir: Path = args.logdir
    method_name: str = args.method
    checkpoint_name: str = args.checkpoint
    mode: str = args.mode
    requires_grad: bool = args.with_grad
    opt_level: str = args.opt_level

    # Requesting an optimization level selects CUDA; otherwise trace on CPU.
    device = "cpu" if opt_level is None else "cuda"

    traced = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if args.out_model is not None:
        # An explicit output file wins over any generated name.
        out_model = str(args.out_model)
    else:
        # Derive the file name from the tracing parameters.
        file_name = trace.get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )

        # Fall back to ``<logdir>/trace`` when no output directory is given.
        target_dir: Path = args.out_dir
        if target_dir is None:
            target_dir = logdir / "trace"
        target_dir.mkdir(exist_ok=True, parents=True)

        out_model = str(target_dir / file_name)

    torch.jit.save(traced, out_model)
Exemplo n.º 2
0
    def trace(
        self,
        model: Model = None,
        batch=None,
        logdir: str = None,
        loader: DataLoader = None,
        method_name: str = "forward",
        mode: str = "eval",
        requires_grad: bool = False,
        fp16: Union[Dict, bool] = None,
        device: Device = "cpu",
        predict_params: dict = None,
    ) -> ScriptModule:
        """
        Traces the runner's model with Torch JIT and optionally saves it.

        Args:
            model (Model): model to trace; when given, it replaces
                ``self.model`` before tracing
            batch: batch to forward through the model to trace
            logdir (str, optional): if specified, the traced model is
                saved into the ``trace`` subdirectory of this directory
            loader (DataLoader, optional): if ``batch`` is not specified,
                the batch will be ``next(iter(loader))``
            method_name (str): model's method name that will be traced
            mode (str): ``train`` or ``eval``
            requires_grad (bool): flag to trace with gradients
            fp16 (Union[Dict, bool]): ``True`` selects opt level ``"O1"``,
                ``False`` selects none, a dict supplies it under the
                ``"opt_level"`` key; any other value is used as the
                opt level as is
            device (Device): Torch device or a string; overridden with
                ``"cuda"`` whenever an opt level is set, and replaced by
                ``self.device`` when passed as ``None``
            predict_params (dict): additional parameters for model forward

        Returns:
            ScriptModule: the result of ``trace.trace_model``

        Raises:
            ValueError: if neither ``batch`` nor ``loader`` is provided
        """
        if batch is None:
            if loader is None:
                raise ValueError(
                    "If batch is not provided the loader must be specified"
                )
            batch = next(iter(loader))

        if model is not None:
            self.model = model

        # Resolve the optimization level from the ``fp16`` argument.
        if isinstance(fp16, bool):
            opt_level = "O1" if fp16 else None
        elif isinstance(fp16, dict):
            opt_level = fp16["opt_level"]
        else:
            opt_level = fp16

        # A set opt level forces CUDA; an explicit ``None`` device falls
        # back to the runner's device, initializing it first if unset.
        if opt_level is not None:
            device = "cuda"
        elif device is None:
            if self.device is None:
                self.device = utils.get_device()
            device = self.device

        result = trace.trace_model(
            model=self.model,
            runner=self,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
            predict_params=predict_params,
        )

        if logdir is None:
            return result

        # Persist the traced module under ``<logdir>/trace/<generated name>``.
        filename = trace.get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
        )
        trace_dir: Path = Path(logdir) / "trace"
        trace_dir.mkdir(exist_ok=True, parents=True)
        torch.jit.save(result, str(trace_dir / filename))

        return result