def _trace(self, runner: IRunner):
    """
    Trace the model at epoch end when the condition metric has improved.

    Args:
        runner (IRunner): current runner
    """
    # An apex ``opt_level`` implies CUDA; plain tracing runs on CPU.
    device = "cuda" if self.opt_level is not None else "cpu"

    # The only case where the model must be restored from a previous
    # checkpoint is tracing the best model once at the end of the stage.
    checkpoint_name_to_restore = (
        "best" if self.do_once and self.mode == "best" else None
    )

    traced_model = trace_model_from_runner(
        runner=runner,
        checkpoint_name=checkpoint_name_to_restore,
        method_name=self.method_name,
        mode=self.trace_mode,
        requires_grad=self.requires_grad,
        opt_level=self.opt_level,
        device=device,
    )
    save_traced_model(
        model=traced_model,
        logdir=runner.logdir,
        checkpoint_name=self.mode,
        method_name=self.method_name,
        mode=self.trace_mode,
        requires_grad=self.requires_grad,
        opt_level=self.opt_level,
        out_model=self.out_model,
        out_dir=self.out_dir,
    )
def main(args, _):
    """Main method for ``catalyst-dl trace``."""
    logdir: Path = args.logdir
    method_name: str = args.method
    checkpoint_name: str = args.checkpoint
    mode: str = args.mode
    requires_grad: bool = args.with_grad
    opt_level: str = args.opt_level

    # An apex ``opt_level`` implies CUDA; plain tracing runs on CPU.
    device = "cpu" if opt_level is None else "cuda"

    traced_model = trace_model_from_checkpoint(
        logdir,
        method_name,
        checkpoint_name=checkpoint_name,
        stage=args.stage,
        loader=args.loader,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )
    save_traced_model(
        model=traced_model,
        logdir=logdir,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        out_model=args.out_model,
        out_dir=args.out_dir,
        checkpoint_name=checkpoint_name,
    )
def trace(
    self,
    *,
    model: Model = None,
    batch: Any = None,
    logdir: str = None,
    loader: DataLoader = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    fp16: Union[Dict, bool] = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using Torch Jit.

    Args:
        model (Model): model to trace
        batch: batch to forward through the model to trace
        logdir (str, optional): If specified, the result will be
            written to the directory
        loader (DataLoader, optional): if batch is not specified, the
            batch will be ``next(iter(loader))``
        method_name (str): model's method name that will be traced
        mode (str): ``train`` or ``eval``
        requires_grad (bool): flag to trace with gradients
        fp16 (Union[Dict, bool]): If not None, then sets
            tracing params to FP16
        device (Device): Torch device or a string
        predict_params (dict): additional parameters for model forward

    Returns:
        ScriptModule: the traced model

    Raises:
        ValueError: if neither ``batch`` nor ``loader`` is given
    """
    # Resolve the batch to trace on: explicit batch wins, otherwise
    # take the first batch from the loader.
    if batch is None:
        if loader is None:
            raise ValueError(
                "If batch is not provided the loader must be specified"
            )
        batch = next(iter(loader))

    if model is not None:
        self.model = model
    assert self.model is not None

    # Normalize ``fp16`` into an apex ``opt_level`` string (or None):
    # True -> "O1", False -> None, dict -> its "opt_level" entry,
    # anything else is passed through as-is.
    if isinstance(fp16, bool):
        opt_level = "O1" if fp16 else None
    elif isinstance(fp16, dict):
        opt_level = fp16["opt_level"]
    else:
        opt_level = fp16

    # FP16 tracing requires CUDA; otherwise fall back to the runner's
    # device when the caller passed ``device=None``.
    if opt_level is not None:
        device = "cuda"
    elif device is None:
        if self.device is None:
            self.device = utils.get_device()
        device = self.device

    # Remember the model's state so it can be restored after tracing.
    prev_device = self.device
    prev_training = self.model.training
    prev_requires_grad = utils.get_requires_grad(self.model)

    self.model.to(device)

    # Runs prediction on a batch by temporarily swapping the runner's
    # model for the (possibly wrapped) one handed in by the tracer.
    def predict_fn(model, inputs, **kwargs):
        original_model = self.model
        self.model = model
        result = self.predict_batch(inputs, **kwargs)
        self.model = original_model
        return result

    traced_model = utils.trace_model(
        model=self.model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
        predict_params=predict_params,
    )

    if logdir is not None:
        utils.save_traced_model(
            model=traced_model,
            logdir=logdir,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
        )

    # Restore the model's pre-tracing state.
    if prev_training:
        self.model.train()
    else:
        self.model.eval()
    utils.set_requires_grad(self.model, prev_requires_grad)
    self.model.to(prev_device)

    return traced_model