Example #1
0
def quantize_model(
    model: Model, qconfig_spec: Dict = None, dtype: Union[str, Optional[torch.dtype]] = "qint8"
) -> Model:
    """Function to quantize model weights.

    The model is unwrapped with ``get_nn_from_ddp_module``, moved to CPU and
    passed to ``quantization.quantize_dynamic``.

    Args:
        model: model to be quantized
        qconfig_spec (Dict, optional): quantization config in PyTorch format. Defaults to None.
        dtype (Union[str, Optional[torch.dtype]], optional): Type of weights after quantization.
            Either one of the strings ``"qint8"`` / ``"quint8"`` or a value that is
            forwarded to ``quantize_dynamic`` unchanged (e.g. a ``torch.dtype``).
            Defaults to "qint8".

    Returns:
        Model: quantized model

    Raises:
        KeyError: if ``dtype`` is a string other than ``"qint8"`` or ``"quint8"``.
    """
    nn_model = get_nn_from_ddp_module(model)

    # Resolve the target dtype once, up front. Previously the name mapping was
    # only created for string inputs but indexed unconditionally, so passing a
    # ``torch.dtype`` crashed with UnboundLocalError.
    if isinstance(dtype, str):
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        target_dtype = type_mapping[dtype]
    else:
        target_dtype = dtype

    try:
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=target_dtype
        )
    except RuntimeError:
        # Retry once with the "qnnpack" engine. NOTE: this changes the
        # process-wide ``torch.backends.quantized.engine`` setting and it is
        # not restored afterwards.
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(), qconfig_spec=qconfig_spec, dtype=target_dtype
        )

    return quantized_model
Example #2
0
def prune_model(
    model: Module,
    pruning_fn: Union[Callable, str],
    amount: Union[float, int],
    keys_to_prune: Optional[List[str]] = None,
    layers_to_prune: Optional[List[str]] = None,
    dim: int = None,
    l_norm: int = None,
) -> None:
    """
    Prune selected tensors of the model's modules in place.

    Every module yielded by ``named_modules()`` (optionally restricted to
    ``layers_to_prune``) gets ``pruning_fn`` applied to each tensor listed in
    ``keys_to_prune``. The function returns nothing.

    Args:
        model: Model to be pruned.
        pruning_fn: Pruning function with API same as in torch.nn.utils.pruning,
            i.e. ``pruning_fn(module, name, amount)``, or a string; the value
            is resolved through ``get_pruning_fn``.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        keys_to_prune: list of strings. Determines which tensor in modules
            will be pruned. ``["weight"]`` is used when None or empty.
        layers_to_prune: list of strings - module names to be pruned.
            If None provided then will try to prune every module in model,
            silently skipping modules on which pruning raises AttributeError.
        dim (int, optional): if you are using structured pruning method you need
            to specify dimension. Defaults to None.
        l_norm (int, optional): if you are using ln_structured you need to specify l_norm.
            Defaults to None.

    Example:
        .. code-block:: python

           prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

    Raises:
        AttributeError: If layers_to_prune is not None and pruning one of
            the selected modules raises AttributeError.
        ValueError: if no module was pruned at all.
    """
    network = get_nn_from_ddp_module(model)
    apply_pruning = get_pruning_fn(pruning_fn, l_norm=l_norm, dim=dim)
    tensor_keys = keys_to_prune if keys_to_prune else ["weight"]
    explicit_layers = layers_to_prune is not None

    n_pruned = 0
    for layer_name, layer in network.named_modules():
        # Skip modules the caller did not ask for.
        if explicit_layers and layer_name not in layers_to_prune:
            continue
        try:
            for tensor_key in tensor_keys:
                apply_pruning(layer, name=tensor_key, amount=amount)
        except AttributeError:
            # In "prune everything" mode a failing module is simply skipped;
            # for an explicitly requested layer the error is propagated.
            if explicit_layers:
                raise
        else:
            n_pruned += 1

    if not n_pruned:
        raise ValueError(f"There is no {tensor_keys} key in your model")
Example #3
0
def trace_model(
    model: Model,
    batch: Union[Tuple[torch.Tensor], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Trace one method of the model with ``jit.trace``.

    The model is unwrapped with ``get_nn_from_ddp_module``, wrapped into a
    ``ModelForwardWrapper`` bound to ``method_name`` and traced using
    ``batch`` as example inputs.

    Args:
        model: Model to trace
        batch: Batch passed to ``jit.trace`` as ``example_inputs``
        method_name: Name of the model's method that will be
            used as entrypoint during tracing

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils import trace_model

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           traced_model = trace_model(
               lin_model, batch=torch.randn(1, 10), method_name="first_only"
           )

    Returns:
        jit.ScriptModule: Traced model
    """
    entrypoint = ModelForwardWrapper(
        model=get_nn_from_ddp_module(model), method_name=method_name
    )
    return jit.trace(entrypoint, example_inputs=batch)
Example #4
0
def remove_reparametrization(
    model: Module, keys_to_prune: List[str], layers_to_prune: Optional[List[str]] = None,
) -> None:
    """
    Removes pre-hooks and pruning masks from the model.

    ``prune.remove`` is called for every requested key of every selected
    module. A ``ValueError`` raised for one key is ignored and does not stop
    the remaining keys of the same module from being processed.

    Args:
        model: model to remove reparametrization.
        keys_to_prune: list of strings. Determines
            which tensor in modules have already been pruned.
        layers_to_prune: list of strings - module names
            have already been pruned.
            If None provided then will try to remove the reparametrization
            from every module in model.
    """
    nn_model = get_nn_from_ddp_module(model)
    for name, module in nn_model.named_modules():
        if layers_to_prune is not None and name not in layers_to_prune:
            continue
        for key in keys_to_prune:
            # Handle the failure per key: previously one ValueError aborted the
            # whole key loop for the module, leaving later keys reparametrized.
            try:
                prune.remove(module, key)
            except ValueError:
                # Best effort: presumably this key was never pruned on this
                # module, so there is nothing to remove - move on.
                continue
Example #5
0
def onnx_export(
    model: torch.nn.Module,
    batch: torch.Tensor,
    file: str,
    method_name: str = "forward",
    input_names: Iterable = None,
    output_names: List[str] = None,
    dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
    opset_version: int = 9,
    do_constant_folding: bool = False,
    return_model: bool = False,
    verbose: bool = False,
) -> Union[None, "onnx"]:
    """Export the model to an ONNX file via ``torch.onnx.export``.

    The model is unwrapped with ``get_nn_from_ddp_module``; when
    ``method_name`` is not ``"forward"`` it is additionally wrapped into a
    ``ModelForwardWrapper`` so that the given method is exported.

    Args:
        model (torch.nn.Module): model
        batch (Tensor): inputs
        file (str): file to save the exported model to.
        method_name (str, optional): Forward pass method to be converted. Defaults to "forward".
        input_names (Iterable, optional): name of inputs in graph. Defaults to None.
        output_names (List[str], optional): name of outputs in graph. Defaults to None.
        dynamic_axes (Union[Dict[str, int], Dict[str, Dict[str, int]]], optional): axes
            with dynamic shapes. Defaults to None.
        opset_version (int, optional): Defaults to 9.
        do_constant_folding (bool, optional): If True, the constant-folding optimization
            is applied to the model during export. Defaults to False.
        return_model (bool, optional): If True then the exported file is loaded with
            ``onnx.load`` and returned (onnx required). Defaults to False.
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils import onnx_export

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           onnx_export(
               lin_model, batch=torch.randn((1, 10)), file="model.onnx", method_name="first_only"
           )

    Raises:
        ImportError: when ``return_model`` is True, but onnx is not installed.
            The check runs after the export, so the file has already been written.

    Returns:
        Union[None, "onnx"]: onnx model if return_model set to True.
    """
    export_target = get_nn_from_ddp_module(model)
    if method_name != "forward":
        # Route the export through the requested method instead of forward().
        export_target = ModelForwardWrapper(model=export_target, method_name=method_name)

    export_options = {
        "verbose": verbose,
        "input_names": input_names,
        "output_names": output_names,
        "dynamic_axes": dynamic_axes,
        "do_constant_folding": do_constant_folding,
        "opset_version": opset_version,
    }
    torch.onnx.export(export_target, batch, file, **export_options)

    if not return_model:
        return None
    if SETTINGS.onnx_required:
        return onnx.load(file)
    raise ImportError(
        "To use onnx model you should install it with ``pip install onnx``"
    )