Example #1
0
def barycenter(z,
               wik=None,
               lr=5e-3,
               tau=5e-3,
               max_iter=math.inf,
               distance=pf.distance,
               normed=False):
    """Iteratively estimate the barycenter of the points ``z`` with ``pf`` maps.

    Starting from the arithmetic mean of ``z`` along dim 0, each step:
      1. maps every point into the tangent space at the current estimate with ``pf.log``;
      2. optionally weights and normalises those tangent vectors;
      3. sums them and moves the estimate with ``pf.exp``.

    Iteration stops once the largest ``distance`` between successive estimates
    is at most ``tau``, or after ``max_iter`` steps.

    NOTE: another ``def barycenter`` appears later in this module and re-binds
    the name, so this version is shadowed once the module has been executed.

    Args:
        z: points stacked along dim 0 (presumably ``(N, D)`` -- confirm).
        wik: optional weight per point, or ``None`` for the unweighted case.
            It is unsqueezed on the last dim and expanded to ``z``'s shape
            (presumably of shape ``(N,)`` -- confirm).
        lr: step size applied to the summed tangent vector.
        tau: convergence threshold on the distance between successive
            estimates.
        max_iter: maximum number of iterations (unbounded by default).
        distance: callable ``(a, b)`` returning a tensor; its ``.max()`` is
            used as the convergence measure.
        normed: if true, divide the tangent vectors by the per-column sum of
            the weights, or by ``len(z)`` when unweighted.

    Returns:
        The last estimate produced by ``pf.exp``. If the loop never runs, this
        is the initial mean of shape ``(1, ...)``. When ``z`` holds a single
        point, ``z`` itself is returned.
    """
    weighted = wik is not None
    if weighted:
        # One weight per row of ``z``, broadcast along the trailing dimension.
        wik = wik.unsqueeze(-1).expand_as(z)

    # Initial estimate: the plain mean, keeping dim 0 so it can be expanded.
    barycenter = z.mean(0, keepdim=True)
    if len(z) == 1:
        return z

    iteration = 0
    cvg = math.inf
    while cvg > tau and max_iter > iteration:
        iteration += 1
        grad_tangent = pf.log(barycenter.expand_as(z), z)
        if weighted:
            grad_tangent = grad_tangent * wik
        if normed:
            # Out-of-place division: ``grad_tangent`` may be the tensor
            # returned by ``pf.log``, which autograd might have saved.
            if weighted:
                grad_tangent = grad_tangent / wik.sum(
                    0, keepdim=True).expand_as(wik)
            else:
                grad_tangent = grad_tangent / len(z)
        # keepdim=True keeps the step shaped like ``barycenter`` (1, ...).
        cc_barycenter = pf.exp(barycenter,
                               lr * grad_tangent.sum(0, keepdim=True))
        cvg = distance(cc_barycenter, barycenter).max().item()
        barycenter = cc_barycenter
    return barycenter
def test():
    """Print torch and numpy versions of the Poincare helpers side by side.

    Manual debugging aid. Random points are drawn, then ``log``, ``exp``,
    ``riemannian_distance`` and the gradient of the distance are printed twice:
      * as computed by ``function_tools.poincare_function`` (torch, one point
        per row of a 2-column tensor);
      * as computed by ``function_tools.numpy_function.RiemannianFunction``
        (numpy, each point encoded as the complex number ``x + iy``).

    Nothing is asserted; the printed values are meant to be compared by eye.
    """
    import numpy as np
    import torch
    from torch import nn
    from function_tools import poincare_function as tf
    from function_tools.numpy_function import RiemannianFunction as nf

    def to_complex(points):
        # Encode each (x, y) row as x + iy, the format the numpy helpers take.
        return (points[:, 0].detach().numpy()
                + points[:, 1].detach().numpy() * 1j)

    # Each coordinate lies in [0, 2/3), so every point has norm < 1.
    x = torch.rand(3, 2) / 1.5
    y = torch.rand(3, 2) / 1.5
    xn = to_complex(x)
    yn = to_complex(y)
    x = nn.Parameter(x)
    y = nn.Parameter(y)

    print("LOG : ")
    print("   Torch version")
    print("   " + str(tf.log(x, y)))
    print("   numpy version")
    print("   " + str(nf.log(xn, yn)))

    print("EXP: ")
    print("   Torch version")
    print("   " + str(tf.exp(x, y)))
    print("   numpy version")
    # nf.exp is applied one pair of points at a time.
    print("   " + str(np.array([nf.exp(a, b) for a, b in zip(xn, yn)])))

    print("Poincare DIST : ")
    print("   Torch version")
    print("   " + str(tf.riemannian_distance(x, y)))
    print("   numpy version")
    print("   " + str(nf.riemannian_distance(xn, yn)))

    print("Poincare Grad : ")
    print("   Torch version")
    x = torch.rand(1, 2) / 1.5
    y = torch.rand(1, 2) / 1.5
    xn = to_complex(x)
    yn = to_complex(y)
    x = nn.Parameter(x)
    y = nn.Parameter(y)
    dist = tf.riemannian_distance(x, y)
    dist.backward()
    print(" Gradient Angle")
    # keepdim=True so each row is divided by its own norm.
    print("   " + str(x.grad / x.grad.norm(2, -1, keepdim=True)))
    print("   numpy version")
    # numpy gradient of the distance; computed once and reused below.
    g_xn = nf.riemannian_distance_grad(xn, yn)
    print("   " + str(g_xn / abs(g_xn)))
    print(" Real value not angle")
    print(x.grad)
    print(g_xn)
    # Ratio between the numpy and torch gradients on the first coordinate.
    factor = g_xn[0].real / x.grad[0, 0].item()
    print("factor ", factor)
    print(g_xn / factor)
    print(g_xn)
    print("new num", nf.exp(xn, -g_xn))
    # NOTE(review): the usual Poincare-ball Riemannian-gradient scale is
    # ((1 - |x|^2) ** 2) / 4; here it is (1 - |x|^2) / 4 -- confirm intent.
    xr = (1 - torch.sum(x.data ** 2, dim=-1, keepdim=True)) / 4
    print(-x.grad * xr)
    print(x - (x.grad * xr))
    print("new num", tf.exp(x, -x.grad))
def barycenter(z,
               wik=None,
               lr=5e-2,
               tau=5e-3,
               max_iter=math.inf,
               distance=pf.distance,
               normed=False,
               init_method="default"):
    """Iteratively estimate the (optionally weighted) barycenter of ``z``.

    Each step:
      1. maps every point into the tangent space at the current estimate with
         ``pf.log`` and scales it by 2;
      2. optionally weights and normalises those tangent vectors;
      3. sums them and moves the estimate with ``pf.exp``.

    Iteration stops once the largest ``distance`` between successive estimates
    is at most ``tau``, or after ``max_iter`` steps.

    NOTE: this definition re-binds the name ``barycenter`` and replaces the
    earlier definition of the same name in this module.

    Args:
        z: points stacked along dim 0 (presumably ``(N, D)`` -- confirm).
        wik: optional weight per point, or ``None`` for the unweighted case.
            It is unsqueezed on the last dim and expanded to ``z``'s shape
            (presumably of shape ``(N,)`` -- confirm).
        lr: step size applied to the summed tangent vector.
        tau: convergence threshold on the distance between successive
            estimates.
        max_iter: maximum number of iterations (unbounded by default).
        distance: callable ``(a, b)`` returning a tensor; its ``.max()`` is
            used as the convergence measure.
        normed: if true, divide the tangent vectors by the per-column sum of
            the weights, or by ``len(z)`` when unweighted.
        init_method: initialisation of the weighted case.
            * ``"global_mean"``: prints a warning and starts from the
              unweighted mean.
            * anything else: starts from the weighted mean.
            Ignored when ``wik`` is ``None``, in which case the start is
            all zeros.

    Returns:
        The last estimate produced by ``pf.exp``. If the loop never runs, this
        is the initial estimate of shape ``(1, ...)``. When ``z`` holds a
        single point, ``z`` itself is returned.

    Raises:
        SystemExit: with status 1 if, in the weighted case, the estimate
            contains NaN (diagnostics are printed first).
    """
    import sys

    weighted = wik is not None
    if weighted:
        # One weight per row of ``z``, broadcast along the trailing dimension.
        wik = wik.unsqueeze(-1).expand_as(z)
        if init_method == "global_mean":
            print("Bad init selected")
            barycenter = z.mean(0, keepdim=True)
        else:
            barycenter = (z * wik).sum(0, keepdim=True) / wik.sum(0)
    else:
        # Zeros with the mean's shape/dtype/device, i.e. start at the origin.
        barycenter = z.mean(0, keepdim=True) * 0

    if len(z) == 1:
        return z

    iteration = 0
    cvg = math.inf
    while cvg > tau and max_iter > iteration:
        iteration += 1
        grad_tangent = 2 * pf.log(barycenter.expand_as(z), z)
        if weighted:
            grad_tangent = grad_tangent * wik
            if bool(barycenter.isnan().any()):
                print("\n\n At least one barycenter is Nan : ")
                print(barycenter)
                print(wik.sum(0))
                print(wik.mean(0))
                print(wik.sum(1))
                print(wik.mean(1))
                print(iteration)
                # Non-zero status: a NaN estimate is a failure, not success.
                sys.exit(1)

        if normed:
            if weighted:
                grad_tangent = grad_tangent / wik.sum(
                    0, keepdim=True).expand_as(wik)
            else:
                grad_tangent = grad_tangent / len(z)
        cc_barycenter = pf.exp(barycenter,
                               lr * grad_tangent.sum(0, keepdim=True))
        cvg = distance(cc_barycenter, barycenter).max().item()
        barycenter = cc_barycenter
    return barycenter