def pt_manifold(A, B, Pbs):
    """Parallel transport in the manifold space.

    Maps matrices from the manifold point ``B`` to the manifold point ``A``
    via ``Pa = E @ Pb @ E.T`` where ``E = sqrtm(A @ inv(B))`` (principal
    matrix square root).

    Parameters
    ----------
    A : ndarray
        SPD matrix, the destination point.
    B : ndarray
        SPD matrix, the source point.
    Pbs : array_like
        SPD matrix in the manifold B, shape (n_trials, n_channels, n_channels)
        or (n_channels, n_channels). Any leading dimensions are flattened into
        a single trial axis. Integer/boolean inputs are promoted to float64;
        floating-point and complex inputs keep their dtype.

    Returns
    -------
    Pas : ndarray
        SPD matrix in the manifold A, shape (n_trials, n_channels, n_channels),
        or (n_channels, n_channels) when there is exactly one trial.
    """
    Pbs = np.asarray(Pbs)
    Pbs = Pbs.reshape(-1, *Pbs.shape[-2:])
    n_trials = Pbs.shape[0]
    # Writing float results into an integer buffer would truncate them, so
    # non-floating inputs get a float64 output; float/complex dtypes are kept.
    if np.issubdtype(Pbs.dtype, np.inexact):
        out_dtype = Pbs.dtype
    else:
        out_dtype = np.float64
    # For SPD A and B, A @ inv(B) is similar to an SPD matrix, so its
    # principal root is real in exact arithmetic; `.real` drops any round-off
    # imaginary part explicitly instead of via an implicit complex->real cast.
    E = scipy_sqrtm(A @ inv(B)).real
    # matmul broadcasts E over the leading trial axis.
    Pas = (E @ Pbs @ E.T).astype(out_dtype, copy=False)
    if n_trials == 1:
        Pas = Pas[0]
    return Pas
def pt_tangent(A, B, Sbs):
    """Parallel transport in the tangent space.

    Maps tangent vectors (matrix form) at the manifold point ``B`` to the
    tangent space at ``A`` via ``Sa = E @ Sb @ E.T`` where
    ``E = sqrtm(A @ inv(B))`` (principal matrix square root).

    Parameters
    ----------
    A : ndarray
        SPD matrix, the destination point.
    B : ndarray
        SPD matrix, the source point.
    Sbs : array_like
        The tangent vector (matrix form) in the tangent space of manifold B,
        shape (n_trials, n_channels, n_channels) or (n_channels, n_channels).
        Any leading dimensions are flattened into a single trial axis.
        Integer/boolean inputs are promoted to float64; floating-point and
        complex inputs keep their dtype.

    Returns
    -------
    Sas : ndarray
        The tangent vector (matrix form) in the tangent space of manifold A,
        shape (n_trials, n_channels, n_channels), or (n_channels, n_channels)
        when there is exactly one trial.
    """
    Sbs = np.asarray(Sbs)
    Sbs = Sbs.reshape(-1, *Sbs.shape[-2:])
    n_trials = Sbs.shape[0]
    # Writing float results into an integer buffer would truncate them, so
    # non-floating inputs get a float64 output; float/complex dtypes are kept.
    if np.issubdtype(Sbs.dtype, np.inexact):
        out_dtype = Sbs.dtype
    else:
        out_dtype = np.float64
    # For SPD A and B, A @ inv(B) is similar to an SPD matrix, so its
    # principal root is real in exact arithmetic; `.real` drops any round-off
    # imaginary part explicitly instead of via an implicit complex->real cast.
    E = scipy_sqrtm(A @ inv(B)).real
    # matmul broadcasts E over the leading trial axis.
    Sas = (E @ Sbs @ E.T).astype(out_dtype, copy=False)
    if n_trials == 1:
        Sas = Sas[0]
    return Sas
def test_matrix_sqrt(matrix_size):
    """Test that the matrix sqrt function matches SciPy's result.

    The input is the product of two random sample-covariance matrices. The
    random draws come from a local generator seeded with ``matrix_size`` so
    that any tolerance failure is reproducible.
    """
    # Local, seeded generator: deterministic data without touching torch's
    # global RNG state.
    gen = torch.Generator().manual_seed(matrix_size)

    def generate_cov(n):
        data = torch.randn(2 * n, n, generator=gen)
        centered = data - data.mean(dim=0)
        return centered.T @ centered

    cov1 = generate_cov(matrix_size)
    cov2 = generate_cov(matrix_size)
    product = cov1 @ cov2

    scipy_res = scipy_sqrtm(product.numpy()).real
    # `sqrtm` may hand back a tensor or a NumPy array (a module-level helper
    # of that name returns an ndarray); normalize to a float32 tensor so
    # torch.allclose sees matching types.
    tm_res = torch.as_tensor(sqrtm(product)).float()
    assert torch.allclose(torch.tensor(scipy_res).float(), tm_res, atol=1e-3)
def sqrtm(A):
    """Return the real part of the principal square root of ``A``.

    Calls ``scipy_sqrtm`` on ``A`` and keeps only the real component of the
    result. For matrices whose root is genuinely complex (e.g. negative
    eigenvalues) the discarded imaginary part may be significant.
    """
    principal_root = scipy_sqrtm(A)
    return principal_root.real