def _product_unary_case(ex, ev, manifolds):
    """Build one ``UnaryCase`` for a ``geoopt.ProductManifold`` over ``manifolds``.

    ``ex[i]`` and ``ev[i]`` are raw (not yet projected) point and vector
    samples for ``manifolds[i]``. The shape of ``ex[i]`` becomes the shape of
    the i-th factor of the product.

    :raises ValueError: if ``ex``, ``ev`` and ``manifolds`` differ in length.
    """
    if not len(manifolds) == len(ex) == len(ev):
        raise ValueError(
            "manifolds, ex and ev must have equal lengths, got "
            f"{len(manifolds)}, {len(ex)}, {len(ev)}"
        )
    # Project each raw sample onto its factor, then project the raw vector
    # onto the tangent space at that projected point.
    x = [manifold.projx(point) for manifold, point in zip(manifolds, ex)]
    v = [manifold.proju(base, vec) for manifold, base, vec in zip(manifolds, x, ev)]
    product_manifold = geoopt.ProductManifold(
        *((manifold, point.shape) for manifold, point in zip(manifolds, ex))
    )
    return UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )


def product_case():
    """Yield ``UnaryCase`` fixtures for ``geoopt.ProductManifold``.

    Two cases are produced, each from data drawn right after
    ``torch.manual_seed(42)``:

    1. Sphere x PoincareBall x Stiefel x Euclidean.
    2. The same factors without Stiefel.
    """
    torch.manual_seed(42)
    # The PoincareBall sample is scaled by 1/10, presumably to keep it well
    # inside the unit ball before projection.
    ex = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    ev = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    yield _product_unary_case(ex, ev, manifolds)

    # Second case: same setup without the Stiefel factor. The seed is reset so
    # its draws are reproducible independently of the first case.
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Euclidean(),
    ]
    yield _product_unary_case(ex, ev, manifolds)
def test_product():
    """Random samples of a four-factor product manifold lie on that manifold."""
    factor_specs = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factor_specs)
    # Size arguments (20, n_elements): presumably a batch of 20 packed points.
    points = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(points)
def get_prod_hyeu_manifold(dims):
    """Return a PoincareBall x Euclidean product manifold over ``dims`` coordinates.

    The Poincare ball gets the first ``dims // 2`` coordinates and the
    Euclidean factor gets the remaining ``dims - dims // 2``. The factor sizes
    therefore sum to ``dims`` even when ``dims`` is odd.

    :param dims: total number of coordinates of a packed product point.
    :return: ``gt.ProductManifold`` built from the two factors above.
    """
    first = dims // 2
    # The original used dims // 2 for both factors, which drops one coordinate
    # for odd dims. The leftover coordinate now goes to the second factor.
    second = dims - first
    poincare = gt.PoincareBall()
    # The argument 1 is presumably ``ndim=1`` (vector-valued points); confirm.
    euclidean = gt.Euclidean(1)
    return gt.ProductManifold((poincare, first), (euclidean, second))
def get_prod_sphsph_manifold(dims):
    """Return a Sphere x Sphere product manifold over ``dims`` coordinates.

    The first sphere gets ``dims // 2`` coordinates and the second gets
    ``dims - dims // 2``. The factor sizes therefore sum to ``dims`` even when
    ``dims`` is odd. Both factors share one ``gt.Sphere`` instance, as before.

    :param dims: total number of coordinates of a packed product point.
    :return: ``gt.ProductManifold`` built from the two factors above.
    """
    first = dims // 2
    # The original used dims // 2 for both factors, which drops one coordinate
    # for odd dims. The leftover coordinate now goes to the second factor.
    second = dims - first
    sphere = gt.Sphere()
    return gt.ProductManifold((sphere, first), (sphere, second))
def get_prod_hyhy_manifold(dims):
    """Return a PoincareBall x PoincareBall product manifold over ``dims`` coordinates.

    The first ball gets ``dims // 2`` coordinates and the second gets
    ``dims - dims // 2``. The factor sizes therefore sum to ``dims`` even when
    ``dims`` is odd. Both factors share one ``gt.PoincareBall`` instance, as
    before.

    :param dims: total number of coordinates of a packed product point.
    :return: ``gt.ProductManifold`` built from the two factors above.
    """
    first = dims // 2
    # The original used dims // 2 for both factors, which drops one coordinate
    # for odd dims. The leftover coordinate now goes to the second factor.
    second = dims - first
    poincare = gt.PoincareBall()
    return gt.ProductManifold((poincare, first), (poincare, second))
def get_prod_hysph_manifold(dims):
    """Return a PoincareBall x Sphere product manifold over ``dims`` coordinates.

    The Poincare ball gets the first ``dims // 2`` coordinates and the sphere
    gets the remaining ``dims - dims // 2``. The factor sizes therefore sum to
    ``dims`` even when ``dims`` is odd.

    :param dims: total number of coordinates of a packed product point.
    :return: ``gt.ProductManifold`` built from the two factors above.
    """
    first = dims // 2
    # The original used dims // 2 for both factors, which drops one coordinate
    # for odd dims. The leftover coordinate now goes to the second factor.
    second = dims - first
    poincare = gt.PoincareBall()
    sphere = gt.Sphere()
    return gt.ProductManifold((poincare, first), (sphere, second))