コード例 #1
0
ファイル: test_scaling.py プロジェクト: xiongbo010/geoopt
def test_rescaling_methods_accessible():
    ball = geoopt.PoincareBallExact()
    sball = geoopt.Scaled(ball, 2)
    rsball = geoopt.Scaled(sball, 0.5)
    v0 = torch.arange(10).float() / 10
    v1 = -torch.arange(10).float() / 10
    rsball.geodesic(0.5, v0, v1)
コード例 #2
0
ファイル: test_utils.py プロジェクト: xiongbo010/geoopt
def test_ismanifold():
    m1 = geoopt.Euclidean()
    assert geoopt.ismanifold(m1, geoopt.Euclidean)
    m1 = geoopt.Scaled(m1)
    m1 = geoopt.Scaled(m1)
    assert geoopt.ismanifold(m1, geoopt.Euclidean)

    with pytest.raises(TypeError):
        geoopt.ismanifold(m1, int)

    with pytest.raises(TypeError):
        geoopt.ismanifold(m1, 1)

    assert not geoopt.ismanifold(1, geoopt.Euclidean)
コード例 #3
0
ファイル: test_manifold_basic.py プロジェクト: geoopt/geoopt
def unary_case(unary_case_base, scaled):
    if scaled:
        return unary_case_base._replace(
            manifold=geoopt.Scaled(unary_case_base.manifold, 2)
        )
    else:
        return unary_case_base
コード例 #4
0
ファイル: test_scaling.py プロジェクト: xiongbo010/geoopt
def test_scaling_not_implemented():
    ball = geoopt.PoincareBallExact()
    sball = geoopt.Scaled(ball, 2)
    pa = sball.random(10)
    with pytest.raises(NotImplementedError) as e:
        sball.mobius_fn_apply(lambda x: x, pa)
    assert e.match("Scaled version of 'mobius_fn_apply' is not available")
コード例 #5
0
ファイル: test_scaling.py プロジェクト: xiongbo010/geoopt
def test_scale_poincare():
    ball = geoopt.PoincareBallExact()
    sball = geoopt.Scaled(ball, 2)
    v = torch.arange(10).float() / 10
    np.testing.assert_allclose(
        ball.dist0(ball.expmap0(v)).item(),
        sball.dist0(sball.expmap0(v)).item(),
        atol=1e-5,
    )
    np.testing.assert_allclose(
        sball.dist0(sball.expmap0(v)).item(),
        sball.norm(torch.zeros_like(v), v),
        atol=1e-5,
    )
コード例 #6
0
ファイル: test_scaling.py プロジェクト: xiongbo010/geoopt
def test_scaling_getattr():
    ball = geoopt.PoincareBallExact()
    sball = geoopt.Scaled(ball, 2)
    pa, pb = sball.random(2, 10)
    # this one is representative and not present in __scaling__
    sball.geodesic(0.5, pa, pb)
コード例 #7
0
ファイル: test_scaling.py プロジェクト: xiongbo010/geoopt
def test_scaling_compensates():
    ball = geoopt.PoincareBallExact()
    sball = geoopt.Scaled(ball, 2)
    rsball = geoopt.Scaled(sball, 0.5)
    v = torch.arange(10).float() / 10
    np.testing.assert_allclose(ball.expmap0(v), rsball.expmap0(v))