def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    r"""
    Convert an autodiff (Euclidean) gradient into the Riemannian gradient at :math:`Z`.

    For a function :math:`f(Z)` on :math:`\mathcal{B}_n`, the Riemannian gradient is

    .. math::

        \operatorname{grad}_{R}(f(Z)) = A \cdot \operatorname{grad}_E(f(Z)) \cdot A

    with :math:`A = Id - \overline{Z}Z`.

    Parameters
    ----------
    z : torch.Tensor
        point on the manifold
    u : torch.Tensor
        Euclidean gradient to be converted

    Returns
    -------
    torch.Tensor
        Riemannian gradient, passed through ``lalg.sym`` before returning
    """
    # A = Id - conj(Z) Z, applied on both sides of the Euclidean gradient.
    id_minus_zbar_z = get_id_minus_conjugate_z_times_z(z)
    rgrad = id_minus_zbar_z @ u @ id_minus_zbar_z
    # Impose symmetry due to numerical instabilities.
    # NOTE(review): for a symmetric Z, A is Hermitian (A^T = conj(A)), not
    # symmetric, so A @ u @ A is not symmetric in general even in exact
    # arithmetic; lalg.sym may be masking that. Confirm whether
    # A^T @ u @ A (or another conjugation) was intended.
    return lalg.sym(rgrad)
def egrad2rgrad(self, z: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    r"""
    Convert an autodiff (Euclidean) gradient into the Riemannian gradient at :math:`Z`.

    For a function :math:`f(Z)` on :math:`\mathcal{S}_n`, the Riemannian gradient is

    .. math::

        \operatorname{grad}_{R}(f(Z)) = Y \cdot \operatorname{grad}_E(f(Z)) \cdot Y

    where :math:`Y` is the imaginary part of :math:`Z`.

    Parameters
    ----------
    z : torch.Tensor
        point on the manifold
    u : torch.Tensor
        Euclidean gradient to be converted

    Returns
    -------
    torch.Tensor
        Riemannian gradient, passed through ``lalg.sym`` before returning
    """
    y = z.imag
    # ``y`` is a real tensor while ``u`` is complex, and torch.matmul requires
    # matching dtypes. Because ``y`` is real, Y (a + ib) Y = YaY + i YbY, so the
    # sandwich is applied to each component separately.
    rgrad_real = y @ u.real @ y
    rgrad_imag = y @ u.imag @ y
    # Impose symmetry due to numerical instabilities.
    return lalg.sym(sm.to_complex(rgrad_real, rgrad_imag))
def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
    """
    Sample a random complex symmetric tensor.

    A complex tensor of shape ``size`` is drawn with ``torch.randn`` and scaled
    by 0.5. It is symmetrized with ``lalg.sym``, and its imaginary part is then
    replaced by ``lalg.expm`` applied to that imaginary part.

    Parameters
    ----------
    size : int
        shape of the sample, forwarded to ``torch.randn``
    dtype : torch.dtype, optional
        must be one of ``COMPLEX_DTYPES``; defaults to ``torch.complex128``
    device : torch.device, optional
        device on which the sample is created
    kwargs
        unused by this implementation

    Returns
    -------
    torch.Tensor
        the sampled tensor

    Raises
    ------
    ValueError
        if ``dtype`` is given and is not in ``COMPLEX_DTYPES``
    """
    # Compare against None explicitly. A truthiness test would let a falsy
    # non-None value skip validation.
    if dtype is None:
        dtype = torch.complex128
    elif dtype not in COMPLEX_DTYPES:
        raise ValueError(f"dtype must be one of {COMPLEX_DTYPES}")
    tens = 0.5 * torch.randn(*size, dtype=dtype, device=device)
    tens = lalg.sym(tens)
    # NOTE(review): presumably lalg.expm is the matrix exponential, which maps a
    # real symmetric matrix to a symmetric positive-definite one, keeping the
    # imaginary part positive definite. Confirm against lalg.
    tens.imag = lalg.expm(tens.imag)
    return tens
def test_inverse_cayley_transform_from_projx(bounded, upper, rank, dtype, eps):
    """Round-trip a projected point through the inverse and forward Cayley transforms.

    A batch of 10 symmetrized random ``rank`` x ``rank`` matrices is projected
    with ``upper.projx``. It is mapped by ``sm.inverse_cayley_transform`` and back
    by ``sm.cayley_transform``. The test checks that the original point is
    recovered, that the recovered point passes ``upper``'s manifold check, and
    that the intermediate point passes ``bounded``'s manifold check.
    """
    raw = sym(torch.randn((10, rank, rank), dtype=dtype))
    point = upper.projx(raw).detach()

    bounded_point = sm.inverse_cayley_transform(point)
    recovered = sm.cayley_transform(bounded_point)

    np.testing.assert_allclose(point, recovered, atol=eps, rtol=eps)
    upper.assert_check_point_on_manifold(recovered)
    bounded.assert_check_point_on_manifold(bounded_point)
def projx(self, x: torch.Tensor) -> torch.Tensor:
    """Project ``x`` by symmetrizing it with ``lalg.sym``.

    Parameters
    ----------
    x : torch.Tensor
        tensor to project

    Returns
    -------
    torch.Tensor
        the symmetrized tensor
    """
    symmetrized = lalg.sym(x)
    return symmetrized
def get_random_complex_symmetric_matrices(points: int, dims: int, dtype: torch.dtype):
    """Return ``points`` random symmetric matrices of size ``dims`` x ``dims``.

    Entries are drawn with ``torch.rand`` in the requested ``dtype``, so the
    matrices are complex only when ``dtype`` is a complex dtype. The batch is
    then symmetrized with ``sym``.
    """
    batch_shape = (points, dims, dims)
    samples = torch.rand(batch_shape, dtype=dtype)
    return sym(samples)