Example 1
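This appears to be a Riemannian AMSGrad-style update on the Poincaré ball: `pf` is the project's Poincaré function module (`exp`, `norm`, `parallel_transport`), `self.tau` stores the transported momentum and `self.v` a running maximum of the second moment; only parameters with a non-zero gradient are touched.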
    def _optimization_method(self, p, d_p, lr):
        with torch.no_grad():
            gradient = d_p

            if self.tau is None or self.v is None:
                self.tau = torch.zeros(gradient.size()).to(p.device)
                self.v = torch.zeros(gradient.size()).to(p.device)
            # update only the weights whose gradient is non-zero
            mask = (gradient.norm(2, -1) != 0)
            if mask.sum() == 0:
                return
            gradient = gradient[mask]
            tau = self.tau[mask]
            v = self.v[mask]
            # first moment (momentum) in the tangent space
            m = self.beta_1 * tau + (1 - self.beta_1) * gradient
            # second moment, kept monotonically non-decreasing (AMSGrad-style)
            self.v[mask] =\
                self.beta_2 * v + (1 - self.beta_2) * pf.norm(p[mask], gradient)
            self.v[mask] = torch.cat(
                (v.unsqueeze(0), self.v[mask].unsqueeze(0))).max(0)[0]

            # retract along the rescaled momentum with the exponential map
            updated_weight = pf.exp(p[mask],
                                    -lr * m / torch.sqrt(self.v[mask]))
            # carry the momentum to the new point via parallel transport
            self.tau[mask] =\
                pf.parallel_transport(from_point=p[mask],
                                      to_point=updated_weight,
                                      vector=m)
            p[mask] = updated_weight
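Example 2

This function appears to compute a (weighted) Fréchet barycenter of the rows of `z` on the Poincaré ball by Riemannian gradient descent: each iteration accumulates tangent directions given by the logarithm map `pf.log`, steps through the exponential map `pf.exp`, and stops once consecutive iterates are closer than `tau` under `distance`.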
def barycenter(z, wik=None, lr=1e-3, tau=5e-6, max_iter=100, distance=pf.distance,
               normed=False, init_method="default", verbose=False):
    with torch.no_grad():
        if wik is None:
            wik = 1.
            barycenter = z.mean(0, keepdim=True)
        else:
            wik = wik.unsqueeze(-1).expand_as(z)
            if init_method == "global_mean":
                print("Warning: 'global_mean' is a poor initialization choice")
                barycenter = z.mean(0, keepdim=True)
            else:
                # weighted Euclidean mean as the starting point
                barycenter = (z * wik).sum(0, keepdim=True) / wik.sum(0)

        if len(z) == 1:
            return z
        iteration = 0
        cvg = math.inf
        while cvg > tau and iteration < max_iter:
            iteration += 1
            if type(wik) != float:
                # Riemannian gradient: tangent directions toward the points,
                # scaled by the weights
                grad_tangent = 2 * pf.log(barycenter.expand_as(z), z) * wik
                nan_values = ~(barycenter == barycenter)  # NaN != NaN
                if torch.nonzero(nan_values.squeeze()).shape[0] > 0:
                    print("\n\nAt least one barycenter is NaN:")
                    print(pf.log(barycenter.expand_as(z), z).sum(0))
                    # torch >= 1.3 is required for this operation
                    print("index of NaN values ", nan_values.squeeze().nonzero())
                    quit()
            else:
                grad_tangent = 2 * pf.log(barycenter.expand_as(z), z)

            if normed:
                if type(wik) != float:
                    grad_tangent /= wik.sum(0, keepdim=True).expand_as(wik)
                else:
                    grad_tangent /= len(z)

            # gradient step through the exponential map
            cc_barycenter = pf.exp(barycenter, lr * grad_tangent.sum(0, keepdim=True))
            nan_values = ~(cc_barycenter == cc_barycenter)
            if torch.nonzero(nan_values.squeeze()).shape[0] > 0:
                print("\n\nAt least one barycenter is NaN; the exp update may contain 0:")
                print(grad_tangent.sum(0, keepdim=True))
                quit()

            cvg = distance(cc_barycenter, barycenter).max().item()
            barycenter = cc_barycenter
            if cvg <= tau and verbose:
                print("Frechet mean converged in ", iteration, " iterations")
        return barycenter
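A minimal usage sketch, not from the source: it assumes `pf` is the same Poincaré function module used above (providing `log`, `exp` and `distance`), that `math` and `torch` are imported where `barycenter` is defined, and that `z` holds one point of the ball per row with one weight per point in `wik`.

import torch

z = torch.rand(10, 2) * 0.1   # 10 points of dimension 2, near the origin
w = torch.rand(10)            # one non-negative weight per point
mean = barycenter(z, wik=w, lr=1e-2, max_iter=500, verbose=True)
print(mean.shape)             # torch.Size([1, 2])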
Example 3
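This update appears to guard the exponential-map step: any point whose candidate update reaches norm `self.eps` (the edge of the disc) is reverted to its previous value, with a one-time warning on the first occurrence.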
    def _optimization_method(self, p, d_p, lr):
        with torch.no_grad():
            # candidate update via the exponential map
            a = pf.exp(p, -lr * d_p)

            # revert any point that the update pushed out of the disc
            mask = a.norm(2, -1) >= self.eps
            if mask.any():
                if not self.first_over:
                    print("New update out of the disc:", a.norm(2, -1).max())
                    self.first_over = True
                a[mask] = p[mask]
            p.copy_(a)
Example 4
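A bare Riemannian SGD step through the exponential map. Note the sign convention: unlike Examples 1 and 3, this steps along `+lr * d_p`, so the caller presumably passes the descent direction rather than the raw gradient.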
    def _optimization_method(self, p, d_p, lr):
        p.copy_(pf.exp(p, lr * d_p))
Example 5
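Here the step maps the gradient onto the ball with the exponential map at the origin (`d_p.new(d_p.size()).zero_()` builds a zero base point) and combines it with the re-projected parameter via `pf.add`, presumably Möbius addition.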
    def _optimization_method(self, p, d_p, lr):
        p.copy_(
            pf.add(pf.renorm_projection(p.data),
                   -lr * pf.exp(d_p.new(d_p.size()).zero_(), d_p)))
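For context, a minimal sketch of how a method with this `_optimization_method(p, d_p, lr)` signature might be driven. This is not the project's actual API: the class name and attributes below are illustrative, and `pf.exp` is assumed from the snippets above.

import torch

class SketchedOptimizer:
    # hypothetical driver; the real optimizer class is not shown in these examples
    def __init__(self, params, lr=1e-2):
        self.params = list(params)
        self.lr = lr

    def step(self):
        with torch.no_grad():
            for p in self.params:
                if p.grad is None:
                    continue
                self._optimization_method(p, p.grad, self.lr)

    def _optimization_method(self, p, d_p, lr):
        # steps against the gradient, as in Examples 1 and 3
        p.copy_(pf.exp(p, -lr * d_p))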