Example #1
def fixed_rank(module, tensor_name, rank, f="softplus", triv="expm"):
    r""" Adds a fixed rank parametrization to the tensor ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the
    parametrized version :math:`X` which will have rank equal to ``rank``.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.fixed_rank(layer, "weight", 5)
        >>> list(torch.svd(layer.weight).S > 1e-7).count(True)
        5

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        rank (int): Rank of the matrix.
            It has to be less than the minimum of the two dimensions of the
            matrix
        f (str or callable): Optional. The string `"softplus"` or a callable
            that maps real numbers to the interval :math:`(0, \infty)`. Default: `"softplus"`
        triv (str or callable): Optional.
            A map that maps :math:`\operatorname{Skew}(n)` onto the orthogonal
            matrices surjectively. This is used to optimize the U, V in the
            SVD. It can be one of `["expm", "cayley"]` or a custom
            callable. Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name,
                               FixedRank(size, rank, f, triv))
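
A brief usage sketch (not part of the original docstring), assuming ``geotorch`` is importable and that, as documented above, ``f`` may be a single callable mapping the reals onto (0, ∞); ``torch.exp`` is such a map:

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 30)
# torch.exp maps R onto (0, infty), so it is a valid custom choice for f
geotorch.fixed_rank(layer, "weight", rank=5, f=torch.exp)
# The exposed weight should now have (numerically) exactly 5 non-zero singular values
print((torch.svd(layer.weight).S > 1e-7).sum())  # expected: tensor(5)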
Example #2
def skew(module, tensor_name="weight", lower=True):
    r"""Adds a skew-symmetric parametrization to the matrix ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the parametrized
    version :math:`X` so that :math:`X^\intercal = -X`.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(30, 30)
        >>> geotorch.skew(layer, "weight")
        >>> torch.allclose(layer.weight, -layer.weight.T)
        True

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied. Default: ``"weight"``
        lower (bool): Optional. Uses the lower triangular part of the matrix to
            parametrize the matrix. Default: ``True``
    """
    P.register_parametrization(module, tensor_name, Skew(lower))
    return module
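
A minimal training sketch (an illustration, not taken from the docs): after registering the parametrization, the exposed weight should remain skew-symmetric across optimizer steps.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(30, 30)
geotorch.skew(layer, "weight")

optim = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.rand(5, 30)).sum()
optim.zero_grad()
loss.backward()
optim.step()
# Still skew-symmetric after the update
print(torch.allclose(layer.weight, -layer.weight.T, atol=1e-6))  # expected: True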
Example #3
def sphere(module, tensor_name, radius=1.0, embedded=False):
    r"""Adds a spherical parametrization to the vector (or tensor) ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the parametrized
    version :math:`v` so that :math:`\lVert v \rVert = 1`.

    If the tensor has more than one dimension, the parametrization will be
    applied to the last dimension.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.sphere(layer, "bias")
        >>> torch.norm(layer.bias)
        tensor(1.)
        >>> geotorch.sphere(layer, "weight")  # Make the columns unit norm
        >>> torch.norm(torch.norm(layer.weight, dim=-1) - torch.ones(30))
        tensor(6.1656e-07)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        radius (float): Optional.
            Radius of the sphere. It has to be positive. Default: 1.
        embedded (bool): Optional.
            Chooses between the implementation of the sphere using the exponential
            map (``embedded=False``) and that using the projection from the ambient
            space (``embedded=True``). Default: ``False``
    """
    size = getattr(module, tensor_name).size()
    cls = SphereEmbedded if embedded else Sphere
    M = cls(size, radius)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
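
A short sketch comparing the two implementations, assuming the signature above; both should expose a bias of the requested radius.

import torch
import torch.nn as nn
import geotorch

lin1, lin2 = nn.Linear(20, 30), nn.Linear(20, 30)
geotorch.sphere(lin1, "bias", radius=2.0)                 # exponential-map implementation
geotorch.sphere(lin2, "bias", radius=2.0, embedded=True)  # projection from the ambient space
print(torch.norm(lin1.bias))  # expected: ~tensor(2.)
print(torch.norm(lin2.bias))  # expected: ~tensor(2.)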
Example #4
    def test_backprop(self):
        r"""Test that we may instantiate the parametrizations and
        register them in modules of several sizes. Check that the
        results are on the sphere
        """
        sizes = [1, 2, 3, 8]

        for n, lower in itertools.product(sizes, [True, False]):
            layer = nn.Linear(n, n)
            P.register_parametrization(
                layer, "weight",
                Symmetric(size=layer.weight.size(), lower=lower))

            input_ = torch.rand(5, n)
            optim = torch.optim.SGD(layer.parameters(), lr=1.0)

            # Assert that it stays in Sym(n) after some optimiser steps
            for i in range(2):
                print(i)
                with P.cached():
                    self.assertIsSymmetric(layer.weight)
                    loss = layer(input_).sum()
                optim.zero_grad()
                loss.backward()
                optim.step()
Example #5
def skew(module, tensor_name, lower=True):
    r""" Adds a skew-symmetric parametrization to the matrix ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the parametrized
    version :math:`X` so that :math:`X^\intercal = -X`.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(30, 30)
        >>> geotorch.skew(layer, "weight")
        >>> torch.norm(layer.weight + layer.weight.t())
        tensor(0.)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        lower (bool): Optional. Uses the lower triangular part of the matrix to
            parametrize the matrix. Default: `True`
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name, Skew(size, lower))
Example #6
def positive_semidefinite(module, tensor_name, triv="expm"):
    r""" Adds a positive definiteness constraint to the tensor
    ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the
    parametrized version :math:`X` which will be symmetric and with
    non-negative eigenvalues

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 20)
        >>> geotorch.positive_semidefinite(layer, "weight")
        >>> L = torch.symeig(layer.weight).eigenvalues
        >>> L[L.abs() < 1e-7] = 0.0  # Round errors
        >>> (L >= 0.0).all()
        tensor(True)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        triv (str or callable): Optional.
            A map that maps :math:`\operatorname{Skew}(n)` onto the orthogonal
            matrices surjectively. This is used to optimize the Q in the eigenvalue
            decomposition. It can be one of `["expm", "cayley"]` or a custom
            callable. Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name, PSSD(size, triv))
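
The doctest above uses the deprecated ``torch.symeig``; a sketch of the same check with ``torch.linalg.eigvalsh`` (assumes PyTorch >= 1.8):

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 20)
geotorch.positive_semidefinite(layer, "weight")
L = torch.linalg.eigvalsh(layer.weight).detach()  # detach so the in-place rounding below is safe
L[L.abs() < 1e-7] = 0.0  # round off numerical noise
print((L >= 0.0).all())  # expected: tensor(True)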
Example #7
def sphere(module, tensor_name, r=1.0):
    r""" Adds a spherical parametrization to the vector (or tensor)
    ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the parametrized
    version :math:`v` so that :math:`\lVert v \rVert = 1`.

    If the tensor has more than one dimension, the parametrization will be
    applied to the last dimension.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.sphere(layer, "bias")
        >>> torch.norm(layer.bias)
        tensor(1.)
        >>> geotorch.sphere(layer, "weight")  # Make the columns unit norm
        >>> torch.norm(torch.norm(layer.weight, dim=-1) - torch.ones(30))
        tensor(6.1656e-07)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        r (float): Optional.
            Radius of the sphere. It has to be positive. Default: 1.
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name, Sphere(size, r))
Example #8
def low_rank(module, tensor_name, rank, triv="expm"):
    r"""Adds a low rank parametrization to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will have rank at most ``rank``.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.low_rank(layer, "weight", 4)
        >>> list(torch.svd(layer.weight).S > 1e-7).count(True) <= 4
        True

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        rank (int): Rank of the matrix.
            It has to be less than the minimum of the two dimensions of the
            matrix
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal matrices
            surjectively. This is used to optimize the :math:`U` and :math:`V` in the
            SVD. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = LowRank(size, rank, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
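
An alternative verification sketch, assuming ``torch.linalg.matrix_rank`` is available; it checks the rank bound directly instead of counting singular values.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 30)
geotorch.low_rank(layer, "weight", rank=4)
print(torch.linalg.matrix_rank(layer.weight) <= 4)  # expected: tensor(True)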
Example #9
def orthogonal(module, tensor_name, triv="expm"):
    r"""Adds an orthogonal parametrization to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` so that :math:`X^\intercal X = \operatorname{I}`.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.orthogonal(layer, "weight")
        >>> torch.norm(layer.weight.t() @ layer.weight - torch.eye(20,20))
        tensor(4.8488e-05)

        >>> layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels orthogonal
        >>> geotorch.orthogonal(layer, "weight")
        >>> torch.norm(layer.weight.transpose(-2, -1) @ layer.weight - torch.eye(3,3).repeat(40,20,1,1))
        tensor(1.2225e-05)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        triv (str or callable): Optional.
            A map that maps a skew-symmetric matrix to an orthogonal matrix.
            It can be the exponential of matrices or the cayley transform passing
            ``["expm", "cayley"]`` or a custom callable.  Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = Stiefel(size, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
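
A sketch (illustrative only) of what the constraint buys you: since X^T X = I, the weight preserves Euclidean norms, so the layer neither shrinks nor blows up the signal.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 30)
geotorch.orthogonal(layer, "weight")
v = torch.rand(20)
# ||X v|| = sqrt(v^T X^T X v) = ||v|| for every v
print(torch.allclose(torch.norm(layer.weight @ v), torch.norm(v), atol=1e-4))  # expected: True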
Example #10
    def _test_manifold(self, M, args_sample, args_constr, device, size,
                       initialize):
        inputs = [torch.rand(3, size[0], device=device)]
        layers = [nn.Linear(*size, device=device)]
        # Just test on convolution for small layers, otherwise it takes too long
        if min(size) < 100:
            inputs.append(
                torch.rand(6, 5, size[0] + 7, size[1] + 3, device=device))
            layers.append(nn.Conv2d(5, 4, size, device=device))

        for input_, layer in zip(inputs, layers):
            old_size = layer.weight.size()
            # Somewhat dirty but will do
            if isinstance(M, types.FunctionType):
                M(layer, "weight", **args_constr)
            else:
                # initialize the weight first (annoying)
                M_ = M(size=layer.weight.size(), **args_constr).to(device)
                X = M_.sample(**args_sample)
                with torch.no_grad():
                    layer.weight.copy_(X)
                P.register_parametrization(layer, "weight", M_)
            # Check that it does not change the size of the layer
            self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}")
            self._test_training(layer, args_sample, input_, initialize)
Example #11
    def _test_manifold(self, M, args_sample, args_constr, device, size):
        # Test Linear
        layer = nn.Linear(*size)
        input_ = torch.rand(3, size[0]).to(device)
        old_size = layer.weight.size()
        # Somewhat dirty but will do
        if isinstance(M, types.FunctionType):
            M(layer, "weight", **args_constr)
        else:
            P.register_parametrization(
                layer, "weight", M(size=layer.weight.size(), **args_constr))
        layer = layer.to(device)
        # Check that it does not change the size of the layer
        self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}")
        self._test_interface(layer, args_sample, input_)

        # Just for the smaller ones, for the large ones this is just too expensive
        if min(size) < 100:
            # Test convolutional (tensorial)
            layer = nn.Conv2d(5, 4, size)
            input_ = torch.rand(6, 5, size[0] + 7, size[1] + 3).to(device)
            old_size = layer.weight.size()
            # Somewhat dirty but will do
            if isinstance(M, types.FunctionType):
                M(layer, "weight", **args_constr)
            else:
                P.register_parametrization(
                    layer, "weight", M(size=layer.weight.size(),
                                       **args_constr))
            layer = layer.to(device)
            # Check that it does not change the size of the layer
            self.assertEqual(old_size, layer.weight.size(), msg=f"{layer}")
            self._test_interface(layer, args_sample, input_)
Example #12
def almost_orthogonal(module, tensor_name, lam, f="sigmoid", triv="expm"):
    r""" Adds an almost orthogonal parametrization to the tensor ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the
    parametrized version :math:`X` which will have its singular values in
    the interval :math:`[1-\texttt{lam}, 1+\texttt{lam}]`

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.almost_orthogonal(layer, "weight", 0.5)
        >>> S = torch.svd(layer.weight).S
        >>> all((S >= 0.5) & (S <= 1.5))
        True

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        lam (float): Radius. A float in the interval [0, 1]
        f (str or callable): Optional. One of `["sigmoid", "tanh", "sin"]`
            or a callable that maps real numbers to the interval [-1, 1].
            Default: `"sigmoid"`
        triv (str or callable): Optional.
            A map that maps :math:`\operatorname{Skew}(n)` onto the orthogonal
            matrices surjectively. This is used to optimize the U, V in the
            SVD. It can be one of `["expm", "cayley"]` or a custom
            callable. Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name,
                               AlmostOrthogonal(size, lam, f, triv))
Example #13
def invertible(module, tensor_name, f="softplus", triv="expm"):
    r""" Adds an invertibility constraint to the tensor ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the
    parametrized version :math:`X` which will have positive determinant and,
    in particular, it will be invertible.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 20)
        >>> geotorch.invertible(layer, "weight")
        >>> torch.det(layer.weight) > 0.0
        tensor(True)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        f (str or callable): Optional. The string `"softplus"` or a callable
            that maps real numbers to the interval :math:`(0, \infty)`. Default: `"softplus"`
        triv (str or callable): Optional.
            A map that maps :math:`\operatorname{Skew}(n)` onto the orthogonal
            matrices surjectively. This is used to optimize the U, V in the
            SVD. It can be one of `["expm", "cayley"]` or a custom
            callable. Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    P.register_parametrization(module, tensor_name, GLp(size, f, triv))
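
A sketch exercising the invertibility, assuming ``torch.linalg.solve`` is available: since det(X) > 0, the linear system X x = b has a unique solution.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 20)
geotorch.invertible(layer, "weight")
b = torch.rand(20)
x = torch.linalg.solve(layer.weight, b)
print(torch.allclose(layer.weight @ x, b, atol=1e-4))  # expected: True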
Example #14
def positive_semidefinite_fixed_rank(module,
                                     tensor_name,
                                     rank,
                                     f="softplus",
                                     triv="expm"):
    r"""Adds a positive definiteness constraint to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will be symmetric and with non-negative
    eigenvalues and exactly ``rank`` of them non-zero.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 20)
        >>> geotorch.positive_semidefinite_fixed_rank(layer, "weight", 5)
        >>> L = torch.symeig(layer.weight).eigenvalues
        >>> L[L.abs() < 1e-7] = 0.0  # Round errors
        >>> (L >= 0.0).all()
        tensor(True)
        >>> list(L > 0.0).count(True)
        5

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        rank (int): Rank of the matrix.
            It has to be less than the minimum of the two dimensions of the
            matrix
        f (str or callable or pair of callables): Optional. Either:

            - ``"softplus"``

            - A callable that maps real numbers to the interval :math:`(0, \infty)`

            - A pair of callables such that the first maps the real numbers to
              :math:`(0, \infty)` and the second is a (right) inverse of the first

            Default: ``"softplus"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal
            matrices surjectively. This is used to optimize the :math:`Q` in the
            eigenvalue decomposition. It can be one of ``["expm", "cayley"]`` or
            a custom callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = PSSDFixedRank(size, rank, f, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
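
A verification sketch using ``torch.linalg.eigvalsh`` instead of the deprecated ``torch.symeig`` (an assumption about the available PyTorch version): exactly ``rank`` eigenvalues should be numerically non-zero and none negative.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 20)
geotorch.positive_semidefinite_fixed_rank(layer, "weight", 5)
L = torch.linalg.eigvalsh(layer.weight).detach()
print(int((L > 1e-7).sum()))    # expected: 5
print(bool((L > -1e-7).all()))  # expected: True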
Example #15
    def test_backprop(self):
        r"""Test that we may instantiate the parametrizations and
        register them in modules of several sizes. Check that the
        results are on the sphere
        """
        sizes = [1, 2, 3, 4, 7, 8]

        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for n in sizes:
                for cls in [Sphere, SphereEmbedded]:
                    layer = nn.Linear(n, 4)
                    P.register_parametrization(layer, "bias",
                                               cls(size=layer.bias.size()))
                    P.register_parametrization(layer, "weight",
                                               cls(size=layer.weight.size()))

                    with torch.no_grad():
                        layer.parametrizations.weight.uniform_init_()
                        layer.parametrizations.bias.uniform_init_()
                        self.assertInSn(layer.weight)
                        self.assertInSn(layer.bias)

                    input_ = torch.rand(5, n)
                    optim = torch.optim.SGD(layer.parameters(), lr=1.0)

                    # Assert that it stays in S^n after some optimiser steps
                    with torch.autograd.set_detect_anomaly(True):
                        for i in range(2):
                            print(i)
                            with P.cached():
                                self.assertInSn(layer.weight)
                                self.assertInSn(layer.bias)
                                loss = layer(input_).sum()
                            optim.zero_grad()
                            loss.backward()
                            optim.step()

                    # If we change the base, the forward pass should give the same result
                    # SphereEmbedded does not have a base
                    if cls != SphereEmbedded:
                        for w in ["weight", "bias"]:
                            with torch.no_grad():
                                out_old = layer(input_)
                                getattr(layer.parametrizations,
                                        w).update_base()
                                out_new = layer(input_)
                                self.assertAlmostEqual(
                                    (out_old - out_new).abs().max().item(),
                                    0.0,
                                    places=5,
                                )
Example #16
def grassmannian(module, tensor_name, triv="expm"):
    r""" Adds an parametrization to the tensor ``module[tensor_name]`` so that the
    result represents a subspace. If the initial matrix was of size :math:`n \times k`
    the parametrized matrix will represent a subspace of dimension :math:`k` of
    :math:`\mathbb{R}^n`.

    When accessing ``module[tensor_name]``, the module will return the parametrized
    version :math:`X` so that :math:`X` represents :math:`k` orthogonal vectors of
    :math:`\mathbb{R}^n` that span the subspace. That is, the resulting matrix will
    be orthogonal, :math:`X^\intercal X = \operatorname{Id}`.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    .. note::

        Even though this space resembles that generated by :func:`geotorch.orthogonal`,
        it is actually a quotient of that, as every subspace can be represented by many
        different bases of vectors that span it.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.grassmannian(layer, "weight")
        >>> torch.norm(layer.weight.t() @ layer.weight - torch.eye(20,20))
        tensor(1.8933e-05)

        >>> layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels represent subspaces
        >>> geotorch.grassmannian(layer, "weight")
        >>> torch.norm(layer.weight.transpose(-2, -1) @ layer.weight - torch.eye(3,3).repeat(40,20,1,1))
        tensor(8.3796e-06)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        triv (str or callable): Optional.
            A map that maps a skew-symmetric matrix to an orthogonal matrix.
            It can be the exponential of matrices or the cayley transform passing
            `["expm", "cayley"]` or a custom callable.  Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    if len(size) < 2:
        raise ValueError(
            "Cannot put grassmannian constraints on a vector. "
            "Got a tensor of size {}".format(size)
        )
    n, k = size[-2:]
    n, k = max(n, k), min(n, k)
    cls = GrassmannianTall if n > 4 * k else Grassmannian
    P.register_parametrization(module, tensor_name, cls(size, triv))
Example #17
def _register_manifold(module, tensor_name, cls, *args):
    tensor = getattr(module, tensor_name)
    M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype)
    P.register_parametrization(module, tensor_name, M)

    # Initialize without checking in manifold
    X = M.sample()
    param_list = module.parametrizations[tensor_name]
    with torch.no_grad():
        for m in reversed(param_list):
            X = m.right_inverse(X, check_in_manifold=False)
        param_list.original.copy_(X)

    return module
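
A hypothetical call site for this helper (not taken from the library), assuming the ``Stiefel`` class used elsewhere in these examples is in scope; it mirrors how the public constraint functions delegate to ``_register_manifold``.

import torch.nn as nn

layer = nn.Linear(20, 30)
# Roughly what geotorch.orthogonal(layer, "weight") would do for a tall weight
_register_manifold(layer, "weight", Stiefel, "expm")
# The weight is now initialized from a sample of the manifold and constrained to it
print(type(layer.parametrizations.weight[0]).__name__)  # expected: "Stiefel"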
Example #18
def _register_manifold(module, tensor_name, cls, *args):
    tensor = getattr(module, tensor_name)
    M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype)

    # Initialize without checking in manifold
    X = M.sample()
    if not P.is_parametrized(module, tensor_name):
        with torch.no_grad():
            tensor.copy_(X)
    else:
        setattr(module, tensor_name, X)

    P.register_parametrization(module, tensor_name, M, unsafe=True)

    return module
Example #19
def orthogonal(module, tensor_name, triv="expm"):
    r""" Adds an orthogonal parametrization to the tensor ``module[tensor_name]``.

    When accessing ``module[tensor_name]``, the module will return the
    parametrized version :math:`X` so that :math:`X^\intercal X = \operatorname{Id}`.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.orthogonal(layer, "weight")
        >>> torch.norm(layer.weight.t() @ layer.weight - torch.eye(20,20))
        tensor(4.8488e-05)

        >>> layer = nn.Conv2d(20, 40, 3, 3)  # Make the kernels orthogonal
        >>> geotorch.orthogonal(layer, "weight")
        >>> torch.norm(layer.weight.transpose(-2, -1) @ layer.weight - torch.eye(3,3).repeat(40,20,1,1))
        tensor(1.2225e-05)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        triv (str or callable): Optional.
            A map that maps a skew-symmetric matrix to an orthogonal matrix.
            It can be the exponential of matrices or the cayley transform passing
            `["expm", "cayley"]` or a custom callable.  Default: `"expm"`
    """
    size = getattr(module, tensor_name).size()
    if len(size) < 2:
        raise ValueError(
            "Cannot put orthogonal constraints on a vector. "
            "Got a tensor of size {}".format(size)
        )
    n, k = size[-2:]
    n, k = max(n, k), min(n, k)
    if n == k:
        cls = SO
    elif n > 4 * k:
        cls = StiefelTall
    else:
        cls = Stiefel
    P.register_parametrization(module, tensor_name, cls(size, triv))
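
A small sketch of the dispatch rule above (square matrices use SO, very tall ones StiefelTall, the rest Stiefel), assuming the classes referenced in the function are importable; the registered parametrization is inspected through ``layer.parametrizations``.

import torch.nn as nn

for (rows, cols), expected in [((20, 20), "SO"), ((100, 4), "StiefelTall"), ((30, 20), "Stiefel")]:
    layer = nn.Linear(cols, rows)  # the weight has size rows x cols
    orthogonal(layer, "weight")
    print(type(layer.parametrizations.weight[0]).__name__, "- expected:", expected)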
Example #20
    def _test_custom_trivialization(self, cls):
        def qr(X):
            return torch.qr(X).Q

        # Note that qr is not an analytic function. As such, it may not be used with StiefelTall
        layer = nn.Linear(5, 3)
        P.register_parametrization(layer, "weight",
                                   cls(size=layer.weight.size(), triv=qr))

        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
        input_ = torch.rand(5, layer.in_features)
        for _ in range(2):
            with P.cached():
                self.assertIsOrthogonal(layer.weight)
                loss = layer(input_).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()
Example #21
    def _test_custom_trivialization(self, cls):
        def cayley(X):
            n = X.size(0)
            Id = torch.eye(n, dtype=X.dtype, device=X.device)
            return torch.solve(Id - X, Id + X)[0]

        layer = nn.Linear(5, 3)
        P.register_parametrization(layer, "weight",
                                   cls(size=layer.weight.size(), triv=cayley))

        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
        input_ = torch.rand(5, layer.in_features)
        for _ in range(2):
            with P.cached():
                self.assertIsOrthogonal(layer.weight)
                loss = layer(input_).sum()
            optim.zero_grad()
            loss.backward()
            optim.step()
Example #22
def almost_orthogonal(module, tensor_name, lam, f="sin", triv="expm"):
    r"""Adds an almost orthogonal parametrization to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will have its singular values in
    the interval :math:`[1-\texttt{lam}, 1+\texttt{lam}]`

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.almost_orthogonal(layer, "weight", 0.5)
        >>> S = torch.svd(layer.weight).S
        >>> all((S >= 0.5) & (S <= 1.5))
        True

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        lam (float): Radius of the interval for the singular values. A float in the interval :math:`[0, 1]`
        f (str or callable or pair of callables): Optional. Either:

            - One of ``["scaled_sigmoid", "tanh", "sin"]``

            - A callable that maps real numbers to the interval :math:`[-1, 1]`

            - A pair of callables such that the first maps the real numbers to
              :math:`[-1, 1]` and the second is a (right) inverse of the first

            Default: ``"sin"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal matrices
            surjectively. This is used to optimize the :math:`U` and :math:`V` in the
            SVD. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = AlmostOrthogonal(size, lam, f, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
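
A check sketch with ``torch.linalg.svdvals`` (assumes a PyTorch version that provides it) rather than the deprecated ``torch.svd``:

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 30)
geotorch.almost_orthogonal(layer, "weight", 0.3)
S = torch.linalg.svdvals(layer.weight)
# All singular values should lie in [1 - lam, 1 + lam] = [0.7, 1.3]
print(bool(((S >= 0.7) & (S <= 1.3)).all()))  # expected: True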
Example #23
def fixed_rank(module, tensor_name, rank, f="softplus", triv="expm"):
    r"""Adds a fixed rank parametrization to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will have rank equal to ``rank``.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 30)
        >>> geotorch.fixed_rank(layer, "weight", 5)
        >>> list(torch.svd(layer.weight).S > 1e-7).count(True)
        5

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        rank (int): Rank of the matrix.
            It has to be less than the minimum of the two dimensions of the
            matrix
        f (str or callable or pair of callables): Optional. Either:

            - ``"softplus"``

            - A callable that maps real numbers to the interval :math:`(0, \infty)`

            - A pair of callables such that the first maps the real numbers to
              :math:`(0, \infty)` and the second is a (right) inverse of the first

            Default: ``"softplus"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal matrices
            surjectively. This is used to optimize the :math:`U` and :math:`V` in the
            SVD. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = FixedRank(size, rank, f, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
Example #24
def invertible(module, tensor_name, f="softplus", triv="expm"):
    r"""Adds an invertibility constraint to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will have positive determinant and,
    in particular, it will be invertible.

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 20)
        >>> geotorch.invertible(layer, "weight")
        >>> torch.det(layer.weight) > 0.0
        tensor(True)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        f (str or callable or pair of callables): Optional. Either:

            - ``"softplus"``

            - A callable that maps real numbers to the interval :math:`(0, \infty)`

            - A pair of callables such that the first maps the real numbers to
              :math:`(0, \infty)` and the second is a (right) inverse of the first

            Default: ``"softplus"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal matrices
            surjectively. This is used to optimize the :math:`U` and :math:`V` in the
            SVD. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = GLp(size, f, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
Example #25
def positive_definite(module, tensor_name, f="softplus", triv="expm"):
    r"""Adds a positive definiteness constraint to the tensor ``module.tensor_name``.

    When accessing ``module.tensor_name``, the module will return the
    parametrized version :math:`X` which will be symmetric and with positive
    eigenvalues

    If the tensor has more than two dimensions, the parametrization will be
    applied to the last two dimensions.

    Examples::

        >>> layer = nn.Linear(20, 20)
        >>> geotorch.positive_definite(layer, "weight")
        >>> (torch.symeig(layer.weight).eigenvalues > 0.0).all()
        tensor(True)

    Args:
        module (nn.Module): module on which to register the parametrization
        tensor_name (string): name of the parameter, buffer, or parametrization
            on which the parametrization will be applied
        f (str or callable or pair of callables): Optional. Either:

            - ``"softplus"``

            - A callable that maps real numbers to the interval :math:`(0, \infty)`

            - A pair of callables such that the first maps the real numbers to
              :math:`(0, \infty)` and the second is a (right) inverse of the first

            Default: ``"softplus"``
        triv (str or callable): Optional.
            A map that maps skew-symmetric matrices onto the orthogonal
            matrices surjectively. This is used to optimize the :math:`Q` in the eigenvalue
            decomposition. It can be one of ``["expm", "cayley"]`` or a custom
            callable. Default: ``"expm"``
    """
    size = getattr(module, tensor_name).size()
    M = PSD(size, f, triv)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
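
A sketch of an alternative check (an illustration, not from the docs): a Cholesky factorization exists precisely for symmetric positive definite matrices, so the call succeeding is itself evidence of the constraint.

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(20, 20)
geotorch.positive_definite(layer, "weight")
L = torch.linalg.cholesky(layer.weight)  # would raise if the matrix were not SPD
print(torch.allclose(L @ L.T, layer.weight, atol=1e-4))  # expected: True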
Example #26
    def test_GLp(self):
        sizes = [1, 2, 3, 4, 7, 8]

        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for n in sizes:
                for layer in [nn.Linear(n, n), nn.Conv2d(7, 4, n)]:
                    print("GLp({}) on {}".format(n, str(layer)))
                    M = GLp(size=layer.weight.size())
                    P.register_parametrization(layer, "weight", M)
                    self.assertTrue(P.is_parametrized(layer, "weight"))
                    self.assertPositiveDet(layer.weight)

                    optim = torch.optim.SGD(layer.parameters(), lr=0.1)
                    if isinstance(layer, nn.Linear):
                        input_ = torch.rand(5, n)
                    elif isinstance(layer, nn.Conv2d):
                        # batch x in_channel x in_length x in_width
                        input_ = torch.rand(6, 7, 9, 8)

                    for i in range(2):
                        print(i)
                        loss = layer(input_).sum()
                        optim.zero_grad()
                        loss.backward()
                        optim.step()

                        self.assertPositiveDet(layer.weight)

                    # Test update_base
                    prev_out = layer(input_)
                    layer.parametrizations.weight.update_base()
                    new_out = layer(input_)
                    self.assertAlmostEqual(
                        torch.norm(prev_out - new_out).abs().max().item(),
                        0.0,
                        places=3,
                    )
Example #27
    def _test_layers(self, cls, cls_tall):
        sizes = [
            (8, 1),
            (8, 3),
            (8, 4),
            (8, 8),
            (7, 1),
            (7, 3),
            (7, 4),
            (7, 7),
            (1, 7),
            (2, 7),
            (1, 1),
            (1, 2),
        ]
        trivs = ["expm"]

        for (n, k), triv in itertools.product(sizes, trivs):
            for layer in [nn.Linear(n, k), nn.Conv2d(n, 4, k)]:
                layers = []
                test_so = cls != Grassmannian and n == k
                layers.append(layer)
                layers.append(deepcopy(layer))
                if test_so:
                    layers.append(deepcopy(layer))
                    P.register_parametrization(
                        layers[2], "weight", SO(size=layers[2].weight.size(), triv=triv)
                    )
                elif n != k:
                    # If it's not square it should throw
                    with self.assertRaises(ValueError):
                        size = layer.weight.size()[:-2] + (n, k)
                        SO(size=size, triv=triv)

                P.register_parametrization(
                    layers[0], "weight", cls(size=layers[0].weight.size(), triv=triv)
                )
                P.register_parametrization(
                    layers[1],
                    "weight",
                    cls_tall(size=layers[1].weight.size(), triv=triv),
                )
                yield layers
Example #28
    def test_positive_semidefinite(self):
        sizes = [
            (1, 1),
            (2, 2),
            (3, 3),
            (4, 4),
            (7, 7),
            (8, 8),
        ]

        rs = [1, 3, 4]

        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for cls in [PSSDLowRank, PSSDFixedRank, PSSD, PSD]:
                for (n, k), r in itertools.product(sizes, rs):
                    for layer in [nn.Linear(n, k), nn.Conv2d(n, 4, k)]:
                        needs_rank = cls in [PSSDLowRank, PSSDFixedRank]
                        if not needs_rank and r != 1:
                            continue
                        # Only show r when we have a non-full rank
                        print(
                            "{}({}, {}{}) on {}".format(
                                cls.__name__,
                                n,
                                k,
                                ", {}".format(r) if needs_rank else "",
                                str(layer),
                            )
                        )
                        r = min(n, k, r)
                        if needs_rank:
                            M = cls(size=layer.weight.size(), rank=r)
                        else:
                            M = cls(size=layer.weight.size())
                        P.register_parametrization(layer, "weight", M)
                        self.assertTrue(P.is_parametrized(layer, "weight"))
                        Q_orig, L_orig = M.original
                        L_orig = M.f(L_orig)
                        self.assertIsOrthogonal(Q_orig)
                        self.assertIsSymmetric(layer.weight)
                        self.assertHasEigenvalues(layer.weight, L_orig)

                        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
                        if isinstance(layer, nn.Linear):
                            input_ = torch.rand(5, n)
                        elif isinstance(layer, nn.Conv2d):
                            # batch x in_channel x in_length x in_width
                            input_ = torch.rand(6, n, 9, 8)

                        for i in range(2):
                            print(i)
                            loss = layer(input_).sum()
                            optim.zero_grad()
                            loss.backward()
                            optim.step()

                            Q_orig, L_orig = M.original
                            L_orig = M.f(L_orig)
                            self.assertIsOrthogonal(Q_orig)
                            self.assertIsSymmetric(layer.weight)
                            self.assertHasEigenvalues(layer.weight, L_orig)

                        # Test update_base
                        prev_out = layer(input_)
                        layer.parametrizations.weight.update_base()
                        new_out = layer(input_)
                        self.assertAlmostEqual(
                            torch.norm(prev_out - new_out).abs().max().item(),
                            0.0,
                            places=3,
                        )
Example #29
    def test_lowrank(self):
        sizes = [
            (8, 1),
            (8, 4),
            (8, 8),
            (7, 1),
            (7, 3),
            (7, 4),
            (7, 7),
            (1, 7),
            (2, 7),
            (1, 8),
            (2, 8),
            (1, 1),
            (2, 1),
            (1, 2),
        ]

        rs = [1, 3, 8]

        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(8888)
            for cls in [FixedRank, LowRank]:
                for (n, k), r in itertools.product(sizes, rs):
                    for layer in [nn.Linear(n, k), nn.Conv2d(n, 4, k)]:
                        print("{}({}, {}, {}) on {}".format(
                            cls.__name__, n, k, r, str(layer)))
                        r = min(n, k, r)
                        M = cls(size=layer.weight.size(), rank=r)
                        P.register_parametrization(layer, "weight", M)
                        self.assertTrue(P.is_parametrized(layer, "weight"))
                        U_orig, S_orig, V_orig = M.original
                        if cls == FixedRank:
                            # Apply f, as S_orig is just the unconstrained vector in R^n
                            S_orig = M.f(S_orig)
                        self.assertIsOrthogonal(U_orig)
                        self.assertIsOrthogonal(V_orig)
                        self.assertHasSingularValues(layer.weight, S_orig)

                        optim = torch.optim.SGD(layer.parameters(), lr=0.1)
                        if isinstance(layer, nn.Linear):
                            input_ = torch.rand(5, n)
                        elif isinstance(layer, nn.Conv2d):
                            # batch x in_channel x in_length x in_width
                            input_ = torch.rand(6, n, 9, 8)

                        for i in range(2):
                            print(i)
                            loss = layer(input_).sum()
                            optim.zero_grad()
                            loss.backward()
                            optim.step()

                            U_orig, S_orig, V_orig = M.original
                            if cls == FixedRank:
                                # Apply f, as S_orig is just the unconstrained vector in R^n
                                S_orig = M.f(S_orig)
                            self.assertIsOrthogonal(U_orig)
                            self.assertIsOrthogonal(V_orig)
                            self.assertHasSingularValues(layer.weight, S_orig)

                        # Test update_base
                        prev_out = layer(input_)
                        layer.parametrizations.weight.update_base()
                        new_out = layer(input_)
                        self.assertAlmostEqual(
                            torch.norm(prev_out - new_out).abs().max().item(),
                            0.0,
                            places=3,
                        )
Example #30
def _register_manifold(module, tensor_name, cls, *args):
    tensor = getattr(module, tensor_name)
    M = cls(tensor.size(), *args).to(device=tensor.device, dtype=tensor.dtype)
    P.register_parametrization(module, tensor_name, M)
    setattr(module, tensor_name, M.sample())
    return module