Example 1
    def _load_filenames(self):
        """Load preprocessed pair data (input and gt) in handler."""
        for subject_id, input_filename, gt_filename, roi_filename, metadata in self.filename_pairs:
            # Create the subject group if it does not exist yet
            if str(subject_id) in self.hdf5_file:
                grp = self.hdf5_file[str(subject_id)]
            else:
                grp = self.hdf5_file.create_group(str(subject_id))

            roi_pair = imed_loader.SegmentationPair(input_filename,
                                                    roi_filename,
                                                    metadata=metadata,
                                                    slice_axis=self.slice_axis,
                                                    cache=False,
                                                    soft_gt=self.soft_gt)

            seg_pair = imed_loader.SegmentationPair(input_filename,
                                                    gt_filename,
                                                    metadata=metadata,
                                                    slice_axis=self.slice_axis,
                                                    cache=False,
                                                    soft_gt=self.soft_gt)
            print("gt filename", gt_filename)
            input_data_shape, _ = seg_pair.get_pair_shapes()

            useful_slices = []
            input_volumes = []
            gt_volume = []
            roi_volume = []

            for idx_pair_slice in range(input_data_shape[-1]):

                slice_seg_pair = seg_pair.get_pair_slice(idx_pair_slice)

                self.has_bounding_box = imed_obj_detect.verify_metadata(
                    slice_seg_pair, self.has_bounding_box)
                if self.has_bounding_box:
                    imed_obj_detect.adjust_transforms(self.prepro_transforms,
                                                      slice_seg_pair)

                # Keep the indices of slices that have a ground truth
                if self.slice_filter_fn and self.slice_filter_fn(slice_seg_pair):
                    useful_slices.append(idx_pair_slice)

                roi_pair_slice = roi_pair.get_pair_slice(idx_pair_slice)
                slice_seg_pair, roi_pair_slice = imed_transforms.apply_preprocessing_transforms(
                    self.prepro_transforms, slice_seg_pair, roi_pair_slice)

                input_volumes.append(slice_seg_pair["input"][0])

                # Handle unlabeled data
                if not len(slice_seg_pair["gt"]):
                    gt_volume = []
                else:
                    # Quantize the ground truth to 8 bits, then rescale to [0, 1]
                    gt_volume.append(
                        (slice_seg_pair["gt"][0] * 255).astype(np.uint8) /
                        255.)

                # Handle data with no ROI provided
                if not len(roi_pair_slice["gt"]):
                    roi_volume = []
                else:
                    roi_volume.append(
                        (roi_pair_slice["gt"][0] * 255).astype(np.uint8) /
                        255.)

            # Getting metadata using the one from the last slice
            input_metadata = slice_seg_pair['input_metadata'][0]
            gt_metadata = slice_seg_pair['gt_metadata'][0]
            roi_metadata = roi_pair_slice['input_metadata'][0]

            if 'slices' in grp.attrs:
                grp.attrs['slices'] = list(
                    set(np.concatenate((grp.attrs['slices'], useful_slices))))
            else:
                grp.attrs['slices'] = useful_slices

            # Creating datasets and metadata
            contrast = input_metadata['contrast']
            # Inputs
            print(len(input_volumes))
            print("grp= ", str(subject_id))
            key = "inputs/{}".format(contrast)
            print("key = ", key)
            if len(input_volumes) < 1:
                print("list empty")
                continue
            grp.create_dataset(key, data=input_volumes)
            # Sub-group metadata
            if 'contrast' in grp['inputs'].attrs:
                new_attr = list(grp['inputs'].attrs['contrast'])
                new_attr.append(contrast)
                grp['inputs'].attrs.create('contrast', new_attr, dtype=self.dt)
            else:
                grp['inputs'].attrs.create('contrast', [contrast],
                                           dtype=self.dt)

            # dataset metadata
            grp[key].attrs['input_filenames'] = input_metadata[
                'input_filenames']
            grp[key].attrs['data_type'] = input_metadata['data_type']

            if "zooms" in input_metadata.keys():
                grp[key].attrs["zooms"] = input_metadata['zooms']
            if "data_shape" in input_metadata.keys():
                grp[key].attrs["data_shape"] = input_metadata['data_shape']
            if "bounding_box" in input_metadata.keys():
                grp[key].attrs["bounding_box"] = input_metadata['bounding_box']

            # GT
            key = "gt/{}".format(contrast)
            grp.create_dataset(key, data=gt_volume)
            # Sub-group metadata
            if 'contrast' in grp['gt'].attrs:
                new_attr = list(grp['gt'].attrs['contrast'])
                new_attr.append(contrast)
                grp['gt'].attrs.create('contrast', new_attr, dtype=self.dt)
            else:
                grp['gt'].attrs.create('contrast', [contrast], dtype=self.dt)

            # dataset metadata
            grp[key].attrs['gt_filenames'] = input_metadata['gt_filenames']
            grp[key].attrs['data_type'] = gt_metadata['data_type']

            if "zooms" in gt_metadata.keys():
                grp[key].attrs["zooms"] = gt_metadata['zooms']
            if "data_shape" in gt_metadata.keys():
                grp[key].attrs["data_shape"] = gt_metadata['data_shape']
            if gt_metadata['bounding_box'] is not None:
                grp[key].attrs["bounding_box"] = gt_metadata['bounding_box']

            # ROI
            key = "roi/{}".format(contrast)
            grp.create_dataset(key, data=roi_volume)
            # Sub-group metadata
            if 'contrast' in grp['roi'].attrs:
                new_attr = list(grp['roi'].attrs['contrast'])
                new_attr.append(contrast)
                grp['roi'].attrs.create('contrast', new_attr, dtype=self.dt)
            else:
                grp['roi'].attrs.create('contrast', [contrast], dtype=self.dt)

            # dataset metadata
            grp[key].attrs['roi_filename'] = roi_metadata['gt_filenames']
            grp[key].attrs['data_type'] = roi_metadata['data_type']

            if "zooms" in roi_metadata.keys():
                grp[key].attrs["zooms"] = roi_metadata['zooms']
            if "data_shape" in roi_metadata.keys():
                grp[key].attrs["data_shape"] = roi_metadata['data_shape']

            # Adding contrast to group metadata
            if 'contrast' in grp.attrs:
                new_attr = list(grp.attrs['contrast'])
                new_attr.append(contrast)
                grp.attrs.create('contrast', new_attr, dtype=self.dt)
            else:
                grp.attrs.create('contrast', [contrast], dtype=self.dt)
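
The writer above produces one HDF5 group per subject, each holding "inputs/<contrast>", "gt/<contrast>" and "roi/<contrast>" datasets, with filenames, zooms, shapes and contrasts stored as attributes. Below is a minimal read-back sketch using h5py; "data.hdf5", "sub-01" and "T1w" are hypothetical placeholders, not values taken from the code above.

import h5py

# Minimal read-back sketch; the file name, subject id and contrast are
# placeholders for whatever _load_filenames() actually wrote.
with h5py.File("data.hdf5", "r") as f:
    grp = f["sub-01"]
    print(list(grp.attrs["slices"]))    # slice indices kept by the filter
    print(list(grp.attrs["contrast"]))  # contrasts stored for this subject
    t1 = grp["inputs/T1w"][()]          # stacked 2D slices as a numpy array
    print(t1.shape, grp["inputs/T1w"].attrs["input_filenames"])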
Example 2
    def _load_filenames(self):
        """Load preprocessed pair data (input and gt) in handler."""
        with h5py.File(self.path_hdf5, "a") as hdf5_file:
            for subject_id, input_filename, gt_filename, roi_filename, metadata in self.filename_pairs:
                # Create the subject group if it does not exist yet
                if str(subject_id) in hdf5_file:
                    grp = hdf5_file[str(subject_id)]
                else:
                    grp = hdf5_file.create_group(str(subject_id))

                roi_pair = imed_loader.SegmentationPair(
                    input_filename,
                    roi_filename,
                    metadata=metadata,
                    slice_axis=self.slice_axis,
                    cache=False,
                    soft_gt=self.soft_gt)

                seg_pair = imed_loader.SegmentationPair(
                    input_filename,
                    gt_filename,
                    metadata=metadata,
                    slice_axis=self.slice_axis,
                    cache=False,
                    soft_gt=self.soft_gt)
                print("gt filename", gt_filename)
                input_data_shape, _ = seg_pair.get_pair_shapes()

                useful_slices = []
                input_volumes = []
                gt_volume = []
                roi_volume = []

                for idx_pair_slice in range(input_data_shape[-1]):
                    # The helper filters the slice and appends the kept data
                    # to the accumulator lists (cf. the loop body in Example 1)
                    slice_seg_pair, roi_pair_slice = self._slice_seg_pair(
                        idx_pair_slice, seg_pair, roi_pair, useful_slices,
                        input_volumes, gt_volume, roi_volume)

                # Getting metadata using the one from the last slice
                input_metadata = slice_seg_pair['input_metadata'][0]
                gt_metadata = slice_seg_pair['gt_metadata'][0]
                roi_metadata = roi_pair_slice['input_metadata'][0]

                if 'slices' in grp.attrs:
                    grp.attrs['slices'] = list(
                        set(
                            np.concatenate(
                                (grp.attrs['slices'], useful_slices))))
                else:
                    grp.attrs['slices'] = useful_slices

                # Creating datasets and metadata
                contrast = input_metadata['contrast']

                # Inputs
                print(len(input_volumes))
                print("grp= ", str(subject_id))
                key = "inputs/{}".format(contrast)
                print("key = ", key)
                if len(input_volumes) < 1:
                    print("list empty")
                    continue
                grp.create_dataset(key, data=input_volumes)

                # Sub-group metadata
                self.create_subgrp_metadata('inputs', grp, contrast)

                # dataset metadata
                grp[key].attrs['input_filenames'] = input_metadata[
                    'input_filenames']
                self.create_metadata(grp, key, input_metadata)

                # GT
                key = "gt/{}".format(contrast)
                grp.create_dataset(key, data=gt_volume)
                # Sub-group metadata
                self.create_subgrp_metadata('gt', grp, contrast)

                # dataset metadata
                grp[key].attrs['gt_filenames'] = input_metadata['gt_filenames']
                self.create_metadata(grp, key, gt_metadata)

                # ROI
                key = "roi/{}".format(contrast)
                grp.create_dataset(key, data=roi_volume)
                # Sub-group metadata
                self.create_subgrp_metadata('roi', grp, contrast)

                # dataset metadata
                grp[key].attrs['roi_filename'] = roi_metadata['gt_filenames']
                self.create_metadata(grp, key, roi_metadata)

                # Adding contrast to group metadata
                self.add_grp_contrast(grp, contrast)
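
Example 2 factors the attribute bookkeeping that Example 1 repeats for 'inputs', 'gt' and 'roi' into helper methods whose bodies are not shown. The following is a plausible reconstruction of those helpers based on the repeated blocks in Example 1; the exact implementations in the library may differ.

    def create_subgrp_metadata(self, subgrp_name, grp, contrast):
        """Append the contrast to a sub-group's 'contrast' attribute,
        creating the attribute on first use (cf. Example 1)."""
        if 'contrast' in grp[subgrp_name].attrs:
            new_attr = list(grp[subgrp_name].attrs['contrast'])
            new_attr.append(contrast)
            grp[subgrp_name].attrs.create('contrast', new_attr, dtype=self.dt)
        else:
            grp[subgrp_name].attrs.create('contrast', [contrast], dtype=self.dt)

    def create_metadata(self, grp, key, metadata):
        """Copy the per-dataset metadata that Example 1 wrote inline."""
        grp[key].attrs['data_type'] = metadata['data_type']
        if "zooms" in metadata:
            grp[key].attrs["zooms"] = metadata['zooms']
        if "data_shape" in metadata:
            grp[key].attrs["data_shape"] = metadata['data_shape']
        if metadata.get('bounding_box') is not None:
            grp[key].attrs["bounding_box"] = metadata['bounding_box']

    def add_grp_contrast(self, grp, contrast):
        """Same append-or-create pattern at the subject-group level."""
        if 'contrast' in grp.attrs:
            new_attr = list(grp.attrs['contrast'])
            new_attr.append(contrast)
            grp.attrs.create('contrast', new_attr, dtype=self.dt)
        else:
            grp.attrs.create('contrast', [contrast], dtype=self.dt)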
Example 3
def test_image_orientation():
    device = torch.device("cuda:" + str(GPU_NUMBER) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        print("Using GPU number {}".format(device))

    train_lst = ['sub-unf01']

    training_transform_dict = {
        "Resample":
            {
                "wspace": 1.5,
                "hspace": 1,
                "dspace": 3
            },
        "CenterCrop":
            {
                "size": [176, 128, 160]
            },
        "NumpyToTensor": {},
        "NormalizeInstance": {"applied_to": ['im']}
    }

    transform_lst, training_undo_transform = imed_transforms.prepare_transforms(training_transform_dict)

    model_params = {
            "name": "Modified3DUNet",
            "dropout_rate": 0.3,
            "bn_momentum": 0.9,
            "depth": 2,
            "in_channel": 1,
            "out_channel": 1,
            "length_3D": [176, 128, 160],
            "stride_3D": [176, 128, 160],
            "attention": False,
            "n_filters": 8
        }
    contrast_params = {
        "contrast_lst": ['T1w'],
        "balance": {}
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = imed_loader.BidsDataset(PATH_BIDS,
                                             subject_lst=train_lst,
                                             target_suffix=["_seg-manual"],
                                             contrast_params=contrast_params,
                                             metadata_choice=False,
                                             slice_axis=slice_axis,
                                             transform=transform_lst,
                                             multichannel=False)
                ds.load_filenames()
            else:
                ds = imed_loader.Bids3DDataset(PATH_BIDS,
                                               subject_lst=train_lst,
                                               target_suffix=["_seg-manual"],
                                               model_params=model_params,
                                               contrast_params=contrast_params,
                                               metadata_choice=False,
                                               slice_axis=slice_axis,
                                               transform=transform_lst,
                                               multichannel=False)

            loader = DataLoader(ds, batch_size=1,
                                shuffle=True, pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[0]
            segpair = imed_loader.SegmentationPair(input_filename, gt_filename, metadata=metadata,
                                                   slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations (identical call for 2d and 3d)
                    preds_idx_undo, metadata_idx = training_undo_transform(batch["gt"][smp_idx],
                                                                           batch["gt_metadata"][smp_idx],
                                                                           data_type='gt')
                    if dim == '2d':
                        # add the new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
                                tmp_lst.append(pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are expected from the transform / undo-transform
                        # round trip (e.g. Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2, input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2, input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2, input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []
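
The assertions above rely on nibabel's canonical reorientation: the prediction is rebuilt in HWD orientation, converted to RAS, and finally mapped back to the orientation of the file on disk. A minimal standalone sketch of the RAS round trip, using only nibabel (the file name is a placeholder):

import nibabel as nib

# Hypothetical input file; any NIfTI volume works
img = nib.load("volume.nii.gz")
print(nib.aff2axcodes(img.affine))      # orientation on disk, e.g. ('P', 'I', 'R')

# Reorient to the closest RAS+ orientation, as the test does with its reference
img_ras = nib.as_closest_canonical(img)
print(nib.aff2axcodes(img_ras.affine))  # always ('R', 'A', 'S')
data_ras = img_ras.get_fdata()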