예제 #1
0
    def forward(self, x):
        """Score each sample with the quadratic form ``v^T M_inv v``.

        ``v`` is the per-sample column of the Veronese map returned by
        ``generate_veronese(..., self.n)``. That map is unpacked below as a 2-D
        ``(dim_veronese, BS)`` tensor.

        When ``self.evaluation`` is falsy (training), ``M_inv`` is refreshed
        first:

        * ``self.create_M`` is applied to the current Veronese map.
        * Its result is averaged 50/50 with the stored ``self.M_inv_copy``.
        * That average is used for scoring.
        * A detached CPU clone of the average is written back to
          ``self.M_inv_copy``.

        When ``self.evaluation`` is truthy, ``self.M_inv_copy`` is used as is.

        Args:
            x: 2-D tensor of shape ``(npoints, dims)``.

        Returns:
            Tensor of shape ``(BS, 1, 1)`` with one score per Veronese column
            (presumably ``BS == npoints`` -- confirm against generate_veronese).

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(
                f"expected a 2-D (npoints, dims) tensor, got shape {tuple(x.shape)}")

        # generate_veronese is fed a (dims, npoints) tensor. Use a real
        # transpose: view() would reinterpret the row-major memory and mix
        # coordinates of different samples.
        v_x, _ = generate_veronese(x.t().contiguous(), self.n)
        dim_veronese, BS = v_x.size()

        # Batch-first Veronese vectors, (BS, dim_veronese), again via a true
        # transpose rather than view().
        v_rows = v_x.t()
        lhs = v_rows.unsqueeze(1)  # (BS, 1, dim_veronese)
        rhs = v_rows.unsqueeze(2)  # (BS, dim_veronese, 1)

        # Keep M_inv on the same device as v_x (matmul requires it); this
        # replaces a hard-coded .cuda(1).
        device = v_x.device

        if not self.evaluation:
            M_inv_temp = self.create_M(v_x).to(device)
            # TODO: update M_inv batch-wise with Sherman-Morrison updates.
            # Running 50/50 average of the stored estimate and this batch's.
            M_inv = (self.M_inv_copy.to(device) + M_inv_temp) * 0.5

            del M_inv_temp
            torch.cuda.empty_cache()

            # (BS, 1, D) @ (D, D) -> (BS, 1, D); then @ (BS, D, 1) -> (BS, 1, 1)
            scores = torch.matmul(torch.matmul(lhs, M_inv), rhs)
            # Persist a graph-free CPU copy for the next call / evaluation.
            self.M_inv_copy = M_inv.cpu().detach().clone()
        else:
            scores = torch.matmul(
                torch.matmul(lhs, self.M_inv_copy.to(device)), rhs)

        return scores
예제 #2
0
 def forward(self, x):
     """Collect Veronese maps for the moment matrix, or score samples with it.

     Behaviour depends on two flags:

     * ``self.has_M_inv`` truthy: return the per-sample quadratic form
       ``v^T self.M_inv v`` of shape ``(BS, 1, 1)``. Here ``v`` is a column of
       the ``(dim_veronese, BS)`` Veronese map of ``x``.
     * ``self.has_M_inv`` falsy and ``self.build_M`` truthy: append the
       Veronese map of ``x`` (moved to CPU) to ``self.veroneses`` and return
       ``x`` unchanged.
     * both falsy: return ``x`` unchanged.

     Args:
         x: 2-D tensor of shape ``(npoints, dims)``.

     Raises:
         ValueError: if a Veronese map is needed and ``x`` is not 2-D.
     """
     if not self.has_M_inv and not self.build_M:
         # Veronese maps are only collected once self.build_M is set (i.e.
         # once reconstruction is good enough); until then pass x through.
         return x

     if x.dim() != 2:
         raise ValueError(
             f"expected a 2-D (npoints, dims) tensor, got shape {tuple(x.shape)}")

     # generate_veronese is fed a (dims, npoints) tensor. Use a real
     # transpose: view() would scramble samples.
     v_x, _ = generate_veronese(x.t().contiguous(), self.n)

     if not self.has_M_inv:
         # Building phase: store the map to assemble the moment matrix later.
         self.veroneses.append(v_x.cpu())
         return x

     # Batch-first Veronese vectors, (BS, dim_veronese).
     v_rows = v_x.t()
     # (BS, 1, D) @ (D, D) -> (BS, 1, D); then @ (BS, D, 1) -> (BS, 1, 1)
     return torch.matmul(
         torch.matmul(v_rows.unsqueeze(1), self.M_inv),
         v_rows.unsqueeze(2))
예제 #3
0
 def forward(self, x):
     """Apply ``self.B`` to the Veronese map of ``x``.

     Args:
         x: 2-D tensor of shape ``(npoints, dims)``.

     Returns:
         Whatever ``self.B`` returns for the Veronese map. That map is assumed
         to be ``(dim_veronese, BS)``; TODO confirm against generate_veronese.

     Raises:
         ValueError: if ``x`` is not 2-D.
     """
     if x.dim() != 2:
         raise ValueError(
             f"expected a 2-D (npoints, dims) tensor, got shape {tuple(x.shape)}")

     # generate_veronese is fed a (dims, npoints) tensor. Use a real
     # transpose: view() would reinterpret memory and mix samples.
     v_x, _ = generate_veronese(x.t().contiguous(), self.n)
     # v_x is (dim_veronese, BS)
     return self.B(v_x)
예제 #4
0
 def forward(self, x):
     """Apply ``self.B`` to the batch-first Veronese map of ``x``.

     The Veronese map is transposed to ``(BS, dim_veronese)``. That same
     tensor is passed as both arguments of ``self.B``.

     Args:
         x: 2-D tensor of shape ``(npoints, dims)``.

     Returns:
         Whatever ``self.B(v, v)`` returns for the batch-first map ``v``.

     Raises:
         ValueError: if ``x`` is not 2-D.
     """
     if x.dim() != 2:
         raise ValueError(
             f"expected a 2-D (npoints, dims) tensor, got shape {tuple(x.shape)}")

     # generate_veronese is fed a (dims, npoints) tensor. Use a real
     # transpose: view() would reinterpret memory and mix samples.
     v_x, _ = generate_veronese(x.t().contiguous(), self.n)

     # v_x is (dim_veronese, BS); put the batch dim first.
     # A non-mutating t() makes it explicit that both arguments are the same
     # transposed tensor. An in-place t_() followed by v_x gives the same
     # arguments, but only by aliasing.
     v_rows = v_x.t()
     return self.B(v_rows, v_rows)