Exemplo n.º 1
    def _inject_new_class(module: Module) -> None:
        r"""Sets up a module to be parametrized.

        This works by substituting the class of the module by a dynamically
        created subclass of it. Properties can then be injected into that
        per-instance subclass without affecting other instances of the
        original class.

        Args:
            module (nn.Module): module into which to inject the property
        """
        cls = module.__class__

        def getstate(self):
            # Pickling relies on __getstate__; raising here blocks it on purpose
            # because the dynamically created class cannot be looked up on load.
            # NOTE(review): the URL below was restored from the upstream PyTorch
            # message; confirm it matches the targeted release.
            raise RuntimeError(
                "Serialization of parametrized modules is only "
                "supported through state_dict(). See:\n"
                "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
                "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
            )

        # The new class has the original class as its only base, so the
        # original can later be recovered through ``__bases__[0]``.
        param_cls = type(
            f"Parametrized{cls.__name__}",
            (cls,),
            {
                "__getstate__": getstate,
            },
        )

        module.__class__ = param_cls
Exemplo n.º 2
def _inject_new_class(module: Module) -> None:
    r"""Sets up a module to be parametrized.

    This works by substituting the class of the module by a dynamically
    created subclass of it. Properties can then be injected into that
    per-instance subclass without affecting other instances of the
    original class.

    Pickling the resulting module is disabled (``__getstate__`` raises), but
    deep-copying stays possible: if the original class does not define
    ``__deepcopy__``, a default implementation is installed on the subclass.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        # Register the replica before recursing so reference cycles resolve to it.
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        # NOTE(review): the URL below was restored from the upstream PyTorch
        # message; confirm it matches the targeted release.
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    # The new class has the original class as its only base, so the original
    # can later be recovered through ``__bases__[0]``.
    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls,),
        dct,
    )

    module.__class__ = param_cls
Exemplo n.º 3
    def remove_parametrizations(
        module: Module, tensor_name: str, leave_parametrized: bool = True
    ) -> Module:
        r"""Removes the parametrizations on a tensor in a module.

        - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
          its current output. In this case, the parametrization shall not change the ``dtype``
          of the tensor.
        - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
          the unparametrised tensor in ``module.parametrizations[tensor_name].original``.
          This is only possible when the parametrization depends on just one tensor.

        Args:
            module (nn.Module): module from which remove the parametrization
            tensor_name (str): name of the parametrization to be removed
            leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
                Default: ``True``

        Returns:
            Module: module

        Raises:
            ValueError: if ``module[tensor_name]`` is not parametrized
            ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
        """

        if not is_parametrized(module, tensor_name):
            raise ValueError(
                f"Module {module} does not have a parametrization on {tensor_name}"
            )

        # Fetch the original tensor
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        parametrizations = module.parametrizations[tensor_name]
        if parametrizations.is_tensor:
            original = parametrizations.original
            if leave_parametrized:
                with torch.no_grad():
                    t = getattr(module, tensor_name)
                # We know they have the same dtype because we have checked this when registering the
                # parametrizations. As such, we can use set_
                # We do this so that the parameter does not to change the id()
                # This way the user does not need to update the optimizer
                with torch.no_grad():
                    # NOTE(review): restored from upstream; the scraped snippet
                    # had dropped the body of this ``with`` block.
                    original.set_(t)
        else:
            if leave_parametrized:
                # We cannot use no_grad because we need to know whether one or more
                # original tensors required grad
                t = getattr(module, tensor_name)
                # We'll have to trust the user to add it to the optimizer
                original = Parameter(t) if t.requires_grad else t
            else:
                raise ValueError(
                    "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                    "that is parametrized in terms of a sequence of tensors."
                )

        # Delete the property that manages the parametrization
        delattr(module.__class__, tensor_name)
        # Delete the ParametrizationList
        del module.parametrizations[tensor_name]

        # Restore the parameter / buffer into the main class
        _register_parameter_or_buffer(module, tensor_name, original)

        # Roll back the parametrized class if no other buffer or parameter
        # is currently parametrized in this class
        if not is_parametrized(module):
            delattr(module, "parametrizations")
            # Restore class
            orig_cls = module.__class__.__bases__[0]
            module.__class__ = orig_cls
        return module
Exemplo n.º 4
def remove_parametrizations(module: Module,
                            tensor_name: str,
                            leave_parametrized: bool = True) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrised tensor in ``module.parametrizations[tensor_name].original``.

    Args:
        module (nn.Module): module from which remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=True`` and the parametrization changes the size or dtype
            of the tensor
    """

    if not is_parametrized(module, tensor_name):
        raise ValueError(
            f"Module {module} does not have a parametrization on {tensor_name}"
        )

    # Fetch the original tensor
    original = module.parametrizations[tensor_name].original  # type: ignore
    if leave_parametrized:
        t = getattr(module, tensor_name)
        # If they have the same dtype, we reuse the original tensor.
        # We do this so that the parameter does not to change the id()
        # This way the user does not need to update the optimizer
        if t.dtype == original.dtype:
            with torch.no_grad():
                # NOTE(review): restored from upstream; the scraped snippet
                # had dropped the body of this ``with`` block.
                original.set_(t)
        else:
            raise ValueError(
                f"The parametrization changes the dtype of the tensor from {original.dtype} to {t.dtype}. "
                "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) "
                "in this case."
            )
    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]  # type: ignore

    # Restore the parameter / buffer into the main class
    if isinstance(original, Parameter):
        module.register_parameter(tensor_name, original)
    else:
        # Non-Parameter tensors go back as buffers; without this ``else`` a
        # Parameter would be registered a second time as a buffer.
        module.register_buffer(tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module