def get_item(item, mode, config):
    """Build one training/evaluation sample with per-class surface point clouds.

    In ``DataModes.TRAINING`` the volume is randomly augmented: an optional
    swap of the last two axes and optional flips along each axis (each with
    probability 0.5), followed by a random rotation/shift/scale applied with
    ``stns.transform``.

    For every foreground class ``i`` in ``1 .. config.num_classes - 1`` the
    outer surface voxels of the mask ``y == i`` are extracted with
    ``sample_outer_surface_in_voxel``. They are converted from (z, y, x) to
    (x, y, z) order, normalized with ``normalize_vertices`` and randomly
    subsampled to at most 3000 points. Outside training a mesh of each class
    mask is additionally built with ``voxel2mesh``.

    Args:
        item: sample record exposing ``x``, ``y`` and ``y_outer`` tensors.
        mode: a ``DataModes`` value; augmentation only runs for ``TRAINING``.
        config: must provide ``num_classes``, and ``augmentation_shift_range``
            when training.

    Returns:
        dict with ``x``, ``y_voxels``, ``surface_points`` (one tensor per
        foreground class) and ``unpool``; outside training also
        ``vertices_mc`` and ``faces_mc`` (one entry per foreground class).
    """
    point_count = 3000  # max number of surface points kept per class
    gap = 1  # border width passed to clean_border_pixels / voxel2mesh

    x = item.x.cuda()[None]  # extra leading axis: x axis k + 1 matches y axis k
    y = item.y.cuda()
    y_outer = item.y_outer.cuda()

    if mode == DataModes.TRAINING:
        # Random axis swap / flips, each applied with probability 0.5.
        if torch.rand(1)[0] > 0.5:
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            y_outer = y_outer.permute([0, 2, 1])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            y_outer = torch.flip(y_outer, dims=[0])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            y_outer = torch.flip(y_outer, dims=[1])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            y_outer = torch.flip(y_outer, dims=[2])

        # Random rotation: a random unit direction added to (0, -1, 0), renormalized.
        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = F.normalize((torch.rand(3) - 0.5) * 2 * np.pi, dim=0)
        q = F.normalize(orientation + new_orientation, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        # Random shift per axis, expressed relative to half that axis' size.
        offsets = 2 * (torch.rand(3) - 0.5) * config.augmentation_shift_range
        shift = torch.tensor([d / (size // 2) for d, size in zip(offsets, y.shape)])
        theta_shift = stns.shift(shift)

        f = 0.1  # scale factor lies in roughly [1 - f, 1 + f]
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale
        x, y, y_outer = stns.transform(theta, x, y, y_outer)

    surface_points_normalized_all = []
    vertices_mc_all = []
    faces_mc_all = []
    for i in range(1, config.num_classes):
        shape = torch.tensor(y.shape)[None].float()
        if mode != DataModes.TRAINING:
            y_ = clean_border_pixels((y == i).long(), gap=gap)
            vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)
            vertices_mc_all.append(vertices_mc)
            faces_mc_all.append(faces_mc)

        outer = sample_outer_surface_in_voxel((y == i).long())
        surface_points = torch.nonzero(outer)
        surface_points = torch.flip(surface_points, dims=[1]).float()  # z, y, x -> x, y, z
        surface_points_normalized = normalize_vertices(surface_points, shape)

        # Randomly keep at most `point_count` points; slicing clamps to the length.
        perm = torch.randperm(len(surface_points_normalized))
        surface_points_normalized_all.append(
            surface_points_normalized[perm[:point_count]].cuda())

    if mode == DataModes.TRAINING:
        return {
            'x': x,
            'y_voxels': y,
            'surface_points': surface_points_normalized_all,
            'unpool': [0, 1, 0, 1, 0]
        }
    return {
        'x': x,
        'y_voxels': y,
        'vertices_mc': vertices_mc_all,
        'faces_mc': faces_mc_all,
        'surface_points': surface_points_normalized_all,
        'unpool': [0, 1, 1, 1, 1]
    }
def get_item_(item, mode, config):
    """Build one sample whose surface points are given directly by ``item.y_outer``.

    ``item.y_outer`` is used as a point list: it is indexed as ``[:, k]`` and
    concatenated with a column of ones. It is therefore presumably an (N, 3)
    tensor of x, y, z coordinates in [-1, 1] (TODO: confirm against the data
    producer).

    In ``DataModes.TRAINING_EXTENDED`` the following augmentation is applied:
      * ``x``, ``y`` and ``item.w`` get an optional swap of the last two axes
        and optional flips along each axis, each with probability 0.5.
      * They then get a random rotation/shift/scale applied with
        ``stns.transform``.
      * The points receive the matching swap/flips, are mapped with the
        inverse of the affine transform, and are dropped if they fall outside
        the open cube (-1, 1)^3.

    Args:
        item: sample record exposing ``x``, ``y``, ``y_outer``, ``w``,
            ``x_super_res`` and ``y_super_res``.
        mode: a ``DataModes`` value.
        config: must provide ``sphere_vertices`` and ``sphere_faces``, and
            ``augmentation_shift_range`` / ``unpool_indices`` when training.

    Returns:
        dict with the volumes, the template sphere faces, the surface points,
        the angles ``p``/``t`` of the template sphere vertices and an
        ``unpool`` schedule. Outside training it also holds the
        super-resolution volumes and a ``voxel2mesh`` mesh of ``y``.
    """
    x = item.x.cuda()[None]  # extra leading axis: x axis k + 1 matches y axis k
    y = item.y.cuda()
    # Clone: the flips below negate columns in place, and ``.cuda()`` returns
    # the very same tensor when it already lives on the GPU, which would
    # corrupt item.y_outer for later epochs.
    surface_points = item.y_outer.cuda().clone()
    w = item.w.cuda()
    x_super_res = item.x_super_res[None]
    y_super_res = item.y_super_res

    if mode == DataModes.TRAINING_EXTENDED:
        # Each swap/flip fires with probability 0.5, like the other loaders.
        if torch.rand(1)[0] > 0.5:
            # Swapping the last two volume axes swaps the x/y point coordinates.
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            surface_points = surface_points[:, [1, 0, 2]]
        # Flipping y axis k (x axis k + 1) negates point coordinate 2 - k.
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            surface_points[:, 2] = -surface_points[:, 2]
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            surface_points[:, 1] = -surface_points[:, 1]
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            surface_points[:, 0] = -surface_points[:, 0]

        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = F.normalize((torch.rand(3) - 0.5) * 2 * np.pi, dim=0)
        q = F.normalize(orientation + new_orientation, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        # Random shift per axis, expressed relative to half that axis' size.
        offsets = 2 * (torch.rand(3) - 0.5) * config.augmentation_shift_range
        shift = torch.tensor([d / (size // 2) for d, size in zip(offsets, y.shape)])
        theta_shift = stns.shift(shift)

        f = 0.1  # scale factor lies in roughly [1 - f, 1 + f]
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale
        x, y, w = stns.transform(theta, x, y, w)

        # super-resolution volumes are not needed during training
        x_super_res = None
        y_super_res = None

        # (R @ S @ C)^-1 = C^-1 @ S^-1 @ R^-1; keep the first three rows
        # (assumes 4x4 homogeneous matrices).
        theta_inv = (theta_scale.inverse() @ theta_shift.inverse()
                     @ theta_rotate.inverse())[:3]
        homogeneous = torch.cat(
            [surface_points, torch.ones(len(surface_points), 1).cuda()], dim=1)
        surface_points = (theta_inv.cuda() @ homogeneous.float().permute(1, 0)).permute(1, 0)
        inside = torch.all(surface_points > -1, dim=1) & torch.all(surface_points < 1, dim=1)
        surface_points = surface_points[inside]

    gap = 1
    y_ = clean_border_pixels(y, gap=gap)
    vertices_mc, faces_mc = voxel2mesh(y_, gap, torch.tensor(y.shape)[None].float())

    sphere_vertices = config.sphere_vertices
    sphere_faces = config.sphere_faces
    # p = acos(z), t = atan2(y, x) of each template vertex (polar / azimuth
    # angles if the vertices lie on the unit sphere), as fresh leaf tensors
    # that require grad.
    p = torch.acos(sphere_vertices[:, 2]).clone().detach().requires_grad_(True)
    t = torch.atan2(sphere_vertices[:, 1],
                    sphere_vertices[:, 0]).clone().detach().requires_grad_(True)

    if mode == DataModes.TRAINING_EXTENDED:
        return {
            'x': x,
            'faces_atlas': sphere_faces,
            'y_voxels': y,
            'surface_points': surface_points,
            'p': p,
            't': t,
            'unpool': config.unpool_indices
        }
    return {
        'x': x,
        'x_super_res': x_super_res,
        'faces_atlas': sphere_faces,
        'y_voxels': y,
        'y_voxels_super_res': y_super_res,
        'vertices_mc': vertices_mc,
        'faces_mc': faces_mc,
        'surface_points': surface_points,
        'p': p,
        't': t,
        'unpool': [1, 1, 1, 0, 0]
    }
def get_item__(item, mode, config):
    """Build one sample with a single subsampled surface point cloud.

    In ``DataModes.TRAINING_EXTENDED`` the volume is randomly augmented: an
    optional swap of the last two axes and optional flips along each axis
    (each with probability 0.5), followed by a random rotation/shift/scale
    applied with ``stns.transform``.

    Surface points are the non-zero voxels of the (augmented)
    ``item.y_outer``. They are converted from (z, y, x) to (x, y, z) order,
    normalized with ``normalize_vertices`` and randomly subsampled to at most
    3000 points. Outside training a ``voxel2mesh`` mesh of ``y`` is also
    returned.

    Args:
        item: sample record exposing ``x``, ``y``, ``y_outer``,
            ``x_super_res`` and ``y_super_res``.
        mode: a ``DataModes`` value.
        config: must provide ``sphere_vertices`` and ``sphere_faces``, and
            ``augmentation_shift_range`` / ``unpool_indices`` when training.

    Returns:
        dict with the volumes, the template sphere faces, the surface points,
        the angles ``p``/``t`` of the template sphere vertices and an
        ``unpool`` schedule. In training it also holds ``w`` (the augmented
        ``y_outer``); outside training it holds the super-resolution volumes
        and the mesh.
    """
    point_count = 3000  # max number of surface points kept

    x = item.x.cuda()[None]  # extra leading axis: x axis k + 1 matches y axis k
    y = item.y.cuda()
    y_outer = item.y_outer.cuda()
    x_super_res = item.x_super_res[None]
    y_super_res = item.y_super_res

    if mode == DataModes.TRAINING_EXTENDED:
        # Random axis swap / flips, each applied with probability 0.5.
        if torch.rand(1)[0] > 0.5:
            x = x.permute([0, 1, 3, 2])
            y = y.permute([0, 2, 1])
            y_outer = y_outer.permute([0, 2, 1])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[1])
            y = torch.flip(y, dims=[0])
            y_outer = torch.flip(y_outer, dims=[0])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[2])
            y = torch.flip(y, dims=[1])
            y_outer = torch.flip(y_outer, dims=[1])
        if torch.rand(1)[0] > 0.5:
            x = torch.flip(x, dims=[3])
            y = torch.flip(y, dims=[2])
            y_outer = torch.flip(y_outer, dims=[2])

        orientation = torch.tensor([0, -1, 0]).float()
        new_orientation = F.normalize((torch.rand(3) - 0.5) * 2 * np.pi, dim=0)
        q = F.normalize(orientation + new_orientation, dim=0)
        theta_rotate = stns.stn_quaternion_rotations(q)

        # Random shift per axis, expressed relative to half that axis' size.
        offsets = 2 * (torch.rand(3) - 0.5) * config.augmentation_shift_range
        shift = torch.tensor([d / (size // 2) for d, size in zip(offsets, y.shape)])
        theta_shift = stns.shift(shift)

        f = 0.1  # scale factor lies in roughly [1 - f, 1 + f]
        scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
        theta_scale = stns.scale(scale)

        theta = theta_rotate @ theta_shift @ theta_scale
        x, y, y_outer = stns.transform(theta, x, y, y_outer)

        # super-resolution volumes are not needed during training
        x_super_res = None
        y_super_res = None

    # Shape of the volume *after* augmentation. The axis swap exchanges the
    # last two dimensions, so item.shape would be stale for non-cubic volumes.
    # This is the same shape format the other loaders pass to voxel2mesh and
    # normalize_vertices.
    shape = torch.tensor(y.shape)[None].float()

    if mode != DataModes.TRAINING_EXTENDED:
        gap = 1
        y_ = clean_border_pixels(y, gap=gap)
        vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)

    sphere_vertices = config.sphere_vertices
    atlas_faces = config.sphere_faces
    # p = acos(z), t = atan2(y, x) of each template vertex, as fresh GPU leaf
    # tensors that require grad.
    p = torch.acos(sphere_vertices[:, 2]).cuda().clone().detach().requires_grad_(True)
    t = torch.atan2(sphere_vertices[:, 1],
                    sphere_vertices[:, 0]).cuda().clone().detach().requires_grad_(True)

    surface_points = torch.nonzero(y_outer)
    surface_points = torch.flip(surface_points, dims=[1]).float()  # z, y, x -> x, y, z
    surface_points_normalized = normalize_vertices(surface_points, shape)

    # Randomly keep at most `point_count` points; slicing clamps to the length.
    perm = torch.randperm(len(surface_points_normalized))
    surface_points_normalized = surface_points_normalized[perm[:point_count]]

    if mode == DataModes.TRAINING_EXTENDED:
        return {
            'x': x,
            'faces_atlas': atlas_faces,
            'y_voxels': y,
            'surface_points': surface_points_normalized,
            'p': p,
            't': t,
            'unpool': config.unpool_indices,
            'w': y_outer
        }
    return {
        'x': x,
        'x_super_res': x_super_res,
        'faces_atlas': atlas_faces,
        'y_voxels': y,
        'y_voxels_super_res': y_super_res,
        'vertices_mc': vertices_mc,
        'faces_mc': faces_mc,
        'surface_points': surface_points_normalized,
        'p': p,
        't': t,
        'unpool': [0, 1, 0, 1, 0]
    }
def __getitem__(self, idx):
    """Return sample ``idx`` as a dict of GPU tensors.

    Steps:
      * Load the image and labels from ``self.data[idx]`` and zero label 3.
      * Load a per-sample ``base_plane`` volume (all ones when
        ``self.base_sparse_plane`` is None).
      * In ``DataModes.TRAINING_EXTENDED`` mode, apply a random
        rotation/shift/scale with ``stns.transform``.
      * Crop everything to ``self.cfg.patch_shape`` around the volume centre.
      * For each foreground class ``1 .. num_classes - 1``, collect (x, y, z)
        surface points normalized with ``normalize_vertices``. Outside
        training, also build a ``voxel2mesh`` mesh per class.

    When ``self.point_model`` is set, the voxels labelled 1 serve as the
    surface points for every class, and label 2 is merged into class 1.

    In training, if the last class ends up with no surface points, a fresh
    random augmentation is drawn and the sample is rebuilt. Outside training
    nothing is random, so the sample is returned as-is rather than retrying
    forever.

    Args:
        idx: index into ``self.data`` (and ``self.base_sparse_plane`` when set).

    Returns:
        dict with the cropped volumes, per-class surface points, the template
        sphere faces and angles ``p``/``t``, an ``unpool`` schedule, the
        inverse transform ``theta`` (first three rows) and ``base_plane``.
        In training it also holds ``w``; outside training it holds the
        super-resolution placeholders and the per-class meshes.
    """
    item = self.data[idx]
    training = self.mode == DataModes.TRAINING_EXTENDED

    # The template-sphere angles do not depend on the sample: compute them once.
    # p = acos(z), t = atan2(y, x) of each vertex, as fresh GPU leaf tensors
    # that require grad.
    sphere_vertices = self.cfg.sphere_vertices
    atlas_faces = self.cfg.sphere_faces
    p = torch.acos(sphere_vertices[:, 2]).cuda().clone().detach().requires_grad_(True)
    t = torch.atan2(sphere_vertices[:, 1],
                    sphere_vertices[:, 0]).cuda().clone().detach().requires_grad_(True)

    while True:
        x = torch.from_numpy(item.x).cuda()[None]
        y = torch.from_numpy(item.y).cuda().long()
        y[y == 3] = 0  # label 3 is mapped to background
        if self.base_sparse_plane is not None:
            base_plane = torch.from_numpy(self.base_sparse_plane[idx]).cuda().float()
        else:
            base_plane = torch.ones_like(y).float()

        C, D, H, W = x.shape
        center = (D // 2, H // 2, W // 2)

        if training:
            # Random rotation: a random unit direction added to (0, -1, 0), renormalized.
            orientation = torch.tensor([0, -1, 0]).float()
            new_orientation = F.normalize((torch.rand(3) - 0.5) * 2 * np.pi, dim=0)
            q = F.normalize(orientation + new_orientation, dim=0)
            theta_rotate = stns.stn_quaternion_rotations(q)

            # Random shift per axis, expressed relative to half that axis' size.
            offsets = 2 * (torch.rand(3) - 0.5) * self.cfg.augmentation_shift_range
            shift = torch.tensor([d / (size // 2) for d, size in zip(offsets, y.shape)])
            theta_shift = stns.shift(shift)

            f = 0.1  # scale factor lies in roughly [1 - f, 1 + f]
            scale = 1.0 - 2 * f * (torch.rand(1) - 0.5)
            theta_scale = stns.scale(scale)

            theta = theta_rotate @ theta_shift @ theta_scale
            x, y, base_plane = stns.transform(theta, x, y, w=base_plane)
        else:
            theta = torch.eye(4).cuda()
            x_super_res = torch.tensor(1)
            y_super_res = torch.tensor(1)

        x = crop(x, (C,) + self.cfg.patch_shape, (0,) + center)
        y = crop(y, self.cfg.patch_shape, center)
        base_plane = crop(base_plane, self.cfg.patch_shape, center)

        if self.point_model is not None:
            # Voxels labelled 1 are the surface points for every class;
            # label 2 is then merged into class 1.
            point_model_points = torch.nonzero(y == 1)
            y_outer = torch.zeros_like(y)
            y_outer[point_model_points[:, 0],
                    point_model_points[:, 1],
                    point_model_points[:, 2]] = 1
            y[y == 2] = 1

        surface_points_normalized_all = []
        vertices_mc_all = []
        faces_mc_all = []
        for i in range(1, self.cfg.num_classes):
            shape = torch.tensor(y.shape)[None].float()
            if not training:
                gap = 1
                y_ = clean_border_pixels((y == i).long(), gap=gap)
                vertices_mc, faces_mc = voxel2mesh(y_, gap, shape)
                vertices_mc_all.append(vertices_mc)
                faces_mc_all.append(faces_mc)

            if self.point_model is None:
                y_outer = sample_outer_surface_in_voxel((y == i).long())
                surface_points = torch.nonzero(y_outer)
            else:
                # Start from the unflipped voxel indices for every class, so the
                # coordinate order does not alternate between iterations.
                surface_points = point_model_points
            surface_points = torch.flip(surface_points, dims=[1]).float()  # z, y, x -> x, y, z
            surface_points_normalized = normalize_vertices(surface_points, shape)
            surface_points_normalized_all.append(surface_points_normalized.cuda())

        # Only the last class's point count is checked. Re-drawing only makes
        # sense when the augmentation is random, i.e. in training.
        last_class_empty = (bool(surface_points_normalized_all)
                            and len(surface_points_normalized_all[-1]) == 0)
        if not (training and last_class_empty):
            break
        print("re-applying deformation coz N=0")

    if training:
        return {
            'x': x,
            'faces_atlas': atlas_faces,
            'y_voxels': y,
            'surface_points': surface_points_normalized_all,
            'p': p,
            't': t,
            'unpool': self.cfg.unpool_indices,
            'w': y_outer,
            'theta': theta.inverse()[:3],
            'base_plane': base_plane
        }
    return {
        'x': x,
        'x_super_res': x_super_res,
        'faces_atlas': atlas_faces,
        'y_voxels': y,
        'y_voxels_super_res': y_super_res,
        'vertices_mc': vertices_mc_all,
        'faces_mc': faces_mc_all,
        'surface_points': surface_points_normalized_all,
        'p': p,
        't': t,
        'unpool': [0, 1, 0, 1, 1],
        'theta': theta.inverse()[:3],
        'base_plane': base_plane
    }