def sphere_case():
    """Yield two ``UnaryCase`` fixtures: ``geoopt.Sphere``, then ``geoopt.SphereExact``.

    The random generator is seeded with 42 and two float64 tensors ``ex`` and
    ``ev`` are drawn (in that order) with the shape registered for
    ``geoopt.manifolds.Sphere`` in ``manifold_shapes``. The point is ``ex``
    divided by its norm, and the vector is ``ev`` minus ``(x @ ev) * x``.

    NOTE(review): ``x @ ev`` is a plain dot product only if the registered
    shape is one-dimensional -- confirm against ``manifold_shapes``.

    Yields:
        ``UnaryCase(shape, x, ex, v, ev, manifold)`` for each manifold class.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]

    # Draw order is part of the fixture: ``ex`` first, ``ev`` second.
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    point = ex / torch.norm(ex)
    # Computed from the raw (not yet wrapped) point.
    tangent = ev - torch.matmul(point, ev) * point

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        # Each manifold is built lazily, only when its case is requested.
        manifold = manifold_cls()
        # Deliberate rebinding: the second pass wraps the tensor produced by
        # the first pass, exactly as the sequential version did.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
def sphere_subspace_case():
    """Yield two ``UnaryCase`` fixtures for spheres built with ``intersection=subspace``.

    After seeding with 42, a random ``(shape[-1], 2)`` float64 matrix
    ``subspace`` is drawn, followed by ``ex`` and ``ev``. ``P = Q @ Q.t()`` is
    formed from the reduced QR factor ``Q`` of ``subspace`` (presumably the
    orthogonal projector onto its column span -- this relies on
    ``geoopt.linalg.batch_linalg.qr`` returning orthonormal columns).

    The point is ``ex @ P.t()`` divided by its norm. The vector is
    ``(ev - (x @ ev) * x) @ P.t()``.

    Yields:
        ``UnaryCase(shape, x, ex, v, ev, manifold)`` first for
        ``geoopt.Sphere(intersection=subspace)``, then for
        ``geoopt.SphereExact(intersection=subspace)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]

    # Random draws must stay in this order: subspace, ex, ev.
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)
    q_factor, _ = geoopt.linalg.batch_linalg.qr(subspace, "reduced")
    projector = q_factor @ q_factor.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    ex_projected = ex @ projector.t()
    point = ex_projected / torch.norm(ex_projected)
    # Computed from the raw (not yet wrapped) point.
    tangent = (ev - torch.matmul(point, ev) * point) @ projector.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        # Each manifold is built lazily, only when its case is requested.
        manifold = manifold_cls(intersection=subspace)
        # Deliberate rebinding: the second pass wraps the tensor produced by
        # the first pass, exactly as the sequential version did.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
def sphere_compliment_case():
    """Yield two ``UnaryCase`` fixtures for spheres built with ``complement=complement``.

    After seeding with 42, a random ``(shape[-1], 1)`` float64 matrix
    ``complement`` is drawn, followed by ``ex`` and ``ev``. With ``Q`` the
    reduced QR factor of ``complement``, the matrix ``P`` is
    ``-Q @ Q.transpose(-1, -2)`` with 1 added to every diagonal entry, i.e.
    ``I - Q Q^T``.

    The point is ``ex @ P.t()`` divided by its norm. The vector is
    ``(ev - (x @ ev) * x) @ P.t()``.

    Yields:
        ``UnaryCase(shape, x, ex, v, ev, manifold)`` first for
        ``geoopt.Sphere(complement=complement)``, then for
        ``geoopt.SphereExact(complement=complement)``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]

    # Random draws must stay in this order: complement, ex, ev.
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)
    q_factor, _ = geoopt.linalg.batch_linalg.qr(complement, "reduced")

    # Negate first, then multiply -- same operand order as ``-Q @ Q^T``.
    projector = (-q_factor) @ q_factor.transpose(-1, -2)
    diag = torch.arange(projector.shape[-2])
    projector[..., diag, diag] += 1

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    ex_projected = ex @ projector.t()
    point = ex_projected / torch.norm(ex_projected)
    # Computed from the raw (not yet wrapped) point.
    tangent = (ev - torch.matmul(point, ev) * point) @ projector.t()

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        # Each manifold is built lazily, only when its case is requested.
        manifold = manifold_cls(complement=complement)
        # Deliberate rebinding: the second pass wraps the tensor produced by
        # the first pass, exactly as the sequential version did.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)