def quantize_model(
    model: Model,
    qconfig_spec: Optional[Dict] = None,
    dtype: Union[str, Optional[torch.dtype]] = "qint8",
) -> Model:
    """Quantize model weights with PyTorch dynamic quantization.

    The model is unwrapped with ``get_nn_from_ddp_module``, moved to CPU and
    passed to ``quantization.quantize_dynamic``. If that call raises a
    ``RuntimeError`` (presumably because the active quantized engine is not
    supported on this platform), ``torch.backends.quantized.engine`` is set
    to ``"qnnpack"`` and quantization is retried once.

    Args:
        model: model to be quantized. Note that ``.cpu()`` is called on the
            (unwrapped) model, so it is moved to CPU as a side effect.
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype (Union[str, Optional[torch.dtype]], optional): type of weights
            after quantization. Either one of the strings ``"qint8"`` /
            ``"quint8"``, or a ``torch.dtype`` that is passed through as is.
            ``None`` falls back to ``torch.qint8``. Defaults to "qint8".

    Returns:
        Model: quantized model

    Raises:
        KeyError: if ``dtype`` is a string other than "qint8" or "quint8".
    """
    # Resolve the target dtype. Previously only strings were handled and any
    # torch.dtype (as allowed by the annotation) crashed with a NameError.
    if dtype is None:
        torch_dtype = torch.qint8
    elif isinstance(dtype, str):
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        torch_dtype = type_mapping[dtype]
    else:
        torch_dtype = dtype

    nn_model = get_nn_from_ddp_module(model)
    # Module.cpu() moves parameters in place and returns the module itself,
    # so a single call is enough for both the first attempt and the retry.
    cpu_model = nn_model.cpu()
    try:
        quantized_model = quantization.quantize_dynamic(
            cpu_model, qconfig_spec=qconfig_spec, dtype=torch_dtype
        )
    except RuntimeError:
        # NOTE: this switches the process-wide quantized engine, not only
        # the engine used for this call.
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            cpu_model, qconfig_spec=qconfig_spec, dtype=torch_dtype
        )
    return quantized_model
def prune_model(
    model: Module,
    pruning_fn: Union[Callable, str],
    amount: Union[float, int],
    keys_to_prune: Optional[List[str]] = None,
    layers_to_prune: Optional[List[str]] = None,
    dim: Optional[int] = None,
    l_norm: Optional[int] = None,
) -> None:
    """Prune tensors of the model's layers in place.

    Args:
        model: Model to be pruned. It is unwrapped with
            ``get_nn_from_ddp_module`` before pruning.
        pruning_fn: Pruning function with the same API as the functions in
            ``torch.nn.utils.prune``: ``pruning_fn(module, name, amount)``,
            or a value understood by ``get_pruning_fn`` (e.g. a string).
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If int, it represents
            the absolute number of parameters to prune.
        keys_to_prune: list of strings. Determines which tensors in the
            selected modules will be pruned. Defaults to ``["weight"]``.
        layers_to_prune: list of strings - module names (as yielded by
            ``named_modules()``) to be pruned. If None, every module in the
            model is tried and modules that raise ``AttributeError`` (e.g.
            because they lack a key) are skipped. Names that do not match
            any module are silently ignored.
        dim (int, optional): if you are using structured pruning method
            you need to specify dimension. Defaults to None.
        l_norm (int, optional): if you are using ln_structured you need
            to specify l_norm. Defaults to None.

    Example:
        .. code-block:: python

            # pruning happens in place; the function returns None
            prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

    Raises:
        AttributeError: if ``layers_to_prune`` is not None and pruning one of
            the selected layers raises ``AttributeError`` (e.g. the layer has
            no tensor named as one of ``keys_to_prune``).
        ValueError: if no tensor was pruned at all.
    """
    nn_model = get_nn_from_ddp_module(model)
    pruning_fn = get_pruning_fn(pruning_fn, l_norm=l_norm, dim=dim)
    keys_to_prune = keys_to_prune or ["weight"]
    # counts successfully pruned (module, key) pairs
    pruned_modules = 0
    for name, module in nn_model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        try:
            for key in keys_to_prune:
                pruning_fn(module, name=key, amount=amount)
                pruned_modules += 1
        except AttributeError:
            # When sweeping the whole model, modules without the key are
            # expected and skipped; for explicitly requested layers the
            # failure is surfaced to the caller.
            if layers_to_prune is not None:
                raise
    if pruned_modules == 0:
        raise ValueError(f"There is no {keys_to_prune} key in your model")
def trace_model(
    model: Model,
    batch: Union[Tuple[torch.Tensor], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Trace ``model`` with TorchScript, using ``batch`` as the example input.

    The model is unwrapped with ``get_nn_from_ddp_module`` and wrapped in
    ``ModelForwardWrapper`` so that ``method_name`` is used as the entry
    point that ``jit.trace`` records.

    Args:
        model: model to trace
        batch: example input(s), forwarded to ``jit.trace`` as
            ``example_inputs``
        method_name: name of the model method used as the entry point
            during tracing

    Example:
        .. code-block:: python

            import torch

            from catalyst.utils import trace_model

            class LinModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin1 = torch.nn.Linear(10, 10)
                    self.lin2 = torch.nn.Linear(2, 10)

                def forward(self, inp_1, inp_2):
                    return self.lin1(inp_1), self.lin2(inp_2)

                def first_only(self, inp_1):
                    return self.lin1(inp_1)

            lin_model = LinModel()
            traced_model = trace_model(
                lin_model, batch=torch.randn(1, 10), method_name="first_only"
            )

    Returns:
        jit.ScriptModule: traced model
    """
    entrypoint = ModelForwardWrapper(
        model=get_nn_from_ddp_module(model), method_name=method_name
    )
    return jit.trace(entrypoint, example_inputs=batch)
def remove_reparametrization(
    model: Module,
    keys_to_prune: List[str],
    layers_to_prune: Optional[List[str]] = None,
) -> None:
    """Remove pruning pre-hooks and masks from the model, making pruning permanent.

    Calls ``prune.remove(module, key)`` for every selected module and every
    key. Keys that were not pruned on a given module are skipped, and the
    remaining keys of that module are still processed.

    Args:
        model: model to remove reparametrization from. It is unwrapped with
            ``get_nn_from_ddp_module`` first.
        keys_to_prune: list of strings. Determines which tensors in modules
            have already been pruned.
        layers_to_prune: list of strings - names of the modules that have
            already been pruned. If None, every module in the model is
            processed.
    """
    nn_model = get_nn_from_ddp_module(model)
    for name, module in nn_model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        for key in keys_to_prune:
            # The try is per key: previously a single unpruned key aborted
            # the loop and left the module's other keys reparametrized.
            try:
                prune.remove(module, key)
            except ValueError:
                # prune.remove raises ValueError when ``key`` was not pruned
                # on this module; nothing to remove, so skip it.
                pass
def onnx_export(
    model: torch.nn.Module,
    batch: torch.Tensor,
    file: str,
    method_name: str = "forward",
    input_names: Iterable = None,
    output_names: List[str] = None,
    dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
    opset_version: int = 9,
    do_constant_folding: bool = False,
    return_model: bool = False,
    verbose: bool = False,
) -> Union[None, "onnx"]:
    """Export a model to the ONNX format with ``torch.onnx.export``.

    Args:
        model (torch.nn.Module): model to export. It is unwrapped with
            ``get_nn_from_ddp_module`` first.
        batch (Tensor): example inputs used to trace the model.
        file (str): path of the file the ONNX model is written to.
        method_name (str, optional): model method to be exported.
            Defaults to "forward".
        input_names (Iterable, optional): names of the inputs in the graph.
            Defaults to None.
        output_names (List[str], optional): names of the outputs in the
            graph. Defaults to None.
        dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]],
            optional): axes with dynamic shapes. Defaults to None.
        opset_version (int, optional): ONNX opset version to target.
            Defaults to 9.
        do_constant_folding (bool, optional): If True, the constant-folding
            optimization is applied to the model during export.
            Defaults to False.
        return_model (bool, optional): If True, the exported file is loaded
            back with ``onnx.load`` and returned (requires ``onnx``).
            Defaults to False.
        verbose (bool, default False): if specified, a debug description
            of the trace being exported is printed.

    Example:
        .. code-block:: python

            import torch

            from catalyst.utils import onnx_export

            class LinModel(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.lin1 = torch.nn.Linear(10, 10)
                    self.lin2 = torch.nn.Linear(2, 10)

                def forward(self, inp_1, inp_2):
                    return self.lin1(inp_1), self.lin2(inp_2)

                def first_only(self, inp_1):
                    return self.lin1(inp_1)

            lin_model = LinModel()
            onnx_export(
                lin_model,
                batch=torch.randn((1, 10)),
                file="model.onnx",
                method_name="first_only",
            )

    Raises:
        ImportError: when ``return_model`` is True, but onnx is not installed.
            The file has already been written at that point.

    Returns:
        Union[None, "onnx"]: the ONNX model loaded from ``file`` if
        ``return_model`` is True, otherwise None.
    """
    nn_model = get_nn_from_ddp_module(model)
    if method_name != "forward":
        # torch.onnx.export runs the module's forward; the wrapper presumably
        # routes forward() to ``method_name`` (same usage as in trace_model).
        nn_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
    torch.onnx.export(
        nn_model,
        batch,
        file,
        verbose=verbose,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=do_constant_folding,
        opset_version=opset_version,
    )
    if return_model:
        if not SETTINGS.onnx_required:
            raise ImportError(
                "To use onnx model you should install it with ``pip install onnx``"
            )
        return onnx.load(file)