def test_weighted_midpoint_euclidean(lincomb):
    """Zero curvature: the midpoint over dim 0 is a plain sum (lincomb) or mean."""
    flat = stereographic.Stereographic(0)
    points = geoopt.ManifoldParameter(flat.random(2, 3, 10))
    midpoint = flat.weighted_midpoint(points, reducedim=[0], lincomb=lincomb)
    assert midpoint.shape == points.shape[-2:]
    # In the flat case the result must coincide with ordinary Euclidean reduction.
    expected = points.sum(0) if lincomb else points.mean(0)
    assert torch.allclose(midpoint, expected)
def test_weighted_midpoint_reduce_dim(_k, lincomb):
    """Reducing over dim 0 gives finite values and finite, non-zero gradients."""
    space = stereographic.Stereographic(_k, learnable=True)
    points = geoopt.ManifoldParameter(space.random(2, 3, 10))
    midpoint = space.weighted_midpoint(points, reducedim=[0], lincomb=lincomb)
    assert midpoint.shape == points.shape[-2:]
    assert torch.isfinite(midpoint).all()

    midpoint.sum().backward()
    assert torch.isfinite(points.grad).all()
    # The curvature is learnable, so it is expected to receive a non-zero gradient.
    curvature_grad = space.k.grad
    assert not torch.isclose(curvature_grad, space.k.new_zeros(()))
def test_weighted_midpoint(_k, lincomb):
    """Without ``reducedim`` all leading dims collapse to a single vector."""
    space = stereographic.Stereographic(_k, learnable=True)
    points = space.random(2, 3, 10)
    points.requires_grad_(True)
    midpoint = space.weighted_midpoint(points, lincomb=lincomb)
    assert torch.isfinite(midpoint).all()
    assert midpoint.shape == points.shape[-1:]

    midpoint.sum().backward()
    assert torch.isfinite(points.grad).all()
    curvature_grad = space.k.grad
    assert not torch.isclose(curvature_grad, space.k.new_zeros(()))
def test_weighted_midpoint_zero(_k, lincomb):
    """All-zero weights yield the zero vector, with finite gradients everywhere."""
    space = stereographic.Stereographic(_k, learnable=True)
    points = geoopt.ManifoldParameter(space.random(2, 3, 10))
    zero_weights = torch.zeros_like(points[..., 0])
    midpoint = space.weighted_midpoint(
        points, reducedim=[0], lincomb=lincomb, weights=zero_weights
    )
    assert midpoint.shape == points.shape[-2:]
    assert torch.allclose(midpoint, torch.zeros_like(midpoint))

    midpoint.sum().backward()
    assert torch.isfinite(points.grad).all()
    assert torch.isfinite(space.k.grad).all()
def test_weighted_midpoint_weighted(_k, lincomb):
    """Random weights over dim 0 give finite output and non-zero curvature grad."""
    space = stereographic.Stereographic(_k, learnable=True)
    points = space.random(2, 3, 10).requires_grad_(True)
    random_weights = torch.rand_like(points[..., 0])
    midpoint = space.weighted_midpoint(
        points, reducedim=[0], lincomb=lincomb, weights=random_weights
    )
    assert midpoint.shape == points.shape[-2:]
    assert torch.isfinite(midpoint).all()

    midpoint.sum().backward()
    assert torch.isfinite(points.grad).all()
    curvature_grad = space.k.grad
    assert not torch.isclose(curvature_grad, space.k.new_zeros(()))
def test_weighted_midpoint_weighted_zero_sum(_k, lincomb):
    """Weights summing to (about) zero, with ``posweight=True``, stay finite.

    The inputs are the basis vectors mapped through ``expmap0``. For zero
    curvature with ``lincomb`` the midpoint is compared against the weights
    padded with zeros up to the ambient dimension.
    """
    space = stereographic.Stereographic(_k, learnable=True)
    basis = space.expmap0(torch.eye(3, 10)).detach().requires_grad_(True)
    raw = torch.rand_like(basis[..., 0])
    # Centre the weights so that their sum is (approximately) zero.
    weights = raw - raw.sum() / raw.numel()
    midpoint = space.weighted_midpoint(
        basis, lincomb=lincomb, weights=weights, posweight=True
    )
    if _k == 0 and lincomb:
        padding = torch.zeros(basis.size(-1) - basis.size(0))
        np.testing.assert_allclose(
            midpoint.detach(),
            torch.cat([weights, padding]),
            atol=1e-6,
        )
    assert midpoint.shape == basis.shape[-1:]
    assert torch.isfinite(midpoint).all()

    midpoint.sum().backward()
    assert torch.isfinite(basis.grad).all()
def manifold(k):
    """Build a stereographic manifold with learnable curvature ``k``."""
    space = stereographic.Stereographic(k=k, learnable=True)
    return space