Exemple #1
0
def test_expmap0():
    """Product ``expmap0`` must agree with each factor's own ``expmap0``.

    A random 6-vector is mapped from the origin onto a 2+2+2 product of a
    Poincare ball, a projected sphere and a generic stereographic model. The
    result must lie on the product, and each slice must equal the matching
    factor's ``expmap0`` applied to the matching slice of the input vector.
    """
    factors = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    product = StereographicProductManifold(*factors)

    tangent = torch.randn(6)
    point = product.expmap0(tangent)
    on_manifold = product._check_point_on_manifold(point)[0]
    assert on_manifold

    for idx, (factor, _dim) in enumerate(factors):
        expected = factor.expmap0(product.take_submanifold_value(tangent, idx))
        actual = product.take_submanifold_value(point, idx)
        assert torch.isclose(actual, expected).all()
Exemple #2
0
def test_mobius_add():
    """Product-manifold Mobius addition must act factor by factor.

    Two random points on a 2+2+2 product (Poincare ball, projected sphere,
    generic stereographic model) are added. The sum must lie on the product,
    and each 2D slice must equal the factor's own ``mobius_add`` of the
    matching operand slices.
    """
    factors = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    product = StereographicProductManifold(*factors)
    lhs = product.random(6)
    rhs = product.random(6)

    total = product.mobius_add(lhs, rhs)
    assert product._check_point_on_manifold(total)[0]

    for idx, (factor, _dim) in enumerate(factors):
        lhs_part = product.take_submanifold_value(lhs, idx)
        rhs_part = product.take_submanifold_value(rhs, idx)
        expected = factor.mobius_add(lhs_part, rhs_part)
        actual = product.take_submanifold_value(total, idx)
        assert torch.isclose(actual, expected).all()
Exemple #3
0
def test_dist2plane():
    """Product ``dist2plane`` is the l2-combination of per-factor distances.

    For random points ``x`` and ``p`` on a 2+2+2 product manifold and a
    random vector ``a``, the product's ``dist2plane(x, p, a)`` must equal
    ``sqrt(sum_i dist2plane_i(x_i, p_i, a_i) ** 2)``, where ``*_i`` is the
    i-th factor's slice taken with ``take_submanifold_value``.
    """
    manifolds = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    pman = StereographicProductManifold(*manifolds)
    x = pman.random(6)
    p = pman.random(6)
    a = torch.randn(6)

    dist = pman.dist2plane(x, p, a)
    dists = []
    for i, (manifold, _) in enumerate(manifolds):
        x_ = pman.take_submanifold_value(x, i)
        p_ = pman.take_submanifold_value(p, i)
        a_ = pman.take_submanifold_value(a, i)

        dists.append(manifold.dist2plane(x_, p_, a_))
    # torch.stack combines the per-factor results as tensors (keeping their
    # device and dtype). torch.tensor(list_of_tensors) would instead copy
    # every element into a brand-new tensor on the default device.
    dists = torch.stack(dists) ** 2
    assert torch.isclose(dists.sum().sqrt(), dist).all()
Exemple #4
0
def add_geodesic_grid(ax: plt.Axes, manifold: Stereographic, line_width: float = 0.1) -> None:
    """Overlay a dashed gray grid of geodesics for a 2D stereographic model.

    Grid lines are spaced as if the manifold were a sphere of radius
    ``R = manifold.radius`` cut into 24 equal arcs (step ``2*pi*R / 24``).
    This keeps the grid looking alike for ``K < 0`` and ``K >= 0``, where
    ``K = manifold.k``.

    - For ``K < 0``, grid lines are placed at offsets of 1 .. n-1 steps from
      the origin. Here ``n`` is how many steps fit into the distance from the
      origin to the point ``(R, 0)`` after ``manifold.projx``.
    - For ``K >= 0``, the offsets are 1 .. 11 steps, i.e. up to 11/24 of the
      circumference.

    Args:
        ax: matplotlib axes that receives the lines (drawn in place).
        manifold: stereographic manifold. All points and directions here are
            built as 2-vectors, so it is assumed to be 2D. ``k`` and
            ``radius`` are read with ``.item()`` and must be single-element
            tensors.
        line_width: line width applied to every grid line.

    NOTE(review): with K == 0 the radius is presumably infinite. That would
    make the spacing and time range non-finite, so this helper looks
    unsuitable for the flat case; confirm before using it there.
    """
    import math
    # define geodesic grid parameters (samples per curve, dashed gray lines)
    N_EVALS_PER_GEODESIC = 10000
    STYLE = "--"
    COLOR = "gray"
    LINE_WIDTH = line_width

    # get manifold properties as Python floats
    K = manifold.k.item()
    R = manifold.radius.item()

    # get maximal numerical distance to origin on manifold
    if K < 0:
        # create the 2D point (R, 0), at Euclidean distance R from the origin
        r = torch.tensor((R, 0.0), dtype=manifold.dtype)
        # projx moves it into the valid range (epsilon away from the border)
        r = manifold.projx(r)
        # determine distance from origin
        max_dist_0 = manifold.dist0(r).item()
    else:
        # K >= 0: half the circumference of a sphere of radius R
        max_dist_0 = math.pi * R
    # circumference of a sphere of radius R; sets the grid spacing in all cases
    circumference = 2 * math.pi * R

    # determine reasonable number of geodesics
    # choose the grid interval size always as if we'd be in spherical
    # geometry, such that the grid interpolates smoothly and evenly
    # divides the sphere circumference
    n_geodesics_per_circumference = 4 * 6  # multiple of 4!
    # Despite the name, this is half the per-circumference count (12), so for
    # K >= 0 the offsets run up to just under pi * R from the origin.
    n_geodesics_per_quadrant = n_geodesics_per_circumference // 2
    grid_interval_size = circumference / n_geodesics_per_circumference
    if K < 0:
        # as many grid steps as fit between the origin and max_dist_0
        n_geodesics_per_quadrant = int(max_dist_0 / grid_interval_size)

    # create time evaluation array for geodesics
    # K < 0: overshoot max_dist_0 by 20% on both sides (presumably so lines
    # reach the visible border). K >= 0: one full circumference, -pi*R..pi*R.
    if K < 0:
        min_t = -1.2 * max_dist_0
    else:
        min_t = -circumference / 2.0
    # shape (N_EVALS_PER_GEODESIC, 1) so it can broadcast against a 2D point
    t = torch.linspace(min_t, -min_t, N_EVALS_PER_GEODESIC)[:, None]

    # define a function to plot the geodesics
    # gv is presumably (N, 2); transposing gives one row of x values and one
    # row of y values for ax.plot.
    def plot_geodesic(gv):
        ax.plot(*gv.t().numpy(), STYLE, color=COLOR, linewidth=LINE_WIDTH)

    # define geodesic directions
    # NOTE: u_x points along the 2nd coordinate and u_y along the 1st. Below,
    # an "x" geodesic starts at a point shifted along u_y and runs along u_x.
    u_x = torch.tensor((0.0, 1.0))
    u_y = torch.tensor((1.0, 0.0))

    # add origin x/y-crosshair
    o = torch.tensor((0.0, 0.0))
    if K < 0:
        x_geodesic = manifold.geodesic_unit(t, o, u_x)
        y_geodesic = manifold.geodesic_unit(t, o, u_y)
        plot_geodesic(x_geodesic)
        plot_geodesic(y_geodesic)
    else:
        # add the crosshair manually for the sproj of sphere
        # because the lines tend to get thicker if plotted
        # as done for K<0
        ax.axvline(0, linestyle=STYLE, color=COLOR, linewidth=LINE_WIDTH)
        ax.axhline(0, linestyle=STYLE, color=COLOR, linewidth=LINE_WIDTH)

    # add geodesics per quadrant
    for i in range(1, n_geodesics_per_quadrant):
        # rebinds the loop variable as a 0-d float tensor (presumably because
        # geodesic_unit expects tensor times)
        i = torch.as_tensor(float(i))
        # determine start of geodesic on x/y-crosshair, i grid steps away
        # from the origin
        x = manifold.geodesic_unit(i * grid_interval_size, o, u_y)
        y = manifold.geodesic_unit(i * grid_interval_size, o, u_x)

        # compute point on geodesics
        x_geodesic = manifold.geodesic_unit(t, x, u_x)
        y_geodesic = manifold.geodesic_unit(t, y, u_y)

        # plot geodesics
        plot_geodesic(x_geodesic)
        plot_geodesic(y_geodesic)
        # K < 0: also draw the point reflections through the origin (negated
        # coordinates) to cover the opposite side.
        # NOTE(review): the K >= 0 branch does not mirror; presumably the
        # longer spherical time range already covers it — confirm.
        if K < 0:
            plot_geodesic(-x_geodesic)
            plot_geodesic(-y_geodesic)
Exemple #5
0
# Base point x, plus two tangent vectors at x of length 1/5, pointing
# 60 degrees to either side of the second coordinate axis.
# NOTE(review): xv1/xv2 are built from NumPy float64 scalars and may end up
# float64 while x is float32 — confirm the manifold ops are fine mixing them.
x = torch.tensor((-0.25, -0.75))
xv1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5
xv2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5
# Curve parameters: 10 values in [0, 1], shaped (10, 1) so they presumably
# broadcast against 2D points in geodesic / geodesic_unit.
t = torch.linspace(0, 1, 10)[:, None]


def plot_gv(a, gv, **kwargs):
    """Draw the sampled curve ``gv`` on axes ``a`` with an arrow at its end.

    ``gv`` holds one point per row. Its transpose is unpacked into
    ``a.plot`` (one coordinate array per column). An arrow is then drawn
    from the second-to-last point along the final segment. ``kwargs`` are
    forwarded to both calls.
    """
    columns = gv.t().numpy()
    a.plot(*columns, **kwargs)

    tail = gv[-2]
    step = gv[-1] - tail
    a.arrow(*tail, *step, width=0.01, **kwargs)


# Two side-by-side panels for the parallel-transport illustration.
fig, ax = plt.subplots(1, 2, figsize=(9, 4))
fig.suptitle(r"gyrovector parallel transport $P_{x\to y}$")

# Stereographic model constructed with -1 (presumably k = -1, the K < 0 case).
manifold = Stereographic(-1)

# Target point y, the log map from x to y, and the geodesic from x to y
# sampled at the parameters ``t``.
# NOTE(review): ``xy`` and ``path`` are not used in the visible part of this
# script; presumably they are drawn further down.
y = torch.tensor((0.65, -0.55))
xy = manifold.logmap(x, y)
path = manifold.geodesic(t, x, y)
# Move the two tangent vectors from x to y.
yv1 = manifold.transp(x, y, xv1)
yv2 = manifold.transp(x, y, xv2)

# Short curves traced from x along xv1 / xv2 (presumably unit-speed
# geodesics, per the ``geodesic_unit`` name).
xgv1 = manifold.geodesic_unit(t, x, xv1)
xgv2 = manifold.geodesic_unit(t, x, xv2)

# The same for y along the transported vectors.
ygv1 = manifold.geodesic_unit(t, y, yv1)
ygv2 = manifold.geodesic_unit(t, y, yv2)

# Unit-circle outline, presumably the model's boundary for k = -1. It is
# attached to an axes later in the script (not visible here).
circle = plt.Circle((0, 0), 1, fill=False, color="b")
Exemple #6
0
from geoopt import Stereographic
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
import shutil
# Use LaTeX text rendering (with amsmath) only when ``latex`` is on PATH.
if shutil.which("latex") is not None:
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

# Two points inside the unit disk and their Mobius sum on Stereographic(-1).
x = torch.tensor((-0.25, -0.75)) / 2
y = torch.tensor((0.65, -0.55)) / 2
manifold = Stereographic(-1)
x_plus_y = manifold.mobius_add(x, y)

# Draw everything on the current axes, framing the unit circle in a square,
# equal-aspect window.
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax = plt.gca()
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")

# Operand labels sit slightly down-left of their points; the sum's label is
# offset a little further.
for label, point in (("x", x), ("y", y)):
    ax.annotate(label, point - 0.09, fontsize=15)
ax.annotate(r"$x\oplus y$", x_plus_y - torch.tensor([0.1, 0.15]), fontsize=15)

# One arrow from the origin per vector: x red, y green, the sum blue.
for vec, colour in ((x, "r"), (y, "g"), (x_plus_y, "b")):
    ax.arrow(0, 0, *vec, width=0.01, color=colour)

ax.set_title(r"Addition $x\oplus y$")
Exemple #7
0
        if K < 0:
            plot_geodesic(-x_geodesic)
            plot_geodesic(-y_geodesic)


# Plot window: the square [-lim, lim]^2, slightly larger than the unit disk.
lim = 1.1
coords = np.linspace(-lim, lim, 100)
# Passed below as the 2nd (point) and 3rd (vector) arguments of dist2plane.
x = torch.tensor([-0.75, 0])
v = torch.tensor([0.1 / 3, 0.0])
# 100x100 evaluation grid, stacked into shape (100, 100, 2), plus a mask of
# the grid points inside the closed unit disk.
xx, yy = np.meshgrid(coords, coords)
dist2 = xx**2 + yy**2
mask = dist2 <= 1
grid = np.stack([xx, yy], axis=-1)
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

# Stereographic model constructed with -1 (presumably k = -1).
manifold = Stereographic(-1)

# Distance from every grid point to the plane defined by (x, v). Points
# outside the disk are set to NaN, presumably so contourf leaves them blank.
dists = manifold.dist2plane(torch.from_numpy(grid).float(), x, v)
dists[(~mask).nonzero()] = np.nan
circle = plt.Circle((0, 0), 1, fill=False, color="b")

# Left panel: log-distance heat map inside the disk, overlaid with a
# geodesic grid (add_geodesic_grid is defined elsewhere in the file).
# NOTE(review): a zero distance would give log(0) = -inf — confirm that
# dist2plane never returns exactly 0 on this grid.
ax[0].add_artist(circle)
ax[0].set_xlim(-lim, lim)
ax[0].set_ylim(-lim, lim)
ax[0].set_aspect("equal")
ax[0].contourf(grid[..., 0],
               grid[..., 1],
               dists.log().numpy(),
               levels=100,
               cmap="inferno")
add_geodesic_grid(ax[0], manifold, 0.5)
Exemple #8
0
from geoopt import Stereographic
from matplotlib import rcParams
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
# Switch matplotlib to LaTeX text rendering (with amsmath) only when a
# ``latex`` executable is available on PATH.
if shutil.which("latex") is not None:
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

# A point inside the unit disk and its image under the Mobius version of
# the sigmoid on Stereographic(-1).
x = torch.tensor((-0.25, -0.75)) / 3
manifold = Stereographic(-1)
f_x = manifold.mobius_fn_apply(torch.sigmoid, x)

# Draw on the current axes, framing the unit circle in a square,
# equal-aspect window.
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax = plt.gca()
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")

# Labels: the input point, the sigmoid formula, and the Mobius result.
ax.annotate("x", x - 0.09, fontsize=15)
ax.annotate(
    r"$\sigma(x)=\frac{1}{1+e^{-x}}$",
    x + torch.tensor([-0.7, 0.5]),
    fontsize=15,
)
ax.annotate(
    r"$\sigma^\otimes(x)$",
    f_x - torch.tensor([0.1, 0.15]),
    fontsize=15,
)

# Arrows from the origin: x in red, the transformed point in blue.
for tip, shade in ((x, "r"), (f_x, "b")):
    ax.arrow(0, 0, *tip, width=0.01, color=shade)

ax.set_title(r"Mobius function (sigmoid) apply $\sigma^\otimes(x)$")
Exemple #9
0
from geoopt import Stereographic
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
import shutil
# Whether a ``latex`` executable is on PATH. This decides both the rcParams
# setup and whether the matrix label below (which needs amsmath's bmatrix)
# can be drawn.
has_latex = shutil.which("latex") is not None
if has_latex:
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")
x = torch.tensor((-0.25, -0.75)) / 3
M = torch.tensor([[-1, -1.5], [0.2, 0.5]])

# Mobius matrix-vector product of M and x on Stereographic(-1).
manifold = Stereographic(-1)
M_x = manifold.mobius_matvec(M, x)

# Draw on the current axes, framing the unit circle in a square,
# equal-aspect window.
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax = plt.gca()
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")
ax.annotate("x", x - 0.09, fontsize=15)
if has_latex:
    ax.annotate(
        r"$M=\begin{bmatrix}-1 &-1.5\\.2 &.5\end{bmatrix}$",
        x + torch.tensor([-0.5, 0.5]),
        fontsize=15,
    )

ax.annotate(r"$M^\otimes x$", M_x - torch.tensor([0.1, 0.15]), fontsize=15)