def test_fails_Sphere_origin():
    """``Sphere.origin`` raises ValueError for the shapes ``()`` and ``1``.

    This test used to share the name ``test_fails_Sphere`` with the
    ``random_uniform`` test in this module; the later definition rebound the
    name, so pytest never collected these ``origin`` checks.
    """
    # Build the manifold outside ``pytest.raises`` so that only ``origin``
    # itself can produce the expected ValueError.
    manifold = geoopt.Sphere()
    with pytest.raises(ValueError):
        manifold.origin(())
    with pytest.raises(ValueError):
        manifold.origin(1)
def test_fails_Sphere():
    """``Sphere.random_uniform`` raises ValueError for the shapes ``()`` and ``1``."""
    # Build the manifold outside ``pytest.raises`` so that only
    # ``random_uniform`` itself can produce the expected ValueError.
    manifold = geoopt.Sphere()
    with pytest.raises(ValueError):
        manifold.random_uniform(())
    with pytest.raises(ValueError):
        manifold.random_uniform(1)
def _product_unary_case(manifolds, ex, ev):
    """Build one UnaryCase for the product of ``manifolds``.

    ``ex[i]`` and ``ev[i]`` are the raw point and vector for ``manifolds[i]``.
    Each point is projected with ``projx``, each vector is projected with
    ``proju`` at that point, and all factors are packed into single product
    tensors.
    """
    x = [manifold.projx(raw) for manifold, raw in zip(manifolds, ex)]
    v = [
        manifold.proju(point, raw)
        for manifold, point, raw in zip(manifolds, x, ev)
    ]
    product_manifold = geoopt.ProductManifold(
        *((manifold, raw.shape) for manifold, raw in zip(manifolds, ex))
    )
    return UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )


def product_case():
    """Yield UnaryCases for product manifolds with and without a Stiefel factor.

    The RNG is reseeded before each case, so the raw tensors are reproducible.
    """
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(3, 2), torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(3, 2), torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    yield _product_unary_case(manifolds, ex, ev)

    # Second case: the same construction without the Stiefel factor.
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [geoopt.Sphere(), geoopt.PoincareBall(), geoopt.Euclidean()]
    yield _product_unary_case(manifolds, ex, ev)
class ManifoldFactory:
    """Construct a manifold from its string name.

    ``geoopt_manifolds`` maps a name to a callable ``f(dims)`` that builds a
    manifold from the ``gt`` module (presumably geoopt) or from a local
    ``get_prod_*`` helper. ``sympa_manifolds`` maps a name to a class that is
    instantiated with ``dims`` and a ``MetricType``.
    """

    # The product builders are wrapped in lambdas so their names are resolved
    # when a manifold is requested, not when this class body executes.
    # Otherwise the class cannot be defined before those module-level helpers.
    geoopt_manifolds = {
        "euclidean": lambda dims: gt.Euclidean(1),
        "poincare": lambda dims: gt.PoincareBall(),
        "lorentz": lambda dims: gt.Lorentz(),
        "sphere": lambda dims: gt.Sphere(),
        "prod-hysph": lambda dims: get_prod_hysph_manifold(dims),
        "prod-hyhy": lambda dims: get_prod_hyhy_manifold(dims),
        "prod-hyeu": lambda dims: get_prod_hyeu_manifold(dims),
        "prod-sphsph": lambda dims: get_prod_sphsph_manifold(dims),
        "spd": lambda dims: SymmetricPositiveDefinite(),
    }
    sympa_manifolds = {
        "upper": UpperHalfManifold,
        "bounded": BoundedDomainManifold,
        "dual": CompactDualManifold,
    }

    @classmethod
    def get_manifold(cls, manifold_name, metric_name, dims):
        """Return the manifold registered under ``manifold_name``.

        Args:
            manifold_name: key of ``geoopt_manifolds`` or ``sympa_manifolds``.
            metric_name: parsed with ``MetricType.from_str``; only used for
                ``sympa_manifolds`` entries.
            dims: dimension passed to the selected builder.

        Raises:
            KeyError: if ``manifold_name`` is not registered in either table.
        """
        if manifold_name in cls.geoopt_manifolds:
            return cls.geoopt_manifolds[manifold_name](dims)
        try:
            manifold = cls.sympa_manifolds[manifold_name]
        except KeyError:
            known = sorted({*cls.geoopt_manifolds, *cls.sympa_manifolds})
            raise KeyError(
                f"Unknown manifold {manifold_name!r}; expected one of {known}"
            ) from None
        metric = MetricType.from_str(metric_name)
        return manifold(dims=dims, metric=metric)
def test_knn():
    """KNN along dim 0 with k=5 returns a (10, 5, 9) tensor of sphere points."""
    sphere = geoopt.Sphere()
    knn = geoopt_layers.KNN(manifold=sphere, dim=0, k=5)
    neighbours = knn(sphere.random(10, 9))
    assert neighbours.shape == (10, 5, 9)
    sphere.assert_check_point_on_manifold(neighbours)
def test_expmap():
    """Expmap from the sphere origin maps random vectors onto the sphere."""
    manifold = geoopt.Sphere()
    expmap = geoopt_layers.Expmap(manifold=manifold, origin=manifold.origin(10))
    result = expmap(torch.randn(12, 3, 10))
    manifold.assert_check_point_on_manifold(result)
def test_product():
    """Random samples of a mixed product manifold lie on that manifold."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    samples = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(samples)
def test_knn_unroll():
    """KNN with ``unroll=0`` matches the default KNN in shape and values."""
    sphere = geoopt.Sphere()
    plain = geoopt_layers.KNN(manifold=sphere, dim=1, k=5)
    unrolled = geoopt_layers.KNN(manifold=sphere, dim=1, k=5, unroll=0)
    data = sphere.random(7, 10, 9)
    out_plain, out_unrolled = plain(data), unrolled(data)
    assert out_plain.shape == (7, 10, 5, 9)
    assert out_unrolled.shape == out_plain.shape
    np.testing.assert_allclose(out_plain, out_unrolled)
def test_lambda_one_origin():
    """TangentLambda needs ``origin_shape``; once given, its output is on the sphere."""
    sphere = geoopt.Sphere()
    x = sphere.random(1, 10)
    linear = torch.nn.Linear(10, 10)
    with pytest.raises(ValueError):
        geoopt_layers.TangentLambda(linear, manifold=sphere)
    layer = geoopt_layers.TangentLambda(linear, manifold=sphere, origin_shape=10)
    sphere.assert_check_point_on_manifold(layer(x))
def test_lambda_no_dim_change_origins():
    """TangentLambda raises ValueError when ``origin`` and ``out_origin`` differ in size.

    The origins here have sizes 10 and 11, wrapped around a Linear(10, 11).
    """
    sphere = geoopt.Sphere()
    func = torch.nn.Linear(10, 11)
    # Build the origins outside ``pytest.raises`` so that only a ValueError
    # raised by TangentLambda itself can satisfy the check.
    origin = sphere.origin(10)
    out_origin = sphere.origin(11)
    with pytest.raises(ValueError):
        geoopt_layers.TangentLambda(
            func,
            manifold=sphere,
            origin=origin,
            out_origin=out_origin,
        )
def test_lambda_two_origins():
    """TangentLambda with ``same_origin=False`` still outputs points on the sphere."""
    sphere = geoopt.Sphere()
    x = sphere.random(1, 10)
    layer = geoopt_layers.TangentLambda(
        torch.nn.Linear(10, 10),
        manifold=sphere,
        origin_shape=10,
        same_origin=False,
    )
    sphere.assert_check_point_on_manifold(layer(x))
def test_distance_pairwise_auto():
    """Single-input PairwiseDistances gives a square matrix along ``dim``."""
    man = geoopt.Sphere()
    along_first = geoopt_layers.PairwiseDistances(manifold=man, dim=0)
    assert along_first(man.random(10, 9)).shape == (10, 10)
    along_middle = geoopt_layers.PairwiseDistances(manifold=man, dim=-2)
    batch = man.random(10, 3, 9)
    assert along_middle(batch.detach()).shape == (10, 3, 3)
def test_pickle2():
    """A ManifoldParameter survives a ``torch.save`` / ``torch.load`` round trip."""
    original = geoopt.ManifoldParameter(torch.ones(10), manifold=geoopt.Sphere())
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "tens.t7")
        torch.save(original, target)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject this Parameter subclass --
        # confirm against the torch version this suite runs on.
        restored = torch.load(target)
        assert isinstance(restored, geoopt.ManifoldParameter)
        assert original.stride() == restored.stride()
        assert original.storage_offset() == restored.storage_offset()
        assert original.requires_grad == restored.requires_grad
        np.testing.assert_allclose(original.detach(), restored.detach())
        assert isinstance(original.manifold, type(restored.manifold))
def test_distance_pairwise_paired():
    """Two-input PairwiseDistances gives an (n, m, ...) distance tensor."""
    man = geoopt.Sphere()
    layer = geoopt_layers.PairwiseDistances(manifold=man, dim=0)
    left, right = man.random(10, 9), man.random(7, 9)
    assert layer(left, right).shape == (10, 7)
    layer = geoopt_layers.PairwiseDistances(manifold=man, dim=-3)
    left, right = man.random(10, 3, 9), man.random(7, 3, 9)
    assert layer(left, right).shape == (10, 7, 3)
def test_knn_permutations():
    """KNN along dim=-2 equals KNN along dim=0 on a permuted copy, permuted back."""
    man = geoopt.Sphere()
    knn_inner = geoopt_layers.KNN(manifold=man, dim=-2, k=7)
    points = man.random(1, 2, 3, 9, 5)
    reference = knn_inner(points)
    assert reference.shape == (1, 2, 3, 9, 7, 5)

    # Move axis 3 (the 9 points) to the front, run KNN there, then move it back.
    to_front = geoopt_layers.shape.Permute(3, 0, 1, 2, 4, contiguous=True)
    knn_front = geoopt_layers.KNN(manifold=man, dim=0, k=7)
    permuted = knn_front(to_front(points))
    assert permuted.shape == (9, 7, 1, 2, 3, 5)
    restored = permuted.permute(2, 3, 4, 0, 1, 5)

    man.assert_check_point_on_manifold(reference)
    np.testing.assert_allclose(reference, restored)
def test_remap_provided_origin():
    """RemapLambda with explicit origins maps sphere points into the Poincare ball."""
    source, target = geoopt.Sphere(), geoopt.PoincareBall()
    x = source.random(1, 10)
    layer = geoopt_layers.RemapLambda(
        torch.nn.Linear(10, 13),
        source_manifold=source,
        target_manifold=target,
        source_origin=source.origin(10),
        target_origin=target.origin(13),
    )
    target.assert_check_point_on_manifold(layer(x))
def sphere_case():
    """Yield UnaryCases for Sphere and SphereExact built from one seeded sample."""
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    # Normalise ex to unit norm, then remove ev's component along that point.
    point = ex / torch.norm(ex)
    tangent = ev - (point @ ev) * point
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls()
        x = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, x, ex, tangent, ev, manifold)
def test_remap_request_shapes():
    """RemapLambda raises ValueError when either origin shape is None."""
    sphere = geoopt.Sphere()
    poincare = geoopt.PoincareBall()
    func = torch.nn.Linear(10, 13)
    for source_shape, target_shape in ((None, 10), (10, None)):
        with pytest.raises(ValueError):
            geoopt_layers.RemapLambda(
                func,
                source_manifold=sphere,
                target_manifold=poincare,
                source_origin_shape=source_shape,
                target_origin_shape=target_shape,
            )
def sphere_subspace_case():
    """Yield UnaryCases for Sphere/SphereExact restricted to a random 2-column subspace.

    Points and tangent vectors are projected with ``Q Q^T``, where ``Q`` is the
    reduced-QR factor of ``subspace``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace, "reduced")
    projector = basis @ basis.t()
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ projector.t()
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ projector.t()
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        x = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, x, ex, tangent, ev, manifold)
def sphere_compliment_case():
    """Yield UnaryCases for Sphere/SphereExact restricted by a random ``complement``.

    Points and tangent vectors are projected with ``I - Q Q^T``, where ``Q`` is
    the reduced-QR factor of the single-column ``complement``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)
    basis, _ = geoopt.linalg.batch_linalg.qr(complement, "reduced")
    # projector = I - basis @ basis^T: negate the outer product, then add 1 on
    # the diagonal in place.
    projector = -basis @ basis.transpose(-1, -2)
    projector.diagonal(dim1=-2, dim2=-1).add_(1)
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ projector.t()
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ projector.t()
    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(complement=complement)
        x = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, x, ex, tangent, ev, manifold)
def test_knn_index():
    """KNNIndex along dim 0 with k=5 returns a (10, 5) index tensor."""
    sphere = geoopt.Sphere()
    index_layer = geoopt_layers.KNNIndex(manifold=sphere, dim=0, k=5)
    assert index_layer(sphere.random(10, 9)).shape == (10, 5)
def test_remap_init():
    """Remap built without a target manifold uses the source manifold as target."""
    manifold = geoopt.Sphere()
    remap = geoopt_layers.Remap(
        source_manifold=manifold,
        source_origin=manifold.origin(10),
    )
    assert remap.target_manifold is manifold
def test_fails_SphereProjection():
    """random_uniform(50) on a subspace-restricted Sphere raises ValueError.

    Presumably this is because 50 does not match the 10 rows of ``subspace``.
    """
    basis = torch.rand(10, 2, dtype=torch.float64)
    restricted = geoopt.Sphere(intersection=basis)
    with pytest.raises(ValueError):
        restricted.random_uniform(50)
def test_random_Sphere():
    """random_uniform samples lie on the sphere and carry it as ``.manifold``."""
    sphere = geoopt.Sphere()
    sample = sphere.random_uniform(3, 10, 10)
    sphere.assert_check_point_on_manifold(sample)
    assert sample.manifold is sphere
def test_random_SphereProjection():
    """random_uniform on a subspace-restricted Sphere yields points on that manifold."""
    basis = torch.rand(10, 2, dtype=torch.float64)
    restricted = geoopt.Sphere(intersection=basis)
    sample = restricted.random_uniform(3, 10, 10)
    restricted.assert_check_point_on_manifold(sample)
    assert sample.manifold is restricted
def get_prod_hysph_manifold(dims):
    """Return a PoincareBall x Sphere product manifold with ``dims`` elements in total.

    The Poincare ball factor gets ``dims // 2`` elements and the sphere gets the
    rest. For even ``dims`` this is an equal split. For odd ``dims``, the old
    ``dims // 2`` for both factors silently dropped one element.
    """
    poincare = gt.PoincareBall()
    sphere = gt.Sphere()
    hyperbolic_dims = dims // 2
    return gt.ProductManifold(
        (poincare, hyperbolic_dims), (sphere, dims - hyperbolic_dims)
    )
def get_prod_sphsph_manifold(dims):
    """Return a Sphere x Sphere product manifold with ``dims`` elements in total.

    The first factor gets ``dims // 2`` elements and the second the remainder.
    For even ``dims`` this is an equal split. For odd ``dims``, the old
    ``dims // 2`` for both factors silently dropped one element.
    """
    sphere = gt.Sphere()
    first_dims = dims // 2
    return gt.ProductManifold((sphere, first_dims), (sphere, dims - first_dims))
def test_no_type_promotion():
    """proju does not return an instance of the input point's tensor subclass."""
    point = geoopt.Sphere().random(10)
    projected = point.manifold.proju(point, torch.randn(10))
    assert not isinstance(projected, type(point))
def test_logmap():
    """Logmap at the sphere origin returns a NaN-free torch.Tensor."""
    manifold = geoopt.Sphere()
    logmap = geoopt_layers.Logmap(manifold=manifold, origin=manifold.origin(10))
    result = logmap(manifold.random(2, 30, 10))
    assert isinstance(result, torch.Tensor)
    assert not torch.isnan(result).any()
def test_distance2centroids():
    """Distance2Centroids(man, 9, 256) maps 10 points to a (10, 256) tensor.

    The arguments are presumably point dim 9 and 256 centroids.
    """
    sphere = geoopt.Sphere()
    to_centroids = geoopt_layers.Distance2Centroids(sphere, 9, 256)
    assert to_centroids(sphere.random(10, 9)).shape == (10, 256)