Example #1
    def _inject_new_class(module: Module) -> None:
        r"""Sets up a module to be parametrized.

        This works by substituting the class of the module with a subclass
        of it so that a property can be injected.

        Args:
            module (nn.Module): module into which to inject the property
        """
        cls = module.__class__

        def getstate(self):
            raise RuntimeError(
                "Serialization of parametrized modules is only "
                "supported through state_dict(). See:\n"
                "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
                "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
            )

        param_cls = type(
            f"Parametrized{cls.__name__}",
            (cls,),
            {
                "__getstate__": getstate,
            },
        )

        module.__class__ = param_cls
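
The mechanism above can be seen in isolation with a small, purely illustrative sketch (the names below are made up and not part of torch.nn.utils.parametrize): a subclass named ``Parametrized<Name>`` is created at runtime with ``type()`` and assigned to ``__class__``, after which a property can be attached to that new class without affecting other instances of the original class.

class Layer:
    def __init__(self, weight):
        self.raw_weight = weight

m = Layer(weight=3.0)

# Build "ParametrizedLayer" on the fly, mirroring type(f"Parametrized{cls.__name__}", (cls,), {...}).
param_cls = type(f"Parametrized{type(m).__name__}", (type(m),), {})
m.__class__ = param_cls

# A property can now be injected on the new class only (this is what _inject_property
# later does for real modules); plain Layer instances are unaffected.
setattr(param_cls, "weight", property(lambda self: 2 * self.raw_weight))

print(type(m).__name__)  # ParametrizedLayer
print(m.weight)          # 6.0
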
Example #2
def _inject_new_class(module: Module) -> None:
    r"""Sets up a module to be parametrized.

    This works by substituting the class of the module with a subclass
    of it so that a property can be injected.

    Args:
        module (nn.Module): module into which to inject the property
    """
    cls = module.__class__

    def default_deepcopy(self, memo):
        # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
        obj = memo.get(id(self), None)
        if obj is not None:
            return obj
        replica = self.__new__(self.__class__)
        memo[id(self)] = replica
        replica.__dict__ = deepcopy(self.__dict__, memo)
        # Also save all slots if they exist.
        slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
        for slot in slots_to_save:
            if hasattr(self, slot):
                setattr(replica, slot, deepcopy(getattr(self, slot), memo))
        return replica

    def getstate(self):
        raise RuntimeError(
            "Serialization of parametrized modules is only "
            "supported through state_dict(). See:\n"
            "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
            "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
        )

    dct = {"__getstate__": getstate}
    # We don't allow serialization of parametrized modules but should still allow deepcopying.
    # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
    if not hasattr(cls, "__deepcopy__"):
        dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment]

    param_cls = type(
        f"Parametrized{cls.__name__}",
        (cls, ),
        dct,
    )

    module.__class__ = param_cls
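
This variant keeps pickling blocked through the injected ``__getstate__`` while restoring deepcopy support via the ``default_deepcopy`` fallback. A hedged usage sketch, assuming a torch build that includes this deepcopy fallback (``Symmetric`` is just an example parametrization):

import copy
import pickle

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

m = nn.Linear(3, 3)
parametrize.register_parametrization(m, "weight", Symmetric())

# Deepcopy works thanks to the injected __deepcopy__ fallback.
m_copy = copy.deepcopy(m)
print(torch.allclose(m_copy.weight, m.weight))  # True

# Pickling is still rejected by the injected __getstate__.
try:
    pickle.dumps(m)
except RuntimeError:
    print("pickling blocked")
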
Example #3
    def remove_parametrizations(
        module: Module, tensor_name: str, leave_parametrized: bool = True
    ) -> Module:
        r"""Removes the parametrizations on a tensor in a module.

        - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
          its current output. In this case, the parametrization shall not change the ``dtype``
          of the tensor.
        - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
          the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
          This is only possible when the parametrization depends on just one tensor.

        Args:
            module (nn.Module): module from which to remove the parametrization
            tensor_name (str): name of the parametrization to be removed
            leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
                Default: ``True``

        Returns:
            Module: module

        Raises:
            ValueError: if ``module[tensor_name]`` is not parametrized
            ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
        """

        if not is_parametrized(module, tensor_name):
            raise ValueError(
                f"Module {module} does not have a parametrization on {tensor_name}"
            )

        # Fetch the original tensor
        assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
        parametrizations = module.parametrizations[tensor_name]
        if parametrizations.is_tensor:
            original = parametrizations.original
            if leave_parametrized:
                with torch.no_grad():
                    t = getattr(module, tensor_name)
                # We know they have the same dtype because we checked this when registering the
                # parametrizations. As such, we can use set_().
                # We do this so that the parameter does not change its id().
                # This way the user does not need to update the optimizer.
                with torch.no_grad():
                    original.set_(t)
        else:
            if leave_parametrized:
                # We cannot use no_grad because we need to know whether one or more
                # original tensors required grad
                t = getattr(module, tensor_name)
                # We'll have to trust the user to add it to the optimizer
                original = Parameter(t) if t.requires_grad else t
            else:
                raise ValueError(
                    "Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
                    "that is parametrized in terms of a sequence of tensors."
                )

        # Delete the property that manages the parametrization
        delattr(module.__class__, tensor_name)
        # Delete the ParametrizationList
        del module.parametrizations[tensor_name]

        # Restore the parameter / buffer into the main class
        _register_parameter_or_buffer(module, tensor_name, original)

        # Roll back the parametrized class if no other buffer or parameter
        # is currently parametrized in this class
        if not is_parametrized(module):
            delattr(module, "parametrizations")
            # Restore class
            orig_cls = module.__class__.__bases__[0]
            module.__class__ = orig_cls
        return module
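
A hedged usage sketch of the two removal modes described in the docstring (``Symmetric`` is an assumed example parametrization; since it has no ``right_inverse``, the stored original is simply the initial weight):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

# leave_parametrized=True (default): the attribute keeps its current, parametrized value.
m = nn.Linear(4, 4)
parametrize.register_parametrization(m, "weight", Symmetric())
symmetric_weight = m.weight.detach().clone()
parametrize.remove_parametrizations(m, "weight")
print(parametrize.is_parametrized(m))              # False
print(torch.allclose(m.weight, symmetric_weight))  # True

# leave_parametrized=False: the attribute reverts to the stored unparametrized original.
m2 = nn.Linear(4, 4)
w0 = m2.weight.detach().clone()
parametrize.register_parametrization(m2, "weight", Symmetric())
parametrize.remove_parametrizations(m2, "weight", leave_parametrized=False)
print(torch.allclose(m2.weight, w0))               # True
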
Example #4
    def register_parametrization(
        module: Module,
        tensor_name: str,
        parametrization: Module,
        *,
        unsafe: bool = False,
    ) -> Module:
        r"""Adds a parametrization to a tensor in a module.

        Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
        the module will return the parametrized version ``parametrization(module.weight)``.
        If the original tensor requires a gradient, the backward pass will differentiate
        through :attr:`parametrization`, and the optimizer will update the tensor accordingly.

        The first time that a module registers a parametrization, this function will add an attribute
        ``parametrizations`` to the module of type :class:`~ParametrizationList`.

        The list of parametrizations on the tensor ``weight`` will be accessible under
        ``module.parametrizations.weight``.

        The original tensor will be accessible under
        ``module.parametrizations.weight.original``.

        Parametrizations may be concatenated by registering several parametrizations
        on the same attribute.

        The training mode of a registered parametrization is updated on registration
        to match the training mode of the host module.

        Parametrized parameters and buffers have an inbuilt caching system that can be activated
        using the context manager :func:`cached`.

        A :attr:`parametrization` may optionally implement a method with signature

        .. code-block:: python

            def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

        This method is called on the unparametrized tensor when the first parametrization
        is registered to compute the initial value of the original tensor.
        If this method is not implemented, the original tensor will be just the unparametrized tensor.

        If all the parametrizations registered on a tensor implement ``right_inverse``, it is
        possible to initialize a parametrized tensor by assigning to it, as shown in the example below.

        It is possible for the first parametrization to depend on several inputs.
        This may be implemented by returning a tuple of tensors from ``right_inverse``
        (see the example implementation of a ``RankOne`` parametrization below).

        In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
        with names ``original0``, ``original1``,...

        .. note::

            If ``unsafe=False`` (default), both the ``forward`` and ``right_inverse`` methods will be
            called once to perform a number of consistency checks.
            If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
            and nothing will be called otherwise.

        .. note::

            In most situations, ``right_inverse`` will be a function such that
            ``forward(right_inverse(X)) == X`` (see
            `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
            Sometimes, when the parametrization is not surjective, it may be reasonable
            to relax this.

        .. warning::

            If a parametrization depends on several inputs, :func:`~register_parametrization`
            will register a number of new parameters. If such a parametrization is registered
            after the optimizer is created, these new parameters will need to be added manually
            to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.

        Args:
            module (nn.Module): module on which to register the parametrization
            tensor_name (str): name of the parameter or buffer on which to register
                the parametrization
            parametrization (nn.Module): the parametrization to register
        Keyword args:
            unsafe (bool): a boolean flag that denotes whether the parametrization
                may change the dtype and shape of the tensor. Default: ``False``.
                Warning: the parametrization is not checked for consistency upon registration.
                Enable this flag at your own risk.

        Raises:
            ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

        Examples:
            >>> import torch
            >>> import torch.nn as nn
            >>> import torch.nn.utils.parametrize as P
            >>>
            >>> class Symmetric(nn.Module):
            >>>     def forward(self, X):
            >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
            >>>
            >>>     def right_inverse(self, A):
            >>>         return A.triu()
            >>>
            >>> m = nn.Linear(5, 5)
            >>> P.register_parametrization(m, "weight", Symmetric())
            >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
            True
            >>> A = torch.rand(5, 5)
            >>> A = A + A.T   # A is now symmetric
            >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
            >>> print(torch.allclose(m.weight, A))
            True

            >>> class RankOne(nn.Module):
            >>>     def forward(self, x, y):
            >>>         # Form a rank 1 matrix multiplying two vectors
            >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
            >>>
            >>>     def right_inverse(self, Z):
            >>>         # Project Z onto the rank 1 matrices
            >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
            >>>         # Return rescaled singular vectors
            >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
            >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
            >>>
            >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
            >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
            1

        """
        parametrization.train(module.training)
        if is_parametrized(module, tensor_name):
            # Correctness checks.
            # If A is the space of tensors with shape and dtype equal to module.weight
            # we check that parametrization.forward and parametrization.right_inverse are
            # functions from A to A
            if not unsafe:
                Y = getattr(module, tensor_name)
                X = parametrization(Y)
                if not isinstance(X, Tensor):
                    raise ValueError(
                        f"A parametrization must return a tensor. Got {type(X).__name__}."
                    )
                if X.dtype != Y.dtype:
                    raise ValueError(
                        "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
                        f"module.{tensor_name}.dtype: {Y.dtype}\n"
                        f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
                    )
                if X.shape != Y.shape:
                    raise ValueError(
                        "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
                        f"module.{tensor_name}.shape: {Y.shape}\n"
                        f"parametrization(module.{tensor_name}).shape: {X.shape}"
                    )
                if hasattr(parametrization, "right_inverse"):
                    try:
                        Z = parametrization.right_inverse(X)  # type: ignore[operator]
                    except NotImplementedError:
                        pass
                    else:
                        if not isinstance(Z, Tensor):
                            raise ValueError(
                                f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
                            )
                        if Z.dtype != Y.dtype:
                            raise ValueError(
                                "The tensor returned by parametrization.right_inverse must have the same dtype "
                                f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                                f"module.{tensor_name}.dtype: {Y.dtype}\n"
                                f"returned dtype: {Z.dtype}"
                            )
                        if Z.shape != Y.shape:
                            raise ValueError(
                                "The tensor returned by parametrization.right_inverse must have the same shape "
                                f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
                                f"module.{tensor_name}.shape: {Y.shape}\n"
                                f"returned shape: {Z.shape}"
                            )
                # else right_inverse is assumed to be the identity

            # add the new parametrization to the parametrization list
            assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
            module.parametrizations[tensor_name].append(parametrization)
            # If unsafe was True in previous parametrization, keep it enabled
            module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr]
        elif tensor_name in module._buffers or tensor_name in module._parameters:
            # Set the parametrization mechanism
            # Fetch the original buffer or parameter
            original = getattr(module, tensor_name)
            # We create this early to check for possible errors
            parametrizations = ParametrizationList(
                [parametrization], original, unsafe=unsafe
            )
            # Delete the previous parameter or buffer
            delattr(module, tensor_name)
            # If this is the first parametrization registered on the module,
            # we prepare the module to inject the property
            if not is_parametrized(module):
                # Change the class
                _inject_new_class(module)
                # Inject a ``ModuleDict`` into the instance under module.parametrizations
                module.parametrizations = ModuleDict()
            # Add a property into the class
            _inject_property(module, tensor_name)
            # Add a ParametrizationList
            assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy
            module.parametrizations[tensor_name] = parametrizations
        else:
            raise ValueError(
                f"Module '{module}' does not have a parameter, a buffer, or a "
                f"parametrized element with name '{tensor_name}'"
            )
        return module
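
A hedged sketch of the keyword-only ``unsafe`` flag documented above. ``FirstRows`` is a made-up parametrization that changes the tensor's shape, so the default consistency checks should reject it, while ``unsafe=True`` skips those checks:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class FirstRows(nn.Module):
    def forward(self, X):
        # Changes the shape: (4, 4) -> (2, 4)
        return X[:2]

m = nn.Linear(4, 4)
try:
    P.register_parametrization(m, "weight", FirstRows())
except ValueError:
    print("rejected: the default checks forbid shape changes")

# The checks are skipped, so the parametrized weight may have a different shape.
P.register_parametrization(m, "weight", FirstRows(), unsafe=True)
print(m.weight.shape)  # torch.Size([2, 4])
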
Example #5
def remove_parametrizations(module: Module,
                            tensor_name: str,
                            leave_parametrized: bool = True) -> Module:
    r"""Removes the parametrizations on a tensor in a module.

    - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
      its current output. In this case, the parametrization shall not change the ``dtype``
      of the tensor.
    - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.

    Args:
        module (nn.Module): module from which to remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
            Default: ``True``

    Returns:
        Module: module

    Raises:
        ValueError: if ``module[tensor_name]`` is not parametrized
        ValueError: if ``leave_parametrized=True`` and the parametrization changes the size or dtype
            of the tensor
    """

    if not is_parametrized(module, tensor_name):
        raise ValueError(
            "Module {} does not have a parametrization on {}".format(
                module, tensor_name))

    # Fetch the original tensor
    original = module.parametrizations[tensor_name].original  # type: ignore
    if leave_parametrized:
        t = getattr(module, tensor_name)
        # If they have the same dtype, we reuse the original tensor.
        # We do this so that the parameter does not change its id().
        # This way the user does not need to update the optimizer.
        if t.dtype == original.dtype:
            with torch.no_grad():
                original.set_(t)
        else:
            raise ValueError(
                "The parametrization changes the dtype of the tensor from {} to {}. "
                "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) "
                "in this case.".format(original.dtype, t.dtype))
    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the ParametrizationList
    del module.parametrizations[tensor_name]  # type: ignore

    # Restore the parameter / buffer into the main class
    if isinstance(original, Parameter):
        module.register_parameter(tensor_name, original)
    else:
        module.register_buffer(tensor_name, original)

    # Roll back the parametrized class if no other buffer or parameter
    # is currently parametrized in this class
    if not is_parametrized(module):
        delattr(module, "parametrizations")
        # Restore class
        orig_cls = module.__class__.__bases__[0]
        module.__class__ = orig_cls
    return module
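
The comments above note that the current value is written into the original tensor with ``set_()`` so that the parameter keeps its identity and an existing optimizer does not need to be updated. A hedged sketch of that behavior through the public API (``Symmetric`` is again an example parametrization):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

m = nn.Linear(3, 3)
opt = torch.optim.SGD(m.parameters(), lr=0.1)

parametrize.register_parametrization(m, "weight", Symmetric())
original = m.parametrizations.weight.original

parametrize.remove_parametrizations(m, "weight", leave_parametrized=True)
# The very same Parameter object is restored (its value was updated in place),
# so the optimizer created earlier still tracks it.
print(m.weight is original)  # True
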
Example #6
def register_parametrization(module: Module, tensor_name: str,
                             parametrization: Module) -> Module:
    r"""Adds a parametrization to a tensor in a module.

    Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
    the module will return the parametrized version ``parametrization(module.weight)``.
    If the original tensor requires a gradient, the backward pass will differentiate
    through the :attr:`parametrization`, and the optimizer will update the tensor accordingly.

    The first time that a module registers a parametrization, this function will add an attribute
    ``parametrizations`` to the module of type :class:`~ParametrizationList`.

    The list of parametrizations on a tensor will be accessible under
    ``module.parametrizations.weight``.

    The original tensor will be accessible under
    ``module.parametrizations.weight.original``.

    Parametrizations may be concatenated by registering several parametrizations
    on the same attribute.

    Parametrized parameters and buffers have an inbuilt caching system that can be activated
    using the context manager :func:`cached`.

    A :attr:`parametrization` may optionally implement a method with signature

    .. code-block:: python

        def right_inverse(self, X: Tensor) -> Tensor

    If :attr:`parametrization` implements this method, it will be possible to assign
    to the parametrized tensor. This may be used to initialize the tensor, as shown in the example.

    In most situations, ``right_inverse`` will be a function such that
    ``forward(right_inverse(X)) == X`` (see
    `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
    Sometimes, when the parametrization is not surjective, it may be reasonable
    to relax this, as shown in the example below.

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (str): name of the parameter or buffer on which to register
            the parametrization
        parametrization (nn.Module): the parametrization to register

    Returns:
        Module: module

    Raises:
        ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`

    Examples:
        >>> import torch
        >>> import torch.nn.utils.parametrize as P
        >>>
        >>> class Symmetric(torch.nn.Module):
        >>>     def forward(self, X):
        >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
        >>>
        >>>     def right_inverse(self, A):
        >>>         return A.triu()
        >>>
        >>> m = torch.nn.Linear(5, 5)
        >>> P.register_parametrization(m, "weight", Symmetric())
        >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
        True
        >>> A = torch.rand(5, 5)
        >>> A = A + A.T   # A is now symmetric
        >>> m.weight = A  # Initialize the weight to be the symmetric matrix A
        >>> print(torch.allclose(m.weight, A))
        True
    """
    if is_parametrized(module, tensor_name):
        # Just add the new parametrization to the parametrization list
        module.parametrizations[tensor_name].append(parametrization)  # type: ignore
    elif tensor_name in module._buffers or tensor_name in module._parameters:
        # Set the parametrization mechanism
        # Fetch the original buffer or parameter
        original = getattr(module, tensor_name)
        # Delete the previous parameter or buffer
        delattr(module, tensor_name)
        # If this is the first parametrization registered on the module,
        # we prepare the module to inject the property
        if not is_parametrized(module):
            # Change the class
            _inject_new_class(module)
        # Inject a ``ModuleDict`` into the instance under module.parametrizations
            module.parametrizations = ModuleDict()
        # Add a property into the class
        _inject_property(module, tensor_name)
        # Add a ParametrizationList
        module.parametrizations[tensor_name] = ParametrizationList(  # type: ignore
            [parametrization], original)
    else:
        raise ValueError(
            "Module '{}' does not have a parameter, a buffer, or a "
            "parametrized element with name '{}'".format(module, tensor_name))
    return module
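
The docstring above mentions that parametrizations may be concatenated on the same attribute and that a caching context manager ``cached`` is available. A hedged sketch using the current public API (``Symmetric`` and ``Scale`` are made-up parametrizations):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).T

class Scale(nn.Module):
    def forward(self, X):
        return 2 * X

m = nn.Linear(3, 3)
P.register_parametrization(m, "weight", Symmetric())
P.register_parametrization(m, "weight", Scale())  # applied to the output of Symmetric
print(torch.allclose(m.weight, m.weight.T))       # True: scaling preserves symmetry

# The built-in caching: within the context manager the parametrized value is
# computed once and reused on subsequent accesses.
with P.cached():
    w1 = m.weight
    w2 = m.weight
    print(w1 is w2)  # True while the cache is active
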