def pack_checkpoint(
    model: nn.Module = None,
    criterion: nn.Module = None,
    optimizer=None,
    scheduler=None,
    **kwargs,
):
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and some extra info ``**kwargs`` to torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    if isinstance(model, dict):
        for key, value in model.items():
            model_module = get_nn_from_ddp_module(value)
            checkpoint[f"model_{key}_state_dict"] = maybe_recursive_call(
                model_module, "state_dict"
            )
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_module, "state_dict"
        )

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call (?)
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()

    return checkpoint
def pack_checkpoint(
    model=None, criterion=None, optimizer=None, scheduler=None, **kwargs
):
    """@TODO: Docs. Contribution is welcome."""
    checkpoint = kwargs

    if isinstance(model, OrderedDict):
        raise NotImplementedError()
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_module, "state_dict"
        )

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()

    return checkpoint
def unpack_checkpoint(
    checkpoint, model=None, criterion=None, optimizer=None, scheduler=None
):
    """@TODO: Docs. Contribution is welcome."""
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
def quantize_model(
    model: TorchModel,
    qconfig_spec: Dict = None,
    dtype: Union[str, Optional[torch.dtype]] = "qint8",
) -> TorchModel:
    """Function to quantize model weights.

    Args:
        model: model to be quantized
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype: type of weights after quantization. Defaults to "qint8".

    Returns:
        TorchModel: quantized model
    """
    nn_model = get_nn_from_ddp_module(model)
    if isinstance(dtype, str):
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        dtype = type_mapping[dtype]
    try:
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
        )
    except RuntimeError:
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=dtype
        )

    return quantized_model
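# A minimal usage sketch for ``quantize_model``; it assumes only that a plain
# ``torch.nn.Module`` is available, and the variable names below are hypothetical.
#
#     import torch
#
#     fp32_model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
#     int8_model = quantize_model(fp32_model, dtype="qint8")
#     # dynamic quantization rewrites supported layers (e.g. ``nn.Linear``)
#     # to use int8 weights, shrinking the checkpoint and speeding up CPU inference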
def prune_model(
    model: Module,
    pruning_fn: Union[Callable, str],
    amount: Union[float, int],
    keys_to_prune: Optional[List[str]] = None,
    layers_to_prune: Optional[List[str]] = None,
    dim: int = None,
    l_norm: int = None,
) -> None:
    """
    Prunes certain tensors in the model's layers.

    Args:
        model: model to be pruned.
        pruning_fn: pruning function with the same API as in
            ``torch.nn.utils.prune``, i.e. ``pruning_fn(module, name, amount)``.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represents
            the fraction of parameters to prune.
            If int, it represents the absolute number of parameters to prune.
        keys_to_prune: list of strings. Determines
            which tensors in the modules will be pruned.
        layers_to_prune: list of strings - module names to be pruned.
            If None, every module in the model will be pruned.
        dim (int, optional): if you are using a structured pruning method,
            you need to specify the dimension. Defaults to None.
        l_norm (int, optional): if you are using ``ln_structured``,
            you need to specify ``l_norm``. Defaults to None.

    Example:
        .. code-block:: python

            prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

    Raises:
        AttributeError: if ``layers_to_prune`` is not None, but there are
            no layers with the specified names.
        ValueError: if no layers have the specified keys.
    """
    nn_model = get_nn_from_ddp_module(model)
    pruning_fn = get_pruning_fn(pruning_fn, l_norm=l_norm, dim=dim)
    keys_to_prune = keys_to_prune or ["weight"]
    pruned_modules = 0
    for name, module in nn_model.named_modules():
        try:
            if layers_to_prune is None or name in layers_to_prune:
                for key in keys_to_prune:
                    pruning_fn(module, name=key, amount=amount)
                pruned_modules += 1
        except AttributeError as e:
            if layers_to_prune is not None:
                raise e

    if pruned_modules == 0:
        raise ValueError(f"There is no {keys_to_prune} key in your model")
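# A hedged sketch of structured pruning with ``prune_model``, assuming the
# string name "ln_structured" is accepted by ``get_pruning_fn``; the model is
# a hypothetical two-layer net. ``ln_structured`` needs both ``dim`` (which
# axis to prune) and ``l_norm`` (which norm ranks the slices), while
# ``l1_unstructured`` needs neither.
#
#     import torch
#
#     net = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))
#     prune_model(
#         net,
#         pruning_fn="ln_structured",
#         amount=0.5,      # prune half of the rows of each weight matrix
#         keys_to_prune=["weight"],
#         dim=0,           # prune along the output-channel dimension
#         l_norm=2,        # rank rows by their L2 norm
#     )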
def pack_checkpoint(
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
    **kwargs,
) -> Dict:
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and some extra info ``**kwargs`` to torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    for dict2save, name2save in zip(
        [model, criterion, optimizer, scheduler],
        ["model", "criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue

        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key) + "_state_dict"
                    value = get_nn_from_ddp_module(value)
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            dict2save = get_nn_from_ddp_module(dict2save)
            checkpoint[name2save] = dict2save.state_dict()

    return checkpoint
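# A hedged usage sketch for ``pack_checkpoint``; the model, optimizer and
# file path below are hypothetical, and ``torch.save`` is used only for
# illustration.
#
#     import torch
#
#     model = torch.nn.Linear(10, 2)
#     optimizer = torch.optim.Adam(model.parameters())
#     checkpoint = pack_checkpoint(
#         model=model, optimizer=optimizer, epoch=3, valid_loss=0.42
#     )
#     # checkpoint now holds "model_state_dict" and "optimizer_state_dict",
#     # plus the extra "epoch"/"valid_loss" entries from **kwargs
#     torch.save(checkpoint, "checkpoint.pth")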
def trace_model(
    model: TorchModel,
    batch: Union[Tuple[torch.Tensor], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Traces the model using the given batch.

    Args:
        model: model to trace
        batch: batch to trace the model with
        method_name: model's method name that will be
            used as entrypoint during tracing

    Example:
        .. code-block:: python

            import torch

            from catalyst.utils import trace_model

            class LinModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin1 = torch.nn.Linear(10, 10)
                    self.lin2 = torch.nn.Linear(2, 10)

                def forward(self, inp_1, inp_2):
                    return self.lin1(inp_1), self.lin2(inp_2)

                def first_only(self, inp_1):
                    return self.lin1(inp_1)

            lin_model = LinModel()
            traced_model = trace_model(
                lin_model, batch=torch.randn(1, 10), method_name="first_only"
            )

    Returns:
        jit.ScriptModule: traced model
    """
    nn_model = get_nn_from_ddp_module(model)
    wrapped_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
    traced = jit.trace(wrapped_model, example_inputs=batch)
    return traced
def unpack_checkpoint(
    checkpoint: Dict,
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
) -> None:
    """Unpacks the checkpoint content to a model (if not None),
    criterion (if not None), optimizer (if not None),
    scheduler (if not None).

    Args:
        checkpoint: checkpoint to load
        model: model whose state should be updated
        criterion: criterion whose state should be updated
        optimizer: optimizer whose state should be updated
        scheduler: scheduler whose state should be updated
    """
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
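# A round-trip sketch continuing the ``pack_checkpoint`` example above;
# the file name is hypothetical and ``torch.load`` is used for illustration.
#
#     checkpoint = torch.load("checkpoint.pth", map_location="cpu")
#     unpack_checkpoint(checkpoint, model=model, optimizer=optimizer)
#     # model and optimizer are updated in place from the
#     # "model_state_dict" / "optimizer_state_dict" entries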
def remove_reparametrization(
    model: Module,
    keys_to_prune: List[str],
    layers_to_prune: Optional[List[str]] = None,
) -> None:
    """
    Removes pre-hooks and pruning masks from the model.

    Args:
        model: model to remove reparametrization from.
        keys_to_prune: list of strings. Determines
            which tensors in the modules have already been pruned.
        layers_to_prune: list of strings - names of modules that have
            already been pruned. If None, reparametrization will be
            removed from every module in the model.
    """
    nn_model = get_nn_from_ddp_module(model)
    for name, module in nn_model.named_modules():
        try:
            if layers_to_prune is None or name in layers_to_prune:
                for key in keys_to_prune:
                    prune.remove(module, key)
        except ValueError:
            pass
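# A sketch of the typical prune-then-finalize flow: ``prune_model`` attaches
# masks and forward pre-hooks, ``remove_reparametrization`` bakes the pruned
# weights in and drops the hooks. The model below is hypothetical.
#
#     import torch
#
#     net = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 2))
#     prune_model(net, pruning_fn="l1_unstructured", amount=0.5)
#     remove_reparametrization(net, keys_to_prune=["weight"])
#     # after this, ``weight`` is a plain parameter again (with zeros where
#     # connections were pruned) and the pruning masks are gone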
def onnx_export(
    model: torch.nn.Module,
    batch: torch.Tensor,
    file: str,
    method_name: str = "forward",
    input_names: Iterable = None,
    output_names: List[str] = None,
    dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
    opset_version: int = 9,
    do_constant_folding: bool = False,
    return_model: bool = False,
    verbose: bool = False,
) -> Union[None, "onnx"]:
    """Exports the model to ONNX format.

    Args:
        model: model
        batch: inputs
        file: path to save the exported model to.
        method_name: forward pass method to be converted. Defaults to "forward".
        input_names: names of inputs in the graph. Defaults to None.
        output_names: names of outputs in the graph. Defaults to None.
        dynamic_axes: axes with dynamic shapes. Defaults to None.
        opset_version: ONNX opset version. Defaults to 9.
        do_constant_folding: if True, the constant-folding optimization
            is applied to the model during export. Defaults to False.
        return_model: if True, returns the loaded onnx model
            (the ``onnx`` package is required). Defaults to False.
        verbose: if specified, prints out a debug description
            of the trace being exported.

    Example:
        .. code-block:: python

            import torch

            from catalyst.utils import convert_to_onnx

            class LinModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin1 = torch.nn.Linear(10, 10)
                    self.lin2 = torch.nn.Linear(2, 10)

                def forward(self, inp_1, inp_2):
                    return self.lin1(inp_1), self.lin2(inp_2)

                def first_only(self, inp_1):
                    return self.lin1(inp_1)

            lin_model = LinModel()
            convert_to_onnx(
                lin_model,
                batch=torch.randn((1, 10)),
                file="model.onnx",
                method_name="first_only",
            )

    Raises:
        ImportError: when ``return_model`` is True, but onnx is not installed.

    Returns:
        Union[None, "onnx"]: onnx model if ``return_model`` is set to True.
    """
    nn_model = get_nn_from_ddp_module(model)
    if method_name != "forward":
        nn_model = ModelForwardWrapper(model=nn_model, method_name=method_name)

    torch.onnx.export(
        nn_model,
        batch,
        file,
        verbose=verbose,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=do_constant_folding,
        opset_version=opset_version,
    )
    if return_model:
        if not SETTINGS.onnx_required:
            raise ImportError(
                "To use onnx model you should install it with ``pip install onnx``"
            )
        return onnx.load(file)
def trace_model_from_runner(
    runner: "IRunner",
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> jit.ScriptModule:
    """
    Traces the model using the created experiment and runner.

    Args:
        runner: current runner.
        checkpoint_name: name of the model checkpoint to use;
            if None, traces the current model from the runner
        method_name: model's method name that will be
            used as entrypoint during tracing
        mode: mode for the model to trace (``train`` or ``eval``)
        requires_grad: flag to use grads
        opt_level: AMP FP16 init level
        device: torch device

    Returns:
        jit.ScriptModule: traced model
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly"
        )
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dumping the previous state of the model, we will need it to restore
    device_dump, is_training_dump, requires_grad_dump = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    model.to(device)

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):  # noqa: WPS442
        return model(**inputs, **kwargs)

    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if checkpoint_name is not None:
        unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

    # Restore the previous state of the model
    getattr(model, "train" if is_training_dump else "eval")()
    set_requires_grad(model, requires_grad_dump)
    model.to(device_dump)

    return traced_model