예제 #1
0
def test_exp_log(seed, rand_spd, rand_sym, d):
    """``log`` must invert ``exp``, and the geodesic length must match ``norm``.

    Starting from random SPD base points and random symmetric tangent
    vectors, map forward with ``exp``, map back with ``log`` and compare.
    The distance to the endpoint must equal the tangent vector's norm.
    """
    manifold = SPD(d)
    base = rand_spd(10, d)
    tangent = rand_sym(10, d)
    endpoint = manifold.exp(base, tangent)

    recovered = manifold.log(base, endpoint)
    assert_allclose(tangent, recovered, atol=1e-4)

    tangent_len = manifold.norm(base, tangent)
    assert_allclose(tangent_len, manifold.dist(base, endpoint), atol=1e-4)
예제 #2
0
def test_stein_pdiv(seed, d):
    """``stein_pdiv`` must agree with ``stein_div`` evaluated on every pair i < j.

    The manifold is built with the same dimension ``d`` as the sampled
    matrices (previously it was hard-coded to ``SPD(2)`` regardless of ``d``).
    """
    num_points = 10
    spd = SPD(d)
    xs = spd.rand(num_points, ir=1.0,
                  out=torch.empty(num_points, d, d, dtype=torch.float64))
    pdivs = spd.stein_pdiv(xs)

    # Strict upper-triangular index pairs (i < j). The assertion below
    # requires stein_pdiv to return the divergences in this order.
    rows, cols = torch.triu_indices(num_points, num_points, 1)
    ref_pdivs = spd.stein_div(xs[rows], xs[cols])
    assert_allclose(ref_pdivs, pdivs, atol=1e-4)
예제 #3
0
def build_manifold(*names):
    """Instantiate one manifold factor per specification string.

    Each name is an identifier followed by underscore-separated integer
    dimensions:

    * ``euc_<n>``, ``sph_<n>``, ``hyp_<n>``, ``so_<n>`` and ``spd_<n>`` take
      one dimension. So does ``spdstein_<n>``, which builds
      ``SymmetricPositiveDefinite(n, use_stein_div=True)``.
    * ``grass_<n1>_<n2>`` takes two dimensions.

    Parts beyond the required dimensions are ignored.

    Args:
        *names: manifold specification strings, e.g. ``'spd_2'``.

    Returns:
        A list with one manifold object per name, in the given order.

    Raises:
        ValueError: if an identifier is unknown, a required dimension is
            missing, or a dimension is not an integer.
    """
    from graphembed.manifolds import (Euclidean, Grassmann, Lorentz,
                                      SymmetricPositiveDefinite,
                                      SpecialOrthogonalGroup, Sphere)

    # identifier -> (number of integer dimensions, constructor)
    builders = {
        'euc': (1, Euclidean),
        'sph': (1, Sphere),
        'hyp': (1, Lorentz),
        'so': (1, SpecialOrthogonalGroup),
        'spd': (1, SymmetricPositiveDefinite),
        'spdstein': (1, lambda n: SymmetricPositiveDefinite(
            n, use_stein_div=True)),
        'grass': (2, Grassmann),
    }

    factors = []
    for name in names:
        identifier, *dims = name.split('_')
        if identifier not in builders:
            raise ValueError(f'Unknown manifold identifier {identifier}')
        arity, build = builders[identifier]
        # Report a missing dimension clearly instead of an IndexError.
        if len(dims) < arity:
            raise ValueError(f'Manifold {name!r} needs {arity} dimension(s), '
                             f'got {len(dims)}')
        factors.append(build(*(int(dim) for dim in dims[:arity])))

    return factors
예제 #4
0
def test_gradient(seed, d):
    """The Riemannian gradient of ``0.5 * dist(x, y)**2`` w.r.t. x is ``-log_x(y)``.

    The Euclidean gradient comes from autograd and is converted with
    ``egrad2rgrad``; the result is compared to the closed form.
    """
    spd = SPD(d)
    buffer = torch.empty(2, d, d, dtype=torch.float64)
    x, y = spd.rand(2, ir=1.0, out=buffer)
    x.requires_grad_()

    half_sq_dist = 0.5 * spd.dist(x, y, squared=True)
    (euclidean_grad,) = torch.autograd.grad(half_sq_dist, x)
    riemannian_grad = spd.egrad2rgrad(x, euclidean_grad)

    detached = riemannian_grad.detach()
    expected = -spd.log(x.detach(), y)
    assert_allclose(detached, expected, atol=1e-4)
예제 #5
0
def test_unit_distance(d, seed):
    """A unit-norm tangent vector at the identity yields a geodesic of length 1."""
    spd = SPD(d)
    coords = torch.randn(spd.dim)
    unit_coords = coords / coords.norm()
    tangent = SPD.from_vec(unit_coords)
    identity = torch.eye(d)

    assert_allclose(1.0, spd.norm(identity, tangent), atol=1e-4)
    endpoint = spd.exp(identity, tangent)
    assert_allclose(1.0, spd.dist(identity, endpoint), atol=1e-4)
예제 #6
0
def test_h2_to_sspd2(seed, n, d):
    """Check the H^2 -> SSPD(2) map: round trip and scaled pairwise distances.

    Lorentz points are scaled by ``d`` before mapping. The round trip must
    return the unscaled points. SPD pairwise distances must equal
    ``sspd2_hyp_radius_`` times the hyperbolic ones.
    """
    spd, lorentz = SPD(2), Lorentz(3)

    scaled = lorentz.rand(n, ir=1.0).mul_(d)
    mats = h2_to_sspd2(scaled)
    round_trip = sspd2_to_h2(mats)
    unscaled = scaled / d
    assert_allclose(round_trip, unscaled, atol=1e-4)

    expected = sspd2_hyp_radius_ * lorentz.pdist(unscaled)
    assert_allclose(spd.pdist(mats), expected, atol=1e-4)
예제 #7
0
def test_distance_formulas(seed, rand_spd, d):
    """``spd.dist(A, B)`` must equal ``sqrt(sum_i log(lambda_i)**2)``.

    Here ``lambda_i`` are the eigenvalues of ``A^{-1} B``. The product
    ``A^{-1} B`` is in general not symmetric, so the general eigenvalue
    solver is used and only the real parts are kept. Both argument orders
    are checked against the same reference distance.
    """
    spd = SPD(d)
    x, y = rand_spd(2, d)
    ref_dist = spd.dist(x, y)

    def eig_dist(a, b):
        # Solve a @ z = b, i.e. z = a^{-1} b, without forming the inverse.
        z = torch.linalg.solve(a, b)
        return torch.linalg.eigvals(z).real.log().pow(2).sum().sqrt()

    # X^{-1} Y (the legacy torch.solve(y, x) computed exactly this).
    assert_allclose(ref_dist, eig_dist(x, y), atol=1e-4)
    # Y^{-1} X
    assert_allclose(ref_dist, eig_dist(y, x), atol=1e-4)
예제 #8
0
def test_sspd2_to_h2(seed, n):
    """Unit-determinant 2x2 SPD distances equal scaled H^2 distances.

    Random SPD matrices are normalised to determinant 1 and mapped to the
    hyperboloid. Their SPD pairwise distances must equal
    ``sspd2_hyp_radius_`` times the Lorentz pairwise distances.
    """
    spd = SPD(2)
    lorentz = Lorentz(3)

    mats = spd.rand(n, ir=1.0)
    # Divide each matrix by sqrt(det) so its determinant becomes 1.
    root_dets = mats.det().sqrt_().reshape(-1, 1, 1)
    mats.div_(root_dets)
    assert_allclose(mats.det(), torch.ones(n), atol=1e-4)
    assert_allclose(mats, spd.projx(mats), atol=1e-4)

    hyp_points = sspd2_to_h2(mats)
    expected = sspd2_hyp_radius_ * lorentz.pdist(hyp_points)
    assert_allclose(spd.pdist(mats), expected, atol=1e-4)
예제 #9
0
def eucl_to_tangent_space(man, vs):
    """Lift flat coordinate rows ``vs`` into tangent vectors of ``man``.

    * ``Euclidean``: returned unchanged.
    * ``Lorentz`` / ``Sphere``: a zero leading coordinate is prepended to
      each row.
    * ``SymmetricPositiveDefinite``: converted with ``from_vec``.

    Args:
        man: the target manifold instance.
        vs: 2-D tensor, one coordinate vector per row (rank assumed from the
            ``dim=1`` concatenation below).

    Returns:
        The corresponding tangent-vector tensor.

    Raises:
        ValueError: if ``man`` is of an unsupported manifold type.
    """
    if isinstance(man, Euclidean):
        return vs
    if isinstance(man, (Lorentz, Sphere)):
        # Allocate the zero column with vs's dtype and device so the
        # concatenation also works for float64 and non-CPU (e.g. CUDA) inputs.
        return torch.cat([vs.new_zeros(len(vs), 1), vs], dim=1)
    if isinstance(man, SymmetricPositiveDefinite):
        return SymmetricPositiveDefinite.from_vec(vs)

    raise ValueError(f'Manifold {man} is not supported.')
예제 #10
0
def test_sspd2_to_h2_nonconst_factor(seed, n, d):
    """The SPD -> H^2 isometry also holds for determinant ``d**2``.

    Each matrix is normalised to determinant 1 and then scaled by ``d``,
    giving determinant ``d**2``. The pairwise distances must still equal
    ``sspd2_hyp_radius_`` times the hyperbolic distances of the mapped points.
    """
    spd = SPD(2)
    lorentz = Lorentz(3)

    mats = spd.rand(n, ir=1.0)
    root_dets = mats.det().sqrt_().reshape(-1, 1, 1)
    mats.div_(root_dets)  # determinant 1
    mats.mul_(d)  # determinant d**2 for these 2x2 matrices
    expected_dets = torch.empty(n).fill_(d**2)
    assert_allclose(mats.det(), expected_dets, atol=1e-4)
    assert_allclose(mats, spd.projx(mats), atol=1e-4)

    hyp_points = sspd2_to_h2(mats)
    expected = sspd2_hyp_radius_ * lorentz.pdist(hyp_points)

    # Scaling the determinant does not change the geometry: each slice is
    # isometric to the hyperbolic plane of constant sectional curvature -1/2.
    assert_allclose(spd.pdist(mats), expected, atol=1e-4)
예제 #11
0
def test_inner_norm(seed, d):
    """``norm(x, u)`` must equal ``sqrt(inner(x, u, u))`` for random tangents."""
    count = 100
    spd = SPD(d)
    storage = torch.empty(count, d, d, dtype=torch.float64)
    points = spd.rand(count, ir=1.0, out=storage)
    tangents = spd.randvec(points)

    via_inner = spd.inner(points, tangents, tangents)**0.5
    via_norm = spd.norm(points, tangents)
    assert_allclose(via_inner, via_norm, atol=1e-4)
예제 #12
0
def test_no_nan_dists(seed, rand_spd, d, n):
    """Pairwise distances between random SPD matrices must contain no NaNs."""
    manifold = SPD(d)
    samples = rand_spd(n, d)
    pairwise = manifold.pdist(samples)
    nan_mask = torch.isnan(pairwise)
    assert not nan_mask.any()
예제 #13
0
def gen_samples(man, num_nodes, radius):
    """Sample ``num_nodes`` points on ``man`` by exponentiating at ``man.zero``.

    Vectors of length ``man.dim`` come from ``gen_from_ball`` (presumably
    drawn inside a ball of the given radius — confirm against its
    definition). They are reshaped with
    ``SymmetricPositiveDefinite.from_vec`` and pushed through ``man.exp`` at
    ``man.zero(num_nodes)``.

    Raises:
        AssertionError: if the generated vectors are not already fixed by
            ``man.proju`` at the base points (sanity check).
    """
    coords = gen_from_ball(num_nodes, man.dim, radius)
    tangents = SymmetricPositiveDefinite.from_vec(coords)
    base = man.zero(num_nodes)

    # Sanity check: projecting onto the tangent space must be a no-op.
    projected = man.proju(base, tangents)
    assert torch.allclose(tangents, projected, atol=1e-4)

    return man.exp(base, tangents)
예제 #14
0
import numpy as np
from matplotlib.patches import Ellipse
import matplotlib.pyplot as plt
import torch

from graphembed.manifolds import SymmetricPositiveDefinite as SPD

# Two 2x2 SPD matrices x, y from spd.rand (meaning of ir=3.0 not shown here),
# and the tangent vector u at x obtained from log_x(y).
spd = SPD(2)
x, y = spd.rand(2, ir=3.0)
u = spd.log(x, y)

# n points along t -> exp_x(t * u) for t evenly spaced in [0, 1]: t = 0 gives
# x, and t = 1 gives exp_x(log_x(y)), i.e. y up to numerical error.
n = 100
ys = spd.exp(x.repeat(n, 1, 1), torch.linspace(0, 1.0, n).reshape(n, 1, 1) * u)


def add_ellipsis(x, offset):
    """Build a matplotlib ellipse that visualises the 2x2 SPD matrix ``x``.

    The ellipse's axes follow the eigenvectors of ``x``. Its width and height
    are the corresponding eigenvalues, and it is centred at ``(offset, 0)``.

    Args:
        x: symmetric positive-definite 2x2 tensor.
        offset: horizontal position of the ellipse centre.

    Returns:
        ``(ellipse, max_x, max_y)``: the ``Ellipse`` patch, plus the
        half-width and half-height (0-d tensors) of its axis-aligned bounding
        box, e.g. for choosing plot limits.
    """
    # Eigenvalues come back in ascending order; eigenvectors are the COLUMNS
    # of `vecs`.
    ws, vecs = torch.linalg.eigh(x)
    # Orientation of the first eigenvector, i.e. the axis of length ws[0].
    # It is column 0, not row 0 (the old code read row 0 and so flipped the
    # rotation sign for rotation-type eigenbases).
    rad = torch.atan2(vecs[1, 0], vecs[0, 0])
    ellipse = Ellipse(xy=(offset, 0), width=float(ws[0]),
                      height=float(ws[1]), angle=float(np.rad2deg(float(rad))))

    # Exact half-extents of an ellipse with semi-axes a, b rotated by `rad`.
    # The former max(...) form underestimated them by up to sqrt(2).
    a, b = ws[0] / 2, ws[1] / 2
    cos, sin = rad.cos(), rad.sin()
    max_x = torch.sqrt((a * cos)**2 + (b * sin)**2)
    max_y = torch.sqrt((a * sin)**2 + (b * cos)**2)
    return ellipse, max_x, max_y


fig, ax = plt.subplots()
max_width = 0
max_height = 0
min_width = 0
for i, y in enumerate(ys):