# Example #1
# 0
def _product_case_from(ex, ev, manifolds):
    """Build one ``UnaryCase`` for a ``geoopt.ProductManifold`` over ``manifolds``.

    Factor ``i`` is ``manifolds[i]`` with shape ``ex[i].shape``.  ``ex[i]`` is
    projected onto the factor with ``projx`` to obtain a point, and ``ev[i]`` is
    projected with ``proju`` at that point to obtain a tangent vector.  The
    projected and raw tensors are packed with ``pack_point`` and passed to
    ``UnaryCase`` in the order: points, raw points, tangents, raw tangents,
    followed by the product manifold itself.

    Raises:
        ValueError: if ``ex``, ``ev`` and ``manifolds`` differ in length
            (``zip`` would otherwise silently drop trailing factors).
    """
    if not len(ex) == len(ev) == len(manifolds):
        raise ValueError(
            "ex, ev and manifolds must have the same length, got "
            "{}, {} and {}".format(len(ex), len(ev), len(manifolds))
        )
    # Points first, then tangent vectors anchored at those projected points.
    x = [manifold.projx(raw) for manifold, raw in zip(manifolds, ex)]
    v = [
        manifold.proju(point, raw)
        for manifold, point, raw in zip(manifolds, x, ev)
    ]

    product_manifold = geoopt.ProductManifold(
        *((manifold, raw.shape) for manifold, raw in zip(manifolds, ex))
    )

    return UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )


def product_case():
    """Yield ``UnaryCase`` fixtures for ``geoopt.ProductManifold``.

    Two cases are yielded, each after ``torch.manual_seed(42)``:

    1. a product of Sphere, PoincareBall, Stiefel and Euclidean factors;
    2. the same product without the Stiefel factor.
    """
    torch.manual_seed(42)
    # Draw order (all of ``ex``, then all of ``ev``) fixes the sampled values;
    # keep it stable so the fixtures stay reproducible.
    # The PoincareBall sample is scaled by 1/10 (presumably to start well
    # inside the ball; projx is applied afterwards regardless).
    ex = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    ev = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    yield _product_case_from(ex, ev, manifolds)

    # + 1 case without stiefel
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Euclidean(),
    ]
    yield _product_case_from(ex, ev, manifolds)
# Example #2
# 0
def test_product():
    """Check that a random sample of a four-factor product lies on the manifold."""
    factors = [
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    ]
    product = geoopt.ProductManifold(*factors)

    # Batch of 20 samples; the trailing size is ``n_elements`` (presumably the
    # combined flattened size of all factors -- see geoopt docs).
    points = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(points)
# Example #3
# 0
def get_prod_hyeu_manifold(dims):
    """Return a PoincareBall x Euclidean product manifold over ``dims`` coordinates.

    The Poincare factor gets ``dims // 2`` coordinates and the Euclidean factor
    gets the remaining ``dims - dims // 2``. The factor sizes therefore always
    sum to ``dims``. For odd ``dims`` the Euclidean factor is one larger;
    previously the last coordinate was silently dropped.
    """
    first = dims // 2
    poincare = gt.PoincareBall()
    # Positional argument is presumably geoopt's ``ndim`` (one trailing
    # manifold dimension) -- confirm against the geoopt Euclidean docs.
    euclidean = gt.Euclidean(1)
    return gt.ProductManifold((poincare, first), (euclidean, dims - first))
# Example #4
# 0
def get_prod_sphsph_manifold(dims):
    """Return a Sphere x Sphere product manifold over ``dims`` coordinates.

    The first sphere factor gets ``dims // 2`` coordinates and the second gets
    the remaining ``dims - dims // 2``. The factor sizes therefore always sum to
    ``dims``. For odd ``dims`` the second factor is one larger; previously the
    last coordinate was silently dropped. One ``Sphere`` instance is shared by
    both factors.
    """
    first = dims // 2
    sphere = gt.Sphere()
    return gt.ProductManifold((sphere, first), (sphere, dims - first))
# Example #5
# 0
def get_prod_hyhy_manifold(dims):
    """Return a PoincareBall x PoincareBall product manifold over ``dims`` coordinates.

    The first factor gets ``dims // 2`` coordinates and the second gets the
    remaining ``dims - dims // 2``. The factor sizes therefore always sum to
    ``dims``. For odd ``dims`` the second factor is one larger; previously the
    last coordinate was silently dropped. One ``PoincareBall`` instance (and so
    presumably one curvature setting) is shared by both factors.
    """
    first = dims // 2
    poincare = gt.PoincareBall()
    return gt.ProductManifold((poincare, first), (poincare, dims - first))
# Example #6
# 0
def get_prod_hysph_manifold(dims):
    """Return a PoincareBall x Sphere product manifold over ``dims`` coordinates.

    The Poincare factor gets ``dims // 2`` coordinates and the Sphere factor
    gets the remaining ``dims - dims // 2``. The factor sizes therefore always
    sum to ``dims``. For odd ``dims`` the Sphere factor is one larger;
    previously the last coordinate was silently dropped.
    """
    first = dims // 2
    poincare = gt.PoincareBall()
    sphere = gt.Sphere()
    return gt.ProductManifold((poincare, first), (sphere, dims - first))