def on_stage_end(self, runner: "IRunner") -> None:
    """
    On stage end action: trace the runner's model and save it as TorchScript.

    Args:
        runner: runner for experiment
    """
    engine = runner.engine
    # Make sure both the model and the selected batch tensors live on
    # the engine's device before tracing.
    synced_model = engine.sync_device(runner.model)
    inputs = engine.sync_device(
        tuple(runner.batch[key] for key in self.input_key)
    )
    traced = trace_model(
        model=synced_model, batch=inputs, method_name=self.method_name
    )
    torch.jit.save(traced, self.filename)
def trace(
    self,
    *,
    model: Model = None,
    batch: Any = None,
    logdir: str = None,
    loader: DataLoader = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    fp16: Union[Dict, bool] = None,
    device: Device = "cpu",
    predict_params: dict = None,
) -> ScriptModule:
    """
    Traces model using Torch Jit.

    Args:
        model: model to trace; if given, it replaces ``self.model``
        batch: batch to forward through the model to trace
        logdir (str, optional): If specified,
            the result will be written to the directory
        loader (DataLoader, optional): if batch is not specified, the batch
            will be ``next(iter(loader))``
        method_name: model's method name that will be traced
        mode: ``train`` or ``eval``
        requires_grad: flag to trace with gradients
        fp16 (Union[Dict, bool]): fp16 settings (same as in `train`)
        device: Torch device or a string
        predict_params: additional parameters for model forward

    Returns:
        ScriptModule: traced model

    Raises:
        ValueError: if `batch` and `loader` are Nones
    """
    # Resolve the batch: take it either from the argument or from the loader.
    if batch is None:
        if loader is None:
            raise ValueError(
                "If batch is not provided the loader must be specified")
        batch = next(iter(loader))

    # An explicitly passed model replaces the runner's current one.
    # NOTE(review): this mutation persists after the call — presumably
    # intentional, since the restore section below only resets
    # train/eval mode, requires_grad and device, not the model itself.
    if model is not None:
        self.model = model
    assert self.model is not None

    # `_resolve_bool_fp16` presumably normalizes a bool/None into a dict
    # of fp16 settings (TODO confirm against its definition); a truthy
    # result may carry an apex-style "opt_level" which forces CUDA.
    fp16 = _resolve_bool_fp16(fp16)
    opt_level = None
    if fp16:
        opt_level = fp16.get("opt_level", None)

    if opt_level is not None:
        device = "cuda"
    elif device is None:
        # Lazily pick a device only when the caller passed device=None
        # (the default is "cpu", so this branch needs an explicit None).
        if self.device is None:
            self.device = get_device()
        device = self.device

    # Dumping previous state of the model, we will need it to restore
    # NOTE(review): the restored device is `self.device`, not the device
    # the model's parameters were actually on — verify they agree.
    device_dump, is_training_dump, requires_grad_dump = (
        self.device,
        self.model.training,
        get_requires_grad(self.model),
    )

    self.model.to(device)

    # function to run prediction on batch
    # Temporarily swaps `self.model` so `self.predict_batch` runs against
    # the model being traced, then restores the original reference.
    def predict_fn(model, inputs, **kwargs):  # noqa: WPS442
        model_dump = self.model
        self.model = model
        result = self.predict_batch(inputs, **kwargs)
        self.model = model_dump
        return result

    traced_model = trace_model(
        model=self.model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
        predict_params=predict_params,
    )

    # Optionally persist the traced model to disk.
    if logdir is not None:
        save_traced_model(
            model=traced_model,
            logdir=logdir,
            method_name=method_name,
            mode=mode,
            requires_grad=requires_grad,
            opt_level=opt_level,
        )

    # Restore previous state of the model
    getattr(self.model, "train" if is_training_dump else "eval")()
    set_requires_grad(self.model, requires_grad_dump)
    self.model.to(device_dump)
    return traced_model