コード例 #1
0
def test_h2_to_sspd2(seed, n, d):
    """Round-trip and distance check for ``h2_to_sspd2``.

    Points drawn from ``Lorentz(3)`` are scaled by ``d`` and mapped with
    ``h2_to_sspd2``. The test then checks two things:

    * ``sspd2_to_h2`` maps the result back to the *unscaled* points ``x / d``.
    * Pairwise ``SPD(2)`` distances of the mapped points equal the Lorentz
      distances of the unscaled points, multiplied by ``sspd2_hyp_radius_``.

    ``seed`` is not used in the body; presumably it is a fixture that seeds
    the RNG (TODO confirm).
    """
    spd, lorentz = SPD(2), Lorentz(3)

    scaled_pts = lorentz.rand(n, ir=1.0).mul_(d)
    spd_pts = h2_to_sspd2(scaled_pts)

    # The inverse map is expected to undo the scaling by ``d``.
    recovered = sspd2_to_h2(spd_pts)
    unit_pts = scaled_pts / d
    assert_allclose(recovered, unit_pts, atol=1e-4)

    # SPD distances should match the rescaled hyperbolic distances.
    expected = sspd2_hyp_radius_ * lorentz.pdist(unit_pts)
    assert_allclose(spd.pdist(spd_pts), expected, atol=1e-4)
コード例 #2
0
def test_sspd2_to_h2(seed, n):
    """Distance check for ``sspd2_to_h2`` on unit-determinant SPD samples.

    Random ``SPD(2)`` samples are normalized to determinant one and mapped
    with ``sspd2_to_h2``. Their pairwise SPD distances must equal the Lorentz
    distances of the images, multiplied by ``sspd2_hyp_radius_``.

    ``seed`` is not used in the body; presumably it is a fixture that seeds
    the RNG (TODO confirm).
    """
    spd = SPD(2)
    lorentz = Lorentz(3)

    mats = spd.rand(n, ir=1.0)
    # For 2x2 matrices det(c * X) == c**2 * det(X). Dividing each matrix by
    # sqrt(det) therefore yields determinant one, which the assert below checks.
    sqrt_dets = mats.det().sqrt_().reshape(-1, 1, 1)
    mats.div_(sqrt_dets)
    assert_allclose(mats.det(), torch.ones(n), atol=1e-4)
    # The normalized matrices should be fixed points of ``spd.projx``.
    assert_allclose(mats, spd.projx(mats), atol=1e-4)

    hyp_pts = sspd2_to_h2(mats)
    expected = sspd2_hyp_radius_ * lorentz.pdist(hyp_pts)
    assert_allclose(spd.pdist(mats), expected, atol=1e-4)
コード例 #3
0
def test_sspd2_to_h2_nonconst_factor(seed, n, d):
    """Distance check for ``sspd2_to_h2`` on SPD samples with determinant ``d**2``.

    Random ``SPD(2)`` samples are first normalized to determinant one and
    then scaled by ``d``. The pairwise SPD distances must still equal the
    Lorentz distances of their ``sspd2_to_h2`` images, multiplied by
    ``sspd2_hyp_radius_``.

    ``seed`` is not used in the body; presumably it is a fixture that seeds
    the RNG (TODO confirm).
    """
    spd = SPD(2)
    lorentz = Lorentz(3)

    mats = spd.rand(n, ir=1.0)
    mats.div_(mats.det().sqrt_().reshape(-1, 1, 1))  # det == 1
    mats.mul_(d)  # det == d**2, since det(c * X) == c**2 * det(X) for 2x2
    expected_dets = torch.full((n,), d**2, dtype=torch.get_default_dtype())
    assert_allclose(mats.det(), expected_dets, atol=1e-4)
    # The scaled matrices should be fixed points of ``spd.projx``.
    assert_allclose(mats, spd.projx(mats), atol=1e-4)

    hyp_pts = sspd2_to_h2(mats)
    expected = sspd2_hyp_radius_ * lorentz.pdist(hyp_pts)

    # Scaling the determinant does not change the curvature: every
    # fixed-determinant slice is isometric to the 2-dimensional hyperbolic
    # space with constant sectional curvature -1/2, so the same relation holds.
    assert_allclose(spd.pdist(mats), expected, atol=1e-4)
コード例 #4
0
def test_no_nan_dists(seed, rand_spd, d, n):
    """``SPD(d).pdist`` must not produce NaNs on ``rand_spd(n, d)`` samples.

    ``seed`` and ``rand_spd`` are presumably pytest fixtures: an RNG seeder
    and a sampler of random SPD matrices (TODO confirm).
    """
    manifold = SPD(d)
    points = rand_spd(n, d)

    dists = manifold.pdist(points)
    has_nan = torch.isnan(dists).any()
    assert not has_nan