예제 #1
0
def sphere_case():
    """Yield ``UnaryCase`` fixtures for the plain unit sphere.

    The global torch RNG is reseeded with 42 and two float64 samples of
    ``manifold_shapes[Sphere]`` are drawn.
    - The point is the first sample scaled to unit norm.
    - The vector is the second sample minus its component along the point.

    The same data is yielded twice: first wrapped for ``geoopt.Sphere()``,
    then re-wrapped for ``geoopt.SphereExact()``.

    NOTE(review): ``point @ ev`` acts as a dot product only if the shape is
    1-D -- confirm against ``manifold_shapes``.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    # Scale the raw sample onto the unit sphere.
    point = ex / torch.norm(ex)
    # Strip the radial component so the vector is orthogonal to ``point``.
    tangent = ev - (point @ ev) * point

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls()
        # Each case wraps the previous tensor, exactly as before.
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
예제 #2
0
def sphere_subspace_case():
    """Yield ``UnaryCase`` fixtures for a sphere intersected with a random 2-D span.

    Steps:
    - Reseed the global torch RNG with 42.
    - Draw a random ``(n, 2)`` matrix ``subspace``.
    - Orthonormalise it with a reduced QR to get the orthogonal projector
      ``P = Q Q^T`` onto its column span.
    - Project a random sample with ``P`` and normalise it to give the point.
    - Make a second sample orthogonal to the point, then project it with
      ``P`` to give the vector.

    The data is yielded for ``geoopt.Sphere(intersection=...)`` and then for
    ``geoopt.SphereExact(intersection=...)``.

    NOTE(review): the ``@`` arithmetic assumes a 1-D sample shape -- confirm.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)

    basis, _ = geoopt.linalg.batch_linalg.qr(subspace, "reduced")
    projector = basis @ basis.t()
    # Right-multiplying a row vector by P^T applies P to it.
    proj_t = projector.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)

    projected = ex @ proj_t
    point = projected / torch.norm(projected)
    tangent = (ev - (point @ ev) * point) @ proj_t

    for manifold_cls in (geoopt.Sphere, geoopt.SphereExact):
        manifold = manifold_cls(intersection=subspace)
        point = geoopt.ManifoldTensor(point, manifold=manifold)
        yield UnaryCase(shape, point, ex, tangent, ev, manifold)
예제 #3
0
def sphere_compliment_case():
    """Yield ``UnaryCase`` fixtures for a sphere restricted to the orthogonal
    complement of a random one-dimensional span.

    Steps:
    - Reseed the global torch RNG with 42.
    - Draw a random ``(n, 1)`` column ``complement`` and orthonormalise it
      with a reduced QR into ``Q``.
    - Form ``P = I - Q Q^T``, which projects onto the orthogonal complement
      of that column.
    - Project a random sample with ``P`` and normalise it to give the point.
    - Make a second sample orthogonal to the point, then project it with
      ``P`` to give the vector.

    The data is yielded for ``geoopt.Sphere(complement=...)`` and then for
    ``geoopt.SphereExact(complement=...)``.

    NOTE(review): the ``@`` arithmetic (``x @ ev`` as a dot product) assumes
    ``manifold_shapes[Sphere]`` is 1-D -- confirm.
    NOTE(review): "compliment" in the name is a misspelling of "complement";
    it is kept unchanged because callers reference this name.
    """
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    Q, _ = geoopt.linalg.batch_linalg.qr(complement, "reduced")
    # Orthogonal projector onto the complement of span(Q): I - Q Q^T.
    # Built with torch.eye rather than an in-place fancy-indexed diagonal add.
    # It uses the same 2-D ``.t()`` convention as the rest of this function.
    P = torch.eye(Q.shape[-2], dtype=Q.dtype, device=Q.device) - Q @ Q.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    # P is symmetric by construction, so ``y @ P.t()`` applies P to y.
    x = (ex @ P.t()) / torch.norm(ex @ P.t())
    # Remove the component along x, then keep only the part inside the
    # complement subspace. Because P x = x, the result stays orthogonal to x.
    v = (ev - (x @ ev) * x) @ P.t()

    manifold = geoopt.Sphere(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)
    manifold = geoopt.SphereExact(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    yield UnaryCase(shape, x, ex, v, ev, manifold)