def barycenter(z, wik=None, lr=5e-3, tau=5e-3, max_iter=math.inf,
               distance=pf.distance, normed=False):
    """Iteratively estimate the (optionally weighted) barycenter of ``z``.

    The estimate starts at the arithmetic mean of ``z`` along dim 0. Each
    iteration computes ``pf.log(estimate, z_i)`` for every point. These
    vectors are multiplied by the weights and summed. The estimate is then
    moved with ``pf.exp(estimate, lr * summed_vector)``.
    NOTE(review): ``pf.log`` / ``pf.exp`` presumably are the Riemannian
    log/exp maps of the Poincare ball -- confirm against ``pf``.

    NOTE(review): a second ``barycenter`` definition appears later in this
    module and replaces this one when the module is imported.

    Args:
        z: points indexed along dim 0; presumably ``(N, D)`` -- TODO confirm
            against callers.
        wik: optional per-point weights. They are unsqueezed on the last dim
            and expanded to ``z``'s shape. ``None`` means uniform weight 1.
        lr: step size applied to the summed tangent vector.
        tau: stop once the largest ``distance`` between two successive
            estimates is <= ``tau``.
        max_iter: maximum number of update steps.
        distance: callable ``distance(a, b)`` returning a tensor. Its
            ``.max()`` is the convergence measure.
        normed: if True, divide the tangent vectors by the total weight
            (weighted case) or by ``len(z)`` (unweighted case).

    Returns:
        ``z`` itself when it holds exactly one point. Otherwise the final
        estimate, which keeps a leading dimension of size 1.
    """
    if len(z) == 1:
        # A single point is its own barycenter; skip the solver entirely.
        return z

    if wik is None:
        wik = 1.
    else:
        # Add a trailing axis so each point's weight scales all its coordinates.
        wik = wik.unsqueeze(-1).expand_as(z)
    weighted = not isinstance(wik, float)

    estimate = z.mean(0, keepdim=True)
    iteration = 0
    cvg = math.inf
    while cvg > tau and max_iter > iteration:
        iteration += 1
        grad_tangent = pf.log(estimate.expand_as(z), z) * wik
        if normed:
            if weighted:
                grad_tangent /= wik.sum(0, keepdim=True).expand_as(wik)
            else:
                grad_tangent /= len(z)
        # keepdim=True keeps the step shaped like ``estimate`` (leading dim 1).
        candidate = pf.exp(estimate, lr * grad_tangent.sum(0, keepdim=True))
        cvg = distance(candidate, estimate).max().item()
        estimate = candidate
    return estimate
def test():
    """Print torch and numpy results of the Poincare helpers side by side.

    This is a manual sanity check, not an automated test. Nothing is asserted
    and the function returns None.

    It samples random 2-D points with ``torch.rand`` (scaled by 1 / 1.5). Each
    point is mirrored as a complex number ``a + b*1j`` for the numpy
    implementation. It then prints ``log``, ``exp`` and
    ``riemannian_distance`` from both back ends, followed by a comparison of
    the distance gradient.

    NOTE(review): relies on a module-level ``torch`` import that this
    function does not perform itself.
    """
    import numpy as np
    from torch import nn
    from function_tools import poincare_function as tf
    from function_tools.numpy_function import RiemannianFunction as nf
    import cmath

    def sigmoid(value):
        # Not called; only referenced by the disabled factor noted below.
        return 1 / (1 + cmath.exp(-value))

    def to_complex(points):
        # Row (a, b) -> complex a + b*1j, the layout the numpy helpers use.
        return points[:, 0].detach().numpy() + points[:, 1].detach().numpy() * 1j

    def compare(title, torch_result, numpy_result):
        # Results are passed as thunks so each is evaluated right after its
        # heading, preserving the interleaving of output and computation.
        print(title)
        print(" Torch version")
        print(" " + str(torch_result()))
        print(" numpy version")
        print(" " + str(numpy_result()))

    # Three random point pairs: compare log, exp and distance.
    raw_x = torch.rand(3, 2) / 1.5
    raw_y = torch.rand(3, 2) / 1.5
    xn, yn = to_complex(raw_x), to_complex(raw_y)
    x, y = nn.Parameter(raw_x), nn.Parameter(raw_y)

    compare("LOG : ",
            lambda: tf.log(x, y),
            lambda: nf.log(xn, yn))
    compare("EXP: ",
            lambda: tf.exp(x, y),
            lambda: np.array([nf.exp(xn[k], yn[k]) for k in range(3)]))
    compare("Poincare DIST : ",
            lambda: tf.riemannian_distance(x, y),
            lambda: nf.riemannian_distance(xn, yn))

    # One fresh pair: compare the gradient of the distance.
    print("Poincare Grad : ")
    print(" Torch version")
    raw_u = torch.rand(1, 2) / 1.5
    raw_v = torch.rand(1, 2) / 1.5
    un, vn = to_complex(raw_u), to_complex(raw_v)
    u, v = nn.Parameter(raw_u), nn.Parameter(raw_v)
    dist = tf.riemannian_distance(u, v)
    dist.backward()
    grad = u.grad

    print(" Gradient Angle")
    print(" " + str(grad / grad.norm(2, -1)))
    print(" numpy version")
    print(" " + str(nf.riemannian_distance_grad(un, vn)
                    / abs(nf.riemannian_distance_grad(un, vn))))

    print(" Real value not angle")
    print(grad)
    g_un = nf.riemannian_distance_grad(un, vn)
    print(nf.riemannian_distance_grad(un, vn))
    print("factor ", g_un[0].real / grad[0, 0].item())
    print(nf.riemannian_distance_grad(un, vn)
          / (g_un[0].real / grad[0, 0].item()))

    # Disabled in the source: step scaled by -sigmoid(nf.riemannian_distance(un, vn)).
    step = nf.riemannian_distance_grad(un, vn)
    print(step)
    print("new num", nf.exp(un, -step))

    scale = (1 - torch.sum(u.data ** 2, dim=-1)) / 4
    print(-grad * scale)
    print(u - (grad * scale))
    print("new num", tf.exp(u, -grad))
def barycenter(z, wik=None, lr=5e-2, tau=5e-3, max_iter=math.inf,
               distance=pf.distance, normed=False, init_method="default"):
    """Iteratively estimate the (optionally weighted) barycenter of ``z``.

    Initialisation:
      * unweighted (``wik is None``): the estimate starts at zeros shaped like
        ``z.mean(0, keepdim=True)``.
      * weighted: the estimate starts at the weighted arithmetic mean
        ``sum_i w_i z_i / sum_i w_i``. If ``init_method == "global_mean"``,
        it starts at the plain mean instead and prints "Bad init selected".

    Each iteration computes ``2 * pf.log(estimate, z_i)`` for every point,
    multiplied by its weight. The sum of these vectors, scaled by ``lr``, is
    applied with ``pf.exp``.
    NOTE(review): ``pf.log`` / ``pf.exp`` presumably are the Riemannian
    log/exp maps of the Poincare ball -- confirm against ``pf``.

    Args:
        z: points indexed along dim 0; presumably ``(N, D)`` -- TODO confirm.
        wik: optional per-point weights. They are unsqueezed on the last dim
            and expanded to ``z``'s shape. ``None`` means unweighted.
        lr: step size applied to the summed tangent vector.
        tau: stop once the largest ``distance`` between two successive
            estimates is <= ``tau``.
        max_iter: maximum number of update steps.
        distance: callable ``distance(a, b)`` returning a tensor. Its
            ``.max()`` is the convergence measure.
        normed: if True, divide the tangent vectors by the total weight
            (weighted case) or by ``len(z)`` (unweighted case).
        init_method: only consulted when ``wik`` is given (see above).

    Returns:
        ``z`` itself when it holds exactly one point. Otherwise the final
        estimate, which keeps a leading dimension of size 1.

    Raises:
        ValueError: in the weighted case, if the estimate contains NaN. This
            happens, for example, when a coordinate's weights sum to zero
            and the initial division produces 0/0.
    """
    if len(z) == 1:
        # A single point is its own barycenter; skip initialisation and solver.
        return z

    if wik is None:
        wik = 1.
        # Start from the origin; "* 0" keeps z's dtype/device and shape (1, ...).
        center = z.mean(0, keepdim=True) * 0
    else:
        # Add a trailing axis so each point's weight scales all its coordinates.
        wik = wik.unsqueeze(-1).expand_as(z)
        if init_method == "global_mean":
            print("Bad init selected")
            center = z.mean(0, keepdim=True)
        else:
            center = (z * wik).sum(0, keepdim=True) / wik.sum(0)
    weighted = not isinstance(wik, float)

    iteration = 0
    cvg = math.inf
    while cvg > tau and max_iter > iteration:
        iteration += 1
        if weighted:
            if torch.isnan(center).any():
                # Previously this printed diagnostics and called exit(), which
                # would terminate the whole interpreter from library code.
                raise ValueError(
                    "At least one barycenter is NaN at iteration {} "
                    "(barycenter={}, wik.sum(0)={}, wik.sum(1)={})".format(
                        iteration, center, wik.sum(0), wik.sum(1)))
            grad_tangent = 2 * pf.log(center.expand_as(z), z) * wik
        else:
            grad_tangent = 2 * pf.log(center.expand_as(z), z)
        if normed:
            if weighted:
                grad_tangent /= wik.sum(0, keepdim=True).expand_as(wik)
            else:
                grad_tangent /= len(z)
        candidate = pf.exp(center, lr * grad_tangent.sum(0, keepdim=True))
        cvg = distance(candidate, center).max().item()
        center = candidate
    return center