Beispiel #1
0
def test_weighted_midpoint_euclidean(lincomb):
    """Check the weighted midpoint on a manifold constructed with ``0``.

    The midpoint is reduced over dimension 0 and must equal the plain sum
    of the points when ``lincomb`` is truthy, and their plain mean otherwise.
    """
    space = stereographic.Stereographic(0)
    points = geoopt.ManifoldParameter(space.random(2, 3, 10))
    midpoint = space.weighted_midpoint(points, reducedim=[0], lincomb=lincomb)
    # Reducing over the leading dimension leaves the trailing two.
    assert midpoint.shape == points.shape[-2:]
    expected = points.sum(0) if lincomb else points.mean(0)
    assert torch.allclose(midpoint, expected)
Beispiel #2
0
def test_weighted_midpoint_reduce_dim(_k, lincomb):
    """Check the midpoint reduced over dimension 0 and its gradients.

    Asserts that the result keeps the trailing two dimensions of the input,
    that values and the gradient w.r.t. the points are finite, and that the
    gradient w.r.t. the learnable ``k`` is not close to zero.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    points = geoopt.ManifoldParameter(space.random(2, 3, 10))
    midpoint = space.weighted_midpoint(points, reducedim=[0], lincomb=lincomb)
    assert midpoint.shape == points.shape[-2:]
    assert torch.isfinite(midpoint).all()
    # Backpropagate a scalar so both the points and ``k`` receive gradients.
    scalar = midpoint.sum()
    scalar.backward()
    assert torch.isfinite(points.grad).all()
    zero = space.k.new_zeros(())
    assert not torch.isclose(space.k.grad, zero)
Beispiel #3
0
def test_weighted_midpoint(_k, lincomb):
    """Check the midpoint computed without an explicit ``reducedim``.

    The result is expected to keep only the last dimension of the input.
    Values and the gradient w.r.t. the points must be finite, and the
    gradient w.r.t. the learnable ``k`` must not be close to zero.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    points = space.random(2, 3, 10).requires_grad_(True)
    midpoint = space.weighted_midpoint(points, lincomb=lincomb)
    assert torch.isfinite(midpoint).all()
    # Only the trailing (feature) dimension should survive the reduction.
    assert midpoint.shape == points.shape[-1:]
    scalar = midpoint.sum()
    scalar.backward()
    assert torch.isfinite(points.grad).all()
    zero = space.k.new_zeros(())
    assert not torch.isclose(space.k.grad, zero)
Beispiel #4
0
def test_weighted_midpoint_zero(_k, lincomb):
    """Check the midpoint when every weight is zero.

    The result must be (close to) the zero tensor, and the gradients
    w.r.t. both the points and the learnable ``k`` must stay finite.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    points = geoopt.ManifoldParameter(space.random(2, 3, 10))
    # One weight per point: drop the last dimension of the input.
    zero_weights = torch.zeros_like(points[..., 0])
    midpoint = space.weighted_midpoint(
        points,
        reducedim=[0],
        lincomb=lincomb,
        weights=zero_weights,
    )
    assert midpoint.shape == points.shape[-2:]
    origin = torch.zeros_like(midpoint)
    assert torch.allclose(midpoint, origin)
    scalar = midpoint.sum()
    scalar.backward()
    assert torch.isfinite(points.grad).all()
    assert torch.isfinite(space.k.grad).all()
Beispiel #5
0
def test_weighted_midpoint_weighted(_k, lincomb):
    """Check the midpoint with random weights, reduced over dimension 0.

    Asserts the output shape, finiteness of the result and of the gradient
    w.r.t. the points, and that the gradient w.r.t. the learnable ``k`` is
    not close to zero.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    points = space.random(2, 3, 10).requires_grad_(True)
    # One random weight per point: same shape as the input minus its last dim.
    random_weights = torch.rand_like(points[..., 0])
    midpoint = space.weighted_midpoint(
        points,
        reducedim=[0],
        lincomb=lincomb,
        weights=random_weights,
    )
    assert midpoint.shape == points.shape[-2:]
    assert torch.isfinite(midpoint).all()
    scalar = midpoint.sum()
    scalar.backward()
    assert torch.isfinite(points.grad).all()
    zero = space.k.new_zeros(())
    assert not torch.isclose(space.k.grad, zero)
Beispiel #6
0
def test_weighted_midpoint_weighted_zero_sum(_k, lincomb):
    """Check the midpoint with weights that are centred to sum to zero.

    The points are ``expmap0`` of the rows of a 3x10 identity-like matrix.
    When ``_k == 0`` and ``lincomb`` is truthy, the midpoint must match the
    weights padded with zeros up to the point dimension. In every case the
    result and the gradient w.r.t. the points must be finite.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    basis = torch.eye(3, 10)
    points = space.expmap0(basis).detach().requires_grad_(True)
    raw_weights = torch.rand_like(points[..., 0])
    # Subtract the average so the weights sum to (approximately) zero.
    weights = raw_weights - raw_weights.sum() / raw_weights.numel()
    midpoint = space.weighted_midpoint(
        points,
        lincomb=lincomb,
        weights=weights,
        posweight=True,
    )
    if _k == 0 and lincomb:
        # Pad the weights with zeros for the remaining coordinates.
        padding = torch.zeros(points.size(-1) - points.size(0))
        expected = torch.cat([weights, padding])
        np.testing.assert_allclose(midpoint.detach(), expected, atol=1e-6)
    assert midpoint.shape == points.shape[-1:]
    assert torch.isfinite(midpoint).all()
    scalar = midpoint.sum()
    scalar.backward()
    assert torch.isfinite(points.grad).all()
Beispiel #7
0
def manifold(k):
    """Build a ``Stereographic`` manifold for ``k`` with ``learnable=True``."""
    learnable_manifold = stereographic.Stereographic(k=k, learnable=True)
    return learnable_manifold