コード例 #1
0
    def compute_geodesic_control(self, man):
        """Compute ``self.controls`` from ``man`` and refresh the moments.

        Steps, as implemented below:

        1. Build ``vs = self.adjoint(man)`` and evaluate it at the points
           ``self.manifold.gd[0]`` (flattened to ``(-1, self.manifold.dim)``)
           with ``k=1``. The result is used as a batch of square matrices.
           ``k=1`` presumably selects the first derivative of the field;
           TODO confirm.
        2. Take the symmetric part of each matrix and contract it with
           ``eta(self.manifold.dim)`` over the two matrix axes.
        3. Solve the resulting system with ``self.solve_sks``, using the
           ``__keops_*`` parameters, ``alpha=self.nu`` and ``eps=self.eps``.
           Then divide by ``self.coeff``.
        4. Solve ``aqkiaq @ c = aq^T @ tlambdas`` for ``c``, with ``aq`` and
           ``aqkiaq`` coming from ``self.__compute_aqkiaq()``.

        Side effects: sets ``self.controls`` (flattened ``c``) and calls
        ``self.__compute_moments()``.

        Args:
            man: forwarded to ``self.adjoint``. Its exact type is not visible
                here; NOTE(review): confirm against ``adjoint``.
        """
        vs = self.adjoint(man)
        d_vx = vs(self.manifold.gd[0].view(-1, self.manifold.dim), k=1)

        # Symmetric part of each (dim x dim) matrix, then projection onto the
        # basis given by eta (contracting both matrix axes).
        S = 0.5 * (d_vx + torch.transpose(d_vx, 1, 2))
        S = torch.tensordot(S,
                            eta(self.manifold.dim, device=self.device),
                            dims=2)

        tlambdas = self.solve_sks(self.manifold.gd[0].reshape(-1, self.dim),
                                  self.manifold.gd[0].reshape(-1, self.dim),
                                  S,
                                  self.__keops_eye,
                                  self.__keops_invsigmasq,
                                  self.__keops_A,
                                  backend=self.__keops_backend,
                                  alpha=self.nu,
                                  eps=self.eps) / self.coeff

        (aq, aqkiaq) = self.__compute_aqkiaq()

        rhs = torch.mm(aq.t(), tlambdas.reshape(-1, 1))
        # ``torch.solve(B, A)`` is deprecated and removed from recent PyTorch
        # releases in favour of ``torch.linalg.solve(A, B)``. Note the swapped
        # argument order. Fall back to the old API only when the new one is
        # unavailable.
        linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
        if linalg_solve is not None:
            c = linalg_solve(aqkiaq, rhs)
        else:
            c, _ = torch.solve(rhs, aqkiaq)

        self.controls = c.flatten()
        self.__compute_moments()
コード例 #2
0
    def __compute_aqh(self, h):
        """Map the control vector ``h`` to one coefficient vector per point.

        Let ``R = self.manifold.gd[1]`` viewed as ``(-1, dim, dim)`` with
        ``dim = self.manifold.dim``, and let ``ch[n] = self.C[n] @ h``. For
        every point ``n`` this returns::

            out[n, t] = sum_{l, v, i} R[n, l, i] * ch[n, i] * R[n, v, i] * eta[l, v, t]

        That is ``R_n diag(ch_n) R_n^T`` contracted with ``eta`` over its
        first two axes.

        Args:
            h: control vector, contracted with the last axis of ``self.C``.
                ``self.C`` is indexed as ``(n, dim, len(h))`` below.

        Returns:
            Tensor of shape ``(n, eta.shape[2])``.
        """
        dim = self.manifold.dim
        R = self.manifold.gd[1].view(-1, dim, dim)
        # eta is indexed (l, v, t) with l and v running over R's matrix axes,
        # so it must be built for the manifold dimension.
        eta_tensor = eta(dim, device=self.device)

        ch = torch.einsum('nik,k->ni', self.C, h)
        # Index ``i`` is shared by both copies of R and by ch, so einsum
        # multiplies them element-wise along i. That is exactly the diagonal
        # product R diag(ch) R^T, so no explicit identity tensor is needed.
        # (R^T[n, i, v] == R[n, v, i], hence the 'nvi' subscript.)
        return torch.einsum('nli,ni,nvi,lvt->nt', R, ch, R, eta_tensor)
コード例 #3
0
 def __compute_moments(self):
     """Recompute ``self.moments`` from the current ``self.controls``.

     Steps, as implemented below:

     1. ``self.__aqh`` <- ``self.__compute_aqh(self.controls)``.
     2. ``self.__lambdas`` <- solution ``x`` of ``self.__sks @ x = aqh``,
        with ``aqh`` flattened to a column vector, made contiguous.
     3. ``self.moments`` <- ``lambdas`` reshaped to ``(-1, self.sym_dim)``
        and contracted (``dims=1``) with ``eta`` after swapping its axes 0
        and 2.

     Uses ``self.__sks`` as-is. It is presumably refreshed by
     ``self.__compute_sks()`` beforehand; NOTE(review): confirm every caller
     guarantees this.
     """
     self.__aqh = self.__compute_aqh(self.controls)
     rhs = self.__aqh.reshape(-1, 1)
     # ``torch.solve(B, A)`` is deprecated and removed from recent PyTorch
     # releases in favour of ``torch.linalg.solve(A, B)``. Note the swapped
     # argument order. Fall back to the old API only when the new one is
     # unavailable.
     linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve', None)
     if linalg_solve is not None:
         lambdas = linalg_solve(self.__sks, rhs)
     else:
         lambdas, _ = torch.solve(rhs, self.__sks)
     self.__lambdas = lambdas.contiguous()
     self.moments = torch.tensordot(self.__lambdas.view(-1, self.sym_dim),
                                    torch.transpose(
                                        eta(self.manifold.dim,
                                            device=self.device), 0, 2),
                                    dims=1)
コード例 #4
0
 def __compute_moments(self):
     """Refresh ``self.moments`` from the current ``self.controls``.

     The controls are first mapped through ``self.__compute_aqh``, and the
     result is cached in ``self.__aqh``. That vector is then fed to
     ``self.solve_sks`` with the ``__keops_*`` parameters, the configured
     backend and ``alpha=self.nu``; the solution is cached in
     ``self.__lambdas``. Finally, every row of ``sym_dim`` coefficients is
     contracted with ``eta`` (axes 0 and 2 swapped) to give ``self.moments``.
     """
     self.__aqh = self.__compute_aqh(self.controls)

     gd_points = self.manifold.gd[0]
     # The same point set is supplied for both point arguments of solve_sks.
     self.__lambdas = self.solve_sks(gd_points.reshape(-1, self.dim),
                                     gd_points.reshape(-1, self.dim),
                                     self.__aqh,
                                     self.__keops_eye,
                                     self.__keops_invsigmasq,
                                     self.__keops_A,
                                     backend=self.__keops_backend,
                                     alpha=self.nu)

     coeff_rows = self.__lambdas.view(-1, self.sym_dim)
     eta_swapped = torch.transpose(eta(self.manifold.dim, device=self.device),
                                   0, 2)
     self.moments = torch.tensordot(coeff_rows, eta_swapped, dims=1)
コード例 #5
0
    def compute_geodesic_control(self, man):
        """Compute ``self.controls`` from ``man`` and refresh the moments.

        Steps, as implemented below:

        1. Build ``vs = self.adjoint(man)`` and evaluate it at
           ``self.manifold.gd[0]`` with ``k=1``. The result is used as a
           batch of square matrices. ``k=1`` presumably selects the first
           derivative of the field; TODO confirm.
        2. Take the symmetric part of each matrix and contract it with
           ``eta(self.manifold.dim)`` over the two matrix axes.
        3. Call ``self.__compute_sks()``, then solve ``self.__sks @ x = S``
           (with ``S`` flattened to a column) and divide by ``self.coeff``.
        4. Solve ``aqkiaq @ c = aq^T @ tlambdas`` for ``c``, with ``aq`` and
           ``aqkiaq`` coming from ``self.__compute_aqkiaq()``.

        Side effects: sets ``self.controls`` (flattened ``c``) and calls
        ``self.__compute_moments()``.

        Args:
            man: forwarded to ``self.adjoint``. Its exact type is not visible
                here; NOTE(review): confirm against ``adjoint``.
        """
        def solve(a, b):
            # Solve ``a @ x = b``. ``torch.solve(B, A)`` is deprecated and
            # removed from recent PyTorch releases in favour of
            # ``torch.linalg.solve(A, B)`` (swapped argument order); use the
            # old API only when the new one is unavailable.
            linalg_solve = getattr(getattr(torch, 'linalg', None), 'solve',
                                   None)
            if linalg_solve is not None:
                return linalg_solve(a, b)
            return torch.solve(b, a)[0]

        vs = self.adjoint(man)
        d_vx = vs(self.manifold.gd[0], k=1)

        # Symmetric part of each (dim x dim) matrix, then projection onto the
        # basis given by eta (contracting both matrix axes).
        S = 0.5 * (d_vx + torch.transpose(d_vx, 1, 2))
        S = torch.tensordot(S,
                            eta(self.manifold.dim, device=self.device),
                            dims=2)

        self.__compute_sks()

        tlambdas = solve(self.__sks, S.reshape(-1, 1))
        tlambdas = tlambdas / self.coeff

        (aq, aqkiaq) = self.__compute_aqkiaq()
        c = solve(aqkiaq, torch.mm(aq.t(), tlambdas))
        self.controls = c.reshape(-1)
        self.__compute_moments()