Example no. 1
 def test_save_subject(self):
     dataset = SubjectsDataset(self.subjects_list, transform=lambda x: x)
     _ = len(dataset)  # for coverage
     subject = dataset[0]
     output_path = self.dir / 'test.nii.gz'
     paths_dict = {'t1': output_path}
     with self.assertWarns(DeprecationWarning):
         dataset.save_sample(subject, paths_dict)
     nii = nib.load(str(output_path))
     ndims_output = len(nii.shape)
     ndims_subject = len(subject['t1'].shape)
     assert ndims_subject == ndims_output + 1
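Note: save_sample is called above only to exercise the DeprecationWarning. For comparison, a minimal sketch of the non-deprecated route, writing an image to disk through Image.save (the Colin27 sample subject and the output path are placeholders, not part of the test above):

import torchio as tio

subject = tio.datasets.Colin27()     # any Subject with a 't1' ScalarImage
output_path = '/tmp/t1_copy.nii.gz'  # placeholder path
subject['t1'].save(output_path)      # each torchio Image can save itself to disk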
Example no. 2
 def test_data_loader(self):
     from torch.utils.data import DataLoader
     subj_list = [torchio.datasets.Colin27()]
     dataset = SubjectsDataset(subj_list)
     loader = DataLoader(dataset, batch_size=1, shuffle=True)
     for batch in loader:
         batch['t1'][DATA]
         batch['brain'][DATA]
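Colin27 is one of torchio's bundled sample datasets; besides 't1' it also carries 'head' and 'brain' label maps, which is why the loop above can read batch['brain']. A quick sketch to inspect the available keys (the image files are downloaded on first use):

import torchio as tio

colin = tio.datasets.Colin27()
print(colin.keys())  # typically includes 't1', 'head' and 'brain'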
Example no. 3
 def test_data_loader(self):
     from torch.utils.data import DataLoader
     subj_list = [self.sample_subject]
     dataset = SubjectsDataset(subj_list)
     loader = DataLoader(dataset, batch_size=1, shuffle=True)
     for batch in loader:
         batch['t1'][DATA]
         batch['label'][DATA]
Example no. 4
 def setUp(self):
     super().setUp()
     self.subjects = [
         Subject(
             image=ScalarImage(self.get_image_path(f'hs_image_{i}')),
             label=LabelMap(self.get_image_path(f'hs_label_{i}')),
         )
         for i in range(5)
     ]
     self.dataset = SubjectsDataset(self.subjects)
Example no. 5
def split_dataset(dataset: Dataset, lengths: List[int]) -> List[Dataset]:
    div_points = [0] + np.cumsum(lengths).tolist()

    datasets = []
    for i in range(len(div_points) - 1):
        st = div_points[i]
        end = div_points[i + 1]
        datasets.append(SubjectsDataset(dataset._subjects[st:end]))

    return datasets
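A possible usage sketch for split_dataset, assuming numpy and SubjectsDataset are imported as in the helper's original module. The subjects below are placeholders built from random tensors; note that the helper reads the private _subjects attribute, so it is tied to the torchio version it was written against:

import torch
import torchio as tio

subjects = [
    tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))
    for _ in range(10)
]
dataset = tio.SubjectsDataset(subjects)
training_set, validation_set = split_dataset(dataset, [8, 2])
print(len(training_set), len(validation_set))  # 8 2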
Example no. 6
def main():
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    patches_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG75_T1.nii.gz'),
        T2=ScalarImage('../BRATS2018_crop_renamed/LGG75_T2.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG75_Label.nii.gz'),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG74_T1.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG74_Label.nii.gz'),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = SubjectsDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        patches_per_volume,
        UniformSampler(patch_size),
    )

    # This collate_fn is needed in the case of missing modalities
    # In this case, each batch is a *list* of samples instead of the usual
    # dictionary produced by PyTorch's default collation
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        logging.info(f'Epoch {epoch_index}')
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)  # nn.Identity returns its input unchanged
            logging.info([sample.keys() for sample in batch])
            logging.info(len(logits))  # logits is a list here, so log its length
    logging.info('')
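When every subject carries the same set of images, the identity collate_fn is unnecessary: PyTorch's default collation stacks the patches, and each modality comes back as a tensor under the DATA key. A minimal sketch with placeholder random-tensor subjects; the DataLoader that consumes a Queue keeps num_workers at 0, since the Queue manages its own workers:

import torch
import torchio as tio
from torch.utils.data import DataLoader

subjects = [
    tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 32, 32, 32)))
    for _ in range(2)
]
queue = tio.Queue(
    tio.SubjectsDataset(subjects),
    max_length=8,
    samples_per_volume=4,
    sampler=tio.data.UniformSampler(16),
)
loader = DataLoader(queue, batch_size=2)  # num_workers left at its default of 0
for batch in loader:
    print(batch['t1'][tio.DATA].shape)  # torch.Size([2, 1, 16, 16, 16])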
Example no. 7
 def setUp(self):
     super().setUp()
     subjects = []
     for i in range(5):
         image = ScalarImage(self.get_image_path(f'hs_image_{i}'))
         label_path = self.get_image_path(f'hs_label_{i}',
                                          binary=True,
                                          force_binary_foreground=True)
         label = LabelMap(label_path)
         subject = Subject(image=image, label=label)
         subjects.append(subject)
     self.subjects = subjects
     self.dataset = SubjectsDataset(self.subjects)
Example no. 8
 def setUp(self):
     super().setUp()
     self.subjects = [
         Subject(
             image=ScalarImage(self.get_image_path(f'hs_image_{i}')),
             label=LabelMap(
                 self.get_image_path(
                     f'hs_label_{i}',
                     binary=True,
                     force_binary_foreground=True,
                 ),
             ),
         ) for i in range(5)
     ]
     self.dataset = SubjectsDataset(self.subjects)
Example no. 9
 def run_queue(self, num_workers, **kwargs):
     subjects_dataset = SubjectsDataset(self.subjects_list)
     patch_size = 10
     sampler = UniformSampler(patch_size)
     queue_dataset = Queue(
         subjects_dataset,
         max_length=6,
         samples_per_volume=2,
         sampler=sampler,
         num_workers=num_workers,  # forward the worker count to the Queue
         **kwargs,
     )
     _ = str(queue_dataset)
     batch_loader = DataLoader(queue_dataset, batch_size=4)
     for batch in batch_loader:
         _ = batch['one_modality'][DATA]
         _ = batch['segmentation'][DATA]
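A hypothetical driver for the helper above; the method name and the exact flags are assumptions, but shuffle_subjects and shuffle_patches are regular Queue keyword arguments that **kwargs forwards:

 def test_queue_configurations(self):
     # Hypothetical sweep over worker counts and Queue shuffling flags
     for num_workers in (0, 2):
         self.run_queue(
             num_workers,
             shuffle_subjects=False,
             shuffle_patches=True,
         )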
Example no. 10
    def setUp(self):
        """Set up test fixtures, if any."""
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = Subject(
            t1=ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = Subject(
            t1=ScalarImage(self.get_image_path('t1_b')),
            label=LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = Subject(
            label=LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=ScalarImage(self.get_image_path('t2_d')),
            label=LabelMap(self.get_image_path('label_d', binary=True)),
        )
        subject_a4 = Subject(
            t1=ScalarImage(self.get_image_path('t1_a'), components=2),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = SubjectsDataset(self.subjects_list)
        self.sample = self.dataset[-1]  # subject_d
Example no. 11
 def iterate_dataset(subjects_list):
     dataset = SubjectsDataset(subjects_list)
     for _ in dataset:
         pass
Example no. 12
 def test_wrong_transform_init(self):
     with self.assertRaises(ValueError):
         SubjectsDataset(
             self.subjects_list,
             transform={},
         )
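The ValueError is expected because the transform must be a callable, which a plain dict is not. A minimal counterpart sketch with valid transforms, including set_transform to swap the transform after construction (RandomFlip and ToCanonical are arbitrary choices here):

import torch
import torchio as tio

subjects = [tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 8, 8, 8)))]
dataset = tio.SubjectsDataset(subjects, transform=tio.RandomFlip())
dataset.set_transform(tio.ToCanonical())  # replace the transform later on
subject = dataset[0]  # the transform is applied when the subject is indexed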
Example no. 13

training_transform = tio.Compose(
    [
        tio.ToCanonical(),
        tio.RandomBlur(std=(0, 1), seed=seed, p=0.1),  # blur 10% of times
        tio.RandomNoise(std=5, seed=1, p=0.5),  # Gaussian noise 50% of times
        tio.OneOf(
            {  # either
                tio.RandomAffine(scales=(0.95, 1.05), degrees=5, seed=seed): 0.75,  # random affine
                tio.RandomElasticDeformation(
                    max_displacement=(5, 5, 5),
                    seed=seed,
                ): 0.25,  # or random elastic deformation
            },
            p=0.8,  # applied to 80% of images
        ),
    ])

for one_subject in dataset:
    image0 = one_subject.mri
    plt.imshow(image0.data[0, int(image0.shape[1] / 2), :, :])
    plt.show()
    break
dataset_augmented = SubjectsDataset(subjects, transform=training_transform)
for one_subject in dataset_augmented:
    image = one_subject.mri
    plt.imshow(image.data[0, int(image.shape[1] / 2), :, :])
    plt.show()
    pass
#plt.ioff()
data_ref, aff = read_image('/data/romain/data_exemple/suj_150423/mT1w_1mm.nii')
res, res_fitpar, extra_info = pd.DataFrame(), pd.DataFrame(), dict()

disp_str = disp_str_list[0]; s = 2; xx = 100  # interactive defaults, overwritten by the loops below
for disp_str in disp_str_list:
    for s in [2, 20]: #[1, 2, 3,  5, 7, 10, 12 , 15, 20 ] : # [2,4,6] : #[1, 3 , 5 , 8, 10 , 12, 15, 20 , 25 ]:
        for xx in x0:
            dico_params['displacement_shift_strategy'] = disp_str

            fp = corrupt_data(xx, sigma=s, method=mvt_type, amplitude=10, mvt_axes=mvt_axes)
            dico_params['fitpars'] = fp
            dico_params['nT'] = fp.shape[1]
            t = RandomMotionFromTimeCourse(**dico_params)
            if 'synth' in suj_type:
                dataset = SubjectsDataset(suj, transform=torchio.Compose([tlab, t]))
            else:
                dataset = SubjectsDataset(suj, transform=t)

            sample = dataset[0]
            fout = out_path + '/{}_{}_{}_s{}_freq{}_{}'.format(suj_type, mvt_axe_str, mvt_type, s, xx, disp_str)
            fit_pars = t.fitpars - np.tile(t.to_substract[..., np.newaxis], (1, 200))
            # fig = plt.figure();plt.plot(fit_pars.T);plt.savefig(fout+'.png');plt.close(fig)
            #sample['image'].save(fout+'.nii')

            extra_info['x0'], extra_info['mvt_type'], extra_info['mvt_axe'] = xx, mvt_type, mvt_axe_str
            extra_info['shift_type'], extra_info['sigma'], extra_info['amp'] = disp_str, s, 10
            extra_info['disp'] = np.sum(t.to_substract)

            dff = pd.DataFrame(fit_pars.T)
            dff.columns = ['x', 'trans_y', 'z', 'r1', 'r2', 'r3']
            dff['nbt'] = range(0, 200)
            for k, v in extra_info.items():
Example no. 15
 def test_indexing_nonint(self):
     dset = SubjectsDataset(self.subjects_list)
     dset[torch.tensor(0)]