Пример #1
0
    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        r"""
        Convert an autodiff (Euclidean) gradient into the Riemannian gradient at :math:`Z`.

        For a function :math:`f(Z)` on :math:`\mathcal{B}_n`, the Riemannian
        gradient is obtained by multiplying the Euclidean one by :math:`A` on
        both sides:

        .. math::

            \operatorname{grad}_{R}(f(Z)) = A \cdot \operatorname{grad}_E(f(Z)) \cdot A

        with :math:`A = Id - \overline{Z}Z`. The product is then passed through
        ``lalg.sym`` so that the returned gradient is symmetric.

        Parameters
        ----------
        z : torch.Tensor
             point on the manifold
        u : torch.Tensor
             Euclidean gradient (e.g. from autodiff) to be converted

        Returns
        -------
        torch.Tensor
            Riemannian gradient, symmetrized
        """
        id_minus_zbar_z = get_id_minus_conjugate_z_times_z(z)
        sandwiched = id_minus_zbar_z @ u @ id_minus_zbar_z
        # Impose symmetry; the original rationale given was numerical instability.
        # NOTE(review): for symmetric Z, A = Id - conj(Z) Z is Hermitian, so
        # A @ u @ A is generally not symmetric even in exact arithmetic, and
        # sym() changes the value beyond round-off. Confirm whether the
        # intended formula is conj(A) @ u @ A.
        return lalg.sym(sandwiched)
Пример #2
0
    def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        r"""
        Convert an autodiff (Euclidean) gradient into the Riemannian gradient at :math:`Z`.

        For a function :math:`f(Z)` on :math:`\mathcal{S}_n`, the Riemannian
        gradient multiplies the Euclidean one by :math:`Y` on both sides:

        .. math::

            \operatorname{grad}_{R}(f(Z)) = Y \cdot \operatorname{grad}_E(f(Z)) \cdot Y

        where :math:`Y = \operatorname{Im}(Z)`. :math:`Y` is a real tensor, so
        the product is formed separately for the real and the imaginary part of
        the gradient. The two parts are then recombined with ``sm.to_complex``.
        This is presumably done to avoid mixing real and complex dtypes in a
        single matmul.

        Parameters
        ----------
        z : torch.Tensor
             point on the manifold
        u : torch.Tensor
             Euclidean gradient (e.g. from autodiff) to be converted

        Returns
        -------
        torch.Tensor
            Riemannian gradient, symmetrized
        """
        y = z.imag
        # Apply Y·(.)·Y to the real component first, then to the imaginary one.
        real_part, imag_part = (y @ component @ y for component in (u.real, u.imag))
        # Impose symmetry to counter numerical instabilities.
        return lalg.sym(sm.to_complex(real_part, imag_part))
Пример #3
0
 def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
     """
     Draw a random complex symmetric matrix (or batch of matrices).

     Parameters
     ----------
     *size
         shape of the sample, forwarded unchanged to ``torch.randn``
     dtype : torch.dtype, optional
         complex dtype of the sample; defaults to ``torch.complex128``
     device : torch.device, optional
         device on which the sample is created
     **kwargs
         accepted for interface compatibility and ignored

     Returns
     -------
     torch.Tensor
         symmetrized Gaussian sample (scaled by 0.5) whose imaginary part has
         been replaced by ``lalg.expm`` of that imaginary part

     Raises
     ------
     ValueError
         if a (truthy) dtype outside ``COMPLEX_DTYPES`` is requested
     """
     if dtype is None:
         dtype = torch.complex128
     elif dtype and dtype not in COMPLEX_DTYPES:
         raise ValueError(f"dtype must be one of {COMPLEX_DTYPES}")
     sample = lalg.sym(0.5 * torch.randn(*size, dtype=dtype, device=device))
     # NOTE(review): presumably expm turns the symmetric imaginary part into a
     # positive-definite matrix, as the upper half space requires. Confirm.
     sample.imag = lalg.expm(sample.imag)
     return sample
Пример #4
0
def test_inverse_cayley_transform_from_projx(bounded, upper, rank, dtype, eps):
    """Check that ``sm.cayley_transform`` undoes ``sm.inverse_cayley_transform``.

    A batch of 10 random symmetric ``rank`` x ``rank`` matrices is passed
    through ``upper.projx``. The result is mapped with the inverse Cayley
    transform and then back with the Cayley transform. The round trip must
    reproduce the projected input within ``eps``. Each intermediate must also
    pass the corresponding manifold's point check.
    """
    raw = sym(torch.randn((10, rank, rank), dtype=dtype))
    point = upper.projx(raw).detach()

    bounded_point = sm.inverse_cayley_transform(point)
    round_trip = sm.cayley_transform(bounded_point)

    np.testing.assert_allclose(point, round_trip, atol=eps, rtol=eps)
    upper.assert_check_point_on_manifold(round_trip)
    bounded.assert_check_point_on_manifold(bounded_point)
Пример #5
0
 def projx(self, x: torch.Tensor) -> torch.Tensor:
     """
     Project ``x`` by symmetrizing it with ``lalg.sym``.

     Parameters
     ----------
     x : torch.Tensor
         matrix (or batch of matrices) to project

     Returns
     -------
     torch.Tensor
         the symmetrized tensor
     """
     symmetric = lalg.sym(x)
     return symmetric
Пример #6
0
def get_random_complex_symmetric_matrices(points: int, dims: int,
                                          dtype: torch.dtype):
    """Return ``points`` random symmetric matrices of size ``dims`` x ``dims``.

    Entries are drawn with ``torch.rand`` in the requested ``dtype``. They are
    complex only when ``dtype`` is a complex dtype. The batch of shape
    ``(points, dims, dims)`` is then symmetrized with ``sym``.
    """
    batch_shape = (points, dims, dims)
    return sym(torch.rand(batch_shape, dtype=dtype))