def test_expmap0():
    """expmap0 on the product manifold must agree with expmap0 on every factor.

    A random tangent vector at the origin is mapped onto the product manifold;
    the result has to be a valid point, and each 2-d slice of it has to equal
    the factor manifold's own expmap0 applied to the matching tangent slice.
    """
    factors = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    product = StereographicProductManifold(*factors)
    tangent = torch.randn(6)
    point = product.expmap0(tangent)
    assert product._check_point_on_manifold(point)[0]

    for idx, (factor, _dim) in enumerate(factors):
        expected = factor.expmap0(product.take_submanifold_value(tangent, idx))
        actual = product.take_submanifold_value(point, idx)
        assert torch.isclose(actual, expected).all()
def test_mobius_add():
    """Mobius addition on the product manifold must act factor-by-factor.

    Two random product points are added; the sum has to lie on the product
    manifold, and every 2-d slice of it has to equal the factor manifold's
    mobius_add of the corresponding slices of the operands.
    """
    factors = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    product = StereographicProductManifold(*factors)
    lhs = product.random(6)
    rhs = product.random(6)
    total = product.mobius_add(lhs, rhs)
    assert product._check_point_on_manifold(total)[0]

    for idx, (factor, _dim) in enumerate(factors):
        lhs_part = product.take_submanifold_value(lhs, idx)
        rhs_part = product.take_submanifold_value(rhs, idx)
        expected = factor.mobius_add(lhs_part, rhs_part)
        actual = product.take_submanifold_value(total, idx)
        assert torch.isclose(actual, expected).all()
def test_dist2plane():
    """Check that the product distance to a plane combines the factor distances.

    The plane in each factor is given by the matching slices of the point ``p``
    and the vector ``a``. The product-level ``dist2plane`` is expected to equal
    the square root of the sum of the squared per-factor distances.
    """
    manifolds = (
        (PoincareBall(), 2),
        (SphereProjection(), 2),
        (Stereographic(), 2),
    )
    pman = StereographicProductManifold(*manifolds)
    x = pman.random(6)
    p = pman.random(6)
    a = torch.randn(6)
    dist = pman.dist2plane(x, p, a)

    factor_dists = []
    for i, (manifold, _) in enumerate(manifolds):
        x_ = pman.take_submanifold_value(x, i)
        p_ = pman.take_submanifold_value(p, i)
        a_ = pman.take_submanifold_value(a, i)
        factor_dists.append(manifold.dist2plane(x_, p_, a_))

    # Combine the per-factor results with torch.stack, which keeps their dtype
    # and device, instead of rebuilding a fresh tensor via torch.tensor(list).
    expected = torch.stack(factor_dists).pow(2).sum().sqrt()
    assert torch.isclose(expected, dist).all()
def add_geodesic_grid(ax: plt.Axes, manifold: Stereographic, line_width=0.1):
    """Overlay a dashed gray grid of geodesics on ``ax`` for a 2-d stereographic model.

    The grid spacing is always taken from the circumference ``2*pi*R`` of the
    sphere of radius ``R = manifold.radius`` (24 intervals per circumference),
    so the grid changes smoothly with curvature and divides the sphere evenly.

    Parameters
    ----------
    ax : plt.Axes
        Axes that receive the grid lines.
    manifold : Stereographic
        Manifold whose ``k`` (sign of the curvature) and ``radius`` set the grid.
    line_width : float
        Line width of every grid line.
    """
    import math

    line_style = {"color": "gray", "linewidth": line_width}

    curvature = manifold.k.item()
    radius = manifold.radius.item()
    hyperbolic = curvature < 0

    circumference = 2 * math.pi * radius
    # number of intervals per circumference; must be a multiple of 4
    intervals_per_circumference = 4 * 6
    step = circumference / intervals_per_circumference

    if hyperbolic:
        # Largest usable distance from the origin: the point (R, 0) is projected
        # back into the valid range (epsilon border) and measured from 0.
        edge = manifold.projx(torch.tensor((radius, 0.0), dtype=manifold.dtype))
        max_dist_0 = manifold.dist0(edge).item()
        n_lines = int(max_dist_0 / step)
        t_max = 1.2 * max_dist_0
    else:
        n_lines = intervals_per_circumference // 2
        t_max = circumference / 2.0
    t = torch.linspace(-t_max, t_max, 10000)[:, None]

    def draw(curve):
        # curve is a sequence of sample points; transpose into coordinate rows
        ax.plot(*curve.t().numpy(), "--", **line_style)

    vertical = torch.tensor((0.0, 1.0))
    horizontal = torch.tensor((1.0, 0.0))
    origin = torch.tensor((0.0, 0.0))

    # crosshair through the origin
    if hyperbolic:
        vertical_axis = manifold.geodesic_unit(t, origin, vertical)
        horizontal_axis = manifold.geodesic_unit(t, origin, horizontal)
        draw(vertical_axis)
        draw(horizontal_axis)
    else:
        # For the sphere projection the crosshair is drawn with axvline/axhline,
        # because plotting it from geodesic samples makes the lines look thicker.
        ax.axvline(0, linestyle="--", **line_style)
        ax.axhline(0, linestyle="--", **line_style)

    for step_index in range(1, n_lines):
        offset = torch.as_tensor(float(step_index)) * step
        # start points at distance `offset` along the horizontal / vertical axis
        start_on_x_axis = manifold.geodesic_unit(offset, origin, horizontal)
        start_on_y_axis = manifold.geodesic_unit(offset, origin, vertical)
        # geodesics leaving those points in the vertical / horizontal direction
        line_through_x = manifold.geodesic_unit(t, start_on_x_axis, vertical)
        line_through_y = manifold.geodesic_unit(t, start_on_y_axis, horizontal)
        draw(line_through_x)
        draw(line_through_y)
        if hyperbolic:
            # point-reflect through the origin to cover the opposite side
            draw(-line_through_x)
            draw(-line_through_y)
# Gyrovector parallel transport P_{x->y} on the curvature -1 model: two vectors
# attached at x are transported to y and both pairs are traced as geodesics.
x = torch.tensor((-0.25, -0.75))
xv1 = torch.tensor((np.sin(np.pi / 3), np.cos(np.pi / 3))) / 5
xv2 = torch.tensor((np.sin(-np.pi / 3), np.cos(np.pi / 3))) / 5
t = torch.linspace(0, 1, 10)[:, None]


def plot_gv(a, gv, **kwargs):
    """Plot the sampled curve ``gv`` on ``a`` with an arrow over its last segment."""
    coords = gv.t().numpy()
    a.plot(*coords, **kwargs)
    tail = gv[-2]
    delta = gv[-1] - tail
    a.arrow(*tail, *delta, width=0.01, **kwargs)


fig, ax = plt.subplots(1, 2, figsize=(9, 4))
fig.suptitle(r"gyrovector parallel transport $P_{x\to y}$")
manifold = Stereographic(-1)
y = torch.tensor((0.65, -0.55))
xy = manifold.logmap(x, y)
path = manifold.geodesic(t, x, y)
# transport both vectors from x to y, then trace unit-speed geodesics
yv1, yv2 = (manifold.transp(x, y, vec) for vec in (xv1, xv2))
xgv1, xgv2 = (manifold.geodesic_unit(t, x, vec) for vec in (xv1, xv2))
ygv1, ygv2 = (manifold.geodesic_unit(t, y, vec) for vec in (yv1, yv2))
circle = plt.Circle((0, 0), 1, fill=False, color="b")
from geoopt import Stereographic
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
import shutil

# Illustrate Mobius addition x (+) y for two points in the unit disk (curvature -1).
if shutil.which("latex") is not None:
    # a LaTeX executable is on PATH: render text with it and load amsmath
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

x = torch.tensor((-0.25, -0.75)) / 2
y = torch.tensor((0.65, -0.55)) / 2
manifold = Stereographic(-1)
x_plus_y = manifold.mobius_add(x, y)

ax = plt.gca()
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")

ax.annotate("x", x - 0.09, fontsize=15)
ax.annotate("y", y - 0.09, fontsize=15)
ax.annotate(r"$x\oplus y$", x_plus_y - torch.tensor([0.1, 0.15]), fontsize=15)

# arrows from the origin to each operand and to their Mobius sum
for vec, color in ((x, "r"), (y, "g"), (x_plus_y, "b")):
    ax.arrow(0, 0, *vec, width=0.01, color=color)

ax.set_title(r"Addition $x\oplus y$")
# NOTE(review): the next three lines are the tail of a geodesic-grid helper
# whose definition starts outside this view (it matches the end of
# add_geodesic_grid); they belong to that helper's loop body and are unchanged.
        if K < 0:
            plot_geodesic(-x_geodesic)
            plot_geodesic(-y_geodesic)


# Heat map of the distance to a hyperplane on the curvature -1 model (unit disk).
lim = 1.1  # axis limits, slightly beyond the unit circle
coords = np.linspace(-lim, lim, 100)
x = torch.tensor([-0.75, 0])  # presumably the point the plane passes through (dist2plane's `p`) — confirm
v = torch.tensor([0.1 / 3, 0.0])  # presumably the plane's direction/normal (dist2plane's `a`) — confirm
xx, yy = np.meshgrid(coords, coords)
dist2 = xx**2 + yy**2  # squared Euclidean norm of each grid point
mask = dist2 <= 1  # True for grid points inside (or on) the unit circle
grid = np.stack([xx, yy], axis=-1)  # (100, 100, 2) array of grid points
fig, ax = plt.subplots(1, 2, figsize=(9, 4))
manifold = Stereographic(-1)
dists = manifold.dist2plane(torch.from_numpy(grid).float(), x, v)
# mark points outside the disk as NaN (presumably so contourf leaves them unfilled)
dists[(~mask).nonzero()] = np.nan
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax[0].add_artist(circle)
ax[0].set_xlim(-lim, lim)
ax[0].set_ylim(-lim, lim)
ax[0].set_aspect("equal")
# filled contours of the log-distance
ax[0].contourf(grid[..., 0], grid[..., 1], dists.log().numpy(), levels=100, cmap="inferno")
add_geodesic_grid(ax[0], manifold, 0.5)
from geoopt import Stereographic
from matplotlib import rcParams
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import shutil

# Illustrate applying the sigmoid as a Mobius function, sigma^(x), in the unit disk.
if shutil.which("latex") is not None:
    # a LaTeX executable is on PATH: render text with it and load amsmath
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

x = torch.tensor((-0.25, -0.75)) / 3
manifold = Stereographic(-1)
f_x = manifold.mobius_fn_apply(torch.sigmoid, x)

ax = plt.gca()
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")

ax.annotate("x", x - 0.09, fontsize=15)
ax.annotate(
    r"$\sigma(x)=\frac{1}{1+e^{-x}}$",
    x + torch.tensor([-0.7, 0.5]),
    fontsize=15,
)
ax.annotate(r"$\sigma^\otimes(x)$", f_x - torch.tensor([0.1, 0.15]), fontsize=15)

# input vector in red, Mobius-sigmoid image in blue
ax.arrow(0, 0, *x, width=0.01, color="r")
ax.arrow(0, 0, *f_x, width=0.01, color="b")

ax.set_title(r"Mobius function (sigmoid) apply $\sigma^\otimes(x)$")
from geoopt import Stereographic
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
import shutil

# Illustrate the Mobius matrix-vector product M^(x) in the unit disk (curvature -1).
latex_available = shutil.which("latex") is not None
if latex_available:
    # render text with LaTeX and load amsmath
    rcParams["text.latex.preamble"] = r"\usepackage{amsmath}"
    rcParams["text.usetex"] = True

sns.set_style("white")

x = torch.tensor((-0.25, -0.75)) / 3
M = torch.tensor([[-1, -1.5], [0.2, 0.5]])
manifold = Stereographic(-1)
M_x = manifold.mobius_matvec(M, x)

ax = plt.gca()
circle = plt.Circle((0, 0), 1, fill=False, color="b")
ax.add_artist(circle)
ax.set_xlim(-1.1, 1.1)
ax.set_ylim(-1.1, 1.1)
ax.set_aspect("equal")

ax.annotate("x", x - 0.09, fontsize=15)
if latex_available:
    # the bmatrix environment is only drawn when real LaTeX (with amsmath) is used
    ax.annotate(
        r"$M=\begin{bmatrix}-1 &-1.5\\.2 &.5\end{bmatrix}$",
        x + torch.tensor([-0.5, 0.5]),
        fontsize=15,
    )
ax.annotate(r"$M^\otimes x$", M_x - torch.tensor([0.1, 0.15]), fontsize=15)