def test_save_subject(self):
    """Saving a subject via the deprecated ``save_sample`` API.

    The call must emit a ``DeprecationWarning``. The NIfTI file written to
    disk must have exactly one dimension fewer than the in-memory ``t1``
    image (presumably the leading channel axis — confirm).
    """
    dataset = SubjectsDataset(self.subjects_list, transform=lambda x: x)
    len(dataset)  # exercised only for coverage
    first_subject = dataset[0]
    target_path = self.dir / 'test.nii.gz'
    with self.assertWarns(DeprecationWarning):
        dataset.save_sample(first_subject, {'t1': target_path})
    saved_ndim = len(nib.load(str(target_path)).shape)
    in_memory_ndim = len(first_subject['t1'].shape)
    assert in_memory_ndim == saved_ndim + 1
def test_data_loader(self):
    """Iterate a shuffled, batch-size-1 DataLoader over the Colin27 subject.

    Each batch is indexed by ``'t1'`` and ``'brain'``; the lookups fail the
    test if either key (or its ``DATA`` entry) is missing.
    """
    from torch.utils.data import DataLoader
    loader = DataLoader(
        SubjectsDataset([torchio.datasets.Colin27()]),
        batch_size=1,
        shuffle=True,
    )
    for batch in loader:
        for image_name in ('t1', 'brain'):
            batch[image_name][DATA]
def test_data_loader(self):
    """Iterate a shuffled, batch-size-1 DataLoader over ``self.sample_subject``.

    Each batch is indexed by ``'t1'`` and ``'label'``; the lookups fail the
    test if either key (or its ``DATA`` entry) is missing.
    """
    from torch.utils.data import DataLoader
    dataset = SubjectsDataset([self.sample_subject])
    for batch in DataLoader(dataset, batch_size=1, shuffle=True):
        for image_name in ('t1', 'label'):
            batch[image_name][DATA]
def setUp(self):
    """Build five image/label subjects and wrap them in a ``SubjectsDataset``.

    Sets ``self.subjects`` (the list of subjects) and ``self.dataset``.
    """
    super().setUp()
    built = []
    for index in range(5):
        # Keep the image-then-label call order of get_image_path.
        image = ScalarImage(self.get_image_path(f'hs_image_{index}'))
        label = LabelMap(self.get_image_path(f'hs_label_{index}'))
        built.append(Subject(image=image, label=label))
    self.subjects = built
    self.dataset = SubjectsDataset(self.subjects)
def split_dataset(dataset: Dataset, lengths: List[int]) -> List[Dataset]:
    """Split ``dataset`` into consecutive, non-overlapping ``SubjectsDataset``s.

    The i-th output holds the next ``lengths[i]`` subjects, taken in order
    from the private ``dataset._subjects`` list. No validation is done.
    If ``sum(lengths)`` is smaller than the number of subjects, the
    remaining subjects are left out. If it is larger, the trailing splits
    come out shorter (or empty), following normal slicing rules.

    Args:
        dataset: a dataset exposing a ``_subjects`` sequence.
        lengths: size of each split, in order.

    Returns:
        One ``SubjectsDataset`` per entry of ``lengths``.
    """
    boundaries = np.cumsum([0] + list(lengths)).tolist()
    return [
        SubjectsDataset(dataset._subjects[start:stop])
        for start, stop in zip(boundaries[:-1], boundaries[1:])
    ]
def main():
    """Demo: sample patches from subjects with missing modalities via a Queue.

    Two subjects are built, and one of them has no T2 image. Uniform
    patches are drawn from them through a ``Queue`` and fed to a mock
    identity model. Because the subjects have different keys, the default
    PyTorch collation is bypassed (``collate_fn=lambda x: x``). Each batch
    is therefore a plain *list* of patch samples, not a dictionary.
    """
    # Training and patch-sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    patches_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG75_T1.nii.gz'),
        T2=ScalarImage('../BRATS2018_crop_renamed/LGG75_T2.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG75_Label.nii.gz'),
    )
    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG74_T1.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG74_Label.nii.gz'),
    )
    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = SubjectsDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        patches_per_volume,
        UniformSampler(patch_size),
    )

    # This collate_fn is needed in the case of missing modalities: the batch
    # is a *list* of samples instead of the dictionary that PyTorch's
    # default collation would build.
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        logging.info('Epoch %d', epoch_index)
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)
            # Iterate over the samples actually present: DataLoader keeps a
            # final partial batch by default, so indexing range(batch_size)
            # could raise IndexError.
            logging.info([sample.keys() for sample in batch])
            # nn.Identity returns its input unchanged, so ``logits`` is the
            # same list and has no ``.shape``; log the number of samples.
            logging.info(len(logits))
        logging.info('')
def setUp(self):
    """Build five subjects, each an image plus a label map, and a dataset.

    Label files are requested with ``binary=True`` and
    ``force_binary_foreground=True``. Sets ``self.subjects`` and
    ``self.dataset``.
    """
    super().setUp()
    self.subjects = [
        Subject(
            image=ScalarImage(self.get_image_path(f'hs_image_{idx}')),
            label=LabelMap(
                self.get_image_path(
                    f'hs_label_{idx}',
                    binary=True,
                    force_binary_foreground=True,
                ),
            ),
        )
        for idx in range(5)
    ]
    self.dataset = SubjectsDataset(self.subjects)
def setUp(self):
    """Create five image/label-map subjects and a ``SubjectsDataset`` over them.

    Label files are requested with ``binary=True`` and
    ``force_binary_foreground=True``.
    """
    super().setUp()
    collected = []
    for number in range(5):
        # Image path is generated before the label path, as in the original.
        image = ScalarImage(self.get_image_path(f'hs_image_{number}'))
        label_file = self.get_image_path(
            f'hs_label_{number}',
            binary=True,
            force_binary_foreground=True,
        )
        collected.append(Subject(image=image, label=LabelMap(label_file)))
    self.subjects = collected
    self.dataset = SubjectsDataset(self.subjects)
def run_queue(self, num_workers, **kwargs):
    """Build a patch ``Queue`` over ``self.subjects_list`` and drain it.

    Args:
        num_workers: number of worker processes used by the ``Queue`` to
            load subjects. It is forwarded to ``Queue`` so that callers
            requesting multiprocessing actually get it.
        **kwargs: extra keyword arguments forwarded to ``Queue``.
    """
    subjects_dataset = SubjectsDataset(self.subjects_list)
    patch_size = 10
    sampler = UniformSampler(patch_size)
    queue_dataset = Queue(
        subjects_dataset,
        max_length=6,
        samples_per_volume=2,
        sampler=sampler,
        num_workers=num_workers,
        **kwargs,
    )
    _ = str(queue_dataset)  # exercise the string representation
    # The outer loader stays single-process; parallel loading is the
    # Queue's job (see torchio Queue docs).
    batch_loader = DataLoader(queue_dataset, batch_size=4)
    for batch in batch_loader:
        _ = batch['one_modality'][DATA]
        _ = batch['segmentation'][DATA]
def setUp(self):
    """Create the temp directory, seed the RNGs and build the test subjects.

    Sets:
        self.dir: ``<tempdir>/.torchio_tests`` (created if missing).
        self.subjects_list: five subjects with different image combinations.
        self.dataset: a ``SubjectsDataset`` over ``self.subjects_list``.
        self.sample: the last dataset item (built from ``subject_d``).
    """
    # Scratch directory for generated files; reused if it already exists.
    self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
    self.dir.mkdir(exist_ok=True)
    # Seed Python's and NumPy's global RNGs with a fixed value.
    # NOTE(review): if get_image_path draws random data, the order of the
    # get_image_path calls below determines the fixture contents — keep it.
    random.seed(42)
    np.random.seed(42)
    # 4x4 homogeneous matrix: translation of 10 along the first axis and a
    # 1.2 scale on the third axis; passed to subject_d's T1 as pre_affine.
    registration_matrix = np.array([
        [1, 0, 0, 10],
        [0, 1, 0, 0],
        [0, 0, 1.2, 0],
        [0, 0, 0, 1]
    ])
    # T1 only.
    subject_a = Subject(
        t1=ScalarImage(self.get_image_path('t1_a')),
    )
    # T1 plus a label map generated with binary=True.
    subject_b = Subject(
        t1=ScalarImage(self.get_image_path('t1_b')),
        label=LabelMap(self.get_image_path('label_b', binary=True)),
    )
    # Label map only, no intensity image.
    subject_c = Subject(
        label=LabelMap(self.get_image_path('label_c', binary=True)),
    )
    # T1 (with a pre-affine), T2 and a label map.
    subject_d = Subject(
        t1=ScalarImage(
            self.get_image_path('t1_d'),
            pre_affine=registration_matrix,
        ),
        t2=ScalarImage(self.get_image_path('t2_d')),
        label=LabelMap(self.get_image_path('label_d', binary=True)),
    )
    # NOTE(review): `components=2` is passed to ScalarImage, not to
    # get_image_path. If the intent was a multi-component image file, the
    # closing parenthesis may be misplaced — confirm.
    subject_a4 = Subject(
        t1=ScalarImage(self.get_image_path('t1_a'), components=2),
    )
    self.subjects_list = [
        subject_a,
        subject_a4,
        subject_b,
        subject_c,
        subject_d,
    ]
    self.dataset = SubjectsDataset(self.subjects_list)
    self.sample = self.dataset[-1]  # subject_d
def iterate_dataset(subjects_list):
    """Iterate once over a ``SubjectsDataset`` built from ``subjects_list``.

    Every item is fetched and immediately discarded.
    """
    for _subject in SubjectsDataset(subjects_list):
        continue
def test_wrong_transform_init(self):
    """Passing a dict as ``transform`` must make ``SubjectsDataset`` raise ValueError."""
    bad_transform = {}
    with self.assertRaises(ValueError):
        SubjectsDataset(self.subjects_list, transform=bad_transform)
# Data-augmentation pipeline for training.
training_transform = tio.Compose(
    [
        tio.ToCanonical(),
        tio.RandomBlur(std=(0, 1), seed=seed, p=0.1),  # blur 10% of the time
        # NOTE(review): this transform uses a fixed seed=1 while the others
        # use `seed`; confirm whether that is intentional.
        tio.RandomNoise(std=5, seed=1, p=0.5),  # Gaussian noise 50% of the time
        tio.OneOf(
            {  # either
                tio.RandomAffine(scales=(0.95, 1.05), degrees=5, seed=seed): 0.75,  # random affine
                tio.RandomElasticDeformation(max_displacement=(5, 5, 5), seed=seed): 0.25,  # or random elastic deformation
            },
            p=0.8,  # applied to 80% of images
        ),
    ]
)

# Display the middle slice along the first spatial axis (index 1 of the
# (channel, ...) data) of the first subject, before augmentation.
for one_subject in dataset:
    image0 = one_subject.mri
    plt.imshow(image0.data[0, image0.shape[1] // 2, :, :])
    plt.show()
    break

# Same view after augmentation; this loop has no `break`, so every
# augmented subject is shown.
dataset_augmented = SubjectsDataset(subjects, transform=training_transform)
for one_subject in dataset_augmented:
    image = one_subject.mri
    plt.imshow(image.data[0, image.shape[1] // 2, :, :])
    plt.show()
#plt.ioff() data_ref, aff = read_image('/data/romain/data_exemple/suj_150423/mT1w_1mm.nii') res, res_fitpar, extra_info = pd.DataFrame(), pd.DataFrame(), dict() disp_str = disp_str_list[0]; s = 2; xx = 100 for disp_str in disp_str_list: for s in [2, 20]: #[1, 2, 3, 5, 7, 10, 12 , 15, 20 ] : # [2,4,6] : #[1, 3 , 5 , 8, 10 , 12, 15, 20 , 25 ]: for xx in x0: dico_params['displacement_shift_strategy'] = disp_str fp = corrupt_data(xx, sigma=s, method=mvt_type, amplitude=10, mvt_axes=mvt_axes) dico_params['fitpars'] = fp dico_params['nT'] = fp.shape[1] t = RandomMotionFromTimeCourse(**dico_params) if 'synth' in suj_type: dataset = SubjectsDataset(suj, transform= torchio.Compose([tlab, t ])) else: dataset = SubjectsDataset(suj, transform= t ) sample = dataset[0] fout = out_path + '/{}_{}_{}_s{}_freq{}_{}'.format(suj_type, mvt_axe_str, mvt_type, s, xx, disp_str) fit_pars = t.fitpars - np.tile(t.to_substract[..., np.newaxis],(1,200)) # fig = plt.figure();plt.plot(fit_pars.T);plt.savefig(fout+'.png');plt.close(fig) #sample['image'].save(fout+'.nii') extra_info['x0'], extra_info['mvt_type'], extra_info['mvt_axe']= xx, mvt_type, mvt_axe_str extra_info['shift_type'], extra_info['sigma'], extra_info['amp'] = disp_str, s, 10 extra_info['disp'] = np.sum(t.to_substract) dff = pd.DataFrame(fit_pars.T); dff.columns = ['x', 'trans_y', 'z', 'r1', 'r2', 'r3']; dff['nbt'] = range(0,200) for k,v in extra_info.items():
def test_indexing_nonint(self):
    """Indexing a ``SubjectsDataset`` with a 0-d tensor (not an int) must not raise."""
    dataset = SubjectsDataset(self.subjects_list)
    tensor_index = torch.tensor(0)
    _ = dataset[tensor_index]