Esempio n. 1
0
 def get_pose(self, ind):
     """Return the ``(rot, tran)`` pose for index/indices ``ind``.

     When ``self.emb_type`` is None the pose is read directly from
     ``self.rots`` / ``self.trans``. Otherwise the rotation is decoded from
     the ``self.rots_emb(ind)`` output with ``lie_tools`` ('s2s2' or 'quat'
     parameterization), and the translation comes from ``self.trans_emb(ind)``.

     Args:
         ind: index (or indices) used to look up the pose(s).

     Returns:
         Tuple ``(rot, tran)``; ``tran`` is None when ``self.use_trans`` is
         falsy.

     Raises:
         RuntimeError: if ``self.emb_type`` is not None, 's2s2' or 'quat'.
     """
     if self.emb_type is None:
         rot = self.rots[ind]
         tran = self.trans[ind] if self.use_trans else None
         return rot, tran

     if self.emb_type == 's2s2':
         rot = lie_tools.s2s2_to_SO3(self.rots_emb(ind))
     elif self.emb_type == 'quat':
         rot = lie_tools.quaternions_to_SO3(self.rots_emb(ind))
     else:
         # Reachable when the object was configured with an unsupported
         # embedding type; name the offending value instead of a bare error.
         raise RuntimeError(
             "Unsupported emb_type {!r}; expected None, 's2s2' or 'quat'".format(
                 self.emb_type))
     tran = self.trans_emb(ind) if self.use_trans else None
     return rot, tran
Esempio n. 2
0
 def real_tform(self, rot, trans):
     """Resample ``self.vol_real`` on a rotated and translated lattice.

     Args:
         rot: batch of rotations. If the last dimension is 4 they are treated
             as quaternions and converted with ``lie_tools.quaternions_to_SO3``;
             otherwise they are used as-is as the right operand of
             ``self.lattice @ rot``.
         trans: translation subtracted from every rotated lattice point.

     Returns:
         ``F.grid_sample`` output viewed as (B, D, D, D), where B is
         ``rot.size(0)``.
     """
     mats = lie_tools.quaternions_to_SO3(rot) if rot.shape[-1] == 4 else rot
     n_batch = mats.size(0)
     size = self.D
     mid = self.D2
     # presumably B x D^3 x 3 before the view — TODO confirm lattice shape
     coords = (self.lattice @ mats).view(-1, size, size, size, 3)
     # Shift each rotated grid so its central voxel lands on self.center.
     shift = self.center - coords[:, mid, mid, mid]
     coords += shift[:, None, None, None, :]
     # NOTE(review): `trans` must broadcast against (B, D, D, D, 3), e.g. a
     # (B, 1, 1, 1, 3) or (3,) tensor; a plain (B, 3) tensor broadcasts only
     # when B == 1 (or, wrongly, when B == D) — confirm with callers.
     coords -= trans
     # NOTE(review): the final view assumes vol_real is (1, 1, D, D, D).
     sampled = F.grid_sample(self.vol_real, coords.view(1, -1, size, size, 3))
     return sampled.view(n_batch, size, size, size)
Esempio n. 3
0
 def rotate(self, rot_or_quat):
     """Resample ``self.volr`` and ``self.voli`` on a rotated lattice.

     Args:
         rot_or_quat: batch of rotations. If the last dimension is 4 they are
             treated as quaternions and converted with
             ``lie_tools.quaternions_to_SO3``; otherwise used as-is.

     Returns:
         Tuple of the two resampled volumes (from ``self.volr`` and
         ``self.voli``), each viewed as (B, D, D, D) with B the rotation
         batch size.
     """
     if rot_or_quat.shape[-1] == 4:
         mats = lie_tools.quaternions_to_SO3(rot_or_quat)
     else:
         mats = rot_or_quat
     n_batch = mats.size(0)
     size = self.D
     mid = self.D2
     # presumably B x D^3 x 3 before the view — TODO confirm lattice shape
     coords = (self.lattice @ mats).view(-1, size, size, size, 3)
     # Translate every rotated grid so its central voxel sits at self.center.
     coords += (self.center - coords[:, mid, mid, mid])[:, None, None, None, :]
     sample_grid = coords.view(1, -1, size, size, 3)
     # NOTE(review): the views below assume each volume is (1, 1, D, D, D).
     out_r = F.grid_sample(self.volr, sample_grid).view(n_batch, size, size, size)
     out_i = F.grid_sample(self.voli, sample_grid).view(n_batch, size, size, size)
     return out_r, out_i
Esempio n. 4
0
    def save(self, out_pkl):
        """Pickle the current poses to ``out_pkl``.

        Rotations are decoded from ``self.rots_emb.weight.data`` when
        ``self.emb_type`` is 'quat' or 's2s2', and taken from ``self.rots``
        otherwise. When ``self.use_trans`` is truthy, translations (from
        ``self.trans`` or ``self.trans_emb.weight.data``) are divided by
        ``self.D`` and the pickled object is ``(rots, trans)``; otherwise it
        is ``(rots,)``. All arrays are NumPy arrays.

        The object's stored tensors are not modified.

        Args:
            out_pkl: path of the pickle file to create or overwrite.
        """
        if self.emb_type == 'quat':
            r = lie_tools.quaternions_to_SO3(self.rots_emb.weight.data).cpu().numpy()
        elif self.emb_type == 's2s2':
            r = lie_tools.s2s2_to_SO3(self.rots_emb.weight.data).cpu().numpy()
        else:
            r = self.rots.cpu().numpy()

        if self.use_trans:
            if self.emb_type is None:
                t = self.trans.cpu().numpy()
            else:
                t = self.trans_emb.weight.data.cpu().numpy()
            # For CPU tensors, .cpu().numpy() shares memory with the tensor, so
            # an in-place division would rescale the model's own translations
            # on every save. Divide out-of-place instead.
            t = t / self.D  # convert from pixels to extent
            poses = (r, t)
        else:
            poses = (r,)

        # The context manager guarantees the file is flushed and closed.
        with open(out_pkl, 'wb') as f:
            pickle.dump(poses, f)
Esempio n. 5
0
 def __init__(self, resol):
     """Build rotation matrices from the quaternion grid ``so3_grid.grid_SO3(resol)``.

     Args:
         resol: resolution argument forwarded to ``so3_grid.grid_SO3``
             (presumably larger values give a denser grid — TODO confirm).

     Sets:
         rots: rotations converted with ``lie_tools.quaternions_to_SO3``.
         N: ``len(self.rots)``, i.e. the size of its first dimension.
     """
     quats = so3_grid.grid_SO3(resol)
     self.rots = lie_tools.quaternions_to_SO3(torch.tensor(quats))
     # Count the stored rotations; there is no local `rots` in this scope.
     self.N = len(self.rots)