def rotator(inputs):
    """Apply an independent random in-plane rotation to every sample of a batch.

    A fixed 3x3 skew-symmetric generator is scaled by one random weight per
    sample and passed to ``expm`` (presumably a batched matrix exponential
    defined elsewhere in this file -- confirm). The top two rows of each
    resulting matrix are used as ``theta`` for ``F.affine_grid``, and the
    inputs are resampled with ``F.grid_sample``.

    Args:
        inputs: 4-D tensor whose first dimension is the batch size.

    Returns:
        The tensor produced by ``F.grid_sample`` for ``inputs``.

    Raises:
        ValueError: if ``inputs`` does not have exactly four dimensions.
    """
    # Unpacking four values keeps the rejection of non-4-D input.
    bs, _, _, _ = inputs.size()

    ## rotation generating matrix ##
    generator = torch.zeros(3, 3)
    generator[0, 1] = -1.
    generator[1, 0] = 1.

    ## weight the rotations randomly ##
    upper = 10
    lower = -10
    # Uniform in [lower, upper). The previous `upper * rand - lower` formula
    # sampled from [10, 20) instead of the range named by the bounds.
    angles = (upper - lower) * torch.rand(bs) + lower

    # Broadcast to (bs, 3, 3): one scaled copy of the generator per sample.
    generators = angles[:, None, None] * generator

    affine_mats = expm(generators)
    # The matrices are built on CPU in the default dtype; grid_sample needs
    # the grid to live on the same device and share the dtype of `inputs`.
    affine_mats = affine_mats.to(device=inputs.device, dtype=inputs.dtype)

    flowgrid = F.affine_grid(affine_mats[:, :2, :], size=inputs.size(),
                             align_corners=True)
    return F.grid_sample(inputs, flowgrid, align_corners=True)
def affine(self, xyz):
    """Sample a batch of random 2x2 matrices together with their counterparts
    built from the negated generators.

    Only the leading dimension, dtype and device of ``xyz`` are used. Each
    generator entry is drawn uniformly from ``[self.lower, self.upper)``.
    ``expm`` is defined elsewhere; if it is a matrix exponential, the second
    returned batch is the inverse of the first -- confirm.

    Returns:
        Tuple ``(A, Ainv)``: ``expm`` of the generators and of their negation,
        each with ``xyz.shape[0]`` entries along the first dimension.
    """
    batch = xyz.shape[0]
    span = self.upper - self.lower

    # Four uniform samples per batch element, one per entry of a 2x2 generator.
    noise = torch.rand(batch, 4, dtype=xyz.dtype, device=xyz.device)
    generators = (noise * span + self.lower).reshape(batch, 2, 2)

    # Exponentiate +G and -G in a single call, then split the halves apart.
    stacked = torch.cat([generators, -generators], dim=0)
    matrices = expm(stacked)

    forward_mats = matrices[:batch]
    inverse_mats = matrices[batch:]
    return forward_mats, inverse_mats
def forward(self, inp):
    """Apply a random affine map to the coordinates of ``inp``.

    ``inp`` is unpacked into ``(xyz, vals, mask)``; the original comment gives
    their shapes as (bs,n,3), (bs,n,c), (bs,n). Twelve random parameters per
    sample, uniform in ``[self.lower, self.upper)``, fill a 4x4 generator:
    the first nine go through ``cross_matrix``, ``shear_matrix`` and
    ``squeeze_matrix`` (helpers defined elsewhere) into the upper-left 3x3
    block, the last three into the first three rows of the last column.
    The generator is passed to ``expm`` (defined elsewhere).

    Returns:
        ``(transformed_xyz, vals, mask)`` with ``vals`` and ``mask`` untouched.
    """
    xyz, vals, mask = inp  # (bs,n,3), (bs,n,c), (bs,n)
    batch = xyz.shape[0]

    # Sampled with the default generator, then moved to xyz's device/dtype.
    params = torch.rand(batch, 12).to(xyz.device, xyz.dtype) * (self.upper - self.lower) + self.lower

    generators = torch.zeros(batch, 4, 4, dtype=xyz.dtype, device=xyz.device)
    linear_part = (
        cross_matrix(params[:, :3])
        + shear_matrix(params[:, 3:6])
        + squeeze_matrix(params[:, 6:9])
    )
    generators[:, :3, :3].add_(linear_part)
    generators[:, :3, 3].add_(params[:, 9:])

    matrices = expm(generators)

    # Right-multiply the points by the 3x3 block, then add the scaled
    # translation column broadcast over the points dimension.
    rotated = torch.matmul(xyz, matrices[:, :3, :3])
    offset = matrices[:, :3, 3].unsqueeze(1) * self.trans_scale
    return rotated + offset, vals, mask
def transform(self, x):
    """Resample ``x`` through a randomly drawn affine grid.

    Six random weights per sample are drawn uniformly from
    ``[-width/2, width/2)``, where ``width = self.softplus(self.width)``.
    ``self.generate`` turns them into generators, which ``expm`` (defined
    elsewhere) maps to matrices; the first two rows of each matrix are used
    as ``theta`` for ``F.affine_grid``.

    Args:
        x: 4-D tensor whose first dimension is the batch size.

    Returns:
        The output of ``F.grid_sample`` applied to ``x``.
    """
    # Four-way unpack keeps the implicit check that x is 4-D.
    batch, _, _, _ = x.size()

    span = self.softplus(self.width)
    noise = torch.rand(batch, 6).to(x.device, x.dtype)
    coeffs = noise * span - span.div(2.)

    ## exponential map
    # expm is evaluated on CPU and the result is moved back afterwards.
    lie_elements = self.generate(coeffs)
    matrices = expm(lie_elements.cpu()).to(coeffs.device)

    theta = matrices[:, :2, :]
    grid = F.affine_grid(theta, size=x.size(), align_corners=True)
    return F.grid_sample(x, grid, align_corners=True)