Exemple #1
0
def pt_manifold(A, B, Pbs):
    """Parallel transport in the manifold space.

    Every matrix ``Pb`` in ``Pbs`` is mapped to ``E @ Pb @ E.T`` where
    ``E = sqrtm(A @ inv(B))``.

    Parameters
    ----------
    A : ndarray
        SPD matrix, shape (n_channels, n_channels).
    B : ndarray
        SPD matrix, shape (n_channels, n_channels).
    Pbs : ndarray
        SPD matrix in the manifold B, shape (n_trials, n_channels, n_channels)
        or (n_channels, n_channels).

    Returns
    -------
    Pas : ndarray
        SPD matrix in the manifold A. When there is exactly one trial the
        result has shape (n_channels, n_channels), otherwise
        (n_trials, n_channels, n_channels). Floating-point and complex inputs
        keep their dtype; integer and boolean inputs are returned as float64
        so the transported values are not truncated.
    """
    # Collapse every leading axis into a single "trial" axis.
    Pbs = Pbs.reshape(-1, *Pbs.shape[-2:])
    E = scipy_sqrtm(A @ inv(B))

    # Writing the result into a buffer with an integer dtype would round the
    # transported matrices, so only inexact input dtypes are kept as-is.
    if np.issubdtype(Pbs.dtype, np.inexact):
        out_dtype = Pbs.dtype
    else:
        out_dtype = np.float64

    # matmul broadcasts the 2-D ``E`` over the trial axis, so no Python loop
    # is needed. Casting to a real dtype drops any imaginary part that sqrtm
    # may return (NumPy emits its usual ComplexWarning in that case).
    Pas = (E @ Pbs @ E.T).astype(out_dtype, copy=False)

    if Pas.shape[0] == 1:
        Pas = Pas[0]

    return Pas
Exemple #2
0
def pt_tangent(A, B, Sbs):
    """Parallel transport in the tangent space.

    Every tangent vector ``Sb`` in ``Sbs`` is mapped to ``E @ Sb @ E.T`` where
    ``E = sqrtm(A @ inv(B))``.

    Parameters
    ----------
    A : ndarray
        SPD matrix, shape (n_channels, n_channels).
    B : ndarray
        SPD matrix, shape (n_channels, n_channels).
    Sbs : ndarray
        The tangent vector (matrix form) in the tangent space of manifold B,
        shape (n_trials, n_channels, n_channels) or (n_channels, n_channels).

    Returns
    -------
    Sas : ndarray
        The tangent vector (matrix form) in the tangent space of manifold A.
        When there is exactly one trial the result has shape
        (n_channels, n_channels), otherwise
        (n_trials, n_channels, n_channels). Floating-point and complex inputs
        keep their dtype; integer and boolean inputs are returned as float64
        so the transported values are not truncated.
    """
    # Collapse every leading axis into a single "trial" axis.
    Sbs = Sbs.reshape(-1, *Sbs.shape[-2:])
    E = scipy_sqrtm(A @ inv(B))

    # Writing the result into a buffer with an integer dtype would round the
    # transported vectors, so only inexact input dtypes are kept as-is.
    if np.issubdtype(Sbs.dtype, np.inexact):
        out_dtype = Sbs.dtype
    else:
        out_dtype = np.float64

    # matmul broadcasts the 2-D ``E`` over the trial axis, so no Python loop
    # is needed. Casting to a real dtype drops any imaginary part that sqrtm
    # may return (NumPy emits its usual ComplexWarning in that case).
    Sas = (E @ Sbs @ E.T).astype(out_dtype, copy=False)

    if Sas.shape[0] == 1:
        Sas = Sas[0]

    return Sas
Exemple #3
0
def test_matrix_sqrt(matrix_size):
    """Check that ``sqrtm`` matches SciPy's matrix square root.

    Two random scatter matrices of size ``matrix_size`` are multiplied and the
    square root of the product is computed with both implementations; the
    results must agree to an absolute tolerance of 1e-3.
    """

    def generate_cov(n):
        # 2n standard-normal samples of dimension n, centred column-wise.
        samples = torch.randn(2 * n, n)
        centered = samples - samples.mean(dim=0)
        return centered.T @ centered

    # Left operand is evaluated first, so the random draws keep their order.
    product = generate_cov(matrix_size) @ generate_cov(matrix_size)

    # Reference: real part of SciPy's result, as a float32 tensor.
    reference = torch.tensor(scipy_sqrtm(product.numpy()).real).float()
    candidate = sqrtm(product)

    assert torch.allclose(reference, candidate, atol=1e-3)
Exemple #4
0
def sqrtm(A):
    """Return the real part of the matrix square root of ``A``.

    The root is computed with ``scipy_sqrtm``; any imaginary component of the
    result is discarded.
    """
    root = scipy_sqrtm(A)
    return root.real