Example #1
def pack_checkpoint(
    model: nn.Module = None,
    criterion: nn.Module = None,
    optimizer=None,
    scheduler=None,
    **kwargs,
):
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and some extra info ``**kwargs`` to torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    if isinstance(model, dict):
        for key, value in model.items():
            model_module = get_nn_from_ddp_module(value)
            checkpoint[f"model_{key}_state_dict"] = maybe_recursive_call(
                model_module, "state_dict")
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_module, "state_dict")

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call (?)
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()

    return checkpoint
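
A minimal usage sketch for the helper above; the ``catalyst.utils`` import path is an assumption and may differ between Catalyst versions:

import torch
import torch.nn as nn

from catalyst.utils import pack_checkpoint  # assumed import path

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# pack model, optimizer and arbitrary metadata into one dict ...
checkpoint = pack_checkpoint(model=model, optimizer=optimizer, epoch=3)
# ... and persist it with plain torch.save
torch.save(checkpoint, "checkpoint.pth")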
Example #2
def pack_checkpoint(model=None,
                    criterion=None,
                    optimizer=None,
                    scheduler=None,
                    **kwargs):
    """@TODO: Docs. Contribution is welcome."""
    checkpoint = kwargs

    if isinstance(model, OrderedDict):
        raise NotImplementedError()
    else:
        model_module = get_nn_from_ddp_module(model)
        checkpoint["model_state_dict"] = maybe_recursive_call(
            model_module, "state_dict")

    for dict2save, name2save in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        # @TODO refactor with maybe_recursive_call
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = name2save + "_" + str(key)
                    # checkpoint[name2save_] = value
                    state_dict2save = state_dict2save + "_state_dict"
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            checkpoint[name2save] = dict2save.state_dict()

    return checkpoint
def unpack_checkpoint(checkpoint,
                      model=None,
                      criterion=None,
                      optimizer=None,
                      scheduler=None):
    """@TODO: Docs. Contribution is welcome."""
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
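
A hedged round-trip sketch for the pack/unpack pair above, again assuming both helpers are importable from ``catalyst.utils``:

import torch
import torch.nn as nn

from catalyst.utils import pack_checkpoint, unpack_checkpoint  # assumed import path

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
torch.save(pack_checkpoint(model=model, optimizer=optimizer), "ckpt.pth")

# restore the saved state into freshly created objects
new_model = nn.Linear(10, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
checkpoint = torch.load("ckpt.pth")
unpack_checkpoint(checkpoint, model=new_model, optimizer=new_optimizer)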
Example #4
def quantize_model(
    model: TorchModel,
    qconfig_spec: Dict = None,
    dtype: Union[str, Optional[torch.dtype]] = "qint8",
) -> TorchModel:
    """Function to quantize model weights.

    Args:
        model: model to be quantized
        qconfig_spec (Dict, optional): quantization config in PyTorch format.
            Defaults to None.
        dtype: Type of weights after quantization.
            Defaults to "qint8".

    Returns:
        Model: quantized model
    """
    nn_model = get_nn_from_ddp_module(model)
    if isinstance(dtype, str):
        type_mapping = {"qint8": torch.qint8, "quint8": torch.quint8}
        dtype = type_mapping[dtype]
    try:
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(),
            qconfig_spec=qconfig_spec,
            dtype=dtype)
    except RuntimeError:
        # the default quantized engine may be unsupported on this platform,
        # fall back to qnnpack and retry
        torch.backends.quantized.engine = "qnnpack"
        quantized_model = quantization.quantize_dynamic(
            nn_model.cpu(),
            qconfig_spec=qconfig_spec,
            dtype=dtype)

    return quantized_model
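
A short usage sketch, assuming ``quantize_model`` is exposed via ``catalyst.utils``; dynamic quantization here converts the ``Linear`` weights to int8:

import torch
import torch.nn as nn

from catalyst.utils import quantize_model  # assumed import path

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
q_model = quantize_model(model, dtype="qint8")

with torch.no_grad():
    out = q_model(torch.randn(1, 128))
print(out.shape)  # torch.Size([1, 10])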
Example #5
def prune_model(
    model: Module,
    pruning_fn: Union[Callable, str],
    amount: Union[float, int],
    keys_to_prune: Optional[List[str]] = None,
    layers_to_prune: Optional[List[str]] = None,
    dim: int = None,
    l_norm: int = None,
) -> None:
    """
    Prune model function can be used for pruning certain
    tensors in model layers.

    Args:
        model: Model to be pruned.
        pruning_fn: pruning function with the same API as in
            ``torch.nn.utils.prune``: ``pruning_fn(module, name, amount)``.
        keys_to_prune: list of strings.
            Determines which tensor in modules will be pruned.
        amount: quantity of parameters to prune.
            If float, should be between 0.0 and 1.0 and
            represent the fraction of parameters to prune.
            If int, it represents the absolute number
            of parameters to prune.
        layers_to_prune: list of strings - module names to be pruned.
            If None, every module in the model will be pruned.
        dim (int, optional): if you are using structured pruning method you need
            to specify dimension. Defaults to None.
        l_norm (int, optional): if you are using
            ln_structured you need to specify l_norm. Defaults to None.

    Example:
        .. code-block:: python

           prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

    Raises:
        AttributeError: if ``layers_to_prune`` is not None, but there are
            no layers with the specified names in the model.
        ValueError: if no layers have the specified keys.
    """
    nn_model = get_nn_from_ddp_module(model)
    pruning_fn = get_pruning_fn(pruning_fn, l_norm=l_norm, dim=dim)
    keys_to_prune = keys_to_prune or ["weight"]
    pruned_modules = 0
    for name, module in nn_model.named_modules():
        try:
            if layers_to_prune is None or name in layers_to_prune:
                for key in keys_to_prune:
                    pruning_fn(module, name=key, amount=amount)
                pruned_modules += 1
        except AttributeError as e:
            if layers_to_prune is not None:
                raise e

    if pruned_modules == 0:
        raise ValueError(f"There is no {keys_to_prune} key in your model")
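
A hedged usage sketch, assuming the ``catalyst.utils`` import path; it zeroes out half of the smallest-magnitude weights in every ``Linear`` layer:

import torch.nn as nn

from catalyst.utils import prune_model  # assumed import path

model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))
prune_model(
    model,
    pruning_fn="l1_unstructured",
    amount=0.5,
    keys_to_prune=["weight"],
)
# each pruned Linear now carries weight_orig / weight_mask reparametrization
print(hasattr(model[0], "weight_mask"))  # True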
Example #6
def pack_checkpoint(
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
    **kwargs,
) -> Dict:
    """
    Packs ``model``, ``criterion``, ``optimizer``, ``scheduler``
    and some extra info ``**kwargs`` to torch-based checkpoint.

    Args:
        model: torch model
        criterion: torch criterion
        optimizer: torch optimizer
        scheduler: torch scheduler
        **kwargs: some extra info to pack

    Returns:
        torch-based checkpoint with ``model_state_dict``,
        ``criterion_state_dict``, ``optimizer_state_dict``,
        ``scheduler_state_dict`` keys.
    """
    checkpoint = kwargs

    for dict2save, name2save in zip(
        [model, criterion, optimizer, scheduler],
        ["model", "criterion", "optimizer", "scheduler"],
    ):
        if dict2save is None:
            continue
        if isinstance(dict2save, dict):
            for key, value in dict2save.items():
                if value is not None:
                    state_dict2save = f"{name2save}_{key}_state_dict"
                    value = get_nn_from_ddp_module(value)
                    checkpoint[state_dict2save] = value.state_dict()
        else:
            # checkpoint[name2save] = dict2save
            name2save = name2save + "_state_dict"
            dict2save = get_nn_from_ddp_module(dict2save)
            checkpoint[name2save] = dict2save.state_dict()
    return checkpoint
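
The ``dict`` branch above packs several models at once; a small sketch (import path assumed):

import torch.nn as nn

from catalyst.utils import pack_checkpoint  # assumed import path

generator = nn.Linear(100, 784)
discriminator = nn.Linear(784, 1)

# a dict of models produces model_<key>_state_dict entries
checkpoint = pack_checkpoint(
    model={"generator": generator, "discriminator": discriminator}
)
print(sorted(checkpoint))
# ['model_discriminator_state_dict', 'model_generator_state_dict']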
Example #7
def trace_model(
    model: TorchModel,
    batch: Union[Tuple[torch.Tensor], torch.Tensor],
    method_name: str = "forward",
) -> jit.ScriptModule:
    """Traces model using runner and batch.

    Args:
        model: Model to trace
        batch: Batch to trace the model
        method_name: Model's method name that will be
            used as entrypoint during tracing

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils import trace_model

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           traced_model = trace_model(
               lin_model, batch=torch.randn(1, 10), method_name="first_only"
           )

    Returns:
        jit.ScriptModule: Traced model
    """
    nn_model = get_nn_from_ddp_module(model)
    wrapped_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
    traced = jit.trace(wrapped_model, example_inputs=batch)
    return traced
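
A short self-contained follow-up sketch: the traced ``ScriptModule`` can be saved and reloaded without the original Python class (``catalyst.utils`` import path assumed):

import torch

from catalyst.utils import trace_model  # assumed import path

model = torch.nn.Linear(10, 10)
traced = trace_model(model, batch=torch.randn(1, 10))

# a ScriptModule is serializable independently of the Python code
torch.jit.save(traced, "linear_traced.pt")
restored = torch.jit.load("linear_traced.pt")
print(restored(torch.randn(1, 10)).shape)  # torch.Size([1, 10])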
Example #8
def unpack_checkpoint(
    checkpoint: Dict,
    model: RunnerModel = None,
    criterion: RunnerCriterion = None,
    optimizer: RunnerOptimizer = None,
    scheduler: RunnerScheduler = None,
) -> None:
    """Load checkpoint from file and unpack the content to a model
    (if not None), criterion (if not None), optimizer (if not None),
    scheduler (if not None).

    Args:
        checkpoint: checkpoint to load
        model: model where should be updated state
        criterion: criterion where should be updated state
        optimizer: optimizer where should be updated state
        scheduler: scheduler where should be updated state
    """
    if model is not None:
        model = get_nn_from_ddp_module(model)
        maybe_recursive_call(
            model,
            "load_state_dict",
            recursive_args=checkpoint["model_state_dict"],
        )

    for dict2load, name2load in zip(
        [criterion, optimizer, scheduler],
        ["criterion", "optimizer", "scheduler"],
    ):
        if dict2load is None:
            continue

        if isinstance(dict2load, dict):
            for key, value in dict2load.items():
                if value is not None:
                    state_dict2load = f"{name2load}_{key}_state_dict"
                    value.load_state_dict(checkpoint[state_dict2load])
        else:
            name2load = f"{name2load}_state_dict"
            dict2load.load_state_dict(checkpoint[name2load])
Example #9
def remove_reparametrization(
        model: Module,
        keys_to_prune: List[str],
        layers_to_prune: Optional[List[str]] = None) -> None:
    """
    Removes pre-hooks and pruning masks from the model.

    Args:
        model: model to remove the reparametrization from.
        keys_to_prune: list of strings. Determines
            which tensors in the modules have already been pruned.
        layers_to_prune: list of strings - names of modules
            that have already been pruned.
            If None, reparametrization is removed from every module
            in the model.
    """
    nn_model = get_nn_from_ddp_module(model)
    for name, module in nn_model.named_modules():
        try:
            if layers_to_prune is None or name in layers_to_prune:
                for key in keys_to_prune:
                    prune.remove(module, key)
        except ValueError:
            pass
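
A hedged end-to-end sketch combining ``prune_model`` with ``remove_reparametrization`` to make the pruning permanent (import paths assumed):

import torch.nn as nn

from catalyst.utils import prune_model, remove_reparametrization  # assumed import paths

model = nn.Sequential(nn.Linear(64, 32), nn.Linear(32, 10))

# apply a pruning mask (adds weight_orig / weight_mask reparametrization)
prune_model(model, pruning_fn="l1_unstructured", amount=0.5)

# bake the mask into the weights and drop the pruning pre-hooks
remove_reparametrization(model, keys_to_prune=["weight"])
print(hasattr(model[0], "weight_mask"))  # False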
Example #10
def onnx_export(
    model: torch.nn.Module,
    batch: torch.Tensor,
    file: str,
    method_name: str = "forward",
    input_names: Iterable = None,
    output_names: List[str] = None,
    dynamic_axes: Union[Dict[str, int], Dict[str, Dict[str, int]]] = None,
    opset_version: int = 9,
    do_constant_folding: bool = False,
    return_model: bool = False,
    verbose: bool = False,
) -> Union[None, "onnx"]:
    """Converts model to onnx runtime.

    Args:
        model: model
        batch: inputs
        file: path to save the exported model to.
        method_name: Forward pass method to be converted. Defaults to "forward".
        input_names: name of inputs in graph. Defaults to None.
        output_names: name of outputs in graph. Defaults to None.
        dynamic_axes: axes
            with dynamic shapes. Defaults to None.
        opset_version: Defaults to 9.
        do_constant_folding: If True, the constant-folding optimization
            is applied to the model during export. Defaults to False.
        return_model: If True then returns onnxruntime model (onnx required).
            Defaults to False.
        verbose: if specified, we will print out a debug
            description of the trace being exported.

    Example:
        .. code-block:: python

           import torch

           from catalyst.utils import convert_to_onnx

           class LinModel(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.lin1 = torch.nn.Linear(10, 10)
                   self.lin2 = torch.nn.Linear(2, 10)

               def forward(self, inp_1, inp_2):
                   return self.lin1(inp_1), self.lin2(inp_2)

               def first_only(self, inp_1):
                   return self.lin1(inp_1)

           lin_model = LinModel()
           convert_to_onnx(
               lin_model, batch=torch.randn((1, 10)),
               file="model.onnx",
               method_name="first_only"
           )

    Raises:
        ImportError: when ``return_model`` is True, but onnx is not installed.

    Returns:
        Union[None, "onnx"]: onnx model if return_model set to True.
    """
    nn_model = get_nn_from_ddp_module(model)
    if method_name != "forward":
        nn_model = ModelForwardWrapper(model=nn_model, method_name=method_name)
    torch.onnx.export(
        nn_model,
        batch,
        file,
        verbose=verbose,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=do_constant_folding,
        opset_version=opset_version,
    )
    if return_model:
        if not SETTINGS.onnx_required:
            raise ImportError(
                "To use onnx model you should install it with ``pip install onnx``"
            )
        return onnx.load(file)
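
A hedged follow-up: once a ``model.onnx`` file has been produced as in the docstring example, it can be checked with the ``onnx`` package and executed with ``onnxruntime`` (both optional dependencies):

import numpy as np
import onnx
import onnxruntime as ort

# validate the exported graph
onnx.checker.check_model(onnx.load("model.onnx"))

# run one inference pass through onnxruntime
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.random.randn(1, 10).astype(np.float32)})
print([o.shape for o in outputs])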
Example #11
def trace_model_from_runner(
    runner: "IRunner",
    checkpoint_name: str = None,
    method_name: str = "forward",
    mode: str = "eval",
    requires_grad: bool = False,
    opt_level: str = None,
    device: Device = "cpu",
) -> jit.ScriptModule:
    """
    Traces model using created experiment and runner.

    Args:
        runner: current runner.
        checkpoint_name: Name of model checkpoint to use, if None
            traces current model from runner
        method_name: Model's method name that will be
            used as entrypoint during tracing
        mode: Mode for model to trace (``train`` or ``eval``)
        requires_grad: Flag to use grads
        opt_level: AMP FP16 init level
        device: Torch device

    Returns:
        ScriptModule: Traced model
    """
    logdir = runner.logdir
    model = get_nn_from_ddp_module(runner.model)

    if checkpoint_name is not None:
        dumped_checkpoint = pack_checkpoint(model=model)
        checkpoint_path = logdir / "checkpoints" / f"{checkpoint_name}.pth"
        checkpoint = load_checkpoint(filepath=checkpoint_path)
        unpack_checkpoint(checkpoint=checkpoint, model=model)

    # getting input names of args for method since we don't have Runner
    # and we don't know input_key to preprocess batch for method call
    fn = getattr(model, method_name)
    method_argnames = _get_input_argnames(fn=fn, exclude=["self"])

    batch = {}
    for name in method_argnames:
        # TODO: We don't know input_keys without runner
        assert name in runner.input, (
            "Input batch should contain the same keys as input argument "
            "names of `forward` function to be traced correctly")
        batch[name] = runner.input[name]

    batch = any2device(batch, device)

    # Dump the previous state of the model; we will need it to restore later
    device_dump, is_training_dump, requires_grad_dump = (
        runner.device,
        model.training,
        get_requires_grad(model),
    )

    model.to(device)

    # Function to run prediction on batch
    def predict_fn(model: Model, inputs, **kwargs):  # noqa: WPS442
        return model(**inputs, **kwargs)

    traced_model = trace_model(
        model=model,
        predict_fn=predict_fn,
        batch=batch,
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        opt_level=opt_level,
        device=device,
    )

    if checkpoint_name is not None:
        unpack_checkpoint(checkpoint=dumped_checkpoint, model=model)

    # Restore the previous state of the model
    getattr(model, "train" if is_training_dump else "eval")()
    set_requires_grad(model, requires_grad_dump)
    model.to(device_dump)

    return traced_model