def main():
    """Demonstrate patch-based loading when some subjects lack a modality.

    Two subjects are loaded; the second one has no T2 image. Their sample keys
    differ, so the default PyTorch collation (which stacks dictionaries key by
    key) cannot be used. An identity ``collate_fn`` makes each batch a plain
    *list* of patch samples instead. A mock identity model is run over the
    batches and the sample keys are logged.
    """
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    patches_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG75_T1.nii.gz'),
        T2=ScalarImage('../BRATS2018_crop_renamed/LGG75_T2.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG75_Label.nii.gz'),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG74_T1.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG74_Label.nii.gz'),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = SubjectsDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        patches_per_volume,
        UniformSampler(patch_size),
    )

    # This collate_fn is needed in the case of missing modalities.
    # In this case, the batch will be composed of a *list* of samples instead
    # of the typical Python dictionary that is collated by default in PyTorch.
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        logging.info(f'Epoch {epoch_index}')
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)
            # Iterate over the samples actually present: the last batch can be
            # shorter than `batch_size`, so indexing with range(batch_size)
            # could raise IndexError.
            logging.info([sample.keys() for sample in batch])
            # nn.Identity returns the input list unchanged, and a list has no
            # `.shape`; log how many outputs were produced instead.
            logging.info(len(logits))
        logging.info('')
def get_volume_torchio(self, idx, return_orig=False):
    """Build the torchio ``Subject`` for row ``idx`` and replay its stored transforms.

    The intensity image is stored under ``"t1"``. When the row has a
    ``label_filename``, the label map is stored under ``"label"``. When
    ``self.read_path`` returns a list of paths, each file is loaded with nibabel
    and the arrays are stacked along a new first axis.

    Args:
        idx: Row index forwarded to ``self.get_row`` and ``self.get_transformations``.
        return_orig: Not used by this implementation; kept for interface compatibility.

    Returns:
        The untransformed Subject if ``self.df_data`` has no ``"history"`` column.
        Otherwise, the Subject produced by applying ``self.get_transformations(idx)``
        to it.
    """
    subject_row = self.get_row(idx)
    dict_suj = dict()
    if not pd.isna(subject_row["image_filename"]):
        path_imgs = self.read_path(subject_row["image_filename"])
        if path_imgs:
            if isinstance(path_imgs, list):
                imgs = ScalarImage(tensor=np.asarray(
                    [nb.load(p).get_fdata() for p in path_imgs]))
            else:
                imgs = ScalarImage(path_imgs)
            dict_suj["t1"] = imgs
    if "label_filename" in subject_row.keys() and not pd.isna(
            subject_row["label_filename"]):
        path_imgs = self.read_path(subject_row["label_filename"])
        if isinstance(path_imgs, list):
            imgs = LabelMap(tensor=np.asarray(
                [nb.load(p).get_fdata() for p in path_imgs]))
        else:
            imgs = LabelMap(path_imgs)
        dict_suj["label"] = imgs
    sub = Subject(dict_suj)
    if "history" not in self.df_data.columns:
        return sub

    trsfms = self.get_transformations(idx)
    # This loop only configures the transforms and saves motion plots. The
    # pipeline itself is applied exactly once, in the return statement below.
    for tr in trsfms.transforms:
        print(tr.name)
        if isinstance(tr, torchio.transforms.LabelsToImage):
            # Use the key under which this method stores the label map.
            tr.label_key = "label"
        if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
            output_path = opj(self.out_tmp, "{}.png".format(idx))
            # Each row of fitpars is plotted as one curve; the legend assumes
            # six rows (3 translations, 3 rotations) — TODO confirm.
            fitpars = tr.fitpars["t1"]
            plt.figure()
            plt.plot(fitpars.T)
            plt.legend([
                "trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z"
            ])
            plt.xlabel("Timesteps")
            plt.ylabel("Magnitude")
            plt.title("Motion parameters")
            plt.savefig(output_path)
            plt.close()
            self.written_files.append(output_path)
    return trsfms(sub)
def get_volume_torchio_without_motion(self, idx, return_orig=False):
    """Build the Subject for row ``idx`` and apply the transforms preceding motion.

    The stored pipeline is replayed only up to, but not including, the first
    ``MotionFromTimeCourse`` transform. That transform is returned separately so
    the caller can apply it itself.

    Args:
        idx: Row index forwarded to ``self.get_row`` and ``self.get_transformations``.
        return_orig: Not used by this implementation; kept for interface compatibility.

    Returns:
        The untransformed Subject when ``self.df_data`` has no ``"history"`` column.
        Otherwise, a tuple ``(transformed_subject, motion_transform)``, where
        ``motion_transform`` is ``None`` if the pipeline contains no
        ``MotionFromTimeCourse``.
    """
    subject_row = self.get_row(idx)
    dict_suj = dict()
    if not pd.isna(subject_row["image_filename"]):
        path_imgs = self.read_path(subject_row["image_filename"])
        if path_imgs:
            imgs = ScalarImage(path_imgs)
            dict_suj["t1"] = imgs
    if "label_filename" in subject_row.keys() and not pd.isna(
            subject_row["label_filename"]):
        path_imgs = self.read_path(subject_row["label_filename"])
        imgs = LabelMap(path_imgs)
        dict_suj["label"] = imgs
    sub = Subject(dict_suj)
    if "history" not in self.df_data.columns:
        return sub

    trsfms = self.get_transformations(idx)
    trsfms_short = []
    # Initialised up front so a pipeline without a motion transform returns
    # None instead of raising UnboundLocalError.
    tmot = None
    for tr in trsfms.transforms:
        print(tr.name)
        if isinstance(tr, torchio.transforms.LabelsToImage):
            # Use the key under which this method stores the label map.
            tr.label_key = "label"
        if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
            tmot = tr
            break
        trsfms_short.append(tr)
    trsfms_short = torchio.Compose(trsfms_short)
    res = trsfms_short(sub)
    return res, tmot
def _resize(self, image: ScalarImage, label_map: LabelMap):
    """Resize an image and its label map together to ``self.cfg["desired_size"]``.

    New images are built from the raw tensors of the inputs (no affine is passed,
    so the inputs' affines are not carried over). They are grouped in a temporary
    Subject so that a single ``Resize`` call handles both.

    Returns:
        Tuple ``(resized_image, resized_label_map)``.
    """
    resizer = Resize(self.cfg["desired_size"])
    scratch = Subject(
        image=ScalarImage(tensor=image.data),
        seg=LabelMap(tensor=label_map.data),
    )
    resized = resizer.apply_transform(scratch)
    return resized["image"], resized["seg"]
def get_subject_with_partial_volume_label_map(self, components=1):
    """Return a subject with a partial-volume label map.

    The subject pairs the ``t1_d`` test image with a ``label_d2`` label map that
    is requested as non-binary, forwarding ``components`` to ``get_image_path``.
    """
    t1_image = ScalarImage(self.get_image_path('t1_d'))
    label_path = self.get_image_path(
        'label_d2',
        binary=False,
        components=components,
    )
    return Subject(t1=t1_image, label=LabelMap(label_path))
def setUp(self):
    """Set up test fixtures, if any.

    Creates the ``.torchio_tests`` directory under the system temp dir, seeds
    ``random`` and ``np.random`` with 42, and builds five subjects with different
    combinations of T1 / T2 / label images. They are stored in
    ``self.subjects_list`` and wrapped in ``self.dataset``.
    ``self.sample`` is the last subject in the dataset (``subject_d``).
    """
    self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
    self.dir.mkdir(exist_ok=True)
    # Seed both RNGs so the generated fixtures are reproducible. NOTE(review):
    # if get_image_path draws from these RNGs, the order of the calls below
    # matters — confirm before reordering.
    random.seed(42)
    np.random.seed(42)
    # Passed as `pre_affine` to subject_d's T1: a translation of 10 along the
    # first axis and a 1.2 scale on the third axis.
    registration_matrix = np.array([
        [1, 0, 0, 10],
        [0, 1, 0, 0],
        [0, 0, 1.2, 0],
        [0, 0, 0, 1]
    ])

    # T1 only.
    subject_a = Subject(
        t1=ScalarImage(self.get_image_path('t1_a')),
    )
    # T1 plus a binary label map.
    subject_b = Subject(
        t1=ScalarImage(self.get_image_path('t1_b')),
        label=LabelMap(self.get_image_path('label_b', binary=True)),
    )
    # Label map only (no intensity image).
    subject_c = Subject(
        label=LabelMap(self.get_image_path('label_c', binary=True)),
    )
    # T1 (with a pre-affine), T2 and a binary label map.
    subject_d = Subject(
        t1=ScalarImage(
            self.get_image_path('t1_d'),
            pre_affine=registration_matrix,
        ),
        t2=ScalarImage(self.get_image_path('t2_d')),
        label=LabelMap(self.get_image_path('label_d', binary=True)),
    )
    # Same T1 stem as subject_a; `components=2` is passed as an extra keyword
    # to ScalarImage (not to get_image_path).
    subject_a4 = Subject(
        t1=ScalarImage(self.get_image_path('t1_a'), components=2),
    )
    self.subjects_list = [
        subject_a,
        subject_a4,
        subject_b,
        subject_c,
        subject_d,
    ]
    self.dataset = SubjectsDataset(self.subjects_list)
    self.sample = self.dataset[-1]  # subject_d
def get_inconsistent_shape_subject(self):
    """Return a subject containing images of different shape.

    ``t1`` uses the default shape of ``get_image_path``; ``t2`` and the two label
    maps request explicit, mutually different spatial shapes.
    """
    t1 = ScalarImage(self.get_image_path('t1_inc'))
    t2 = ScalarImage(self.get_image_path('t2_inc', shape=(10, 20, 31)))
    first_label = LabelMap(
        self.get_image_path('label_inc', shape=(8, 17, 25), binary=True),
    )
    second_label = LabelMap(
        self.get_image_path('label2_inc', shape=(18, 17, 25), binary=True),
    )
    return Subject(t1=t1, t2=t2, label=first_label, label2=second_label)
def setUp(self):
    """Create five image/label subjects (``hs_image_i`` / ``hs_label_i``) and a dataset."""
    super().setUp()
    built = []
    for index in range(5):
        image = ScalarImage(self.get_image_path(f'hs_image_{index}'))
        label = LabelMap(self.get_image_path(f'hs_label_{index}'))
        built.append(Subject(image=image, label=label))
    self.subjects = built
    self.dataset = SubjectsDataset(self.subjects)
def test_all_random_transforms(self):
    """Replaying a composed random transform from its history reproduces its output.

    Every ``Random*`` transform exported by torchio is composed and applied once.
    The transforms and seeds are then rebuilt from the result's history and
    re-applied to the same sample, and the two outputs must match exactly.
    """
    sample = Subject(
        t1=ScalarImage(tensor=torch.rand(1, 20, 20, 20)),
        # NOTE(review): torch.rand is in [0, 1), so `> 1` gives an all-False
        # label map — confirm this is intended.
        seg=LabelMap(tensor=torch.rand(1, 20, 20, 20) > 1),
    )
    transforms_names = [
        name for name in dir(torchio) if name.startswith('Random')
    ]

    # Downsample at the end so that image shape is not modified
    transforms_names.remove('RandomDownsample')
    transforms_names.append('RandomDownsample')

    transforms = []
    for transform_name in transforms_names:
        transform_class = getattr(torchio, transform_name)
        # Only transform needing an argument for __init__
        if transform_name == 'RandomLabelsToImage':
            transform = transform_class(label_key='seg')
        else:
            transform = transform_class()
        transforms.append(transform)
    composed_transform = torchio.Compose(transforms)

    with warnings.catch_warnings():  # ignore elastic deformation warning
        warnings.simplefilter('ignore', RuntimeWarning)
        transformed = composed_transform(sample)
        new_transforms, seeds = compose_from_history(transformed.history)
        new_transformed = self.apply_transforms(
            subject=sample,
            trsfm_list=new_transforms,
            seeds_list=seeds,
        )

    self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
    self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
def _generate_cache(self):
    """Preprocess every image/label pair and save it as a ``.pt`` file in the cache.

    For each entry of ``self._create_data_map()``, the image and label map are
    loaded with NaN checks. A single label file is one-hot encoded and its first
    (background) channel is dropped. Both volumes are then resized with
    ``self._resize``. The result is saved as
    ``{"subject_id", "image", "seg"}`` to ``<cache_path>/<subject_id>.pt``.
    """
    os.makedirs(self.cfg["cache_path"], exist_ok=True)
    data_map = self._create_data_map()
    for path, label_paths in tqdm(
        data_map.items(),
        total=len(data_map),
        position=0,
        leave=False,
        desc="Caching pooled data",
    ):
        subject_id = self._get_subject_id(path)
        image = ScalarImage(path, check_nans=True)
        label_map = LabelMap(label_paths, check_nans=True)
        # The decision must look at the *paths*: `label_map` is always a
        # LabelMap, never a list. Several label files are loaded as one channel
        # each and are presumably already one channel per class (TODO confirm).
        # A single file (or single-element list) holds integer class ids and
        # needs one-hot encoding.
        has_multiple_label_files = (
            isinstance(label_paths, (list, tuple)) and len(label_paths) > 1
        )
        if not has_multiple_label_files:
            one_hot = OneHot(self.cfg["num_classes"] + 1)
            label_map = one_hot(label_map)
            # Drop channel 0 (background) after one-hot encoding.
            label_map = LabelMap(tensor=label_map.tensor[1:])  # type: ignore
        image.load()
        label_map.load()
        image, label_map = self._resize(image, label_map)
        torch.save(
            {"subject_id": subject_id, "image": image.data, "seg": label_map.data},
            os.path.join(self.cfg["cache_path"], f"{subject_id}.pt"),
        )
def get_volume_torchio(self, idx, return_orig=False):
    """Build the torchio ``Subject`` for row ``idx`` and apply its stored transforms.

    Before the pipeline is applied, each transform is adjusted. ``LabelsToImage``
    is pointed at the ``"label"`` key. For ``MotionFromTimeCourse``, the ``"t1"``
    motion parameters are plotted to ``<out_tmp>/<idx>.png``. A scalar
    ``frequency_encoding_dim`` is also expanded into a dict keyed like ``tr``
    (a workaround for how that attribute was saved).

    Args:
        idx: Row index forwarded to ``self.get_row`` and ``self.get_transformations``.
        return_orig: Not used by this implementation; kept for interface compatibility.

    Returns:
        The untransformed Subject if ``self.df_data`` has no ``"history"`` column,
        otherwise the transformed Subject.
    """
    row = self.get_row(idx)
    volumes = dict()
    if not pd.isna(row["image_filename"]):
        image_path = self.read_path(row["image_filename"])
        if image_path:
            volumes["t1"] = ScalarImage(image_path)
    has_label = "label_filename" in row.keys() and not pd.isna(row["label_filename"])
    if has_label:
        volumes["label"] = LabelMap(self.read_path(row["label_filename"]))
    subject = Subject(volumes)

    if "history" not in self.df_data.columns:
        return subject

    pipeline = self.get_transformations(idx)
    for step in pipeline.transforms:
        print(step.name)
        if isinstance(step, torchio.transforms.LabelsToImage):
            step.label_key = "label"
        if isinstance(step, torchio.transforms.MotionFromTimeCourse):
            png_path = opj(self.out_tmp, "{}.png".format(idx))
            motion = step.fitpars["t1"]
            plt.figure()
            plt.plot(motion.T)
            plt.legend([
                "trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z"
            ])
            plt.xlabel("Timesteps")
            plt.ylabel("Magnitude")
            plt.title("Motion parameters")
            plt.savefig(png_path)
            plt.close()
            self.written_files.append(png_path)
        # Workaround: frequency_encoding_dim may have been saved as a single
        # value while `tr` is a per-key dict; broadcast it to the same keys.
        motion_cls = (torchio.transforms.augmentation.intensity
                      .random_motion_from_time_course.MotionFromTimeCourse)
        needs_expansion = (
            isinstance(step, motion_cls)
            and isinstance(step.tr, dict)
            and not isinstance(step.frequency_encoding_dim, dict)
        )
        if needs_expansion:
            step.frequency_encoding_dim = dict.fromkeys(
                step.tr.keys(), step.frequency_encoding_dim)
    return pipeline(subject)
def setUp(self):
    """Create five subjects with binary, foreground-forced label maps and a dataset."""
    super().setUp()

    def build_subject(index):
        image = ScalarImage(self.get_image_path(f'hs_image_{index}'))
        label_path = self.get_image_path(
            f'hs_label_{index}', binary=True, force_binary_foreground=True)
        return Subject(image=image, label=LabelMap(label_path))

    self.subjects = [build_subject(index) for index in range(5)]
    self.dataset = SubjectsDataset(self.subjects)
def setUp(self):
    """Build five image/label subjects (binary labels with forced foreground)."""
    super().setUp()
    collected = []
    for number in range(5):
        image = ScalarImage(self.get_image_path(f'hs_image_{number}'))
        label_file = self.get_image_path(
            f'hs_label_{number}',
            binary=True,
            force_binary_foreground=True,
        )
        collected.append(Subject(image=image, label=LabelMap(label_file)))
    self.subjects = collected
    self.dataset = SubjectsDataset(self.subjects)
def plot_aggregated_image(
    writer: SummaryWriter,
    epoch: int,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    device: torch.device,
    save_path: str,
):
    """Reassemble one random subject from patches, predict it, and plot the result.

    A random subject's grid sampler is drawn from ``data_loader``. Its patches
    are run through ``model``, and the input, ground-truth and predicted patches
    are stitched back into whole volumes with ``GridAggregator``. The three
    volumes are plotted to ``{save_path}/{epoch}-{subject_id}``.

    Args:
        writer: Not used by this implementation; kept for interface compatibility.
        epoch: Epoch number, used in the output path.
        model: Network mapping an image batch to logits (sigmoid + 0.5 threshold
            is applied to obtain the predicted mask).
        data_loader: Loader the subject and batch size are taken from.
        device: Device the model runs on.
        save_path: Directory prefix for the plots.
    """
    sampler, subject_id = random_subject_from_loader(data_loader)

    aggregator_x = GridAggregator(sampler)
    aggregator_y = GridAggregator(sampler)
    aggregator_y_pred = GridAggregator(sampler)

    # Inference only: no autograd graph is needed for the predictions.
    with torch.no_grad():
        for batch, locations in batches_from_sampler(sampler, data_loader.batch_size):
            x: torch.Tensor = batch["image"]["data"]
            aggregator_x.add_batch(x, locations)
            y: torch.Tensor = batch["seg"]["data"]
            aggregator_y.add_batch(y, locations)

            logits = model(x.to(device))
            # Binary mask from sigmoid probabilities. Move it back to the CPU
            # so it matches the inputs given to the other two aggregators.
            y_pred = (torch.sigmoid(logits) > 0.5).float().cpu()
            aggregator_y_pred.add_batch(y_pred, locations)

    whole_x = aggregator_x.get_output_tensor()
    whole_y = aggregator_y.get_output_tensor()
    whole_y_pred = aggregator_y_pred.get_output_tensor()

    plot_subject(
        Subject(
            image=ScalarImage(tensor=whole_x),
            true_seg=LabelMap(tensor=whole_y),
            pred_seg=LabelMap(tensor=whole_y_pred),
        ),
        f"{save_path}/{epoch}-{subject_id}",
    )
def _get_subjects_list(self) -> List[Subject]:
    """Load all cached subjects, generating the cache first if it does not exist.

    Each ``.pt`` file in ``self.cfg["cache_path"]`` is expected to hold a dict
    with ``"subject_id"``, ``"image"`` and ``"seg"`` entries (the format written
    by ``_generate_cache``).

    Returns:
        One Subject per cache file, ordered by file name.
    """
    cache_path = self.cfg["cache_path"]
    if not os.path.exists(cache_path):
        self._generate_cache()
        multitasking.wait_for_tasks()

    subjects = []
    # Sort so the subject order does not depend on the filesystem's listing
    # order. Skip anything that is not a tensor file.
    for pt_name in sorted(os.listdir(cache_path)):
        if not pt_name.endswith(".pt"):
            continue
        # The files are produced locally by the cache step; torch.load
        # unpickles, so do not point cache_path at untrusted data.
        pt_data = torch.load(os.path.join(cache_path, pt_name))
        subjects.append(
            Subject(
                subject_id=pt_data["subject_id"],
                image=ScalarImage(tensor=pt_data["image"]),
                seg=LabelMap(tensor=pt_data["seg"]),
            )
        )
    return subjects
def plot_subject(subject: Subject, save_plot_path: str):
    """Display a subject and, if ``save_plot_path`` is given, save slices and a GIF.

    Every image is re-wrapped with a diagonal affine whose spacings make each
    axis span the same extent (the smallest spatial dimension), and label maps
    are squeezed with ``squeeze_segmentation``. The result is shown
    interactively. When ``save_plot_path`` is non-empty, one PNG per slice index
    (``000.png``, ``001.png``, ...) is also written there, and the PNGs are
    combined into ``<save_plot_path>/<basename>.gif``.

    Args:
        subject: Subject to plot.
        save_plot_path: Output directory; a falsy value only shows the plot.
    """
    if save_plot_path:
        os.makedirs(save_plot_path, exist_ok=True)

    min_dim = min(subject.spatial_shape)
    sx, sy, sz = (min_dim / dim for dim in subject.spatial_shape)
    affine = np.diag([sx, sy, sz, 1.0])

    data_dict = {}
    for name, image in subject.get_images_dict(intensity_only=False).items():
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(
                tensor=squeeze_segmentation(image),
                affine=affine.copy(),
            )
        else:
            data_dict[name] = ScalarImage(tensor=image.data, affine=affine.copy())

    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    # Without a target directory there is nothing to save (the slice paths
    # below would otherwise point at the filesystem root).
    if not save_plot_path:
        return

    mpl, plt = import_mpl_plt()
    backend_ = mpl.get_backend()
    was_interactive = plt.isinteractive()
    # Render the slices off-screen with the non-interactive Agg backend.
    plt.ioff()
    mpl.use("agg")
    try:
        shape = out_subject.spatial_shape
        for x in range(max(shape)):
            out_subject.plot(
                reorient=False,
                # Clamp the index for axes shorter than the longest one.
                indices=tuple(min(x, size - 1) for size in shape),
                output_path=f"{save_plot_path}/{x:03d}.png",
                show=False,
                figsize=(10, 10),
            )
            # Close each figure right away so they do not accumulate.
            plt.close("all")
    finally:
        # Restore the caller's plotting state even if a slice fails.
        if was_interactive:
            plt.ion()
        mpl.use(backend_)

    create_gifs(save_plot_path, f"{save_plot_path}/{os.path.basename(save_plot_path)}.gif")
def test_crop_label_map_type(self):
    """Cropping a LabelMap keeps the LABEL image type."""
    label_map = LabelMap(tensor=torch.ones((10, 10, 10)))
    start, end = (1, 1, 1), (5, 5, 5)
    self.assertIs(label_map.crop(start, end).type, LABEL)
def test_label_map_type(self):
    """A LabelMap created from a tensor reports the LABEL type."""
    ones = torch.ones((10, 10, 10))
    self.assertIs(LabelMap(tensor=ones).type, LABEL)
def test_wrong_label_map_type(self):
    """Requesting the INTENSITY type for a LabelMap raises ValueError."""
    ones = torch.ones((10, 10, 10))
    self.assertRaises(ValueError, LabelMap, tensor=ones, type=INTENSITY)
def get_volume_torchio(self, idx, return_orig=False):
    """Build the torchio ``Subject`` for row ``idx`` and replay its seeded transforms.

    The intensity image is stored under ``"volume"`` and the label map (when the
    row has a ``label_filename``) under ``"label"``. Lists of paths are loaded
    with nibabel and stacked along a new first axis.

    Motion transforms get their parameters in one of two ways. If
    ``self.df_data`` has a ``"fitpars"`` column, they are loaded from the file
    it names, and displacement simulation is disabled. Otherwise, the whole
    pipeline is run once and the result discarded, so the transform computes
    ``fitpars`` itself. In both cases the parameters are plotted to
    ``<out_tmp>/<idx>.png``.

    Args:
        idx: Row index forwarded to ``self.get_row`` and ``self.get_transformations``.
        return_orig: If true, return the untransformed Subject.

    Returns:
        The untransformed Subject when ``return_orig`` is true or ``self.df_data``
        has no ``"transfo_order"`` column. Otherwise, the Subject after applying
        each transform in order with its stored seed.
    """
    subject_row = self.get_row(idx)
    dict_suj = dict()
    if not pd.isna(subject_row["image_filename"]):
        path_imgs = self.read_path(subject_row["image_filename"])
        if isinstance(path_imgs, list):
            imgs = ScalarImage(tensor=np.asarray(
                [nb.load(p).get_fdata() for p in path_imgs]))
        else:
            imgs = ScalarImage(path_imgs)
        dict_suj["volume"] = imgs
    if "label_filename" in subject_row.keys() and not pd.isna(
            subject_row["label_filename"]):
        path_imgs = self.read_path(subject_row["label_filename"])
        if isinstance(path_imgs, list):
            imgs = LabelMap(tensor=np.asarray(
                [nb.load(p).get_fdata() for p in path_imgs]))
        else:
            imgs = LabelMap(path_imgs)
        dict_suj["label"] = imgs
    sub = Subject(dict_suj)
    if return_orig or "transfo_order" not in self.df_data.columns:
        return sub
    else:
        # get_transformations returns the pipeline and one seed per transform.
        trsfms, seeds = self.get_transformations(idx)
        for tr in trsfms.transform.transforms:
            if isinstance(tr, torchio.transforms.RandomLabelsToImage):
                # Use the key under which this method stores the label map.
                tr.label_key = "label"
            if isinstance(tr, torchio.transforms.RandomMotionFromTimeCourse):
                output_path = opj(self.out_tmp, "{}.png".format(idx))
                if "fitpars" in self.df_data.columns:
                    # Reuse precomputed motion parameters instead of simulating them.
                    fitpars = np.loadtxt(self.df_data["fitpars"][idx])
                    tr.fitpars = fitpars
                    tr.simulate_displacement = False
                else:
                    # Dry run of the full pipeline so that `tr` populates its
                    # `fitpars`; the transformed subject itself is discarded.
                    res = sub
                    for trsfm, seed in zip(trsfms.transform.transforms, seeds):
                        # NOTE(review): a seed of 0 is falsy and would be applied
                        # unseeded here — confirm how "no seed" is encoded.
                        if seed:
                            res = trsfm(res, seed)
                        else:
                            res = trsfm(res)
                    del res
                fitpars = tr.fitpars
                # Each row of fitpars is one curve; the legend assumes six rows
                # (3 translations, 3 rotations) — TODO confirm.
                plt.figure()
                plt.plot(fitpars.T)
                plt.legend([
                    "trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z"
                ])
                plt.xlabel("Timesteps")
                plt.ylabel("Magnitude")
                plt.title("Motion parameters")
                plt.savefig(output_path)
                plt.close()
                self.written_files.append(output_path)
        # Apply each transform in order, passing its stored seed when one is set.
        res = sub
        for trsfm, seed in zip(trsfms.transform.transforms, seeds):
            if seed:
                res = trsfm(res, seed)
            else:
                res = trsfm(res)
        #res = trsfms(sub, seeds)
        return res