Esempio n. 1
0
 def test_inverse(self):
     """Applying OneHot and then its inverse must give back the original labels."""
     transform = tio.OneHot()
     encoded = transform(self.sample_subject)
     decoded = encoded.apply_inverse_transform()
     expected = self.sample_subject.label.data
     actual = decoded.label.data
     self.assertTensorEqual(expected, actual)
Esempio n. 2
0
    def __init__(self, mask, ct_path=None, ds_cts=None, structures=None):
        """Load a label-map mask and the CT slices it belongs to.

        Args:
            mask: A path string, which is loaded as a ``tio.LabelMap``, or any
                other object, which is used as-is. That object must support
                ``as_sitk()`` and the ``tio.OneHot`` transform.
            ct_path: Directory holding the CT DICOM files (file names ending
                in ``"dcm"``). Only listed when it is not ``None``.
            ds_cts: Already-read CT datasets. When non-empty they are used
                instead of reading files from ``ct_path``. They are copied,
                so the caller's sequence is never modified.
            structures: Stored unchanged on ``self.structures``.

        Raises:
            ValueError: If neither ``ct_path`` nor ``ds_cts`` is provided.
        """
        if ct_path is None and ds_cts is None:
            raise ValueError('Either ct_path or ds_cts must be provided')

        self.masks = mask if not isinstance(mask, str) else tio.LabelMap(mask)
        self.structures = structures
        self.masks_itk = self.masks.as_sitk()
        self.transform = tio.OneHot()
        self.one_hot_masks = self.transform(self.masks)
        # One channel fewer than the one-hot encoding. Presumably channel 0 is
        # the background class -- TODO confirm against the mask labels.
        self.n_masks = self.one_hot_masks.shape[0] - 1

        # Only touch the filesystem when a directory was actually given.
        # os.listdir(None) would silently list the current working directory,
        # and os.path.join(None, ...) would then raise TypeError.
        if ct_path is not None:
            self.ct_files = natsorted([
                os.path.join(ct_path, ct) for ct in os.listdir(ct_path)
                if ct.endswith("dcm")
            ])
        else:
            self.ct_files = []

        # Copy ds_cts so the in-place reverse below cannot mutate the
        # caller's list.
        if ds_cts:
            self.ds_ct = list(ds_cts)
        else:
            self.ds_ct = [
                dcmread(ct_file, force=True) for ct_file in self.ct_files
            ]
        # Stored in reverse of the natural-sorted file order.
        # NOTE(review): the reason is not visible here (presumably to match the
        # mask's slice order) -- confirm.
        self.ds_ct.reverse()
Esempio n. 3
0
 def test_batch_history(self):
     """Inverting a subject recovered from a history-collated batch restores image shapes.

     The subject goes through RandomAffine, CropOrPad(5) and OneHot inside a
     DataLoader that uses ``tio.utils.history_collate``. The subject extracted
     from the first batch is inverted, and each resulting image must match
     the corresponding original image's shape.
     See https://github.com/fepegar/torchio/discussions/743
     """
     subject = self.sample_subject
     transform = tio.Compose([
         tio.RandomAffine(),
         tio.CropOrPad(5),
         tio.OneHot(),
     ])
     dataset = tio.SubjectsDataset([subject], transform=transform)
     loader = torch.utils.data.DataLoader(
         dataset,
         collate_fn=tio.utils.history_collate,
     )
     batch = tio.utils.get_first_item(loader)
     transformed: tio.Subject = tio.utils.get_subjects_from_batch(batch)[0]
     inverse = transformed.apply_inverse_transform()
     images1 = subject.get_images(intensity_only=False)
     images2 = inverse.get_images(intensity_only=False)
     # zip() stops at the shorter list, so compare the counts first.
     # Otherwise an image dropped by the inverse transform would go unchecked.
     assert len(images1) == len(images2)
     for image1, image2 in zip(images1, images2):
         assert image1.shape == image2.shape
Esempio n. 4
0
# Experiment setup: file paths, a torchio Subject and a set of candidate
# transforms. `home` is defined outside this fragment (presumably the user's
# home directory -- TODO confirm).
image_file = home + '/Sync-Exp/Data/DHCP/sub-CC00060XX03_ses-12501_desc-restore_T2w.nii.gz'
label_file = home + '/Sync-Exp/Data/DHCP/sub-CC00060XX03_ses-12501_desc-fusion_space-T2w_dseg.nii.gz'

subject = tio.Subject(
    image=tio.ScalarImage(image_file),
    label=tio.LabelMap(label_file),
)

# NOTE(review): this rebinds `subject`, so the image/label Subject above is
# never used after construction. Keep only the variant that is needed.
# Both 'hr' and 'lr' start from the same image file.
subject = tio.Subject(
    hr=tio.ScalarImage(image_file),
    lr=tio.ScalarImage(image_file),
)

#normalization = tio.ZNormalization(masking_method='label')#masking_method=tio.ZNormalization.mean)
normalization = tio.ZNormalization()
onehot = tio.OneHot()

# Random transforms with p=0 are never applied (disabled).
# The flip (p=1) is always applied.
spatial = tio.RandomAffine(scales=0.1, degrees=10, translation=0, p=0)

bias = tio.RandomBiasField(coefficients=0.5, p=0)
flip = tio.RandomFlip(axes=('LR', ), p=1)
noise = tio.RandomNoise(std=0.1, p=0)

sampling_1mm = tio.Resample(1)
sampling_05mm = tio.Resample(0.5)
blur = tio.RandomBlur(0.5)

# The *_jess transforms all use exclude='hr', so the 'hr' image is left
# untouched. Presumably they simulate a low-resolution 'lr' image from the
# same source -- TODO confirm.
# NOTE(review): sampling_jess and downsampling_jess are constructed with
# identical arguments.
sampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
blur_jess = tio.Blur(std=(0.001, 0.001, 1), exclude='hr')
downsampling_jess = tio.Resample((0.8, 0.8, 2), exclude='hr')
# Resamples the non-excluded images onto the grid of the 'hr' image.
upsampling_jess = tio.Resample(target='hr', exclude='hr')
Esempio n. 5
0
 def test_one_hot(self):
     """OneHot(num_classes=3) must yield a label map with three channels."""
     label_map = self.sample_subject.label
     encoder = tio.OneHot(num_classes=3)
     encoded = encoder(label_map)
     assert encoded.num_channels == 3
Esempio n. 6
0
 def test_multichannel(self):
     """OneHot must raise RuntimeError for a label map with two channels."""
     # torch.rand draws from [0, 1), so "> 1" yields an all-False bool tensor.
     # Its first dimension gives the label map 2 channels.
     two_channel_data = torch.rand(2, 3, 3, 3) > 1
     label_map = tio.LabelMap(tensor=two_channel_data)
     with self.assertRaises(RuntimeError):
         encoder = tio.OneHot()
         encoder(label_map)
Esempio n. 7
0
                          channels[1],
                          channels[2],
                          mri_chan=subject['data'][0][1] > 0,
                          angle_num=gif_angle_rotation,
                          angle_view=gif_view_angle,
                          fig_size=fig_size_gif)
            make_gif(true_gif_output_path,
                     os.path.join(true_gif_output_path, 'true.gif'),
                     angle_num=gif_angle_rotation)


if __name__ == "__main__":
    validation_transform = tio.Compose([
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.CropOrPad((240, 240, 160)),
        tio.OneHot(num_classes=5)
    ])
    gen_visuals(
        image_path=
        "../brats_new/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_010",
        transforms=validation_transform,
        model_path="./Models/test_train_many_1e-3.pt",
        gen_pred=True,
        gen_true=True,
        input_channels_list=['flair', 't1', 't2', 't1ce'],
        seg_channels=[1, 2, 4],
        gen_gif=False,
        true_gif_output_path="../output/true",
        pred_gif_output_path="../output/pred",
        seg_channels_to_display_gif=[1, 2, 4],
        gif_view_angle=30,
Esempio n. 8
0
    def test_dataset(self, modalities):
        """Test the Dataset class end to end on previously generated output.

        Steps:
            1. Locate the output folder for ``modalities`` (a comma-separated
               string) under ``self.output_path``.
            2. Read its ``dataset.csv`` and load the subjects from the NRRD
               files.
            3. Run them through resample, crop/pad, ROI selection and
               one-hot transforms.
            4. Check the shapes of the first collated batch.

        TODO(review): the original docstring ended mid-sentence
        ("Note that test is not for ..."); restore the intended caveat.
        """
        output_path_mod = pathlib.Path(
            self.output_path, "temp_folder_" + "_".join(modalities.split(","))
        ).as_posix()
        comp_path = pathlib.Path(output_path_mod).resolve().joinpath('dataset.csv').as_posix()
        comp_table = pd.read_csv(comp_path, index_col=0)
        print(comp_path, comp_table)

        # Load the subjects from the NRRD files.
        subjects_nrrd = Dataset.load_image(output_path_mod, ignore_multi=True)
        # TODO: the comparison against Dataset.load_directly(self.input_path, ...)
        # (equal subject count of 2, identical keys, all metadata columns
        # present) is currently disabled.

        transforms = tio.Compose([
            tio.Resample(4),
            tio.CropOrPad((96, 96, 40)),
            select_roi_names(["larynx"]),
            tio.OneHot(),
        ])

        # Form the dataset and dataloader.
        test_set = tio.SubjectsDataset(subjects_nrrd, transform=transforms)
        test_loader = torch.utils.data.DataLoader(
            test_set, batch_size=2, shuffle=True, collate_fn=collate_fn
        )

        # This dataset is expected to contain exactly two subjects.
        assert len(test_set) == 2

        # Fetching a batch runs the whole transform pipeline. A failure here
        # points at a problem in how the subjects or transforms were built.
        data = next(iter(test_loader))
        # Every entry holds a batch of 2 volumes cropped/padded to 96x96x40.
        # Entries whose key contains "RTSTRUCT" must have 2 channels; all
        # others must have 1.
        for key, value in data.items():
            n_channels = 2 if "RTSTRUCT" in key else 1
            expected = (2, n_channels, 96, 96, 40)
            assert value.shape == expected, (
                "There is some problem in the transformation/the formation of subject object: "
                f"{key!r} has shape {tuple(value.shape)}, expected {expected}"
            )