예제 #1
0
def test_fails_Sphere():
    """``Sphere.origin`` must raise ValueError for the shapes ``()`` and ``1``."""
    # Build the manifold outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected ValueError (a failing constructor would
    # otherwise make this test pass spuriously).
    manifold = geoopt.Sphere()
    with pytest.raises(ValueError):
        manifold.origin(())
    with pytest.raises(ValueError):
        manifold.origin(1)
예제 #2
0
def test_fails_Sphere():
    """``Sphere.random_uniform`` must raise ValueError for the shapes ``()`` and ``1``."""
    # Build the manifold outside ``pytest.raises`` so that only the call under
    # test can satisfy the expected ValueError (a failing constructor would
    # otherwise make this test pass spuriously).
    manifold = geoopt.Sphere()
    with pytest.raises(ValueError):
        manifold.random_uniform(())
    with pytest.raises(ValueError):
        manifold.random_uniform(1)
예제 #3
0
def product_case():
    """Yield UnaryCase fixtures for two ProductManifold configurations.

    The first product combines Sphere, PoincareBall, Stiefel and Euclidean
    factors; the second is the same without the Stiefel factor. Both reseed
    the RNG with 42 so their random data are reproducible.
    """
    torch.manual_seed(42)
    ex = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    ev = [
        torch.randn(10),
        torch.randn(3) / 10,
        torch.randn(3, 2),
        torch.randn(()),
    ]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Stiefel(),
        geoopt.Euclidean(),
    ]
    yield _make_product_case(manifolds, ex, ev)

    # + 1 case without stiefel
    torch.manual_seed(42)
    ex = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    ev = [torch.randn(10), torch.randn(3) / 10, torch.randn(())]
    manifolds = [
        geoopt.Sphere(),
        geoopt.PoincareBall(),
        geoopt.Euclidean(),
    ]
    yield _make_product_case(manifolds, ex, ev)


def _make_product_case(manifolds, ex, ev):
    """Project raw data onto each factor and pack it into a product UnaryCase.

    ``ex[i]`` / ``ev[i]`` are the raw point / vector for ``manifolds[i]``;
    each factor's shape in the product is ``ex[i].shape``.
    """
    # zip() would silently truncate mismatched inputs; fail loudly instead.
    if not len(manifolds) == len(ex) == len(ev):
        raise ValueError("manifolds, ex and ev must have the same length")
    x = [manifold.projx(point) for manifold, point in zip(manifolds, ex)]
    v = [
        manifold.proju(point, vector)
        for manifold, point, vector in zip(manifolds, x, ev)
    ]
    product_manifold = geoopt.ProductManifold(
        *((manifold, point.shape) for manifold, point in zip(manifolds, ex))
    )
    return UnaryCase(
        manifold_shapes[geoopt.ProductManifold],
        product_manifold.pack_point(*x),
        product_manifold.pack_point(*ex),
        product_manifold.pack_point(*v),
        product_manifold.pack_point(*ev),
        product_manifold,
    )
예제 #4
0
class ManifoldFactory:
    """Build a manifold instance from its string name.

    Names found in ``geoopt_manifolds`` are built by calling the registered
    callable with ``dims``. All other names are looked up in
    ``sympa_manifolds`` and the class found there is constructed with
    ``dims`` and a ``MetricType`` parsed from ``metric_name``.
    """

    # name -> callable(dims) returning a geoopt manifold. The single-manifold
    # lambdas ignore ``dims``; the product helpers receive it.
    geoopt_manifolds = {
        "euclidean": lambda dims: gt.Euclidean(1),
        "poincare": lambda dims: gt.PoincareBall(),
        "lorentz": lambda dims: gt.Lorentz(),
        "sphere": lambda dims: gt.Sphere(),
        "prod-hysph": get_prod_hysph_manifold,
        "prod-hyhy": get_prod_hyhy_manifold,
        "prod-hyeu": get_prod_hyeu_manifold,
        "prod-sphsph": get_prod_sphsph_manifold,
        "spd": lambda dims: SymmetricPositiveDefinite(),
    }

    # name -> manifold class constructed as cls(dims=..., metric=...).
    sympa_manifolds = {
        "upper": UpperHalfManifold,
        "bounded": BoundedDomainManifold,
        "dual": CompactDualManifold,
    }

    @classmethod
    def get_manifold(cls, manifold_name, metric_name, dims):
        """Return the manifold registered under ``manifold_name``.

        ``metric_name`` is only consulted for entries of ``sympa_manifolds``.
        Raises KeyError when the name is in neither registry.
        """
        builder = cls.geoopt_manifolds.get(manifold_name)
        if builder is not None:
            return builder(dims)

        sympa_cls = cls.sympa_manifolds[manifold_name]
        return sympa_cls(dims=dims, metric=MetricType.from_str(metric_name))
예제 #5
0
def test_knn():
    """KNN along dim 0 returns a (10, 5, 9) tensor of points on the sphere."""
    sphere = geoopt.Sphere()
    knn = geoopt_layers.KNN(manifold=sphere, dim=0, k=5)
    neighbours = knn(sphere.random(10, 9))
    assert neighbours.shape == (10, 5, 9)
    sphere.assert_check_point_on_manifold(neighbours)
    man.assert_check_point_on_manifold(knn_points)
예제 #6
0
def test_expmap():
    """Expmap based at the sphere origin must produce points on the sphere."""
    manifold = geoopt.Sphere()
    expmap = geoopt_layers.Expmap(manifold=manifold, origin=manifold.origin(10))
    inputs = torch.randn(12, 3, 10)
    mapped = expmap(inputs)

    manifold.assert_check_point_on_manifold(mapped)
예제 #7
0
def test_product():
    """Random samples from a four-factor ProductManifold lie on the manifold."""
    factors = (
        (geoopt.Sphere(), 10),
        (geoopt.PoincareBall(), 3),
        (geoopt.Stiefel(), (20, 2)),
        (geoopt.Euclidean(), 43),
    )
    product = geoopt.ProductManifold(*factors)
    points = product.random(20, product.n_elements)
    product.assert_check_point_on_manifold(points)
예제 #8
0
def test_knn_unroll():
    """KNN with ``unroll=0`` must match the default KNN layer exactly."""
    sphere = geoopt.Sphere()
    default_knn = geoopt_layers.KNN(manifold=sphere, dim=1, k=5)
    unrolled_knn = geoopt_layers.KNN(manifold=sphere, dim=1, k=5, unroll=0)
    data = sphere.random(7, 10, 9)
    expected = default_knn(data)
    actual = unrolled_knn(data)
    assert expected.shape == (7, 10, 5, 9)
    assert expected.shape == actual.shape
    np.testing.assert_allclose(expected, actual)
예제 #9
0
def test_lambda_one_origin():
    """TangentLambda without origin info raises; with ``origin_shape`` it stays on the sphere."""
    manifold = geoopt.Sphere()
    sample = manifold.random(1, 10)
    linear = torch.nn.Linear(10, 10)

    with pytest.raises(ValueError):
        geoopt_layers.TangentLambda(linear, manifold=manifold)

    tangent_lambda = geoopt_layers.TangentLambda(
        linear, manifold=manifold, origin_shape=10
    )
    result = tangent_lambda(sample)

    manifold.assert_check_point_on_manifold(result)
예제 #10
0
def test_lambda_no_dim_change_origins():
    """Passing both ``origin`` and ``out_origin`` with a 10->11 Linear raises ValueError."""
    manifold = geoopt.Sphere()
    linear = torch.nn.Linear(10, 11)
    with pytest.raises(ValueError):
        geoopt_layers.TangentLambda(
            linear,
            manifold=manifold,
            origin=manifold.origin(10),
            out_origin=manifold.origin(11),
        )
예제 #11
0
def test_lambda_two_origins():
    """TangentLambda with ``same_origin=False`` still maps onto the sphere."""
    manifold = geoopt.Sphere()
    sample = manifold.random(1, 10)
    linear = torch.nn.Linear(10, 10)
    tangent_lambda = geoopt_layers.TangentLambda(
        linear,
        manifold=manifold,
        origin_shape=10,
        same_origin=False,
    )
    result = tangent_lambda(sample)

    manifold.assert_check_point_on_manifold(result)
예제 #12
0
def test_distance_pairwise_auto():
    """Single-input PairwiseDistances returns square distance blocks along ``dim``."""
    sphere = geoopt.Sphere()

    flat_layer = geoopt_layers.PairwiseDistances(manifold=sphere, dim=0)
    flat_points = sphere.random(10, 9)
    flat = flat_layer(flat_points)
    assert flat.shape == (10, 10)

    batched_layer = geoopt_layers.PairwiseDistances(manifold=sphere, dim=-2)
    batched_points = sphere.random(10, 3, 9)
    batched = batched_layer(batched_points.detach())
    assert batched.shape == (10, 3, 3)
예제 #13
0
def test_pickle2():
    """Round-trip a ManifoldParameter through ``torch.save`` / ``torch.load``.

    The loaded object must keep its class, layout (stride, storage offset),
    ``requires_grad`` flag, values, and manifold type.
    """
    t = torch.ones(10)
    p = geoopt.ManifoldParameter(t, manifold=geoopt.Sphere())
    with tempfile.TemporaryDirectory() as path:
        filename = os.path.join(path, "tens.t7")
        torch.save(p, filename)
        # NOTE(review): newer torch releases default torch.load to
        # weights_only=True, which may reject this subclass; verify against
        # the pinned torch version.
        p1 = torch.load(filename)
    assert isinstance(p1, geoopt.ManifoldParameter)
    assert p.stride() == p1.stride()
    assert p.storage_offset() == p1.storage_offset()
    assert p.requires_grad == p1.requires_grad
    np.testing.assert_allclose(p.detach(), p1.detach())
    # Require the exact same manifold class. The former
    # isinstance(p.manifold, type(p1.manifold)) also passed when p1 came back
    # with only a base class of the original manifold.
    assert type(p1.manifold) is type(p.manifold)
예제 #14
0
def test_distance_pairwise_paired():
    """Two-input PairwiseDistances returns cross-distance shapes along ``dim``."""
    sphere = geoopt.Sphere()

    flat_layer = geoopt_layers.PairwiseDistances(manifold=sphere, dim=0)
    left = sphere.random(10, 9)
    right = sphere.random(7, 9)
    assert flat_layer(left, right).shape == (10, 7)

    batched_layer = geoopt_layers.PairwiseDistances(manifold=sphere, dim=-3)
    left = sphere.random(10, 3, 9)
    right = sphere.random(7, 3, 9)
    assert batched_layer(left, right).shape == (10, 7, 3)
예제 #15
0
def test_knn_permutations():
    """KNN on dim -2 agrees with KNN on dim 0 after moving that axis to the front."""
    sphere = geoopt.Sphere()
    knn_inner = geoopt_layers.KNN(manifold=sphere, dim=-2, k=7)
    points = sphere.random(1, 2, 3, 9, 5)
    expected = knn_inner(points)
    assert expected.shape == (1, 2, 3, 9, 7, 5)

    # Bring axis 3 (size 9) to the front and run KNN there instead.
    to_front = geoopt_layers.shape.Permute(3, 0, 1, 2, 4, contiguous=True)
    knn_front = geoopt_layers.KNN(manifold=sphere, dim=0, k=7)
    permuted = knn_front(to_front(points))
    assert permuted.shape == (9, 7, 1, 2, 3, 5)
    # Undo the permutation so the layouts line up for comparison.
    actual = permuted.permute(2, 3, 4, 0, 1, 5)

    sphere.assert_check_point_on_manifold(expected)
    np.testing.assert_allclose(expected, actual)
예제 #16
0
def test_remap_provided_origin():
    """RemapLambda with explicit origins maps sphere points onto the Poincare ball."""
    source = geoopt.Sphere()
    target = geoopt.PoincareBall()
    sample = source.random(1, 10)
    linear = torch.nn.Linear(10, 13)
    remap = geoopt_layers.RemapLambda(
        linear,
        source_manifold=source,
        target_manifold=target,
        source_origin=source.origin(10),
        target_origin=target.origin(13),
    )
    result = remap(sample)

    target.assert_check_point_on_manifold(result)
예제 #17
0
def sphere_case():
    """Yield UnaryCase fixtures for Sphere and SphereExact built on the same data.

    ``point`` is ``ex`` divided by its norm and ``tangent`` is ``ev`` minus
    ``(point @ ev) * point`` (a tangent projection assuming ``shape`` is 1-D
    — TODO confirm).
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    point = ex / torch.norm(ex)
    tangent = ev - (point @ ev) * point

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls()
        # The tensor from the previous iteration is re-wrapped for the next manifold.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
예제 #18
0
def test_remap_request_shapes():
    """RemapLambda raises ValueError when only one origin shape is given."""
    source = geoopt.Sphere()
    target = geoopt.PoincareBall()
    linear = torch.nn.Linear(10, 13)
    for source_shape, target_shape in ((None, 10), (10, None)):
        with pytest.raises(ValueError):
            geoopt_layers.RemapLambda(
                linear,
                source_manifold=source,
                target_manifold=target,
                source_origin_shape=source_shape,
                target_origin_shape=target_shape,
            )
예제 #19
0
def sphere_subspace_case():
    """Yield UnaryCase fixtures for a sphere intersected with a random 2-D span.

    Both Sphere and SphereExact are built with ``intersection=subspace``.
    Point and vector data are projected onto the span before use.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)

    # Orthonormal basis of the span and the orthogonal projector onto it.
    basis, _ = geoopt.linalg.batch_linalg.qr(subspace, "reduced")
    projector = basis @ basis.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ projector.t()
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ projector.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        # The tensor from the previous iteration is re-wrapped for the next manifold.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
예제 #20
0
def sphere_compliment_case():
    """Yield UnaryCase fixtures for a sphere restricted by ``complement=...``.

    The data are projected with ``I - Q Q^T``, where ``Q`` is an orthonormal
    basis of a random single-column ``complement``. Both Sphere and
    SphereExact are built with that ``complement``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    basis, _ = geoopt.linalg.batch_linalg.qr(complement, "reduced")
    size = basis.shape[-2]
    # Projector onto the orthogonal complement of span(basis).
    projector = torch.eye(size, dtype=basis.dtype) - basis @ basis.transpose(-1, -2)

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    projected = ex @ projector.t()
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ projector.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(complement=complement)
        # The tensor from the previous iteration is re-wrapped for the next manifold.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
예제 #21
0
def test_knn_index():
    """KNNIndex along dim 0 returns a (10, 5) index tensor."""
    sphere = geoopt.Sphere()
    indexer = geoopt_layers.KNNIndex(manifold=sphere, dim=0, k=5)
    indices = indexer(sphere.random(10, 9))
    assert indices.shape == (10, 5)
예제 #22
0
def test_remap_init():
    """Without a target manifold, Remap reuses the source manifold object."""
    source = geoopt.Sphere()
    remap = geoopt_layers.Remap(
        source_manifold=source, source_origin=source.origin(10)
    )
    assert remap.target_manifold is source
예제 #23
0
def test_fails_SphereProjection():
    """``random_uniform(50)`` on a sphere intersected with a 10x2 span raises ValueError."""
    span = torch.rand(10, 2, dtype=torch.float64)
    sphere = geoopt.Sphere(intersection=span)
    with pytest.raises(ValueError):
        sphere.random_uniform(50)
예제 #24
0
def test_random_Sphere():
    """``random_uniform`` samples lie on the sphere and are tagged with it."""
    sphere = geoopt.Sphere()
    sample = sphere.random_uniform(3, 10, 10)
    sphere.assert_check_point_on_manifold(sample)
    assert sample.manifold is sphere
예제 #25
0
def test_random_SphereProjection():
    """``random_uniform`` on a subspace-restricted sphere stays on that manifold."""
    span = torch.rand(10, 2, dtype=torch.float64)
    sphere = geoopt.Sphere(intersection=span)
    sample = sphere.random_uniform(3, 10, 10)
    sphere.assert_check_point_on_manifold(sample)
    assert sample.manifold is sphere
예제 #26
0
def get_prod_hysph_manifold(dims):
    """Return PoincareBall x Sphere, each factor with ``dims // 2`` elements.

    NOTE(review): for odd ``dims`` the product has ``dims - 1`` elements in
    total; confirm callers never pass odd values.
    """
    half = dims // 2
    factors = ((gt.PoincareBall(), half), (gt.Sphere(), half))
    return gt.ProductManifold(*factors)
예제 #27
0
def get_prod_sphsph_manifold(dims):
    """Return Sphere x Sphere (one shared instance), each with ``dims // 2`` elements.

    NOTE(review): for odd ``dims`` the product has ``dims - 1`` elements in
    total; confirm callers never pass odd values.
    """
    half = dims // 2
    sphere = gt.Sphere()
    factor = (sphere, half)
    return gt.ProductManifold(factor, factor)
예제 #28
0
def test_no_type_promotion():
    """``proju`` must not return the point's tensor subclass."""
    point = geoopt.Sphere().random(10)
    projected = point.manifold.proju(point, torch.randn(10))
    point_type = type(point)
    assert not isinstance(projected, point_type)
예제 #29
0
def test_logmap():
    """Logmap at the sphere origin returns a plain tensor free of NaNs."""
    manifold = geoopt.Sphere()
    logmap = geoopt_layers.Logmap(manifold=manifold, origin=manifold.origin(10))
    result = logmap(manifold.random(2, 30, 10))
    assert isinstance(result, torch.Tensor)
    has_nan = torch.isnan(result).any()
    assert not has_nan
예제 #30
0
def test_distance2centroids():
    """``Distance2Centroids(sphere, 9, 256)`` maps 10 points to a (10, 256) output."""
    sphere = geoopt.Sphere()
    to_centroids = geoopt_layers.Distance2Centroids(sphere, 9, 256)
    result = to_centroids(sphere.random(10, 9))
    assert result.shape == (10, 256)