def test_inverse(self):
    """OneHot followed by its inverse transform gives back the original label data."""
    transform = tio.OneHot()
    encoded = transform(self.sample_subject)
    restored = encoded.apply_inverse_transform()
    expected = self.sample_subject.label.data
    actual = restored.label.data
    self.assertTensorEqual(expected, actual)
def __init__(self, mask, ct_path=None, ds_cts=None, structures=None):
    """Store a label map, its one-hot encoding and the associated CT datasets.

    Args:
        mask: A label-map object, or a path (``str``) that is wrapped in
            ``tio.LabelMap``.
        ct_path: Directory scanned for files whose names end in ``"dcm"``;
            they are collected in natural-sort order into ``self.ct_files``.
            May be ``None`` when ``ds_cts`` is supplied.
        ds_cts: Already-loaded DICOM datasets used instead of reading the
            files found in ``ct_path``. The sequence is copied, so the
            caller's list is not reordered.
        structures: Stored unchanged on ``self.structures``.

    Raises:
        ValueError: If both ``ct_path`` and ``ds_cts`` are ``None``.
    """
    if ct_path is None and ds_cts is None:
        raise ValueError('At least ct_path should be provided')
    self.masks = mask if not isinstance(mask, str) else tio.LabelMap(mask)
    self.structures = structures
    self.masks_itk = self.masks.as_sitk()
    self.transform = tio.OneHot()
    self.one_hot_masks = self.transform(self.masks)
    # The first one-hot channel is not counted (presumably background).
    self.n_masks = self.one_hot_masks.shape[0] - 1

    # Only scan a directory that was actually given: os.listdir(None) falls
    # back to the current working directory, and os.path.join(None, ...)
    # then raises TypeError as soon as a matching file is found there.
    if ct_path is not None:
        self.ct_files = natsorted([
            os.path.join(ct_path, ct)
            for ct in os.listdir(ct_path)
            if ct.endswith("dcm")
        ])
    else:
        self.ct_files = []

    if ds_cts:
        # Copy so the in-place reverse below never mutates the caller's list
        # (and so tuples or other sequences are accepted too).
        self.ds_ct = list(ds_cts)
    else:
        # force=True lets pydicom read files that lack the standard preamble.
        self.ds_ct = [dcmread(ct_file, force=True) for ct_file in self.ct_files]
    # Datasets are kept in the reverse of the natural-sort file order.
    self.ds_ct.reverse()
def test_batch_history(self):
    """Inverting a subject recovered from a collated batch restores image shapes.

    Regression test for https://github.com/fepegar/torchio/discussions/743
    """
    original = self.sample_subject
    pipeline = tio.Compose([
        tio.RandomAffine(),
        tio.CropOrPad(5),
        tio.OneHot(),
    ])
    loader = torch.utils.data.DataLoader(
        tio.SubjectsDataset([original], transform=pipeline),
        collate_fn=tio.utils.history_collate,
    )
    first_batch = tio.utils.get_first_item(loader)
    recovered: tio.Subject = tio.utils.get_subjects_from_batch(first_batch)[0]
    inverted = recovered.apply_inverse_transform()
    image_pairs = zip(
        original.get_images(intensity_only=False),
        inverted.get_images(intensity_only=False),
    )
    for before, after in image_pairs:
        assert before.shape == after.shape
# Paths to one dHCP subject's restored T2w image and its segmentation.
# ``home`` is defined outside this view (presumably the user's home
# directory -- confirm).
image_file = home + '/Sync-Exp/Data/DHCP/sub-CC00060XX03_ses-12501_desc-restore_T2w.nii.gz'
label_file = home + '/Sync-Exp/Data/DHCP/sub-CC00060XX03_ses-12501_desc-fusion_space-T2w_dseg.nii.gz'

# NOTE(review): this image/label subject is immediately rebound by the hr/lr
# subject below, so this first value is never used through ``subject``.
subject = tio.Subject(
    image=tio.ScalarImage(image_file),
    label=tio.LabelMap(label_file),
)
# 'hr' and 'lr' start out as the same image. The *_jess transforms below use
# exclude='hr', so presumably only 'lr' gets degraded (simulated low-res scan).
subject = tio.Subject(
    hr=tio.ScalarImage(image_file),
    lr=tio.ScalarImage(image_file),
)

# Alternative kept for reference: mask normalization by the 'label' image,
# which the hr/lr subject above does not contain.
#normalization = tio.ZNormalization(masking_method='label')#masking_method=tio.ZNormalization.mean)
normalization = tio.ZNormalization()
onehot = tio.OneHot()
# ``p`` is the probability of applying a random transform:
# p=0 disables it and p=1 always applies it.
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0)
bias = tio.RandomBiasField(coefficients=0.5, p=0)
flip = tio.RandomFlip(axes=('LR', ), p=1)
noise = tio.RandomNoise(std=0.1, p=0)
# Isotropic resampling to 1 mm and 0.5 mm spacing.
sampling_1mm = tio.Resample(1)
sampling_05mm = tio.Resample(0.5)
blur = tio.RandomBlur(0.5)
# "jess" transforms act on every image except 'hr'.
# First, anisotropic (0.8, 0.8, 2) spacing.
sampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
# Blur almost only along the third axis (std 1 vs 0.001 in-plane).
blur_jess = tio.Blur(std=(0.001, 0.001, 1), exclude='hr')
# NOTE(review): same configuration as ``sampling_jess`` above.
downsampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
# Resample the non-'hr' images back onto the grid of the 'hr' image.
upsampling_jess = tio.Resample(target='hr', exclude='hr')
def test_one_hot(self):
    """Encoding with an explicit ``num_classes`` yields exactly that many channels."""
    label_map = self.sample_subject.label
    encoder = tio.OneHot(num_classes=3)
    encoded = encoder(label_map)
    assert encoded.num_channels == 3
def test_multichannel(self):
    """OneHot is expected to raise RuntimeError on a two-channel label map."""
    # torch.rand samples from [0, 1), so "> 1" gives an all-False boolean
    # tensor; only its 2-channel shape matters here.
    two_channel_tensor = torch.rand(2, 3, 3, 3) > 1
    two_channel_map = tio.LabelMap(tensor=two_channel_tensor)
    with self.assertRaises(RuntimeError):
        tio.OneHot()(two_channel_map)
channels[1], channels[2], mri_chan=subject['data'][0][1] > 0, angle_num=gif_angle_rotation, angle_view=gif_view_angle, fig_size=fig_size_gif) make_gif(true_gif_output_path, os.path.join(true_gif_output_path, 'true.gif'), angle_num=gif_angle_rotation) if __name__ == "__main__": validation_transform = tio.Compose([ tio.ZNormalization(masking_method=tio.ZNormalization.mean), tio.CropOrPad((240, 240, 160)), tio.OneHot(num_classes=5) ]) gen_visuals( image_path= "../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_010", transforms=validation_transform, model_path="./Models/test_train_many_1e-3.pt", gen_pred=True, gen_true=True, input_channels_list=['flair', 't1', 't2', 't1ce'], seg_channels=[1, 2, 4], gen_gif=False, true_gif_output_path="../output/true", pred_gif_output_path="../output/pred", seg_channels_to_display_gif=[1, 2, 4], gif_view_angle=30,
def test_dataset(self, modalities):
    """Check that generated NRRD output loads and batches into expected shapes.

    Loads the subjects written to ``temp_folder_<modalities>`` under
    ``self.output_path`` (commas in ``modalities`` replaced by underscores),
    wraps them in a ``tio.SubjectsDataset`` with a resample / crop-or-pad /
    ROI-selection / one-hot pipeline, and checks one batch of two subjects:
    every entry must be ``(2, C, 96, 96, 40)`` with ``C == 2`` for keys
    containing ``"RTSTRUCT"`` and ``C == 1`` otherwise.

    Args:
        modalities: Comma-separated modality string used to locate the
            output folder.
    """
    output_path_mod = pathlib.Path(
        self.output_path,
        str("temp_folder_" + ("_").join(modalities.split(","))),
    ).as_posix()
    comp_path = pathlib.Path(output_path_mod).resolve().joinpath('dataset.csv').as_posix()
    comp_table = pd.read_csv(comp_path, index_col=0)
    print(comp_path, comp_table)

    # Load subjects from the NRRD files.
    subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)

    # TODO: the cross-checks below (direct loading vs. NRRD loading, and
    # metadata key coverage against comp_table) are disabled; re-enable or
    # delete them.
    # subjects_direct = Dataset.load_directly(self.input_path,modalities=modalities,ignore_multi=True)
    # The number of subjects is equal to the number of components which is 2 for this dataset
    # assert len(subjects_nrrd) == len(subjects_direct) == 2, "There was some error in generation of subject object"
    # assert subjects_nrrd[0].keys() == subjects_direct[0].keys()
    # del subjects_direct
    # To check if all metadata items present in the keys
    # temp_nrrd = subjects_nrrd[0]
    # columns_shdbe_present = set([col if col.split("_")[0]=="metadata" else "mod_"+("_").join(col.split("_")[1:]) for col in list(comp_table.columns) if col.split("_")[0] in ["folder","metadata"]])
    # print(columns_shdbe_present)
    # assert set(temp_nrrd.keys()).issubset(columns_shdbe_present), "Not all items present in dictionary, some fault in going through the different columns in a single component"

    transforms = tio.Compose([
        tio.Resample(4),
        tio.CropOrPad((96, 96, 40)),
        select_roi_names(["larynx"]),
        tio.OneHot(),
    ])

    # Build the dataset and loader.
    test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=2, shuffle=True, collate_fn=collate_fn,
    )

    # The fixture dataset is expected to contain exactly two subjects.
    assert len(test_set) == 2

    # Fetching a batch runs the transforms and collate_fn. An exception here
    # points at the subject objects or the transform pipeline.
    data = next(iter(test_loader))

    # RTSTRUCT entries are one-hot encoded over the selected ROI, giving two
    # channels (presumably background + larynx); images keep one channel.
    # Collect every offending key so a failure says which entry was wrong.
    mismatched = {}
    for key, value in data.items():
        channels = 2 if "RTSTRUCT" in key else 1
        if value.shape != (2, channels, 96, 96, 40):
            mismatched[key] = tuple(value.shape)
    assert not mismatched, (
        "There is some problem in the transformation/the formation of subject object"
        f": {mismatched}"
    )