def main(args, _):
    """Run the ``catalyst-dl trace`` command.

    Traces a model restored from a checkpoint and writes the traced model
    to disk.

    Args:
        args: parsed command-line arguments; this function reads the
            attributes ``logdir``, ``method``, ``checkpoint``, ``mode``,
            ``with_grad``, ``opt_level``, ``stage``, ``loader``,
            ``out_model`` and ``out_dir``
        _: unused second argument
    """
    logdir: Path = args.logdir
    method_name: str = args.method
    checkpoint_name: str = args.checkpoint
    opt_level: str = args.opt_level

    # When an ``opt_level`` is requested the model is traced on CUDA,
    # otherwise on CPU.
    device = "cpu" if opt_level is None else "cuda"

    # Options passed identically to both the tracing and the saving step.
    common_options = {
        "mode": args.mode,
        "requires_grad": args.with_grad,
        "opt_level": opt_level,
    }

    traced_model = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        device=device,
        **common_options,
    )
    save_traced_model(
        model=traced_model,
        logdir=logdir,
        method_name=method_name,
        checkpoint_name=checkpoint_name,
        out_model=args.out_model,
        out_dir=args.out_dir,
        **common_options,
    )
def trace(
    self,
    *,
    model: Model = None,
    batch: Any = None,
    logdir: str = None,
    loader: DataLoader = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    fp16: Union[Dict, bool] = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using Torch Jit.

    The model is moved to ``device`` for tracing. Its previous device
    (``self.device``), train/eval mode and ``requires_grad`` flags are
    restored afterwards, even if tracing or saving raises.

    Args:
        model: model to trace; if given, it replaces ``self.model``
        batch: batch to forward through the model to trace
        logdir (str, optional): If specified, the result will be written
            to the directory
        loader (DataLoader, optional): if batch is not specified, the batch
            will be ``next(iter(loader))``
        method_name: model's method name that will be traced
        mode: ``train`` or ``eval``
        requires_grad: flag to trace with gradients
        fp16 (Union[Dict, bool]): fp16 settings (same as in `train`);
            if they contain an ``opt_level``, tracing is done on ``"cuda"``
        device: Torch device or a string; if ``None``, ``self.device`` is
            used (initialised via ``get_device()`` when unset)
        predict_params: additional parameters for model forward

    Returns:
        ScriptModule: traced model

    Raises:
        ValueError: if `batch` and `loader` are Nones
        AssertionError: if no model is given and ``self.model`` is None
    """
    if batch is None:
        if loader is None:
            raise ValueError(
                "If batch is not provided the loader must be specified"
            )
        batch = next(iter(loader))

    if model is not None:
        self.model = model
    assert self.model is not None

    fp16 = _resolve_bool_fp16(fp16)
    opt_level = None
    if fp16:
        opt_level = fp16.get("opt_level", None)

    if opt_level is not None:
        device = "cuda"
    elif device is None:
        if self.device is None:
            self.device = get_device()
        device = self.device

    # Dumping previous state of the model, we will need it to restore
    device_dump, is_training_dump, requires_grad_dump = (
        self.device,
        self.model.training,
        get_requires_grad(self.model),
    )

    # function to run prediction on batch
    def predict_fn(model, inputs, **kwargs):  # noqa: WPS442
        # Temporarily make the traced model the runner's model so that
        # ``predict_batch`` runs on it; always put the original back, even
        # if prediction raises.
        model_dump = self.model
        self.model = model
        try:
            return self.predict_batch(inputs, **kwargs)
        finally:
            self.model = model_dump

    try:
        self.model.to(device)
        traced_model = trace_model(
            model=self.model,
            predict_fn=predict_fn,
            batch=batch,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
            device=device,
            predict_params=predict_params,
        )

        if logdir is not None:
            save_traced_model(
                model=traced_model,
                logdir=logdir,
                method_name=method_name,
                mode=mode,
                requires_grad=requires_grad,
                opt_level=opt_level,
            )
    finally:
        # Restore previous state of the model, also when tracing/saving
        # failed, so the runner is not left with a moved or mutated model.
        getattr(self.model, "train" if is_training_dump else "eval")()
        set_requires_grad(self.model, requires_grad_dump)
        self.model.to(device_dump)

    return traced_model