def main(args, _):
    """Entry point for the ``catalyst-dl trace`` command.

    Calls ``trace_model_from_checkpoint`` with the options taken from
    ``args`` and writes the traced result with ``torch.jit.save``.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``logdir``, ``method``, ``checkpoint``, ``mode``, ``with_grad``,
            ``opt_level``, ``stage``, ``loader``, ``out_model`` and
            ``out_dir``.
        _: second positional argument, not used by this function.
    """
    logdir: Path = args.logdir
    method_name: str = args.method
    checkpoint_name: str = args.checkpoint
    mode: str = args.mode
    requires_grad: bool = args.with_grad
    opt_level: str = args.opt_level

    # Any explicit opt level switches tracing to "cuda"; otherwise "cpu".
    device = "cpu" if opt_level is None else "cuda"

    traced = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if args.out_model is not None:
        # An explicit output file takes precedence over any directory.
        out_model = str(args.out_model)
    else:
        file_name = trace.get_trace_name(
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            additional_string=checkpoint_name,
        )
        target_dir = args.out_dir
        if target_dir is None:
            # Fall back to a "trace" folder inside the log directory.
            target_dir = logdir / "trace"
        target_dir.mkdir(exist_ok=True, parents=True)
        out_model = str(target_dir / file_name)

    torch.jit.save(traced, out_model)
def trace(
    self,
    model: Model = None,
    batch=None,
    logdir: str = None,
    loader: DataLoader = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    fp16: Union[Dict, bool] = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using Torch Jit

    Args:
        model (Model): model to trace; when given, it replaces
            ``self.model`` before tracing
        batch: batch to forward through the model to trace
        logdir (str, optional): If specified,
            the result will be written to ``<logdir>/trace``
        loader (DataLoader, optional): if batch is not specified, the batch
            will be ``next(iter(loader))``
        method_name (str): model's method name that will be traced
        mode (str): ``train`` or ``eval``
        requires_grad (bool): flag to trace with gradients
        fp16 (Union[Dict, bool]): ``True`` selects opt level ``"O1"``,
            ``False`` selects none, a dict supplies it under the
            ``"opt_level"`` key, any other value is used as the opt level
            itself
        device (Device): Torch device or a string; replaced by ``"cuda"``
            whenever an opt level is set
        predict_params (dict): additional parameters for model forward

    Returns:
        ScriptModule: the value returned by ``trace.trace_model``

    Raises:
        ValueError: if neither ``batch`` nor ``loader`` is provided
    """
    if batch is None:
        if loader is None:
            raise ValueError(
                "If batch is not provided the loader must be specified"
            )
        batch = next(iter(loader))

    if model is not None:
        self.model = model

    # Normalise the ``fp16`` argument into an opt level (or None).
    if isinstance(fp16, bool):
        opt_level = "O1" if fp16 else None
    elif isinstance(fp16, dict):
        opt_level = fp16["opt_level"]
    else:
        opt_level = fp16

    if opt_level is not None:
        # An opt level always forces the "cuda" device.
        device = "cuda"
    elif device is None:
        # No device requested: reuse the runner's, resolving it on demand.
        if self.device is None:
            self.device = utils.get_device()
        device = self.device

    # Options shared by the tracing call and the trace file name.
    trace_options = dict(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
    )

    result = trace.trace_model(
        model=self.model,
        runner=self,
        batch=batch,
        device=device,
        predict_params=predict_params,
        **trace_options,
    )

    if logdir is None:
        return result

    filename = trace.get_trace_name(**trace_options)
    trace_dir = Path(logdir) / "trace"
    trace_dir.mkdir(exist_ok=True, parents=True)
    torch.jit.save(result, str(trace_dir / filename))

    return result