def test_scaling_not_implemented():
    """``Scaled`` raises ``NotImplementedError`` for ``mobius_fn_apply``.

    The error message must name the unsupported method.
    """
    scaled_ball = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    point = scaled_ball.random(10)
    expected_message = "Scaled version of 'mobius_fn_apply' is not available"
    with pytest.raises(NotImplementedError, match=expected_message):
        scaled_ball.mobius_fn_apply(lambda x: x, point)
def test_rescaling_methods_accessible():
    """A ``Scaled`` wrapping another ``Scaled`` still exposes ``geodesic``."""
    base = geoopt.PoincareBallExact()
    doubly_scaled = geoopt.Scaled(geoopt.Scaled(base, 2), 0.5)
    ramp = torch.arange(10).float() / 10
    start = ramp
    end = -torch.arange(10).float() / 10
    doubly_scaled.geodesic(0.5, start, end)
def test_spline_conv(bias, sizes, kernel_size, degree, root_weight, dim):
    """Run ``HyperbolicSplineConv`` on a 4-node graph and validate its output.

    Input points live on a c=1 exact ball (the default) and outputs are
    requested on a c=0.1 exact ball. The result must have one row per node
    with ``sizes[-1]`` features and lie on the output ball.
    """
    src_ball = geoopt.PoincareBallExact()
    dst_ball = geoopt.PoincareBallExact(c=0.1)
    # Node 0 is linked to nodes 1, 2 and 3 in both directions.
    edges = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = 4
    features = src_ball.random(num_nodes, 5)
    edge_pseudo = torch.rand(edges.size(1), dim)
    layer = HyperbolicSplineConv(
        *sizes,
        bias=bias,
        ball=src_ball,
        ball_out=dst_ball,
        kernel_size=kernel_size,
        degree=degree,
        root_weight=root_weight,
        dim=dim,
    )
    result = layer(features, edges, pseudo=edge_pseudo)
    assert result.shape == (num_nodes, sizes[-1])
    dst_ball.assert_check_point_on_manifold(result)
def test_graph_conv(aggr_method, bias, learn_origin, sizes, weighted, improved, cached):
    """Run ``HyperbolicGCNConv`` on a small graph and validate its output.

    Edge weights are random when ``weighted`` is true and omitted otherwise.
    The output must have 3 rows with ``sizes[-1]`` features and lie on the
    c=0.1 output ball.
    """
    src_ball = geoopt.PoincareBallExact()
    dst_ball = geoopt.PoincareBallExact(c=0.1)
    edges = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]])
    weights = torch.rand(edges.size(1)) if weighted else None
    num_nodes = 3
    features = src_ball.random(num_nodes, 5)
    layer = HyperbolicGCNConv(
        *sizes,
        aggr_method=aggr_method,
        bias=bias,
        learn_origin=learn_origin,
        ball=src_ball,
        ball_out=dst_ball,
        improved=improved,
        cached=cached,
    )
    result = layer(features, edges, edge_weight=weights)
    assert result.shape == (num_nodes, sizes[-1])
    dst_ball.assert_check_point_on_manifold(result)
def test_scale_poincare():
    """Compare ``dist0(expmap0(v))`` between a ball and its 2x-scaled version.

    The distance from the origin must agree between the base and the scaled
    ball. On the scaled ball it must also equal the norm of ``v`` as a tangent
    vector at the origin.
    """
    ball = geoopt.PoincareBallExact()
    scaled = geoopt.Scaled(ball, 2)
    tangent = torch.arange(10).float() / 10

    base_distance = ball.dist0(ball.expmap0(tangent)).item()
    scaled_distance = scaled.dist0(scaled.expmap0(tangent)).item()
    np.testing.assert_allclose(base_distance, scaled_distance, atol=1e-5)

    origin = torch.zeros_like(tangent)
    tangent_norm = scaled.norm(origin, tangent)
    np.testing.assert_allclose(scaled_distance, tangent_norm, atol=1e-5)
def test_message_passing():
    """Exercise ``HyperbolicMessagePassing.propagate`` for both flow directions.

    The edge list links sources {0, 1} to every target in {0, 1, 2}. Cases
    cover an explicit ``size``, an inferred size, and ``x`` given as a pair of
    tensors. Every output must have the expected row count and lie on the ball.
    """
    ball = geoopt.PoincareBallExact()
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]])

    def run(flow, features, **kwargs):
        # A fresh layer per call, matching one layer per case.
        layer = HyperbolicMessagePassing(flow=flow, ball=ball)
        return layer.propagate(edge_index, x=features, **kwargs)

    def check(flow, features, rows, **kwargs):
        # Propagate, then verify the output shape and that it stays on the ball.
        result = run(flow, features, **kwargs)
        assert result.shape == (rows, 5)
        ball.assert_check_point_on_manifold(result)

    # With 2 sources and 3 targets, each target receives both source points,
    # so every output row should equal the Poincare mean of the inputs.
    features = ball.random(2, 5)
    result = run("source_to_target", features, size=(2, 3))
    assert result.shape == (3, 5)
    mean = geoopt_layers.poincare.math.poincare_mean(features, ball=ball)
    np.testing.assert_allclose(result, mean.unsqueeze(0).expand_as(result), atol=1e-5)

    check("source_to_target", ball.random(3, 5), 3)
    check("target_to_source", ball.random(3, 5), 2, size=(2, 3))
    check("target_to_source", ball.random(3, 5), 3)
    check("source_to_target", (ball.random(2, 5), ball.random(3, 5)), 3)
    check("target_to_source", (ball.random(2, 5), ball.random(3, 5)), 2)
def poincare_case():
    """Yield ``UnaryCase`` instances for the Poincare ball manifolds.

    One case is produced for ``PoincareBall`` and one for ``PoincareBallExact``,
    both cast to float64. Both are built from the same seeded random data, so
    their values are identical. Each case receives its own tensor copies, so
    a test that mutates one case in place cannot affect the other.

    Yields:
        UnaryCase: ``(shape, x, ex, v, ev, manifold)``, where ``x`` is a
        ``ManifoldTensor`` on ``manifold`` and ``ex`` is a plain copy of it.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    # Rescale to Euclidean norm tanh(||ex||) < 1 so the point lies strictly
    # inside the ball (assumes the default curvature c=1, i.e. unit radius).
    point = torch.tanh(torch.norm(ex)) * ex / torch.norm(ex)

    def make_case(manifold):
        # Clone everything per case. Previously the second case wrapped the
        # first case's ManifoldTensor (sharing its storage), and both cases
        # shared the same v/ev tensors.
        x = geoopt.ManifoldTensor(point.clone(), manifold=manifold)
        return UnaryCase(shape, x, point.clone(), ev.clone(), ev.clone(), manifold)

    yield make_case(geoopt.PoincareBall().to(dtype=torch.float64))
    yield make_case(geoopt.PoincareBallExact().to(dtype=torch.float64))
def test_scaling_getattr():
    """Methods not listed in ``__scaling__`` remain callable on ``Scaled``."""
    scaled_ball = geoopt.Scaled(geoopt.PoincareBallExact(), 2)
    endpoints = scaled_ball.random(2, 10)
    start, end = endpoints
    # ``geodesic`` is representative and absent from ``__scaling__``.
    # Presumably it is reached by delegation to the base manifold, which is
    # what the test name suggests.
    scaled_ball.geodesic(0.5, start, end)
def test_scaling_compensates():
    """Scaling by 2 and then by 0.5 reproduces the unscaled ``expmap0``."""
    base = geoopt.PoincareBallExact()
    round_trip = geoopt.Scaled(geoopt.Scaled(base, 2), 0.5)
    tangent = torch.arange(10).float() / 10
    base_point = base.expmap0(tangent)
    round_trip_point = round_trip.expmap0(tangent)
    np.testing.assert_allclose(base_point, round_trip_point)