Esempio n. 1
0
def test_scaling_not_implemented():
    """Check that ``Scaled`` refuses to provide ``mobius_fn_apply``.

    Calling the method on a 2x-scaled Poincaré ball must raise
    ``NotImplementedError`` whose message mentions the missing scaled
    version of ``mobius_fn_apply``.
    """
    scaled_ball = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    point = scaled_ball.random(10)
    # ``match`` performs the same regex search as ``ExceptionInfo.match``.
    with pytest.raises(
        NotImplementedError,
        match="Scaled version of 'mobius_fn_apply' is not available",
    ):
        scaled_ball.mobius_fn_apply(lambda x: x, point)
Esempio n. 2
0
def test_rescaling_methods_accessible():
    """Check that methods stay reachable through two nested ``Scaled`` wrappers.

    The base ball is scaled by 2 and then by 0.5, and ``geodesic`` is called
    on the result. Only accessibility is exercised; the returned value is
    not inspected.
    """
    base = geoopt.PoincareBallExact()
    rescaled = geoopt.Scaled(geoopt.Scaled(base, 2), 0.5)
    steps = torch.arange(10).float() / 10
    start, end = steps, -steps
    rescaled.geodesic(0.5, start, end)
def test_spline_conv(bias, sizes, kernel_size, degree, root_weight, dim):
    """Run ``HyperbolicSplineConv`` on a small four-node graph.

    Node 0 has edges to and from nodes 1, 2 and 3. Input points live on a
    default ``PoincareBallExact``; the layer maps them onto a ball with
    ``c=0.1``. The output must have one row per node with ``sizes[-1]``
    features and lie on the output ball. The arguments are presumably
    supplied by pytest parametrization outside this view.
    """
    ball_in = geoopt.PoincareBallExact()
    ball_out = geoopt.PoincareBallExact(c=0.1)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    features = ball_in.random(4, 5)
    # One pseudo-coordinate vector of length ``dim`` per edge.
    pseudo = torch.rand(edge_index.size(1), dim)
    layer = HyperbolicSplineConv(
        *sizes,
        bias=bias,
        ball=ball_in,
        ball_out=ball_out,
        kernel_size=kernel_size,
        degree=degree,
        root_weight=root_weight,
        dim=dim,
    )
    out = layer(features, edge_index, pseudo=pseudo)
    assert out.shape == (4, sizes[-1])
    ball_out.assert_check_point_on_manifold(out)
def test_graph_conv(aggr_method, bias, learn_origin, sizes, weighted, improved, cached):
    """Run ``HyperbolicGCNConv`` on a three-node graph.

    Nodes 0 and 1 each have edges to nodes 0, 1 and 2. When ``weighted`` is
    true, random per-edge weights are passed; otherwise ``edge_weight`` is
    ``None``. The output must have one row per node with ``sizes[-1]``
    features and lie on the output ball (``c=0.1``). The arguments are
    presumably supplied by pytest parametrization outside this view.
    """
    ball_in = geoopt.PoincareBallExact()
    ball_out = geoopt.PoincareBallExact(c=0.1)
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]])
    edge_weight = torch.rand(edge_index.size(1)) if weighted else None
    features = ball_in.random(3, 5)
    layer = HyperbolicGCNConv(
        *sizes,
        aggr_method=aggr_method,
        bias=bias,
        learn_origin=learn_origin,
        ball=ball_in,
        ball_out=ball_out,
        improved=improved,
        cached=cached,
    )
    out = layer(features, edge_index, edge_weight=edge_weight)
    assert out.shape == (3, sizes[-1])
    ball_out.assert_check_point_on_manifold(out)
Esempio n. 5
0
def test_scale_poincare():
    """Compare ``expmap0``/``dist0`` on a Poincaré ball and its 2x-scaled copy.

    Mapping a tangent vector with ``expmap0`` and measuring ``dist0`` gives
    the same value on the base ball and on the scaled ball. On the scaled
    ball, that value also equals the norm of the vector at the origin.
    """
    ball = geoopt.PoincareBallExact()
    scaled = geoopt.Scaled(ball, 2)
    tangent = torch.arange(10).float() / 10

    base_dist = ball.dist0(ball.expmap0(tangent)).item()
    scaled_dist = scaled.dist0(scaled.expmap0(tangent)).item()
    np.testing.assert_allclose(base_dist, scaled_dist, atol=1e-5)

    origin = torch.zeros_like(tangent)
    np.testing.assert_allclose(
        scaled_dist,
        scaled.norm(origin, tangent),
        atol=1e-5,
    )
Esempio n. 6
0
def test_message_passing():
    """Check ``HyperbolicMessagePassing.propagate`` for both flow directions.

    The edge list connects nodes {0, 1} to nodes {0, 1, 2}. Each
    configuration checks the expected number of output rows and that the
    output lies on the ball. The first configuration also compares the
    output against the Poincaré mean of all inputs.
    """
    ball = geoopt.PoincareBallExact()
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]])

    # Two sources, three targets: every target receives both sources, so
    # each output row should equal the mean of all of ``x``.
    x = ball.random(2, 5)
    out = HyperbolicMessagePassing(
        flow="source_to_target", ball=ball
    ).propagate(edge_index, x=x, size=(2, 3))
    assert out.shape == (3, 5)
    ball.assert_check_point_on_manifold(out)
    mean = geoopt_layers.poincare.math.poincare_mean(x, ball=ball)
    np.testing.assert_allclose(out,
                               mean.unsqueeze(0).expand_as(out),
                               atol=1e-5)

    # Three nodes, size inferred from ``x``.
    x = ball.random(3, 5)
    out = HyperbolicMessagePassing(
        flow="source_to_target", ball=ball
    ).propagate(edge_index, x=x)
    assert out.shape == (3, 5)
    ball.assert_check_point_on_manifold(out)

    # Reversed flow with an explicit size: output has size[0] == 2 rows.
    x = ball.random(3, 5)
    out = HyperbolicMessagePassing(
        flow="target_to_source", ball=ball
    ).propagate(edge_index, x=x, size=(2, 3))
    assert out.shape == (2, 5)
    ball.assert_check_point_on_manifold(out)

    # Reversed flow, size inferred from ``x``.
    x = ball.random(3, 5)
    out = HyperbolicMessagePassing(
        flow="target_to_source", ball=ball
    ).propagate(edge_index, x=x)
    assert out.shape == (3, 5)
    ball.assert_check_point_on_manifold(out)

    # Bipartite input given as a (source, target) tuple of point sets.
    x = (ball.random(2, 5), ball.random(3, 5))
    out = HyperbolicMessagePassing(
        flow="source_to_target", ball=ball
    ).propagate(edge_index, x=x)
    assert out.shape == (3, 5)
    ball.assert_check_point_on_manifold(out)

    x = (ball.random(2, 5), ball.random(3, 5))
    out = HyperbolicMessagePassing(
        flow="target_to_source", ball=ball
    ).propagate(edge_index, x=x)
    assert out.shape == (2, 5)
    ball.assert_check_point_on_manifold(out)
Esempio n. 7
0
def poincare_case():
    """Yield ``UnaryCase`` objects for ``PoincareBall`` and ``PoincareBallExact``.

    A seeded random float64 sample ``e`` is pulled inside the unit ball as
    ``tanh(|e|) * e / |e|`` and paired with a random tangent vector. Both
    manifolds receive the same values. Each case gets freshly cloned
    tensors, so the two cases no longer share storage. Previously the
    second case re-wrapped the first case's ``ManifoldTensor`` and reused
    its ``ex``/``v``/``ev`` objects.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    # Radial squash: the result has norm tanh(|ex|), strictly below 1.
    point = torch.tanh(torch.norm(ex)) * ex / torch.norm(ex)
    for manifold_cls in (geoopt.PoincareBall, geoopt.PoincareBallExact):
        # Built lazily per case, matching the original construction order.
        manifold = manifold_cls().to(dtype=torch.float64)
        x = geoopt.ManifoldTensor(point.clone(), manifold=manifold)
        yield UnaryCase(shape, x, point.clone(), ev.clone(), ev.clone(), manifold)
Esempio n. 8
0
def test_scaling_getattr():
    """Call a non-scaled method through a ``Scaled`` wrapper.

    ``geodesic`` is a representative method that is not listed in
    ``__scaling__``. Calling it on a 2x-scaled ball presumably exercises
    attribute forwarding to the base manifold.
    """
    scaled = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    first, second = scaled.random(2, 10)
    scaled.geodesic(0.5, first, second)
Esempio n. 9
0
def test_scaling_compensates():
    """Check that scaling by 2 and then by 0.5 reproduces the base ``expmap0``."""
    base = geoopt.PoincareBallExact()
    round_trip = geoopt.Scaled(geoopt.Scaled(base, 2), 0.5)
    tangent = torch.arange(10).float() / 10
    expected = base.expmap0(tangent)
    actual = round_trip.expmap0(tangent)
    np.testing.assert_allclose(expected, actual)