示例#1
0
 def test_axis_name_2d(self):
     """Check that a 2D image's height/width match its 'h'/'w' shape entries."""
     image_path = self.get_image_path('im2d', shape=(5, 6))
     image = ScalarImage(image_path)
     axis_index = {name: image.axis_name_to_index(name) for name in ('h', 'w')}
     self.assertEqual(image.height, image.shape[axis_index['h']])
     self.assertEqual(image.width, image.shape[axis_index['w']])
示例#2
0
    def _generate_cache(self):
        """Build the on-disk cache of preprocessed subjects.

        For every ``(path, label_paths)`` pair returned by
        ``self._create_data_map()`` the image and its label(s) are loaded,
        resized with ``self._resize`` and saved with ``torch.save`` to
        ``<cache_path>/<subject_id>.pt`` as a dict with the keys
        ``"subject_id"``, ``"image"`` and ``"seg"``.

        A single label path is one-hot encoded with
        ``self.cfg["num_classes"] + 1`` classes and its first channel
        (presumably background -- confirm) is dropped. A list of label paths
        is loaded as-is, one channel per file.

        Raises:
            FileExistsError: if ``self.cfg["cache_path"]`` already exists
                (``os.makedirs`` is called without ``exist_ok``).
        """
        cache_path = self.cfg["cache_path"]
        os.makedirs(cache_path)

        data_map = self._create_data_map()
        # The encoder does not depend on the subject, so build it once.
        one_hot = OneHot(self.cfg["num_classes"] + 1)

        for path, label_paths in tqdm(
            data_map.items(),
            total=len(data_map),
            position=0,
            leave=False,
            desc="Caching pooled data",
        ):
            subject_id = self._get_subject_id(path)

            image = ScalarImage(path, check_nans=True)
            label_map = LabelMap(label_paths, check_nans=True)
            # Decide on the *input paths*, not on ``label_map`` (which is
            # always a LabelMap): a list of paths already yields one channel
            # per file, so only a single label volume needs one-hot encoding.
            if not isinstance(label_paths, list):
                label_map = one_hot(label_map)
                label_map = LabelMap(tensor=label_map.tensor[1:])  # type: ignore

            image.load()
            label_map.load()

            image, label_map = self._resize(image, label_map)

            torch.save(
                {"subject_id": subject_id, "image": image.data, "seg": label_map.data},
                os.path.join(cache_path, f"{subject_id}.pt"),
            )
示例#3
0
 def test_inconsistent_spatial_shape(self):
     """Reading spatial_shape of a subject with mismatched images raises RuntimeError."""
     first_image = ScalarImage(tensor=torch.rand(1, 3, 3, 4))
     second_image = ScalarImage(tensor=torch.rand(2, 2, 3, 4))
     subject = Subject(a=first_image, b=second_image)
     with self.assertRaises(RuntimeError):
         _ = subject.spatial_shape
示例#4
0
 def test_get_center(self):
     """A 3x3x3 image is centred at (1, 1, 1) in RAS and (-1, -1, 1) in LPS."""
     image = ScalarImage(tensor=torch.rand(1, 3, 3, 3))
     center_ras = image.get_center()
     center_lps = image.get_center(lps=True)
     self.assertEqual(center_ras, (1, 1, 1))
     self.assertEqual(center_lps, (-1, -1, 1))
示例#5
0
 def get_image_path(
         self,
         stem,
         binary=False,
         shape=(10, 20, 30),
         spacing=(1, 1, 1),
         components=1,
         add_nans=False,
         suffix=None,
         force_binary_foreground=False,
         ):
     """Write a random test image under ``self.dir`` and return its path.

     Args:
         stem: File name without extension.
         binary: If True, threshold the random data at 0.5 into ``uint8``.
         shape: Spatial shape; a 2D shape gets a trailing singleton axis.
         spacing: Voxel spacing, placed on the diagonal of the affine.
         components: Number of channels.
         add_nans: If True, fill the data with NaN and build the image with
             ``check_nans=False``.
         suffix: File extension; when None one of ``.nii.gz``, ``.nii``,
             ``.nrrd`` or ``.img`` is chosen at random.
         force_binary_foreground: If True (and ``binary`` is True), make
             sure at least one voxel is 1 so the label is never empty.

     Returns:
         The written path, randomly either as a path object or as ``str`` so
         both input types get exercised.
     """
     if len(shape) == 2:
         shape = (*shape, 1)
     data = np.random.rand(components, *shape)
     if binary:
         data = (data > 0.5).astype(np.uint8)
         # Thresholding random data can produce an all-zero volume; optionally
         # guarantee a single foreground voxel.
         if force_binary_foreground and not data.any():
             data[(0,) * data.ndim] = 1
     if add_nans:
         data[:] = np.nan
     affine = np.diag((*spacing, 1))
     if suffix is None:
         suffix = random.choice(('.nii.gz', '.nii', '.nrrd', '.img'))
     path = self.dir / f'{stem}{suffix}'
     if np.random.rand() > 0.5:
         path = str(path)
     image = ScalarImage(tensor=data, affine=affine, check_nans=not add_nans)
     image.save(path)
     return path
示例#6
0
def main():
    """Run a mock training loop on subjects with missing modalities.

    Two BraTS subjects are read from hard-coded relative paths; the second
    one has no T2 image. Because the samples do not share the same keys, the
    DataLoader uses an identity ``collate_fn`` and each batch is a plain
    list of samples instead of a collated dictionary.
    """
    # Training and patch-sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    patches_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG75_T1.nii.gz'),
        T2=ScalarImage('../BRATS2018_crop_renamed/LGG75_T2.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG75_Label.nii.gz'),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=ScalarImage('../BRATS2018_crop_renamed/LGG74_T1.nii.gz'),
        label=LabelMap('../BRATS2018_crop_renamed/LGG74_Label.nii.gz'),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = SubjectsDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        patches_per_volume,
        UniformSampler(patch_size),
    )

    # This collate_fn is needed in the case of missing modalities
    # In this case, the batch will be composed by a *list* of samples instead
    # of the typical Python dictionary that is collated by default in Pytorch
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        logging.info(f'Epoch {epoch_index}')
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)
            # Iterate the batch itself: the last batch may hold fewer than
            # ``batch_size`` samples.
            logging.info([sample.keys() for sample in batch])
            # nn.Identity returns the list unchanged; a list has no ``.shape``.
            logging.info(len(logits))
    logging.info('')
示例#7
0
    def get_volume_torchio(self, idx, return_orig=False):
        """Load the subject at row ``idx`` and re-apply its recorded transforms.

        ``image_filename`` becomes the ``"t1"`` ScalarImage: a list of paths
        is loaded with nibabel and stacked along the first axis, while a
        single path is handed to ScalarImage directly. ``label_filename``
        (if the column exists and is not NaN) becomes the ``"label"``
        LabelMap in the same way.

        Returns the bare Subject when ``self.df_data`` has no ``"history"``
        column. Otherwise the Compose returned by
        ``self.get_transformations(idx)`` is first prepared: ``LabelsToImage``
        is pointed at the ``"label"`` key, and the ``"t1"`` motion parameters
        of each ``MotionFromTimeCourse`` are plotted to
        ``<self.out_tmp>/<idx>.png`` (recorded in ``self.written_files``).
        The Compose is then applied once to the subject.

        ``return_orig`` is not used.
        """
        subject_row = self.get_row(idx)
        dict_suj = dict()
        if not pd.isna(subject_row["image_filename"]):
            path_imgs = self.read_path(subject_row["image_filename"])
            if path_imgs:
                if isinstance(path_imgs, list):
                    # NOTE(review): stacking raw arrays drops the files'
                    # affine (ScalarImage's default is used) -- confirm.
                    imgs = ScalarImage(tensor=np.asarray(
                        [nb.load(p).get_fdata() for p in path_imgs]))
                else:
                    imgs = ScalarImage(path_imgs)
                dict_suj["t1"] = imgs

        if "label_filename" in subject_row.keys() and not pd.isna(
                subject_row["label_filename"]):
            path_imgs = self.read_path(subject_row["label_filename"])
            if isinstance(path_imgs, list):
                imgs = LabelMap(tensor=np.asarray(
                    [nb.load(p).get_fdata() for p in path_imgs]))
            else:
                imgs = LabelMap(path_imgs)
            dict_suj["label"] = imgs
        sub = Subject(dict_suj)
        if "history" not in self.df_data.columns:
            return sub

        trsfms = self.get_transformations(idx)
        # This loop only prepares the transforms (and plots motion); they are
        # applied once, below, through the Compose. Applying each one here as
        # well would only duplicate the work.
        for tr in trsfms.transforms:
            print(tr.name)
            if isinstance(tr, torchio.transforms.LabelsToImage):
                tr.label_key = "label"
            if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
                output_path = opj(self.out_tmp, "{}.png".format(idx))
                fitpars = tr.fitpars["t1"]
                plt.figure()
                plt.plot(fitpars.T)
                plt.legend([
                    "trans_x", "trans_y", "trans_z", "rot_x", "rot_y",
                    "rot_z"
                ])
                plt.xlabel("Timesteps")
                plt.ylabel("Magnitude")
                plt.title("Motion parameters")
                plt.savefig(output_path)
                plt.close()
                self.written_files.append(output_path)
        return trsfms(sub)
示例#8
0
    def get_volume_torchio_without_motion(self, idx, return_orig=False):
        """Load the subject at ``idx`` and apply its transforms up to the motion step.

        ``image_filename`` (when present and non-empty) becomes the ``"t1"``
        ScalarImage and ``label_filename`` (when the column exists and is not
        NaN) becomes the ``"label"`` LabelMap.

        Returns:
            The bare Subject when ``self.df_data`` has no ``"history"``
            column. Otherwise a tuple ``(subject, motion_transform)``: the
            subject passed through every recorded transform that precedes the
            first ``MotionFromTimeCourse``, and that motion transform itself
            (``None`` when the history contains no motion transform).

        ``return_orig`` is not used.
        """
        subject_row = self.get_row(idx)
        dict_suj = dict()
        if not pd.isna(subject_row["image_filename"]):
            path_imgs = self.read_path(subject_row["image_filename"])
            if path_imgs:
                imgs = ScalarImage(path_imgs)
                dict_suj["t1"] = imgs

        if "label_filename" in subject_row.keys() and not pd.isna(
                subject_row["label_filename"]):
            path_imgs = self.read_path(subject_row["label_filename"])
            imgs = LabelMap(path_imgs)
            dict_suj["label"] = imgs
        sub = Subject(dict_suj)
        if "history" not in self.df_data.columns:
            return sub

        trsfms = self.get_transformations(idx)
        preceding_transforms = []
        # Stays None when no motion transform is recorded; previously this
        # name was left unbound and the return raised UnboundLocalError.
        tmot = None
        for tr in trsfms.transforms:
            print(tr.name)
            if isinstance(tr, torchio.transforms.LabelsToImage):
                tr.label_key = "label"
            if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
                tmot = tr
                break
            preceding_transforms.append(tr)
        res = torchio.Compose(preceding_transforms)(sub)
        return res, tmot
示例#9
0
 def get_reference_image_and_path(self):
     """Create a 10x20x31 reference image with spacing (1, 1, 2).

     Returns:
         Tuple ``(image, path)``: the ScalarImage and the file it reads.
     """
     ref_path = self.get_image_path('ref', shape=(10, 20, 31), spacing=(1, 1, 2))
     return ScalarImage(ref_path), ref_path
示例#10
0
 def test_with_a_list_of_paths(self):
     """Two single-channel files given as a list are stacked as two channels."""
     spatial_shape = (5, 5, 5)
     paths = [
         self.get_image_path(stem, shape=spatial_shape)
         for stem in ('path1', 'path2')
     ]
     image = ScalarImage(path=paths)
     self.assertEqual(image.shape, (2, 5, 5, 5))
     self.assertEqual(image[STEM], ['path1', 'path2'])
示例#11
0
 def test_with_a_list_of_2d_paths(self):
     """Three 2D files in different formats are stacked along the channel axis."""
     stems_and_suffixes = (
         ('path1', '.nii'),
         ('path2', '.img'),
         ('path3', '.hdr'),
     )
     paths = [
         self.get_image_path(stem, shape=(5, 6), suffix=suffix)
         for stem, suffix in stems_and_suffixes
     ]
     image = ScalarImage(path=paths)
     self.assertEqual(image.shape, (3, 5, 6, 1))
     self.assertEqual(image[STEM], ['path1', 'path2', 'path3'])
示例#12
0
 def _resize(self, image: ScalarImage, label_map: LabelMap):
     """Resize an image and its label map together to ``self.cfg["desired_size"]``.

     Only the data tensors are carried over: both are wrapped in new images
     (the input affines are not passed on) inside one Subject, which is
     resized as a whole.

     Returns:
         Tuple ``(resized_image, resized_label_map)``.
     """
     resize = Resize(self.cfg["desired_size"])
     pair = Subject(
         image=ScalarImage(tensor=image.data),
         seg=LabelMap(tensor=label_map.data),
     )
     resized = resize.apply_transform(pair)
     return resized["image"], resized["seg"]
示例#13
0
 def get_subject_with_partial_volume_label_map(self, components=1):
     """Return a subject whose label map holds continuous (non-binary) values.

     Args:
         components: Number of channels written for the label map.
     """
     t1_image = ScalarImage(self.get_image_path('t1_d'))
     label_path = self.get_image_path(
         'label_d2',
         binary=False,
         components=components,
     )
     return Subject(t1=t1_image, label=LabelMap(label_path))
示例#14
0
    def setUp(self):
        """Create the temporary test directory and a dataset of five subjects.

        ``random`` and ``np.random`` are both seeded with 42 so the generated
        images are reproducible. ``self.sample`` is the last subject
        (``subject_d``): a T1 with a ``pre_affine`` registration matrix, a
        T2 and a binary label map.
        """
        self.dir = Path(tempfile.gettempdir()) / '.torchio_tests'
        self.dir.mkdir(exist_ok=True)
        random.seed(42)
        np.random.seed(42)

        # Translation of 10 along the first axis, scaling of 1.2 on the third.
        registration_matrix = np.array([
            [1, 0, 0, 10],
            [0, 1, 0, 0],
            [0, 0, 1.2, 0],
            [0, 0, 0, 1]
        ])

        subject_a = Subject(
            t1=ScalarImage(self.get_image_path('t1_a')),
        )
        subject_b = Subject(
            t1=ScalarImage(self.get_image_path('t1_b')),
            label=LabelMap(self.get_image_path('label_b', binary=True)),
        )
        subject_c = Subject(
            label=LabelMap(self.get_image_path('label_c', binary=True)),
        )
        subject_d = Subject(
            t1=ScalarImage(
                self.get_image_path('t1_d'),
                pre_affine=registration_matrix,
            ),
            t2=ScalarImage(self.get_image_path('t2_d')),
            label=LabelMap(self.get_image_path('label_d', binary=True)),
        )
        # ``components`` is the get_image_path argument that controls how many
        # channels are written; ScalarImage does not appear to use it (confirm).
        # A separate stem avoids overwriting the file used by subject_a.
        subject_a4 = Subject(
            t1=ScalarImage(self.get_image_path('t1_a4', components=2)),
        )
        self.subjects_list = [
            subject_a,
            subject_a4,
            subject_b,
            subject_c,
            subject_d,
        ]
        self.dataset = SubjectsDataset(self.subjects_list)
        self.sample = self.dataset[-1]  # subject_d
示例#15
0
 def setUp(self):
     """Create five subjects, each pairing a random image with a label map."""
     super().setUp()
     subjects = []
     for index in range(5):
         image = ScalarImage(self.get_image_path(f'hs_image_{index}'))
         # NOTE(review): the label is written with the default binary=False,
         # i.e. continuous values -- confirm this is intended.
         label = LabelMap(self.get_image_path(f'hs_label_{index}'))
         subjects.append(Subject(image=image, label=label))
     self.subjects = subjects
     self.dataset = SubjectsDataset(self.subjects)
示例#16
0
 def get_inconsistent_shape_subject(self):
     """Return a subject whose images do not share one spatial shape.

     ``t1`` uses the default shape, ``t2`` is (10, 20, 31), and the two
     binary label maps are (8, 17, 25) and (18, 17, 25).
     """
     t1 = ScalarImage(self.get_image_path('t1_inc'))
     t2 = ScalarImage(self.get_image_path('t2_inc', shape=(10, 20, 31)))
     label = LabelMap(
         self.get_image_path('label_inc', shape=(8, 17, 25), binary=True))
     label2 = LabelMap(
         self.get_image_path('label2_inc', shape=(18, 17, 25), binary=True))
     return Subject(t1=t1, t2=t2, label=label, label2=label2)
示例#17
0
    def test_all_random_transforms(self):
        """Replaying a transform history must reproduce the original output.

        Every name in ``dir(torchio)`` starting with ``Random`` is
        instantiated, composed and applied to a small subject. The transforms
        and seeds recovered from the resulting history are then re-applied
        with ``self.apply_transforms`` and both outputs are compared.
        """
        sample = Subject(
            t1=ScalarImage(tensor=torch.rand(1, 20, 20, 20)),
            # ``rand > 1`` is never true, so this label is all background.
            seg=LabelMap(tensor=torch.rand(1, 20, 20, 20) > 1),
        )

        transforms_names = [
            name for name in dir(torchio) if name.startswith('Random')
        ]

        # Downsample at the end so that image shape is not modified
        transforms_names.remove('RandomDownsample')
        transforms_names.append('RandomDownsample')

        transforms = []
        for transform_name in transforms_names:
            transform_class = getattr(torchio, transform_name)
            # Only transform needing an argument for __init__
            if transform_name == 'RandomLabelsToImage':
                transform = transform_class(label_key='seg')
            else:
                transform = transform_class()
            transforms.append(transform)
        composed_transform = torchio.Compose(transforms)
        with warnings.catch_warnings():  # ignore elastic deformation warning
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = composed_transform(sample)

        new_transforms, seeds = compose_from_history(transformed.history)
        new_transformed = self.apply_transforms(subject=sample,
                                                trsfm_list=new_transforms,
                                                seeds_list=seeds)

        self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
        self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
示例#18
0
    def get_volume_torchio(self, idx, return_orig=False):
        """Load the subject stored at row ``idx`` and replay its transforms.

        ``image_filename`` (when present and non-empty) becomes the ``"t1"``
        ScalarImage and ``label_filename`` (when the column exists and is not
        NaN) becomes the ``"label"`` LabelMap.

        Without a ``"history"`` column in ``self.df_data`` the bare Subject is
        returned. Otherwise every transform of
        ``self.get_transformations(idx)`` is patched in place before the whole
        Compose is applied to the subject:

        * ``LabelsToImage`` is pointed at the ``"label"`` key;
        * for ``MotionFromTimeCourse`` the ``"t1"`` motion parameters are
          plotted to ``<self.out_tmp>/<idx>.png`` (recorded in
          ``self.written_files``), and a non-dict ``frequency_encoding_dim``
          is expanded into a dict with the same keys as ``tr.tr``.

        ``return_orig`` is not used.
        """
        subject_row = self.get_row(idx)
        images = {}

        if not pd.isna(subject_row["image_filename"]):
            image_path = self.read_path(subject_row["image_filename"])
            if image_path:
                images["t1"] = ScalarImage(image_path)

        has_label = ("label_filename" in subject_row.keys()
                     and not pd.isna(subject_row["label_filename"]))
        if has_label:
            label_path = self.read_path(subject_row["label_filename"])
            images["label"] = LabelMap(label_path)

        subject = Subject(images)
        if "history" in self.df_data.columns:
            transforms = self.get_transformations(idx)
            for tr in transforms.transforms:
                print(tr.name)
                if isinstance(tr, torchio.transforms.LabelsToImage):
                    tr.label_key = "label"
                if isinstance(tr, torchio.transforms.MotionFromTimeCourse):
                    plot_path = opj(self.out_tmp, "{}.png".format(idx))
                    motion_params = tr.fitpars["t1"]
                    plt.figure()
                    plt.plot(motion_params.T)
                    plt.legend([
                        "trans_x", "trans_y", "trans_z",
                        "rot_x", "rot_y", "rot_z",
                    ])
                    plt.xlabel("Timesteps")
                    plt.ylabel("Magnitude")
                    plt.title("Motion parameters")
                    plt.savefig(plot_path)
                    plt.close()
                    self.written_files.append(plot_path)
                # Workaround: some histories store frequency_encoding_dim as a
                # single value while tr.tr is a per-key dict; give every key
                # of tr.tr that same value.
                motion_class = (torchio.transforms.augmentation.intensity
                                .random_motion_from_time_course
                                .MotionFromTimeCourse)
                if (isinstance(tr, motion_class)
                        and isinstance(tr.tr, dict)
                        and not isinstance(tr.frequency_encoding_dim, dict)):
                    tr.frequency_encoding_dim = dict.fromkeys(
                        tr.tr, tr.frequency_encoding_dim)
            return transforms(subject)
        return subject
示例#19
0
 def setUp(self):
     """Create five subjects pairing a random image with a binary label map.

     Every label is requested with ``force_binary_foreground=True``.
     """
     super().setUp()
     self.subjects = [
         Subject(
             image=ScalarImage(self.get_image_path(f'hs_image_{i}')),
             label=LabelMap(
                 self.get_image_path(
                     f'hs_label_{i}',
                     binary=True,
                     force_binary_foreground=True,
                 )
             ),
         )
         for i in range(5)
     ]
     self.dataset = SubjectsDataset(self.subjects)
 def setUp(self):
     """Build ``self.subjects`` (five image/binary-label pairs) and a dataset."""
     super().setUp()
     subjects = []
     for index in range(5):
         image = ScalarImage(self.get_image_path(f'hs_image_{index}'))
         label_path = self.get_image_path(
             f'hs_label_{index}',
             binary=True,
             force_binary_foreground=True,
         )
         subjects.append(Subject(image=image, label=LabelMap(label_path)))
     self.subjects = subjects
     self.dataset = SubjectsDataset(subjects)
示例#21
0
    def _get_subjects_list(self) -> List[Subject]:
        """Return one Subject per cached ``.pt`` file, building the cache if needed.

        When ``self.cfg["cache_path"]`` does not exist, ``self._generate_cache()``
        is called and ``multitasking.wait_for_tasks()`` waits for pending
        tasks. Each cached file must provide the ``"subject_id"``, ``"image"``
        and ``"seg"`` entries read below.

        Files are visited in sorted order so the subject list is reproducible
        (``os.listdir`` order is arbitrary), and entries without a ``.pt``
        suffix are skipped.
        """
        cache_path = self.cfg["cache_path"]
        if not os.path.exists(cache_path):
            self._generate_cache()
            multitasking.wait_for_tasks()

        subjects = []
        for pt_name in sorted(os.listdir(cache_path)):
            if not pt_name.endswith(".pt"):
                continue
            # NOTE: torch.load unpickles the file -- only use a trusted cache.
            pt_data = torch.load(os.path.join(cache_path, pt_name))

            subjects.append(
                Subject(
                    subject_id=pt_data["subject_id"],
                    image=ScalarImage(tensor=pt_data["image"]),
                    seg=LabelMap(tensor=pt_data["seg"]),
                )
            )
        return subjects
示例#22
0
def plot_subject(subject: Subject, save_plot_path: str):
    """Show ``subject`` and, if a directory is given, save slice frames and a GIF.

    All images are re-wrapped with a diagonal affine whose spacing is
    ``min(spatial_shape) / size`` per axis, so every axis is drawn with the
    same extent. Label maps go through ``squeeze_segmentation`` first.

    The rescaled subject is always plotted with ``show=True``. When
    ``save_plot_path`` is truthy, one PNG per index ``x`` in
    ``range(max(spatial_shape))`` (clamped per axis) is written to that
    directory with the ``agg`` backend, and the frames are combined by
    ``create_gifs`` into ``<save_plot_path>/<basename>.gif``. The previous
    backend and interactive mode are restored afterwards, even on error.
    """
    if save_plot_path:
        os.makedirs(save_plot_path, exist_ok=True)

    sx, sy, sz = subject.spatial_shape
    shortest = min(sx, sy, sz)
    scaled_affine = np.diag([shortest / sx, shortest / sy, shortest / sz, 1])

    data_dict = {}
    for name, image in subject.get_images_dict(intensity_only=False).items():
        # Each image gets its own copy of the affine to avoid shared state.
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(
                tensor=squeeze_segmentation(image),
                affine=scaled_affine.copy(),
            )
        else:
            data_dict[name] = ScalarImage(tensor=image.data,
                                          affine=scaled_affine.copy())

    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    # Without a target directory there is nowhere to write frames or the GIF
    # (the paths below would otherwise resolve to the filesystem root).
    if not save_plot_path:
        return

    mpl, plt = import_mpl_plt()
    previous_backend = mpl.get_backend()
    was_interactive = plt.isinteractive()

    plt.ioff()
    mpl.use("agg")
    try:
        nx, ny, nz = out_subject.spatial_shape
        for x in range(max(nx, ny, nz)):
            out_subject.plot(
                reorient=False,
                indices=(min(x, nx - 1), min(x, ny - 1), min(x, nz - 1)),
                output_path=f"{save_plot_path}/{x:03d}.png",
                show=False,
                figsize=(10, 10),
            )
            plt.close("all")
    finally:
        plt.interactive(was_interactive)
        mpl.use(previous_backend)

    create_gifs(save_plot_path,
                f"{save_plot_path}/{os.path.basename(save_plot_path)}.gif")
示例#23
0
def plot_aggregated_image(
    writer: SummaryWriter,
    epoch: int,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    device: torch.device,
    save_path: str,
):
    """Predict a random whole subject patch by patch and plot it.

    A subject is drawn with ``random_subject_from_loader(data_loader)``. Its
    patches are fed to ``model`` and binarised with ``sigmoid(logits) > 0.5``.
    Inputs, targets and predictions are each stitched back with a
    ``GridAggregator`` and passed to ``plot_subject`` with the output
    directory ``<save_path>/<epoch>-<subject_id>``.

    Inference runs under ``torch.no_grad()`` with the model in eval mode, so
    no autograd graph is built and layers such as BatchNorm do not update
    their running statistics. The model's previous mode is restored
    afterwards.

    ``writer`` is currently unused.
    """
    sampler, subject_id = random_subject_from_loader(data_loader)
    aggregator_x = GridAggregator(sampler)
    aggregator_y = GridAggregator(sampler)
    aggregator_y_pred = GridAggregator(sampler)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch, locations in batches_from_sampler(
                    sampler, data_loader.batch_size):
                x: torch.Tensor = batch["image"]["data"]
                aggregator_x.add_batch(x, locations)
                y: torch.Tensor = batch["seg"]["data"]
                aggregator_y.add_batch(y, locations)

                logits = model(x.to(device))
                y_pred = (torch.sigmoid(logits) > 0.5).float()
                # Aggregate on CPU, like the input and target patches.
                aggregator_y_pred.add_batch(y_pred.cpu(), locations)
    finally:
        model.train(was_training)

    whole_x = aggregator_x.get_output_tensor()
    whole_y = aggregator_y.get_output_tensor()
    whole_y_pred = aggregator_y_pred.get_output_tensor()

    plot_subject(
        Subject(
            image=ScalarImage(tensor=whole_x),
            true_seg=LabelMap(tensor=whole_y),
            pred_seg=LabelMap(tensor=whole_y_pred),
        ),
        f"{save_path}/{epoch}-{subject_id}",
    )
示例#24
0
 def test_bad_affine(self):
     """A 3x3 affine is rejected with ValueError."""
     tensor = torch.rand(1, 2, 3, 4)
     bad_affine = np.eye(3)
     with self.assertRaises(ValueError):
         ScalarImage(tensor=tensor, affine=bad_affine)
示例#25
0
 def test_repr(self):
     """The image repr only mentions 'shape' once the subject has been loaded."""
     sample = Subject(t1=ScalarImage(self.get_image_path('repr_test')))
     # unittest assertions, unlike bare ``assert``, still run under ``python -O``.
     self.assertNotIn('shape', repr(sample['t1']))
     sample.load()
     self.assertIn('shape', repr(sample['t1']))
示例#26
0
 def test_bad_key(self):
     """Building a ScalarImage with ``path=''`` plus a ``data`` keyword raises ValueError."""
     self.assertRaises(ValueError, ScalarImage, path='', data=5)
示例#27
0
 def test_no_input(self):
     """A ScalarImage constructed without any argument raises ValueError."""
     self.assertRaises(ValueError, ScalarImage)
示例#28
0
 def test_crop_scalar_image_type(self):
     """Cropping a scalar image yields an image of type INTENSITY."""
     image = ScalarImage(tensor=torch.ones((10, 10, 10)))
     index_ini, index_fin = (1, 1, 1), (5, 5, 5)
     cropped = image.crop(index_ini, index_fin)
     self.assertIs(cropped.type, INTENSITY)
示例#29
0
 def test_wrong_scalar_image_type(self):
     """Requesting the LABEL type for a ScalarImage raises ValueError."""
     ones = torch.ones((10, 10, 10))
     self.assertRaises(ValueError, ScalarImage, tensor=ones, type=LABEL)
示例#30
0
 def test_scalar_image_type(self):
     """A ScalarImage built from a tensor has type INTENSITY."""
     image_type = ScalarImage(tensor=torch.ones((10, 10, 10))).type
     self.assertIs(image_type, INTENSITY)