def apply_transform(self, subject: tio.Subject) -> tio.Subject:
    """Average a random cluster of angularly close DWI volumes.

    A reference gradient direction is drawn at random among the volumes whose
    b-value lies strictly inside ``self.bval_range``. The ``self.subset_size``
    volumes with the smallest squared Euclidean distance to that direction
    are the candidates; a random number of them is averaged and stored under
    ``self.mean_dwi_image_name``.

    Args:
        subject: Subject expected to hold the image named
            ``self.full_dwi_image_name``. That image must expose, under
            ``self.bvec_name``, a 2D tensor whose first three columns are
            read as gradient directions and whose fourth column is read as
            b-values. Volumes are indexed along the first axis of the data.

    Returns:
        The same ``subject`` instance. It is returned untouched when the full
        DWI image is missing; otherwise the mean image is created (as a deep
        copy of the full DWI image) or updated in place.

    Raises:
        ValueError: If no volume has a b-value inside ``self.bval_range``.
    """
    if self.full_dwi_image_name not in subject:
        return subject

    full_dwi_image = subject[self.full_dwi_image_name]
    full_dwi = full_dwi_image.data
    grad = full_dwi_image[self.bvec_name]
    bvals = grad[:, 3]
    bvecs = grad[:, :3]

    # Both bounds are exclusive.
    mask = (bvals > self.bval_range[0]) & (bvals < self.bval_range[1])
    bvecs = bvecs[mask]
    full_dwi = full_dwi[mask]

    num_candidates = bvecs.shape[0]
    if num_candidates == 0:
        # Fail with an explicit message instead of the opaque error that
        # np.random.randint(0) would raise further down.
        raise ValueError(
            'No DWI volume has a b-value inside the range '
            f'({self.bval_range[0]}, {self.bval_range[1]})'
        )

    # Random reference direction, then the volumes closest to it.
    rand_bvec = bvecs[np.random.randint(num_candidates)]
    dist = torch.sum((bvecs - rand_bvec) ** 2, dim=1)
    # Sort with torch so the indices stay tensors and no NumPy round trip of
    # the distance tensor is needed.
    closest_indices = torch.argsort(dist)[: self.subset_size]

    # np.random.randint excludes ``high``: for subset_size >= 2 this draws
    # between 1 and subset_size - 1 volumes, as before. The clamp keeps
    # subset_size == 1 from raising (low >= high) and selects one volume.
    number_of_selections = np.random.randint(
        low=1,
        high=max(self.subset_size, 2),
    )
    ids = torch.randperm(closest_indices.shape[0])[:number_of_selections]
    selected_indices = closest_indices[ids]
    mean_dwi = torch.mean(full_dwi[selected_indices], dim=0)

    if self.mean_dwi_image_name in subject:
        mean_dwi_image = subject[self.mean_dwi_image_name]
    else:
        # Start from a copy of the full image so its other entries carry over.
        mean_dwi_image = copy.deepcopy(full_dwi_image)
        subject.add_image(mean_dwi_image, self.mean_dwi_image_name)
    # Restore the leading axis removed by the mean.
    mean_dwi_image.set_data(mean_dwi.unsqueeze(0))
    return subject
def apply_transform(self, subject: tio.Subject) -> tio.Subject:
    """Store a direction-weighted random average of DWI volumes in the subject.

    Volumes whose b-value lies strictly inside ``self.bval_range`` are
    candidates. A set of random unit vectors is generated, and each candidate
    is weighted by the largest ``|bvec . direction| ** directionality`` over
    those vectors. ``num_dwis`` volumes are then drawn with these weights
    (normalised to sum to one) and averaged.

    Args:
        subject: Subject expected to hold the image named
            ``self.full_dwi_image_name``. The entry stored under
            ``self.bvec_name`` is read as a 2D array whose first three
            columns are gradient directions and whose fourth column holds
            b-values.

    Returns:
        The same ``subject`` instance, unchanged when the full DWI image is
        absent. Otherwise the image named ``self.mean_dwi_image_name`` is
        updated, or created from a deep copy of the full DWI image.
    """
    if self.full_dwi_image_name not in subject:
        return subject

    source_image = subject[self.full_dwi_image_name]
    volumes = source_image.data.numpy()
    gradients = source_image[self.bvec_name].numpy()
    b_values = gradients[:, 3]
    directions = gradients[:, :3]

    # Restrict to b-values strictly between the two bounds.
    above_lower = b_values > self.bval_range[0]
    below_upper = b_values < self.bval_range[1]
    in_range = above_lower & below_upper
    directions = directions[in_range]
    volumes = volumes[in_range]

    # Per-call parameters, queried in this fixed order.
    n_average = self.get_num_dwis()
    n_axes = self.get_num_directions()
    exponent = self.get_directionality()

    # Random unit vectors, one per column.
    axes = np.random.randn(3, n_axes)
    axis_norms = np.linalg.norm(axes, axis=0, keepdims=True)
    axes = axes / axis_norms

    # Weight each volume by its best alignment with any random vector.
    alignment = np.abs(directions @ axes) ** exponent
    weights = np.max(alignment, axis=1)
    weights = weights / weights.sum()

    # np.random.choice samples with replacement by default, so one volume
    # may contribute more than once to the average.
    candidates = np.arange(volumes.shape[0])
    picked = np.random.choice(candidates, size=n_average, p=weights)
    averaged = np.mean(volumes[picked], axis=0, keepdims=True)

    if self.mean_dwi_image_name not in subject:
        target_image = copy.deepcopy(source_image)
        subject.add_image(target_image, self.mean_dwi_image_name)
    else:
        target_image = subject[self.mean_dwi_image_name]
    target_image.set_data(averaged)
    return subject
def __call__(self, subject: tio.Subject):
    """Simulate a resection on ``subject`` and attach the resulting label.

    Random parameters are sampled with ``self.get_params``. The image named
    ``self.image_name`` and the gray-matter and resectable-tissue masks of
    the sampled hemisphere are converted to SimpleITK and passed to
    ``resect``. The resected image replaces the data of the original image,
    and the resection mask is added as a ``tio.LabelMap`` named ``'label'``.

    Depending on the instance flags, the helper images are removed from the
    subject, the sampled parameters are stored under ``'random_resection'``,
    a copy of the input image is kept under ``'image_original'`` and the
    resected structures are stored under ``'resected_structures'``.

    Args:
        subject: Subject holding the image to resect plus the
            ``resection_*`` helper images read below.

    Returns:
        The same ``subject`` instance, modified in place.
    """
    # Sample random parameters
    params = self.get_params(
        self.volumes,
        self.volumes_range,
        self.sigmas_range,
        self.radii_ratio_range,
        self.angles_range,
        self.wm_lesion_p,
        self.clot_p,
    )

    # Convert images to SimpleITK
    with timer('Convert to SITK', self.verbose):
        t1_sitk = subject[self.image_name].as_sitk()
        side = params['hemisphere']
        gm_sitk = subject[f'resection_gray_matter_{side}'].as_sitk()
        resectable_sitk = subject[f'resection_resectable_{side}'].as_sitk()

        wm_lesion = params['add_wm_lesion']
        clot = params['add_clot']
        # The noise image is only read when one of these features needs it.
        needs_noise = self.texture == 'csf' or wm_lesion or clot
        noise_sitk = (
            subject['resection_noise'].as_sitk() if needs_noise else None
        )

    # Simulate resection
    with timer('Resection', self.verbose):
        outputs = resect(
            t1_sitk,
            gm_sitk,
            resectable_sitk,
            params['sigmas'],
            params['radii'],
            noise_image=noise_sitk,
            shape=self.shape,
            texture=self.texture,
            angles=params['angles'],
            noise_offset=params['noise_offset'],
            sphere_poly_data=self.sphere_poly_data,
            wm_lesion=wm_lesion,
            clot=clot,
            simplex_path=self.simplex_path,
            center_ras=self.center_ras,
            verbose=self.verbose,
        )
    resected_sitk, mask_sitk, cavity_center, clot_center = outputs

    # Store centers for visualization purposes
    params['resection_center'] = cavity_center
    params['clot_center'] = clot_center

    # Convert from SITK
    with timer('Convert from SITK', self.verbose):
        brain_array = self.sitk_to_array(resected_sitk)
        mask_array = self.sitk_to_array(mask_sitk)
        image_resected = self.add_channels_axis(brain_array)
        resection_label = self.add_channels_axis(mask_array)
    assert image_resected.ndim == 4
    assert resection_label.ndim == 4

    # Update subject
    if self.delete_resection_keys:
        helper_keys = (
            'resection_gray_matter_left',
            'resection_gray_matter_right',
            'resection_resectable_left',
            'resection_resectable_right',
        )
        for key in helper_keys:
            subject.remove_image(key)
        # NOTE(review): the noise-image removal is assumed to be nested under
        # delete_resection_keys; confirm against the original indentation.
        if needs_noise:
            subject.remove_image('resection_noise')

    # Add resected image and label to subject
    if self.add_params:
        subject['random_resection'] = params
    if self.keep_original:
        # Copy before the image data is overwritten below.
        subject['image_original'] = copy.deepcopy(subject[self.image_name])
    target_image = subject[self.image_name]
    target_image.data = torch.from_numpy(image_resected)
    label = tio.LabelMap(tensor=resection_label, affine=target_image.affine)
    subject.add_image(label, 'label')

    if self.add_resected_structures:
        structures = self.get_resected_structures(subject, mask_sitk)
        subject['resected_structures'] = structures
    return subject