def __getitem__(self, idx):
    """Load, resample, crop/pad and z-normalise the image at position ``idx``.

    If ``load_reorient`` fails, the spacing falls back to ``[2, 2, 2]``, the
    file is deleted when ``self.remove_corrupt`` is set, and an image from
    ``self.get_random(fa=True)`` is used instead. An all-zero image is
    replaced once in the same way.

    Returns:
        dict with 'image' (``torch.from_numpy`` of the preprocessed image),
        'label' (always ``np.array(0)``), 'spacing' (the spacing used, as a
        numpy array) and 'fn' (the image file name).

    NOTE(review): the original header comment suggested modifying the
    DataLoader ``collate_fn`` so it filters out ``None`` elements; this
    method itself never returns ``None``.
    """
    img_name = self.list_images[idx]
    try:
        _, image, spacing = load_reorient(img_name)
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # still propagate.
        spacing = [2, 2, 2]
        print('error loading {0}'.format(img_name))
        if self.remove_corrupt and os.path.isfile(img_name):
            os.remove(img_name)
        image, _ = self.get_random(fa=True)
    if not np.any(image):
        # Empty volume: substitute a random one (single retry).
        image, _ = self.get_random(fa=True)

    # Preprocessing.
    original_spacing = np.array(spacing)
    get_foreground = tio.ZNormalization.mean
    target_shape = 128, 128, 128
    crop_pad = tio.CropOrPad(target_shape)
    standardize = tio.ZNormalization(masking_method=get_foreground)
    if 'wb' not in self.class_names and 'abd-pel' not in self.class_names:
        # Target spacing 2/spacing: presumably torchio sees the raw array
        # with unit voxel size, so this yields ~2 mm voxels — TODO confirm.
        downsample = tio.Resample(
            (2 / spacing[0], 2 / spacing[1], 2 / spacing[2]))
        try:
            image = standardize(crop_pad(downsample(image)))
        except Exception:
            # NOTE(review): on failure the unprocessed image is kept, and its
            # shape may differ from target_shape.
            print(img_name)
    else:
        # Spacing = current shape / target shape, so (assuming unit input
        # spacing) the resampled volume is roughly target_shape.
        x, y, z = image.shape[1:] / np.asarray(target_shape)
        downsample = tio.Resample((x, y, z))
        image = standardize(crop_pad(downsample(image)))
    if image.shape[0] > 1:
        # Keep only the first channel, preserving the channel axis.
        image = np.expand_dims(image[0, :, :, :], axis=0)
    sample = {
        'image': torch.from_numpy(image),
        'label': np.array(0),
        'spacing': original_spacing,
        'fn': img_name
    }
    return sample
def test_different_spaces(self):
    """RandomAffine must refuse images on different grids unless check_shape=False."""
    sample = self.sample_subject
    resampled_label = tio.Resample(2)(sample.label)
    mixed = tio.Subject(t1=sample.t1, label=resampled_label)
    with self.assertRaises(RuntimeError):
        tio.RandomAffine()(mixed)
    # Disabling the shape check lets the same subject through.
    tio.RandomAffine(check_shape=False)(mixed)
def test_bad_affine(self):
    """A (shape, affine) target with a 3x3 affine must make Resample raise RuntimeError."""
    bad_affine = np.eye(3)  # deliberately not 4x4
    transform = tio.Resample(((1, 2, 3), bad_affine))
    with self.assertRaises(RuntimeError):
        transform(self.sample_subject)
def test_spacing(self):
    """Resampling to a scalar spacing gives that spacing on all three axes for every image."""
    # Should this raise an error if sizes are different?
    spacing = 2
    expected = (spacing,) * 3
    result = tio.Resample(spacing)(self.sample_subject)
    for img in result.get_images(intensity_only=False):
        self.assertEqual(img.spacing, expected)
def test_reference_path(self):
    """Resampling onto a reference file path copies that image's shape and affine."""
    ref_image, ref_path = self.get_reference_image_and_path()
    result = tio.Resample(ref_path)(self.sample_subject)
    for img in result.values():
        self.assertEqual(ref_image.shape, img.shape)
        self.assertTensorAlmostEqual(ref_image.affine, img.affine)
def test_transforms(self):
    """Smoke test: a long chain of torchio transforms runs on ``self.sample``."""
    landmarks_dict = {key: np.linspace(0, 100, 13) for key in ('t1', 't2')}
    # Single instance used both on its own and inside OneOf.
    elastic = torchio.RandomElasticDeformation(max_displacement=1)
    pipeline = torchio.Compose((
        torchio.CropOrPad((9, 21, 30)),
        torchio.ToCanonical(),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=2, num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(masking_method='label'),
        torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({torchio.RandomAffine(): 3, elastic: 1}),
        torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
        torchio.Crop((3, 2, 8, 0, 1, 4)),
    ))
    pipeline(self.sample)
def test_reference_name(self):
    """Resampling onto an image named in the subject aligns every image with it."""
    subject = self.get_inconsistent_shape_subject()
    name = 't1'
    result = tio.Resample(name)(subject)
    reference = subject[name]
    for img in result.get_images(intensity_only=False):
        self.assertEqual(reference.shape, img.shape)
        self.assertTensorAlmostEqual(reference.affine, img.affine)
def test_affine(self):
    """Images carrying a 'pre_affine' entry get it applied before resampling to 1 mm."""
    affine_name = 'pre_affine'
    result = tio.Resample(1, pre_affine_name=affine_name)(self.sample_subject)
    for img in result.values():
        if affine_name not in img:
            self.assertTensorEqual(img.affine, np.eye(4))
            continue
        expected = np.eye(4)
        expected[:3, 3] = 10, 0, -0.1
        self.assertTensorAlmostEqual(img.affine, expected)
def inference(dataset, model):
    """Predict, post-process and save a label map for every subject in ``dataset``.

    For each index: run ``patch_predict`` on the transformed subject, invert
    its transform history, fill holes and remove small components, resample
    onto the untransformed subject's first image and save it as
    ``<out folder>/<args.output_filename>``. Uses the module-level ``args``
    and ``device``.

    Raises:
        Warning: if the resampled segmentation shape differs from the
            original image shape.
    """
    for index in range(len(dataset)):
        transformed = dataset[index]
        original = dataset.subjects[index]
        print(f"Running model for subject {transformed['name']}")

        # An empty --out_folder means "write next to the input".
        if args.out_folder == "":
            destination = Path(transformed["folder"])
        else:
            destination = Path(args.out_folder) / transformed['name']
        destination.mkdir(exist_ok=True, parents=True)

        with torch.no_grad():
            predicted = patch_predict(
                model=model,
                device=device,
                subjects=[transformed],
                patch_batch_size=1,
                patch_size=96,
                patch_overlap=48,
                padding_mode="edge",
                overlap_mode="average",
            )[0]

        undo = predicted.get_composed_history().inverse(warn=False)
        restored = undo(tio.Subject({'y': predicted['y_pred']}))
        output_label = restored.get_first_image()

        labels = output_label['data'][0].numpy()
        labels, filled = remove_holes(labels, hole_size=64)
        print(f"Filled {filled} voxels from detected holes.")
        labels, dropped = remove_small_components(labels, 3)
        print(f"Removed {dropped} voxels from small predictions less than size 3.")
        output_label.set_data(torch.from_numpy(labels[None]).to(torch.int32))

        reference = original.get_first_image()
        output_label = tio.Resample(reference)(output_label)
        if output_label.spatial_shape != reference.spatial_shape:
            raise Warning("Segmentation shape and original image shape do not match.")
        print()
        output_label.save(destination / args.output_filename)
def __getitem__(self, index):
    """Return ``(image, label)`` for row ``index`` of ``self.df``.

    Column 1 holds the image path relative to ``self.root_dir`` and column 2
    the diagnosis. The label is a float tensor: 0.0 for 'AD', 1.0 otherwise.
    The image is z-normalised and resampled with
    ``tio.Resample(target=2, pre_affine_name="affine")``. When
    ``self.transform`` is truthy (training), one transform from each of
    ``transforms_dict`` and ``transforms_dict2`` (defined outside this
    method) is added. The result is trilinearly interpolated to
    256 x 256 x 166 and returned with shape (1, 256, 256, 166).
    """
    img_path = os.path.join(self.root_dir, self.df.iloc[index, 1])
    subject = tio.Subject(img=Image(img_path, type=tio.INTENSITY))

    # Preprocessing shared by the training, validation and testing phases.
    transformations = [
        tio.ZNormalization(),
        tio.Resample(target=2, pre_affine_name="affine"),
    ]
    if self.transform:
        # Training phase: one random augmentation from each dictionary.
        transformations += [tio.OneOf(transforms_dict), tio.OneOf(transforms_dict2)]
    transformed = Compose(transformations)(subject)

    # interpolate() with mode='trilinear' needs a 5D input:
    # (mini-batch, channels, x, y, z).
    volume = transformed.img.data.unsqueeze(dim=0)
    volume = torch.nn.functional.interpolate(
        input=volume, size=(256, 256, 166), mode='trilinear'
    )
    # Drop the batch axis with the tensor's own reshape. np.reshape on a
    # torch tensor only worked through NumPy's TypeError fallback path.
    resampled_image = volume.reshape(1, 256, 256, 166)

    y_label = 0.0 if self.df.iloc[index, 2] == 'AD' else 1.0
    y_label = torch.tensor(y_label, dtype=torch.float)
    return resampled_image, y_label
def get_dataset(
    input_path,
    tta_iterations=0,
    interpolation='bspline',
    tolerance=0.1,
    mni_transform_path=None,
):
    """Build the inference dataset for a single image, plus optional TTA copies.

    Preprocessing is: ToCanonical, histogram standardisation with fixed
    landmarks, z-normalisation, optional resampling and a crop to a multiple
    of 8.

    Resampling happens when any voxel size differs from 1 by more than
    ``tolerance``, or when an MNI transform is given. In the MNI case the
    target is the Colin27 T1 and the matrix is applied as a pre-affine.

    Returns:
        The un-augmented ``SubjectsDataset`` when ``tta_iterations`` gives no
        copies. Otherwise a ``ConcatDataset`` of it and a dataset of
        ``tta_iterations`` randomly flipped/affine-augmented copies.
    """
    use_mni = mni_transform_path is not None
    image_kwargs = {}
    if use_mni:
        image_kwargs[TO_MNI] = tio.io.read_matrix(mni_transform_path)
    image = tio.ScalarImage(input_path, **image_kwargs)
    subject = tio.Subject({IMAGE_NAME: image})

    landmarks = np.array([
        0.,
        0.31331614,
        0.61505419,
        0.76732501,
        0.98887953,
        1.71169384,
        3.21741126,
        13.06931455,
        32.70817796,
        40.87807389,
        47.83508873,
        63.4408591,
        100.,
    ])
    steps = [
        tio.ToCanonical(),
        tio.HistogramStandardization({IMAGE_NAME: landmarks}),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ]

    voxel_size = np.array(nib.load(input_path).header.get_zooms())
    off_isotropic = np.any(np.abs(voxel_size - 1) > tolerance)
    if off_isotropic or use_mni:
        resample_kwargs = {'image_interpolation': interpolation}
        if use_mni:
            resample_kwargs['pre_affine_name'] = TO_MNI
            resample_kwargs['target'] = tio.datasets.Colin27().t1.path
        steps.append(tio.Resample(**resample_kwargs))
    steps.append(tio.EnsureShapeMultiple(8, method='crop'))
    preprocess_transform = tio.Compose(steps)

    plain = tio.SubjectsDataset([subject], transform=preprocess_transform)
    copies = tta_iterations * [subject]
    if not copies:
        return plain

    augment_transform = tio.Compose((
        preprocess_transform,
        tio.RandomFlip(),
        tio.RandomAffine(image_interpolation=interpolation),
    ))
    augmented = tio.SubjectsDataset(copies, transform=augment_transform)
    return torch.utils.data.ConcatDataset((plain, augmented))
def get_transform(self, channels, is_3d=True, labels=True):
    """Build a Compose that exercises (almost) every torchio transform.

    Args:
        channels: image names in the subject. Each gets default histogram
            landmarks, and the first one is the CopyAffine reference.
        is_3d: when False, use arguments suited to 2D images (size 1 along
            the last spatial axis).
        labels: when True, append ``RandomLabelsToImage(label_key='label')``.

    Returns:
        A ``tio.Compose`` with the transforms in a fixed order.
    """
    if is_3d:
        disp = 1
        cp_args = (9, 21, 30)
        resize_args = (10, 20, 30)
        flip_axes = (0, 1, 2)
        swap_patch = (2, 3, 4)
        pad_args = (1, 2, 3, 0, 5, 6)
        crop_args = (3, 2, 8, 0, 1, 4)
    else:
        disp = (1, 1, 0.01)
        cp_args = (21, 30, 1)
        resize_args = (10, 20, 1)
        flip_axes = (0, 1)
        swap_patch = (3, 4, 1)
        pad_args = (0, 0, 3, 0, 5, 6)
        crop_args = (0, 0, 8, 0, 1, 4)
    axes_downsample = flip_axes

    landmarks_dict = {}
    for channel in channels:
        landmarks_dict[channel] = np.linspace(0, 100, 13)
    # One shared instance, used both standalone and inside OneOf.
    elastic = tio.RandomElasticDeformation(max_displacement=disp)
    remapping = {1: 2, 2: 1, 3: 20, 4: 25}

    spatial = [
        tio.CropOrPad(cp_args),
        tio.EnsureShapeMultiple(2, method='crop'),
        tio.Resize(resize_args),
        tio.ToCanonical(),
        tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
        tio.CopyAffine(channels[0]),
        tio.Resample((1, 1.1, 1.25)),
        tio.RandomFlip(axes=flip_axes, flip_probability=1),
    ]
    intensity = [
        tio.RandomMotion(),
        tio.RandomGhosting(axes=(0, 1, 2)),
        tio.RandomSpike(),
        tio.RandomNoise(),
        tio.RandomBlur(),
        tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
        tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
        tio.RandomBiasField(),
        tio.RescaleIntensity(out_min_max=(0, 1)),
        tio.ZNormalization(),
        tio.HistogramStandardization(landmarks_dict),
    ]
    deformation = [
        elastic,
        tio.RandomAffine(),
        tio.OneOf({
            tio.RandomAffine(): 3,
            elastic: 1,
        }),
    ]
    label_ops = [
        tio.RemapLabels(remapping=remapping, masking_method='Left'),
        tio.RemoveLabels([1, 3]),
        tio.SequentialLabels(),
    ]
    borders = [
        tio.Pad(pad_args, padding_mode=3),
        tio.Crop(crop_args),
    ]
    transforms = spatial + intensity + deformation + label_ops + borders
    if labels:
        transforms.append(tio.RandomLabelsToImage(label_key='label'))
    return tio.Compose(transforms)
def apply_transform(self, subject):
    """Resample ``subject`` to a spacing that is a "round" ratio of the current one.

    The target spacing is ``self.target_spacing``. When that is a string, it
    is computed from the current spacing with
    ``self.spacing_modes[self.target_spacing]`` and used on all three axes.
    If every axis is already within ``self.tolerance`` of the target, the
    subject is returned unchanged.

    Otherwise, per axis, the scale ratio is rounded to a multiple of 1/step
    for step = 1, 2, ... until the resulting spacing is within tolerance.
    The subject is then resampled with ``tio.Resample`` and the configured
    options.
    """
    current = subject.get_first_image().spacing
    if isinstance(self.target_spacing, str):
        value = self.spacing_modes[self.target_spacing](current)
        target = (value, value, value)
    else:
        target = self.target_spacing

    # No operation if all current spacings are within tolerance of the target.
    within = (abs(c - t) < tol for c, t, tol in zip(current, target, self.tolerance))
    if all(within):
        return subject

    def snap(cur, tar, tol):
        # Increase the rounding resolution until cur * ratio is close enough to tar.
        step = 1
        spacing = cur
        while abs(spacing - tar) > tol:
            if cur < tar:
                ratio = round(tar / cur * step) / step
            else:
                ratio = 1 / (round(cur / tar * step) / step)
            spacing = cur * ratio
            step += 1
        return spacing

    new_spacing = tuple(
        snap(c, t, tol) for c, t, tol in zip(current, target, self.tolerance)
    )
    resampler = tio.Resample(
        target=new_spacing,
        image_interpolation=self.image_interpolation,
        pre_affine_name=self.pre_affine_name,
        scalars_only=self.scalars_only,
        **self.kwargs,
    )
    return resampler(subject)
def get_transform(self, channels, is_3d=True, labels=True):
    """Build a Compose covering most torchio transforms (legacy API variant).

    Args:
        channels: image names; each gets default histogram landmarks.
        is_3d: choose 3D arguments, or 2D ones (size 1 on the last axis).
        labels: append ``RandomLabelsToImage(label_key='label')`` when True.

    Returns:
        A ``torchio.Compose`` with the transforms in a fixed order.
    """
    settings_3d = dict(
        disp=1,
        cp_args=(9, 21, 30),
        flip_axes=(0, 1, 2),
        swap_patch=(2, 3, 4),
        pad_args=(1, 2, 3, 0, 5, 6),
        crop_args=(3, 2, 8, 0, 1, 4),
    )
    settings_2d = dict(
        disp=(1, 1, 0.01),
        cp_args=(21, 30, 1),
        flip_axes=(0, 1),
        swap_patch=(3, 4, 1),
        pad_args=(0, 0, 3, 0, 5, 6),
        crop_args=(0, 0, 8, 0, 1, 4),
    )
    cfg = settings_3d if is_3d else settings_2d

    landmarks_dict = {name: np.linspace(0, 100, 13) for name in channels}
    # Shared instance: used standalone and as an option inside OneOf.
    elastic = torchio.RandomElasticDeformation(max_displacement=cfg['disp'])

    transforms = [
        torchio.CropOrPad(cfg['cp_args']),
        torchio.ToCanonical(),
        torchio.RandomDownsample(downsampling=(1.75, 2), axes=cfg['flip_axes']),
        torchio.Resample((1, 1.1, 1.25)),
        torchio.RandomFlip(axes=cfg['flip_axes'], flip_probability=1),
        torchio.RandomMotion(),
        torchio.RandomGhosting(axes=(0, 1, 2)),
        torchio.RandomSpike(),
        torchio.RandomNoise(),
        torchio.RandomBlur(),
        torchio.RandomSwap(patch_size=cfg['swap_patch'], num_iterations=5),
        torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
        torchio.RandomBiasField(),
        torchio.RescaleIntensity((0, 1)),
        torchio.ZNormalization(),
        torchio.HistogramStandardization(landmarks_dict),
        elastic,
        torchio.RandomAffine(),
        torchio.OneOf({
            torchio.RandomAffine(): 3,
            elastic: 1,
        }),
        torchio.Pad(cfg['pad_args'], padding_mode=3),
        torchio.Crop(cfg['crop_args']),
    ]
    if labels:
        transforms.append(torchio.RandomLabelsToImage(label_key='label'))
    return torchio.Compose(transforms)
def preprocess(self, raw):
    """Resample ``raw`` with a single spacing derived from its first axis.

    A channel axis is prepended (``np.expand_dims(raw, 0)``) and the result
    of ``tio.Resample`` is returned with that axis kept. The target spacing
    is ``raw.shape[0] / self.dim[0]`` and is used for every axis, so
    presumably only the first axis ends up close to ``self.dim[0]``. This
    relies on torchio treating a bare array as having unit voxel size —
    TODO confirm.

    Interpolation is B-spline. An earlier manual crop/resize/pad
    implementation (cv2-based) was left here commented out and has been
    removed.
    """
    target_spacing = raw.shape[0] / self.dim[0]
    resample = tio.Resample(target_spacing, image_interpolation='bspline')
    return resample(np.expand_dims(raw, 0))
def test_dataset(self, modalities):
    """Check that subjects loaded by ``Dataset.load_image`` can be transformed and batched.

    Loads the subjects written under
    ``self.output_path/temp_folder_<modalities joined with '_'>``, applies
    resampling, crop/pad, ROI selection ("larynx") and one-hot encoding,
    then checks that the dataset has 2 subjects. It also checks that one
    batch of 2 has shape (2, 1, 96, 96, 40) per entry, or (2, 2, 96, 96, 40)
    for RTSTRUCT entries.
    """
    folder_name = "temp_folder_" + "_".join(modalities.split(","))
    output_path_mod = pathlib.Path(self.output_path, folder_name).as_posix()
    comp_path = pathlib.Path(output_path_mod).resolve().joinpath('dataset.csv').as_posix()
    comp_table = pd.read_csv(comp_path, index_col=0)
    print(comp_path, comp_table)

    # Load subjects from the generated nrrd files.
    subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)
    # TODO: the original also (in commented-out code) compared these subjects
    # against Dataset.load_directly(self.input_path, ...) and checked that all
    # metadata columns of dataset.csv appear as subject keys; restore those
    # checks once they pass.

    transforms = tio.Compose([
        tio.Resample(4),
        tio.CropOrPad((96, 96, 40)),
        select_roi_names(["larynx"]),
        tio.OneHot(),
    ])
    test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=2, shuffle=True, collate_fn=collate_fn)

    assert len(test_set) == 2

    # Fetching a batch runs the whole transform pipeline; a failure here
    # means the test subjects were built incorrectly.
    data = next(iter(test_loader))
    shapes_ok = [
        value.shape == ((2, 2, 96, 96, 40) if "RTSTRUCT" in key else (2, 1, 96, 96, 40))
        for key, value in data.items()
    ]
    assert all(shapes_ok), "There is some problem in the transformation/the formation of subject object"
def main(input_path, checkpoint_path, output_dir, landmarks_path, batch_size, num_workers, resample):
    """Segment every image found at ``input_path`` with a trained U-Net.

    Each prediction (foreground softmax > 0.5) is saved in ``output_dir`` as
    ``<name with '.nii' replaced by '_seg_cnn.nii'>``. When ``resample`` is
    truthy, images go through ``tio.Resample()`` before the test transform.

    Returns:
        0 on completion.
    """
    import torch
    from tqdm import tqdm
    import torchio as tio
    import models
    import datasets

    fps = get_paths(input_path)
    # The key must be 'image', as in datasets.get_test_transform.
    subjects = [tio.Subject(image=tio.ScalarImage(fp)) for fp in fps]
    transform = tio.Compose((
        tio.ToCanonical(),
        datasets.get_test_transform(landmarks_path),
    ))
    if resample:
        # NOTE: a CropOrPad((264, 268, 144)) step was once considered here (for BITE?).
        transform = tio.Compose((tio.Resample(), transform))
    dataset = tio.SubjectsDataset(subjects, transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = models.get_unet().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    torch.set_grad_enabled(False)

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    for batch in tqdm(loader):
        inputs = batch['image'][tio.DATA].float().to(device)
        # Drop the background channel and threshold the foreground probabilities.
        seg = model(inputs).softmax(dim=1)[:, 1:].cpu() > 0.5
        for tensor, affine, path in zip(seg, batch['image'][tio.AFFINE], batch['image'][tio.PATH]):
            image = tio.LabelMap(tensor=tensor, affine=affine.numpy())
            path = Path(path)
            out_path = output_dir / path.name.replace('.nii', '_seg_cnn.nii')
            image.save(out_path)
    return 0
def save_feature_maps(input_path, output_dir):
    """Run the hub 'ressegnet' model on ``input_path`` and save every captured activation.

    Forward hooks from ``get_activation`` record features into the
    module-level ``activation`` mapping, keyed by ``(part, level,
    conv_layer)``. Each channel is saved as
    ``<part>_level_<level>_layer_<conv_layer>_feature_<i>.nii.gz`` in
    ``output_dir``. The affine comes from the input image downsampled
    ``level`` times by a factor of 2.
    """
    torch.set_grad_enabled(False)
    output_dir = Path(output_dir).expanduser().absolute()
    output_dir.mkdir(exist_ok=True, parents=True)
    device = get_device()
    repo = 'fepegar/resseg'
    model_name = 'ressegnet'
    model = torch.hub.load(repo, model_name)
    model.to(device)
    model.eval()

    # Keep the handles so the hooks can be removed once activations are captured.
    hooks = [
        module.register_forward_hook(get_activation(key))
        for key, module in get_all_modules(model)
    ]
    preprocessed_subject = get_dataset(input_path)[0]
    image = preprocessed_subject[IMAGE_NAME]
    inputs = image.data.unsqueeze(0).float().to(device)
    try:
        with torch.cuda.amp.autocast():
            model(inputs)
    finally:
        for hook in hooks:
            hook.remove()

    # downsampled[k] is the input resampled to 2**k times its spacing,
    # matching the resolution of network level k.
    downsampled = [image]
    for _ in range(2):
        target = torch.Tensor(downsampled[-1].spacing) * 2
        target = tuple(target.tolist())
        transform = tio.Resample(target, image_interpolation='nearest')
        downsampled.append(transform(downsampled[-1]))

    for key, features in tqdm(activation.items()):
        part, level, conv_layer = key
        affine = downsampled[level].affine
        for i, feature_map in enumerate(tqdm(features[0], leave=False)):
            features_image = tio.ScalarImage(
                tensor=feature_map.unsqueeze(0).cpu().float(),
                affine=affine,
            )
            name = f'{part}_level_{level}_layer_{conv_layer}_feature_{i}.nii.gz'
            features_image.save(output_dir / name)
subject = tio.Subject( hr=tio.ScalarImage(image_file), lr=tio.ScalarImage(image_file), ) #normalization = tio.ZNormalization(masking_method='label')#masking_method=tio.ZNormalization.mean) normalization = tio.ZNormalization() onehot = tio.OneHot() spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0) bias = tio.RandomBiasField(coefficients=0.5, p=0) flip = tio.RandomFlip(axes=('LR', ), p=1) noise = tio.RandomNoise(std=0.1, p=0) sampling_1mm = tio.Resample(1) sampling_05mm = tio.Resample(0.5) blur = tio.RandomBlur(0.5) sampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr') blur_jess = tio.Blur(std=(0.001, 0.001, 1), exclude='hr') downsampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr') upsampling_jess = tio.Resample(target='hr', exclude='hr') tocanonical = tio.ToCanonical() crop1 = tio.CropOrPad((290, 290, 200)) crop2 = tio.CropOrPad((290, 290, 200), include='lr') #transforms = tio.Compose([spatial, bias, flip, normalization, noise]) #transforms = tio.Compose([normalization, sampling_1mm, noise, blur, sampling_05mm]) #transforms = tio.Compose([blur_jess,sampling_jess])
# Build one torchio Subject per (image, mask) row of df.
subjects = []
for (image_path, label_path) in zip(df['imagepath'], df['maskpath']):
    subject = tio.Subject(image=tio.ScalarImage(image_path), mask=tio.LabelMap(label_path))
    subjects.append(subject)

# Wrap the patients in a torchio SubjectsDataset.
dataset = tio.SubjectsDataset(subjects)

# Collapse the mask to a single class if requested.
if config.to_one_class:
    # Presumably dry_iter() yields the stored subjects without transforms,
    # so this reassignment persists in the dataset — confirm.
    for subject in dataset.dry_iter():
        subject['mask'] = one(subject['mask'])

training_transform = tio.Compose([
    tio.Resample(4),
    # Recommended throughout the torchio forums.
    tio.ZNormalization(
        masking_method=tio.ZNormalization.mean
    ),
    tio.RandomFlip(p=0.25),
    tio.RandomNoise(p=0.25),
    # !!! Tensors have to be forcibly converted to float.
    tio.Lambda(to_float)
])

# Validation must be deterministic: the RandomNoise(p=0.25) that used to be
# here made validation metrics vary from run to run, so it was removed.
validation_transform = tio.Compose([
    tio.Resample(4),
    tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    tio.Lambda(to_float)
])
def __getitem__(self, index):
    """Return ``(volume, label)`` for row ``index`` of ``self.df``.

    Column 0 holds the image path relative to ``self.root_dir`` and column 1
    the diagnosis. The label is a scalar float tensor: 0.0 for 'AD',
    1.0 otherwise.

    Processing steps:
    - Add a leading channel axis to the volume.
    - Resample with ``tio.Resample(4)`` unless the shape is already
      (1, 256, 256, 166).
    - Min-max scale to [0, 1]. A constant volume becomes all zeros.
    - Pass through ``transform_flip`` when ``self.transform`` is truthy.
    """
    img_path = os.path.join(self.root_dir, self.df.iloc[index, 0])
    image = nib.load(img_path)
    target_img_shape = (1, 256, 256, 166)

    # Add a channel axis: (X, Y, Z) -> (1, X, Y, Z).
    volume = np.expand_dims(np.asarray(image.get_fdata()), axis=0)

    # Compare shapes, not the array itself: ``array != tuple`` is an
    # elementwise comparison (always truthy, or an error on NumPy >= 2).
    if volume.shape != target_img_shape:
        # Keep the resampled data; it used to be computed and then discarded.
        volume = np.asarray(tio.Resample(4)(volume))

    # Min-max normalisation, guarding against division by zero on a constant volume.
    volume = volume - np.min(volume)
    max_value = np.max(volume)
    if max_value > 0:
        volume = volume / max_value

    if self.transform:
        volume = transform_flip(volume)

    # Scalar float target (the original note ties this to "cross entropy #1").
    # A two-class variant would be [1.0, 0.0] for 'AD' and [0.0, 1.0] otherwise.
    y_label = 0.0 if self.df.iloc[index, 1] == 'AD' else 1.0
    y_label = torch.tensor(y_label, dtype=torch.float)
    return volume, y_label
lr_3=tio.ScalarImage(t2_file), ) subjects.append(subject) print('DHCP Dataset size:', len(subjects), 'subjects') # DATA AUGMENTATION normalization = tio.ZNormalization() spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0.75) flip = tio.RandomFlip(axes=('LR', ), flip_probability=0.5) tocanonical = tio.ToCanonical() b1 = tio.Blur(std=(0.001, 0.001, 1), include='lr_1') #blur d1 = tio.Resample((0.8, 0.8, 2), include='lr_1') #downsampling u1 = tio.Resample(target='hr', include='lr_1') #upsampling if in_channels == 3: b2 = tio.Blur(std=(0.001, 1, 0.001), include='lr_2') #blur d2 = tio.Resample((0.8, 2, 0.8), include='lr_2') #downsampling u2 = tio.Resample(target='hr', include='lr_2') #upsampling b3 = tio.Blur(std=(1, 0.001, 0.001), include='lr_3') #blur d3 = tio.Resample((2, 0.8, 0.8), include='lr_3') #downsampling u3 = tio.Resample(target='hr', include='lr_3') #upsampling if in_channels == 1: transforms = [tocanonical, flip, spatial, normalization, b1, d1, u1] training_transform = tio.Compose(transforms) validation_transform = tio.Compose(
def test_image_target(self):
    """Resampling a subject onto one of its own images must not raise."""
    reference = self.sample_subject.t1
    resample = tio.Resample(reference)
    resample(self.sample_subject)
def segment_resection(
    input_path,
    model,
    output_path=None,
    tta_iterations=0,
    interpolation='bspline',
    num_workers=0,
    show_progress=True,
    binarize=True,
    postprocess=True,
    mni_transform_path=None,
):
    """Segment the resection cavity in ``input_path`` and save it in native space.

    Steps:
    - Run ``model`` on the preprocessed image plus ``tta_iterations``
      augmented copies.
    - Drop the background channel from each prediction.
    - Map every prediction back through its inverse transform and average
      them.
    - When ``binarize`` is set, threshold at 0.5 and return a ``LabelMap``.
    - When ``postprocess`` is set, keep only the largest component.
    - Resample onto ``input_path``, applying the inverse MNI matrix as a
      pre-affine when given.

    The output is saved to ``output_path``, which defaults to
    ``<stem>_seg.<exts>`` next to the input.
    """
    dataset = get_dataset(
        input_path,
        tta_iterations,
        interpolation,
        mni_transform_path=mni_transform_path,
    )
    device = get_device()
    model.to(device)
    model.eval()
    torch.set_grad_enabled(False)
    # ``list`` keeps the batch as a list of subjects. Unlike a lambda, it can
    # be pickled when worker processes are started with "spawn".
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=list,
    )

    all_results = []
    for subjects_list_batch in tqdm(loader, disable=not show_progress):
        tensors = [subject[IMAGE_NAME][tio.DATA] for subject in subjects_list_batch]
        inputs = torch.stack(tensors).float().to(device)
        with torch.cuda.amp.autocast():
            probs = model(inputs).softmax(dim=1)[:, 1:].cpu()  # discard background
        iterable = list(zip(subjects_list_batch, probs))
        for subject, prob in tqdm(iterable, leave=False, unit='subject'):
            subject[IMAGE_NAME].set_data(prob)
            subject_back = subject.apply_inverse_transform(
                warn=False, image_interpolation='linear')
            all_results.append(subject_back[IMAGE_NAME].data)

    result = torch.stack(all_results)
    mean_prob = result.mean(dim=0)
    if binarize:
        mean_prob = (mean_prob >= 0.5).byte()
        class_ = tio.LabelMap
    else:
        class_ = tio.ScalarImage

    image_kwargs = {
        'tensor': mean_prob,
        'affine': subject_back[IMAGE_NAME].affine,
    }
    resample_kwargs = {
        'target': input_path,
        'image_interpolation': interpolation,
    }
    if mni_transform_path is not None:
        to_mni = tio.io.read_matrix(mni_transform_path)
        from_mni = np.linalg.inv(to_mni)
        image_kwargs[FROM_MNI] = from_mni
        resample_kwargs['pre_affine_name'] = FROM_MNI
    image = class_(**image_kwargs)
    if postprocess:
        image = tio.KeepLargestComponent()(image)
    resample = tio.Resample(**resample_kwargs)
    image_native = resample(image)

    if output_path is None:
        input_path = Path(input_path)
        split = input_path.name.split('.')
        stem = split[0]
        exts_string = '.'.join(split[1:])
        output_path = input_path.parent / f'{stem}_seg.{exts_string}'
    image_native.save(output_path)
pbar.write( f"\tFilled {hole_voxels_removed} voxels from detected holes." ) out = torch.from_numpy(out).unsqueeze(0) out = out.int() image = tio.LabelMap(tensor=out) else: raise NotImplementedError image = tio.ScalarImage(tensor=probs) # inverse_transforms = subject.get_composed_history().inverse(warn=False) # image = inverse_transforms(image) orig_subject = dataset.subjects_map[ subject["name"]] # access subject without applying transformations # compare shapes ignoring the channel dimension if image.shape[1:] != orig_subject.shape[1:]: resample_transform = tio.Resample(orig_subject.get_images()[0]) image = resample_transform(image) assert orig_subject.shape[1:] == image.shape[ 1:], "Segmentation shape and original image shape do not match" pbar.write("\tSaving image...") image.save(out_folder / args.output_filename) pbar.write("\tFinished subject") pbar.update(1)
def __getitem__(self, index):
    """Load one OCT volume from .npy, augment it and return ``(array, label)``.

    Row ``index`` of ``self.df`` holds the .npy path (column 0) and an
    integer label (column 1).

    Processing steps:
    - Subsample the volume along depth every ``self.depth_interval`` slices
      from a random offset.
    - Run it through a random torchio pipeline. About 1 call in 20 gets only
      the resampling step.
    - Convert to uint8.
    - When needed, resize each slice to ``self.image_shape[:2]`` and
      augment it with ``self.imgaug_iaa``.
    - Scale to [0, 1] float32.

    The array is (D, H, W, 1), or (1, D, H, W) when ``self.channel_first``
    is set.
    """
    file_npy = self.df.iloc[index, 0]
    assert os.path.exists(file_npy), f'npy file {file_npy} does not exists'
    array_npy = np.load(file_npy)  # shape (D, H, W)
    if array_npy.ndim > 3:
        array_npy = np.squeeze(array_npy)
    array_npy = np.expand_dims(array_npy, axis=0)  # (C, D, H, W)

    # Keep every depth_interval-th slice from a random start,
    # e.g. depth_interval=2: (128, 128, 128) -> (64, 128, 128).
    depth_start_random = random.randint(0, 20) % self.depth_interval
    array_npy = array_npy[:, depth_start_random::self.depth_interval, :, :]
    subject = tio.Subject(oct=tio.ScalarImage(tensor=array_npy))

    # Crop crop_h voxels from the start of the H axis, then pad the same
    # amount back, split randomly between both ends.
    crop_h = random.randint(0, self.random_crop_h)
    pad_h_a = random.randint(0, crop_h)
    pad_h_b = crop_h - pad_h_a
    augment = tio.Compose([
        tio.RandomFlip(axes=2, flip_probability=0.5),
        tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),  # (d, h, w): crop height
        tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
        tio.RandomNoise(std=(0, self.random_noise)),
        tio.Resample(self.resample_ratio),
    ])
    if random.randint(1, 20) == 5:
        transform = tio.Compose([tio.Resample(self.resample_ratio)])
    else:
        transform = augment

    subjects_dataset = tio.SubjectsDataset([subject], transform=transform)
    inputs = subjects_dataset[0]['oct'][tio.DATA]
    array_3d = np.squeeze(inputs.cpu().numpy())  # (D, H, W)
    array_3d = array_3d.astype(np.uint8)

    target_hw = None if self.image_shape is None else tuple(self.image_shape[0:2])
    if self.imgaug_iaa is None and (target_hw is None or tuple(array_3d.shape[1:3]) == target_hw):
        # Nothing to resize or augment: just add the channel axis.
        array_4d = np.expand_dims(array_3d, axis=-1)  # (D, H, W, C)
    else:
        if self.imgaug_iaa is not None:
            # Apply the same augmentation to every slice of this volume.
            self.imgaug_iaa.deterministic = True
        try:
            list_images = []
            for img in array_3d:  # (H, W) slices along depth
                if target_hw is not None and tuple(img.shape[0:2]) != target_hw:
                    # cv2.resize takes (width, height).
                    img = cv2.resize(img, (target_hw[1], target_hw[0]))
                # cvtColor does not support float64; imgaug (e.g. MultiplyBrightness)
                # needs a uint8 BGR image.
                img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_GRAY2BGR)
                img = img.astype(np.uint8)
                if self.imgaug_iaa is not None:
                    img = self.imgaug_iaa(image=img)
                list_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))
        finally:
            # Always restore random behaviour, even if a slice fails.
            if self.imgaug_iaa is not None:
                self.imgaug_iaa.deterministic = False
        array_4d = np.expand_dims(np.array(list_images), axis=-1)  # (D, H, W, C)

    if self.channel_first:
        array_4d = np.transpose(array_4d, (3, 0, 1, 2))  # (D,H,W,C) -> (C,D,H,W)
    array_4d = array_4d.astype(np.float32) / 255.

    # Return numpy, not a tensor: returning CUDA tensors from multi-process
    # loading is discouraged (see torch.utils.data docs).
    label = int(self.df.iloc[index, 1])
    return array_4d, label
def __init__(self, **kwargs):
    """Configure the network, loss, data transforms and helpers from hparams.

    All settings are read from ``self.hparams`` (filled by
    ``save_hyperparameters()``). An unsupported ``modelID``, ``lossID`` or
    ``ds_mode``, or an invalid motion configuration, terminates the process
    via ``sys.exit`` with an explanatory message.
    """
    super().__init__()
    self.save_hyperparameters()
    # Device used for the loss modules below (perceptual loss / (MS-)SSIM).
    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and self.hparams.cuda else "cpu")

    model_id = self.hparams.modelID
    if model_id in (0, 2):
        # Constructor arguments shared by ResNet (0) and DualSpaceResNet (2).
        resnet_kwargs = dict(
            in_channels=self.hparams.in_channels,
            out_channels=self.hparams.out_channels,
            res_blocks=self.hparams.model_res_blocks,
            starting_nfeatures=self.hparams.model_starting_nfeatures,
            updown_blocks=self.hparams.model_updown_blocks,
            is_relu_leaky=self.hparams.model_relu_leaky,
            do_batchnorm=self.hparams.model_do_batchnorm,
            res_drop_prob=self.hparams.model_drop_prob,
            is_replicatepad=self.hparams.model_is_replicatepad,
            out_act=self.hparams.model_out_act,
            forwardV=self.hparams.model_forwardV,
            upinterp_algo=self.hparams.model_upinterp_algo,
            post_interp_convtrans=self.hparams.model_post_interp_convtrans,
            is3D=self.hparams.is3D,
        )
        if model_id == 0:
            self.net = ResNet(**resnet_kwargs)  # TODO think of 2D
        else:
            self.net = DualSpaceResNet(
                **resnet_kwargs,
                connect_mode=self.hparams.model_dspace_connect_mode,
                inner_norm_ksp=self.hparams.model_inner_norm_ksp)
    elif model_id == 3:  # Primal-Dual Network, complex Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block=True, use_original_init=True,
                                     use_complex_primal=True, g_normtype="magmax",
                                     transform="Fourier", return_abs=True)
    elif model_id == 4:  # Primal-Dual Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=5, n_dual=5, n_iterations=10,
                                     use_original_block=True, use_original_init=True,
                                     use_complex_primal=False, g_normtype="magmax",
                                     transform="Fourier")
    elif model_id == 5:  # Primal-Dual UNet Network, absolute Primal
        self.net = PrimalDualNetwork(n_primary=4, n_dual=5, n_iterations=2,
                                     use_original_block=False, use_original_init=False,
                                     use_complex_primal=False, g_normtype="magmax",
                                     transform="Fourier")
    elif model_id == 6:  # Primal-Dual Network v2 (no residual), complex Primal
        self.net = PrimalDualNetworkNoResidue(n_primary=5, n_dual=5, n_iterations=10,
                                              use_original_block=True, use_original_init=True,
                                              use_complex_primal=True, residuals=False,
                                              g_normtype="magmax", transform="Fourier",
                                              return_abs=True)
    else:
        # TODO: other models
        sys.exit("Only ReconResNet and DualSpaceResNet have been implemented so far in ReconEngine")

    if bool(self.hparams.preweights_path):
        print("Pre-weights found, loding...")
        # Load on CPU so a checkpoint saved on GPU can be read anywhere.
        chk = torch.load(self.hparams.preweights_path, map_location='cpu')
        self.net.load_state_dict(chk['state_dict'])

    if self.hparams.lossID == 0:
        if self.hparams.in_channels != 1 or self.hparams.out_channels != 1:
            sys.exit(
                "Perceptual Loss used here only works for 1 channel input and output")
        self.loss = PerceptualLoss(device=device, loss_model="unet3Dds", resize=None,
                                   loss_type=self.hparams.ploss_type,
                                   n_level=self.hparams.ploss_level)  # TODO thinkof 2D
    elif self.hparams.lossID == 1:
        self.loss = nn.L1Loss(reduction='mean')
    elif self.hparams.lossID == 2:
        self.loss = MS_SSIM(channel=self.hparams.out_channels, data_range=1,
                            spatial_dims=3 if self.hparams.is3D else 2,
                            nonnegative_ssim=False).to(device)
    elif self.hparams.lossID == 3:
        self.loss = SSIM(channel=self.hparams.out_channels, data_range=1,
                         spatial_dims=3 if self.hparams.is3D else 2,
                         nonnegative_ssim=False).to(device)
    else:
        sys.exit("Invalid Loss ID")

    self.dataspace = DataSpaceHandler(**self.hparams)

    # Pick the transform/augmentation backends: 0 = TorchIO, 1 = the pyt* set.
    if self.hparams.ds_mode == 0:
        trans = tioTransforms
        augs = tioAugmentations
    elif self.hparams.ds_mode == 1:
        trans = pytTransforms
        augs = pytAugmentations
    else:
        # Without this, `trans`/`augs` would be unbound further down.
        sys.exit("Invalid ds_mode: expected 0 or 1")

    # TODO parameterised everything
    self.init_transforms = []
    self.aug_transforms = []
    self.transforms = []
    if self.hparams.ds_mode == 0 and self.hparams.cannonicalResample:  # Only applicable for TorchIO
        self.init_transforms += [tio.ToCanonical(), tio.Resample('gt')]
    if self.hparams.ds_mode == 0 and self.hparams.forceNormAffine:  # Only applicable for TorchIO
        self.init_transforms += [trans.ForceAffine()]
    if self.hparams.croppad and self.hparams.ds_mode == 1:
        self.init_transforms += [
            trans.CropOrPad(size=self.hparams.input_shape)]
    self.init_transforms += [trans.IntensityNorm(type=self.hparams.norm_type,
                                                 return_meta=self.hparams.motion_return_meta)]
    # TODO: dataspace transforms (self.dataspace.getTransforms()) are not in use

    if bool(self.hparams.random_crop) and self.hparams.ds_mode == 1:
        self.aug_transforms += [augs.RandomCrop(
            size=self.hparams.random_crop, p=self.hparams.p_random_crop)]
    if self.hparams.p_contrast_augment > 0:
        self.aug_transforms += [augs.getContrastAugs(
            p=self.hparams.p_contrast_augment)]

    # If the task is MoCo and pre-corrupted volumes are not supplied,
    # simulate motion on the fly.
    if self.hparams.taskID == 1 and not bool(self.hparams.train_path_inp):
        if self.hparams.motion_mode == 0 and self.hparams.ds_mode == 0:
            prefix = 'motionmg_'
            # Forward every `motionmg_<name>` hparam as `<name>=value`.
            motion_params = {k[len(prefix):]: v for k, v in self.hparams.items()
                             if k.startswith(prefix)}
            self.transforms += [tioMotion.RandomMotionGhostingFast(
                **motion_params), trans.IntensityNorm()]
        elif self.hparams.motion_mode == 1 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv0(
                sigma_range=self.hparams.motion_sigma_range,
                n_threads=self.hparams.motion_n_threads,
                p=self.hparams.motion_p,
                return_meta=self.hparams.motion_return_meta)]
        elif self.hparams.motion_mode == 2 and self.hparams.ds_mode == 1 and not self.hparams.is3D:
            self.transforms += [pytMotion.Motion2Dv1(
                sigma_range=self.hparams.motion_sigma_range,
                n_threads=self.hparams.motion_n_threads,
                restore_original=self.hparams.motion_restore_original,
                p=self.hparams.motion_p,
                return_meta=self.hparams.motion_return_meta)]
        else:
            sys.exit(
                "Error: invalid motion_mode, ds_mode, is3D combo. Please double check!")

    self.static_metamat = sio.loadmat(self.hparams.static_metamat_file) if bool(
        self.hparams.static_metamat_file) else None
    if self.hparams.taskID == 0 and self.hparams.use_datacon:
        self.datacon = DataConsistency(
            isRadial=self.hparams.is_radial, metadict=self.static_metamat)
    else:
        self.datacon = None

    # 2D models drop the last entry of input_shape.
    input_shape = self.hparams.input_shape if self.hparams.is3D else self.hparams.input_shape[:-1]
    self.example_input_array = torch.empty(
        self.hparams.batch_size, self.hparams.in_channels, *input_shape).float()

    self.saver = ResSaver(
        self.hparams.res_path, save_inp=self.hparams.save_inp, do_norm=self.hparams.do_savenorm)
fpg.mri.affine #.numpy() # To maybe CropOrPad fpg_ras.mri.spatial_shape target_shape = 96, 96, 48 crop_pad = tio.CropOrPad(target_shape) show_fpg(crop_pad(fpg_ras)) #Random affine :To simulate different positions and size of the patient within the scanner, we can use a RandomAffine transform. #To improve visualization, we will use a 2D image and add a grid to it. image = tio.ScalarImage('slice_7t.png') spacing = image.spacing[0] image = tio.Resample(spacing)(image) print('Downloaded slice:', image) to_pil(image) tio.ScalarImage('slice_7t.png').spacing #grid over the image slice_grid = copy.deepcopy(image) data = slice_grid.data white = data.max() N = 16 data[..., ::N, :, :] = white data[..., :, ::N, :] = white to_pil(slice_grid)
# NOTE(review): `plot_gif(image)`, called further down, is not defined in the
# visible code. The lines from `def get_frame` through
# `return animation.FuncAnimation(...)` look like the body of that function
# with its `def plot_gif(image):` line (and the `_update_frame` callback it
# references) missing; as written, the module-level `return` is a
# SyntaxError. Restore the enclosing definition before running.
def get_frame(image, i):
    # Frame `i` taken along the last axis of `image.data`, reordered with
    # permute(1, 2, 0) (channels last, as imshow expects) and cast to uint8
    # via .byte(). permute with 3 indices means image.data must be 4D;
    # assumes layout (C, W, H, T) -- TODO confirm.
    return image.data[..., i].permute(1, 2, 0).byte()
# Size limit (MB) for animations embedded as HTML; raised so the GIF fits.
plt.rcParams['animation.embed_limit'] = 25
fig, ax = plt.subplots()
# Show the first frame; `im` is presumably updated by `_update_frame`
# (not visible here).
im = ax.imshow(get_frame(image, 0))
# One animation frame per index along the last axis; `delay` is the value
# stored on the image when it was created (see `tio.ScalarImage(..., delay=...)`).
return animation.FuncAnimation(
    fig,
    _update_frame,
    repeat_delay=image['delay'],
    frames=image.shape[-1],
)
# Source: https://thehigherlearning.wordpress.com/2014/06/25/watching-a-cell-divide-under-an-electron-microscope-is-mesmerizing-gif/  # noqa: E501
array, delay = read_clip('nBTu3oi.gif')
# Preview the first frame with channels moved last.
plt.imshow(array[..., 0].transpose(1, 2, 0))
plt.plot()
# Extra keyword `delay` is stored on the image and read back as image['delay'].
image = tio.ScalarImage(tensor=array, delay=delay)
original_animation = plot_gif(image)
# Resample to spacing (2, 2, 1), then apply a random rotation about the last
# axis only (TorchIO's 3-value `degrees` form).
transform = tio.Compose((
    tio.Resample((2, 2, 1)),
    tio.RandomAffine(degrees=(0, 0, 20)),
))
# Seed so the RandomAffine draw is reproducible.
torch.manual_seed(0)
transformed = transform(image)
transformed_animation = plot_gif(transformed)
def patch_sampler(img_filenames, labelmap_filenames, patch_size, sampler_type,
                  out_dir, max_patches=None, voxel_spacing=(),
                  patch_overlap=(0, 0, 0), min_labeled_voxels=1.0,
                  label_prob=0.8, save_patches=False, batch_size=None,
                  prepare_batches=False, inference=False):
    """Extract patches from image/labelmap volume pairs with TorchIO.

    Each image is paired with the labelmap at the same index. The pair is
    rescaled to [0, 1] (intensity image only) and, if requested, resampled
    to ``voxel_spacing``. It is then cut into patches with either a grid
    sampler or a label-weighted random sampler.

    Parameters
    ----------
    img_filenames : list of str
        Paths to the images to extract patches from.
    labelmap_filenames : list of str
        Paths to the labelmaps, one per entry of ``img_filenames``.
    patch_size : tuple of int (patch_x, patch_y, patch_z)
        Size of one patch. When ``patch_size[0] == 1`` the patches are
        returned as 2D arrays.
    sampler_type : str
        ``'grid'`` uses ``tio.data.GridSampler``; any other value uses
        ``tio.data.LabelSampler``.
    out_dir : str
        Forwarded to ``save`` when ``save_patches`` is True.
    max_patches : int, optional
        Total patch budget, split evenly across the images. Required for
        the label sampler, whose generator is otherwise unbounded.
    voxel_spacing : tuple, optional
        If non-empty, ``util.update_affine`` is called on each pair and the
        subject is resampled to this spacing.
    patch_overlap : tuple of int
        Overlap between neighbouring patches (grid sampler only).
    min_labeled_voxels : float in [0, 1] or None
        Grid sampler only. A patch is kept if at least this fraction of its
        voxels is labeled (non-zero). It is also kept when its centre voxel
        is labeled. If None, only the centre-voxel criterion is used.
    label_prob : float
        Label sampler only: probability of centring a patch on label 1
        (label 0 gets ``1 - label_prob``).
    save_patches : bool
        If True, patches are handed to ``save`` as they are collected.
        Collected ids are returned instead of arrays.
    batch_size : int, optional
        Passed to ``save`` as the save size when ``prepare_batches`` is
        True (otherwise a save size of 1 is used).
    prepare_batches : bool
        See ``batch_size``.
    inference : bool
        Forwarded to ``save``.

    Returns
    -------
    img_ids, label_ids
        If ``save_patches`` is True.
    img_patches, label_patches : np.ndarray
        Otherwise. Shape is ``(n_patches, patch_x, patch_y, patch_z, 1)``,
        or ``(n_patches, patch_y, patch_z, 1)`` when ``patch_size[0] == 1``.

    Raises
    ------
    ValueError
        If images are given without labelmaps, or if the label sampler is
        requested without ``max_patches``.
    """
    if img_filenames and not labelmap_filenames:
        raise ValueError('patch_sampler requires one labelmap per image')
    if sampler_type != 'grid' and max_patches is None:
        # LabelSampler yields patches forever when num_patches is None.
        raise ValueError('max_patches is required when sampler_type is not "grid"')

    if max_patches is not None and img_filenames:
        # Split the total budget evenly across the input volumes.
        max_patches = int(max_patches / len(img_filenames))

    img_patches = []
    label_patches = []
    patch_counter = 0
    save_counter = 0
    img_ids = []
    label_ids = []
    save_size = batch_size if prepare_batches else 1

    voxels_per_patch = patch_size[0] * patch_size[1] * patch_size[2]
    center_idx = (0, patch_size[0] // 2, patch_size[1] // 2, patch_size[2] // 2)

    print(f'\nExtracting patches from: {img_filenames}\n')
    for i in tqdm(range(len(img_filenames)), leave=False):
        if voxel_spacing:
            util.update_affine(img_filenames[i], labelmap_filenames[i])

        subject = tio.Subject(
            img=tio.Image(img_filenames[i], type=tio.INTENSITY),
            labelmap=tio.LabelMap(labelmap_filenames[i]))

        # Apply transformations
        transformed = tio.RescaleIntensity((0, 1))(subject)
        if voxel_spacing:
            transformed = tio.Resample(voxel_spacing)(transformed)

        if sampler_type == 'grid':
            num_img_patches = 0
            sampler = tio.data.GridSampler(transformed, patch_size, patch_overlap)
            for patch in sampler:
                img_patch = np.array(patch.img.data)
                label_patch = np.array(patch.labelmap.data)
                if min_labeled_voxels is None:
                    labeled_voxels = False
                else:
                    labeled_voxels = (torch.count_nonzero(patch.labelmap.data)
                                      >= voxels_per_patch * min_labeled_voxels)
                center = label_patch[center_idx] != 0
                if labeled_voxels or center:
                    img_patches.append(img_patch)
                    label_patches.append(label_patch)
                    patch_counter += 1
                    num_img_patches += 1
                    if save_patches:
                        (img_patches, label_patches, img_ids, label_ids,
                         save_counter, patch_counter) = save(
                            img_patches, label_patches, img_ids, label_ids,
                            save_counter, patch_counter, save_size,
                            patch_size, inference, out_dir)
                # Stop once this image's share of the budget is reached
                # (>=, so exactly max_patches are kept, as with LabelSampler).
                if max_patches is not None and num_img_patches >= max_patches:
                    break
        else:
            label_probabilities = {0: 1.0 - label_prob, 1: label_prob}
            sampler = tio.data.LabelSampler(
                patch_size, label_probabilities=label_probabilities)
            for patch in sampler(transformed, max_patches):
                img_patches.append(np.array(patch.img.data))
                label_patches.append(np.array(patch.labelmap.data))
                patch_counter += 1
                if save_patches:
                    (img_patches, label_patches, img_ids, label_ids,
                     save_counter, patch_counter) = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size,
                        patch_size, inference, out_dir)

    print('Finished extracting patches.')

    if save_patches:
        return img_ids, label_ids

    # Image and label patches share the same per-patch shape.
    if patch_size[0] == 1:
        out_shape = (patch_size[1], patch_size[2], 1)
    else:
        out_shape = (patch_size[0], patch_size[1], patch_size[2], 1)
    return (np.array(img_patches).reshape(len(img_patches), *out_shape),
            np.array(label_patches).reshape(len(label_patches), *out_shape))