def test_unpack_tensor_method():
    """Check that a random point's ``unpack_tensor`` restores each factor's batched shape."""
    manifold = ProductManifold(
        (Sphere(), 10),
        (Sphere(), (3, 2)),
        (Euclidean(), ()),
    )
    batch = 4
    sample = manifold.random(batch, manifold.n_elements)
    components = sample.unpack_tensor()
    # One expected shape per factor, in declaration order; the Euclidean
    # factor was declared with shape () so only the batch axis remains.
    expected_shapes = [(batch, 10), (batch, 3, 2), (batch,)]
    for idx, shape in enumerate(expected_shapes):
        assert components[idx].shape == shape
def test_from_point():
    """Check that ``from_point`` with one batch dim infers factor sizes from sample parts."""
    sphere_vec = Sphere().random_uniform(5, 10)
    sphere_mat = Sphere().random_uniform(5, 3, 2)
    euclid = Euclidean().random_normal(5)
    manifold = ProductManifold.from_point(
        sphere_vec, sphere_mat, euclid, batch_dims=1
    )
    # The leading axis (size 5) is treated as batch, so only the trailing
    # shapes (10,), (3, 2) and () contribute to the element count.
    assert manifold.n_elements == (10 + 3 * 2 + 1)
def test_from_point_checks_shapes():
    """Check that ``from_point`` rejects parts whose batch shapes disagree.

    Without ``batch_dims`` every axis belongs to its factor, so parts with
    different leading sizes are accepted. With ``batch_dims=1`` the leading
    sizes (5, 3, 5) must agree, and a ``ValueError`` is expected.
    """
    parts = (
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(3, 3, 2),
        Euclidean().random_normal(5),
    )
    manifold = ProductManifold.from_point(*parts)
    assert manifold.n_elements == (5 * 10 + 3 * 3 * 2 + 5 * 1)
    with pytest.raises(ValueError, match="Not all parts have same batch shape"):
        ProductManifold.from_point(*parts, batch_dims=1)
def test_reshaping():
    """Check that ``pack_point`` followed by ``unpack_tensor`` round-trips the parts."""
    manifold = ProductManifold(
        (Sphere(), 10),
        (Sphere(), (3, 2)),
        (Euclidean(), ()),
    )
    originals = [
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    ]
    packed = manifold.pack_point(*originals)
    # Factors are flattened and concatenated along the last axis.
    assert packed.shape == (5, 10 + 3 * 2 + 1)
    restored = manifold.unpack_tensor(packed)
    for before, after in zip(originals, restored):
        np.testing.assert_allclose(before, after)
def test_component_inner_product():
    """Check that ``component_inner`` yields one value per packed element."""
    manifold = ProductManifold(
        (Sphere(), 10),
        (Sphere(), (3, 2)),
        (Euclidean(), ()),
    )
    packed = manifold.pack_point(
        Sphere().random_uniform(5, 10),
        Sphere().random_uniform(5, 3, 2),
        Euclidean().random_normal(5),
    )
    # Random noise is passed through ``proju`` (presumably a projection onto
    # the tangent space at ``packed``) before taking the inner product.
    direction = manifold.proju(packed, torch.randn_like(packed))

    inner = manifold.component_inner(packed, direction)
    assert inner.shape == (5, manifold.n_elements)
def test_init():
    """Check that construction from (manifold, shape) pairs sets size and name."""
    manifold = ProductManifold((Sphere(), 10), (Sphere(), (3, 2)))
    expected_elements = 10 + 3 * 2
    assert manifold.n_elements == expected_elements
    assert manifold.name == "(Sphere)x(Sphere)"