class SkewSymmetricMatrices(MatrixLieAlgebra):
    """Class for skew-symmetric matrices.

    Real square matrices ``A`` satisfying ``A^T = -A``, handled as a matrix
    Lie algebra of dimension ``n (n - 1) / 2``.

    Parameters
    ----------
    n : int
        Number of rows and columns.
    """

    def __init__(self, n):
        # One free coefficient per strictly-upper-triangular entry; integer
        # division avoids the float round-trip of ``int(n * (n - 1) / 2)``.
        dim = n * (n - 1) // 2
        super(SkewSymmetricMatrices, self).__init__(dim, n)
        self.ambient_space = Matrices(n, n)

        if n == 2:
            self.basis = gs.array([[[0.0, -1.0], [1.0, 0.0]]])
        elif n == 3:
            # so(3) generators; ``basis_representation`` reads the matching
            # coefficients from entries (2, 1), (0, 2) and (1, 0).
            self.basis = gs.array([
                [[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]],
                [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]],
                [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            ])
        else:
            # One element E_{row, col} - E_{col, row} per strictly upper
            # triangular position, enumerated row-major so that coefficients
            # line up with ``gs.triu_to_vec(..., k=1)`` in
            # ``basis_representation``.
            basis = [
                gs.array_from_sparse(
                    [(row, col), (col, row)], [1.0, -1.0], (n, n))
                for row in range(n - 1)
                for col in range(row + 1, n)
            ]
            # For n < 2 the algebra is trivial (dim == 0): there is nothing
            # to stack, so keep an empty (0, n, n) basis instead of raising.
            self.basis = gs.stack(basis) if basis else gs.zeros((dim, n, n))

    def belongs(self, mat, atol=gs.atol):
        """Evaluate if mat is a skew-symmetric matrix.

        Parameters
        ----------
        mat : array-like, shape=[..., n, n]
            Square matrix to check.
        atol : float
            Tolerance for the equality evaluation.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if matrix is skew symmetric.
        """
        has_right_shape = self.ambient_space.belongs(mat)
        if gs.all(has_right_shape):
            return Matrices.is_skew_symmetric(mat=mat, atol=atol)
        # Wrong shape: the shape check already answers "no".
        return has_right_shape

    def random_point(self, n_samples=1, bound=1.0):
        """Sample from a uniform distribution in a cube.

        Samples are drawn with the parent class's ``random_point`` and then
        projected onto the skew-symmetric matrices.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Bound of the interval in which to sample each entry.
            Optional, default: 1.

        Returns
        -------
        point : array-like, shape=[..., n, n]
            Sample.
        """
        return self.projection(
            super(SkewSymmetricMatrices, self).random_point(n_samples, bound))

    @classmethod
    def projection(cls, mat):
        r"""Compute the skew-symmetric component of a matrix.

        The skew-symmetric part of a matrix :math:`X` is defined by

        .. math::

            (X - X^T) / 2

        Parameters
        ----------
        mat : array-like, shape=[..., n, n]
            Matrix.

        Returns
        -------
        skew_sym : array-like, shape=[..., n, n]
            Skew-symmetric matrix.
        """
        return Matrices.to_skew_symmetric(mat)

    def basis_representation(self, matrix_representation):
        """Calculate the coefficients of given matrix in the basis.

        Compute a 1d-array that corresponds to the input matrix in the basis
        representation.

        Parameters
        ----------
        matrix_representation : array-like, shape=[..., n, n]
            Matrix.

        Returns
        -------
        basis_representation : array-like, shape=[..., dim]
            Representation in the basis.
        """
        if self.n == 2:
            return matrix_representation[..., 1, 0][..., None]
        if self.n == 3:
            # Stack on the last axis so any number of leading batch
            # dimensions is preserved (a full transpose would reverse them).
            return gs.stack([
                matrix_representation[..., 2, 1],
                matrix_representation[..., 0, 2],
                matrix_representation[..., 1, 0],
            ], axis=-1)

        return gs.triu_to_vec(matrix_representation, k=1)
class TestMatrices(geomstats.tests.TestCase):
    """Unit tests for the ``Matrices`` space and its Frobenius metric."""

    def setUp(self):
        self.m = 2
        self.n = 3
        self.space = Matrices(m=self.n, n=self.n)
        self.space_nonsquare = Matrices(m=self.m, n=self.n)
        self.metric = self.space.metric
        self.n_samples = 2

    def test_mul(self):
        a = gs.eye(3, 3, 1)
        b = gs.eye(3, 3, -1)
        c = gs.array([
            [1., 0., 0.],
            [0., 1., 0.],
            [0., 0., 0.]])
        d = gs.array([
            [0., 0., 0.],
            [0., 1., 0.],
            [0., 0., 1.]])
        result = self.space.mul([a, b], [b, a])
        expected = gs.array([c, d])
        self.assertAllClose(result, expected)

        result = self.space.mul(a, [a, b])
        expected = gs.array([gs.eye(3, 3, 2), c])
        self.assertAllClose(result, expected)

    def test_bracket(self):
        x = gs.array([
            [0., 0., 0.],
            [0., 0., -1.],
            [0., 1., 0.]])
        y = gs.array([
            [0., 0., 1.],
            [0., 0., 0.],
            [-1., 0., 0.]])
        z = gs.array([
            [0., -1., 0.],
            [1., 0., 0.],
            [0., 0., 0.]])
        result = self.space.bracket([x, y], [y, z])
        expected = gs.array([z, x])
        self.assertAllClose(result, expected)

        result = self.space.bracket(x, [x, y, z])
        expected = gs.array([gs.zeros((3, 3)), z, -y])
        self.assertAllClose(result, expected)

    def test_transpose(self):
        tr = self.space.transpose
        ar = gs.array
        a = gs.eye(3, 3, 1)
        b = gs.eye(3, 3, -1)
        self.assertAllClose(tr(a), b)
        self.assertAllClose(tr(ar([a, b])), ar([b, a]))

    def test_is_symmetric(self):
        not_squared = gs.array([[1., 2.], [2., 1.], [3., 1.]])
        result = self.space.is_symmetric(not_squared)
        expected = False
        self.assertAllClose(result, expected)

        sym_mat = gs.array([[1., 2.], [2., 1.]])
        result = self.space.is_symmetric(sym_mat)
        expected = gs.array(True)
        self.assertAllClose(result, expected)

        not_a_sym_mat = gs.array([[1., 0.6, -3.],
                                  [6., -7., 0.],
                                  [0., 7., 8.]])
        result = self.space.is_symmetric(not_a_sym_mat)
        expected = gs.array(False)
        self.assertAllClose(result, expected)

    def test_is_skew_symmetric(self):
        skew_mat = gs.array([[0., -2.],
                             [2., 0.]])
        result = self.space.is_skew_symmetric(skew_mat)
        expected = gs.array(True)
        self.assertAllClose(result, expected)

        not_a_sym_mat = gs.array([[1., 0.6, -3.],
                                  [6., -7., 0.],
                                  [0., 7., 8.]])
        result = self.space.is_skew_symmetric(not_a_sym_mat)
        expected = gs.array(False)
        self.assertAllClose(result, expected)

    def test_is_symmetric_vectorization(self):
        points = gs.array([
            [[1., 2.],
             [2., 1.]],
            [[3., 4.],
             [4., 5.]],
            [[1., 2.],
             [3., 4.]]])
        result = self.space.is_symmetric(points)
        expected = [True, True, False]
        self.assertAllClose(result, expected)

    def test_make_symmetric(self):
        sym_mat = gs.array([[1., 2.],
                            [2., 1.]])
        result = self.space.to_symmetric(sym_mat)
        expected = sym_mat
        self.assertAllClose(result, expected)

        mat = gs.array([[1., 2., 3.],
                        [0., 0., 0.],
                        [3., 1., 1.]])
        result = self.space.to_symmetric(mat)
        expected = gs.array([[1., 1., 3.],
                             [1., 0., 0.5],
                             [3., 0.5, 1.]])
        self.assertAllClose(result, expected)

        # Mixed magnitudes: symmetrization must not lose the large terms.
        mat = gs.array([[1e100, 1e-100, 1e100],
                        [1e100, 1e-100, 1e100],
                        [1e-100, 1e-100, 1e100]])
        result = self.space.to_symmetric(mat)

        res = 0.5 * (1e100 + 1e-100)

        expected = gs.array([[1e100, res, res],
                             [res, 1e-100, res],
                             [res, res, 1e100]])
        self.assertAllClose(result, expected)

    def test_make_symmetric_and_is_symmetric_vectorization(self):
        points = gs.array([
            [[1., 2.],
             [3., 4.]],
            [[5., 6.],
             [4., 9.]]])

        sym_points = self.space.to_symmetric(points)
        result = gs.all(self.space.is_symmetric(sym_points))
        expected = True
        self.assertAllClose(result, expected)

    def test_inner_product(self):
        base_point = gs.array([
            [1., 2., 3.],
            [0., 0., 0.],
            [3., 1., 1.]])

        tangent_vector_1 = gs.array([
            [1., 2., 3.],
            [0., -10., 0.],
            [30., 1., 1.]])

        tangent_vector_2 = gs.array([
            [1., 4., 3.],
            [5., 0., 0.],
            [3., 1., 1.]])

        result = self.metric.inner_product(
            tangent_vector_1, tangent_vector_2, base_point=base_point)

        # Frobenius inner product: tr(A^T B).
        expected = gs.trace(
            gs.matmul(gs.transpose(tangent_vector_1), tangent_vector_2))

        self.assertAllClose(result, expected)

    def test_cong(self):
        base_point = gs.array([
            [1., 2., 3.],
            [0., 0., 0.],
            [3., 1., 1.]])

        tangent_vector = gs.array([
            [1., 2., 3.],
            [0., -10., 0.],
            [30., 1., 1.]])

        result = self.space.congruent(tangent_vector, base_point)
        # Congruence: base_point @ tangent_vector @ base_point^T.
        expected = gs.matmul(
            tangent_vector, gs.transpose(base_point))
        expected = gs.matmul(base_point, expected)

        self.assertAllClose(result, expected)

    def test_belongs(self):
        base_point_square = gs.zeros((self.n, self.n))
        base_point_nonsquare = gs.zeros((self.m, self.n))

        result = self.space.belongs(base_point_square)
        expected = True
        self.assertAllClose(result, expected)
        result = self.space_nonsquare.belongs(base_point_square)
        expected = False
        self.assertAllClose(result, expected)

        result = self.space.belongs(base_point_nonsquare)
        expected = False
        self.assertAllClose(result, expected)
        result = self.space_nonsquare.belongs(base_point_nonsquare)
        expected = True
        self.assertAllClose(result, expected)

        # Batched inputs: every (2, 3) matrix is rejected by the 3x3 space,
        # every (3, 3) matrix is accepted.
        result = self.space.belongs(gs.zeros((2, 2, 3)))
        self.assertFalse(gs.any(result))

        result = self.space.belongs(gs.zeros((2, 3, 3)))
        self.assertTrue(gs.all(result))

    def test_is_diagonal(self):
        base_point = gs.array([
            [1., 2., 3.],
            [0., 0., 0.],
            [3., 1., 1.]])
        result = self.space.is_diagonal(base_point)
        expected = False
        self.assertAllClose(result, expected)

        diagonal = gs.eye(3)
        result = self.space.is_diagonal(diagonal)
        self.assertTrue(result)

        base_point = gs.stack([base_point, diagonal])
        result = self.space.is_diagonal(base_point)
        expected = gs.array([False, True])
        self.assertAllClose(result, expected)

        base_point = gs.stack([diagonal] * 2)
        result = self.space.is_diagonal(base_point)
        self.assertTrue(gs.all(result))

        # A non-square matrix is never diagonal.
        base_point = gs.reshape(gs.arange(6), (2, 3))
        result = self.space.is_diagonal(base_point)
        self.assertFalse(result)

    def test_norm(self):
        for n_samples in [1, 2]:
            mat = self.space.random_point(n_samples)
            result = self.metric.norm(mat)
            expected = self.space.frobenius_product(mat, mat) ** .5
            self.assertAllClose(result, expected)
class _GraphSpace:
    r"""Class for the Graph Space.

    Graph Space to analyse populations of labelled and unlabelled graphs.
    The space focuses on graphs with scalar Euclidean attributes on nodes and
    edges, with a finite number of nodes and both directed and undirected
    edges. For undirected graphs, use symmetric adjacency matrices. The space
    is a quotient space obtained by applying the permutation action of nodes
    to the space of adjacency matrices.

    Points are represented by :math:`nodes \times nodes` adjacency matrices.

    Parameters
    ----------
    nodes : int
        Number of graph nodes.
    p : int
        Dimension of euclidean parameter or label associated to a graph.
        Optional, default: None.

    References
    ----------
    ..[Calissano2020]  Calissano, A., Feragen, A., Vantini, S.
              “Graph Space: Geodesic Principal Components for a Population of
              Network-valued Data.”
              Mox report 14, 2020.
    """

    def __init__(self, nodes, p=None):
        self.nodes = nodes
        self.p = p
        # Ambient space of (nodes x nodes) adjacency matrices.
        self.adjmat = Matrices(self.nodes, self.nodes)

    def belongs(self, graph, atol=gs.atol):
        r"""Check if the matrix is an adjacency matrix.

        The adjacency matrix should be associated to the
        graph with n nodes.

        Parameters
        ----------
        graph : array-like, shape=[..., n, n]
            Matrix to be checked.
        atol : float
            Tolerance, forwarded to the adjacency-matrix space.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean denoting if graph belongs to the space.
        """
        return self.adjmat.belongs(graph, atol=atol)

    def random_point(self, n_samples=1, bound=1.0):
        r"""Sample in Graph Space.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Sampling bound, forwarded to the adjacency-matrix space's
            ``random_point``.
            Optional, default: 1.

        Returns
        -------
        graph_samples : array-like, shape=[..., n, n]
            Points sampled in GraphSpace(n).
        """
        return self.adjmat.random_point(n_samples=n_samples, bound=bound)

    def permute(self, graph_to_permute, permutation):
        r"""Permutation action applied to graph observation.

        Each graph ``G`` is mapped to ``P G P^T``, where ``P`` has a one at
        position ``(i, permutation[i])`` for every node ``i``.

        Parameters
        ----------
        graph_to_permute : array-like, shape=[..., n, n]
            Input graphs to be permuted.
        permutation : array-like, shape=[..., n]
            Node permutations where in position i we have the value j meaning
            the node i should be permuted with node j.

        Returns
        -------
        graphs_permuted : array-like, shape=[..., n, n]
            Graphs permuted.
        """
        nodes = self.nodes
        single_graph = len(graph_to_permute.shape) < 3
        if single_graph:
            graph_to_permute = [graph_to_permute]
            permutation = [permutation]
        # Compare each permutation to the identity (0, 1, ..., n-1). The
        # previous check compared entries against the scalar node count,
        # which never matches a valid permutation.
        identity = gs.arange(nodes)
        result = []
        for i, perm in enumerate(permutation):
            graph = graph_to_permute[i]
            if gs.all(identity == gs.array(perm)):
                # P is the identity, so P G P^T == G: skip the products.
                result.append(graph)
                continue
            permutation_matrix = gs.array_from_sparse(
                data=gs.ones(nodes, dtype=graph.dtype),
                indices=list(zip(range(nodes), perm)),
                target_shape=(nodes, nodes),
            )
            result.append(
                self.adjmat.mul(
                    permutation_matrix,
                    graph,
                    gs.transpose(permutation_matrix),
                )
            )
        return result[0] if single_graph else gs.array(result)