Example #1
    def get_pair_data(self):
        """Return the tuple (input, ground truth) with the data content in numpy array."""
        cache_mode = 'fill' if self.cache else 'unchanged'

        input_data = []
        for handle in self.input_handle:
            hwd_oriented = imed_loader_utils.orient_img_hwd(
                handle.get_fdata(cache_mode, dtype=np.float32),
                self.slice_axis)
            input_data.append(hwd_oriented)

        gt_data = []
        # Handle unlabeled data: return early, since iterating over a None
        # gt_handle below would raise a TypeError
        if self.gt_handle is None:
            return input_data, None
        for gt in self.gt_handle:
            if gt is not None:
                hwd_oriented = imed_loader_utils.orient_img_hwd(
                    gt.get_fdata(cache_mode, dtype=np.float32),
                    self.slice_axis)
                data_type = np.float32 if self.soft_gt else np.uint8
                gt_data.append(hwd_oriented.astype(data_type))
            else:
                gt_data.append(
                    np.zeros(imed_loader_utils.orient_shapes_hwd(
                        self.input_handle[0].shape, self.slice_axis),
                             dtype=np.float32).astype(np.uint8))

        return input_data, gt_data
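The dtype switch on soft_gt is the key detail here: soft (continuous) ground truths stay in float32, while hard labels are stored compactly as uint8. A minimal, self-contained numpy sketch of that decision (the flag and toy mask are illustrative, not ivadomed code):

    import numpy as np

    soft_gt = False  # illustrative stand-in for self.soft_gt above
    mask = (np.random.rand(4, 4, 2) > 0.5).astype(np.float32)  # toy HWD volume

    data_type = np.float32 if soft_gt else np.uint8  # same rule as get_pair_data
    gt = mask.astype(data_type)
    print(gt.dtype)  # uint8 here; float32 when soft_gt is True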
Example #2
    def get_pair_data(self):
        """Return the tuple (input, ground truth) with the data content in numpy array."""
        cache_mode = 'fill' if self.cache else 'unchanged'

        input_data = []
        for handle in self.input_handle:
            hwd_oriented = imed_loader_utils.orient_img_hwd(handle.get_fdata(cache_mode, dtype=np.float32),
                                                            self.slice_axis)
            input_data.append(hwd_oriented)

        gt_data = []
        # Handle unlabeled data: return early, since iterating over a None
        # gt_handle below would raise a TypeError
        if self.gt_handle is None:
            return input_data, None
        # Soft ground truths keep continuous values; hard ones are cast to uint8
        data_type = np.float32 if self.soft_gt else np.uint8
        for gt in self.gt_handle:
            if gt is not None:
                if not isinstance(gt, list):  # this tissue has annotation from only one rater
                    hwd_oriented = imed_loader_utils.orient_img_hwd(gt.get_fdata(cache_mode, dtype=np.float32),
                                                                    self.slice_axis)
                    gt_data.append(hwd_oriented.astype(data_type))
                else:  # this tissue has annotation from several raters
                    hwd_oriented_list = [
                        imed_loader_utils.orient_img_hwd(gt_rater.get_fdata(cache_mode, dtype=np.float32),
                                                         self.slice_axis) for gt_rater in gt]
                    gt_data.append([hwd_oriented.astype(data_type) for hwd_oriented in hwd_oriented_list])
            else:
                gt_data.append(
                    np.zeros(imed_loader_utils.orient_shapes_hwd(self.input_handle[0].shape, self.slice_axis),
                             dtype=np.float32).astype(np.uint8))

        return input_data, gt_data
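Unlike Example #1, a ground-truth entry here can be a list of arrays (one per rater) instead of a single array. A small sketch of how downstream code might flatten that nesting into a rater consensus (the helper name and toy arrays are assumptions for illustration, not part of ivadomed):

    import numpy as np

    def rater_consensus(gt_entry):
        # Multi-rater entries are lists of arrays; average them into one map.
        if isinstance(gt_entry, list):
            return np.mean(np.stack(gt_entry), axis=0)
        return gt_entry  # single-rater entries pass through unchanged

    single = np.ones((2, 2), dtype=np.uint8)
    multi = [np.zeros((2, 2)), np.ones((2, 2))]
    print(rater_consensus(single).shape)  # (2, 2)
    print(rater_consensus(multi).mean())  # 0.5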
Example #3
    def get_pair_shapes(self):
        """Return the tuple (input, ground truth) representing both the input and ground truth shapes."""
        input_shape = []
        for handle in self.input_handle:
            shape = imed_loader_utils.orient_shapes_hwd(
                handle.header.get_data_shape(), self.slice_axis)
            input_shape.append(tuple(shape))

            # All input volumes must share the same shape
            if len(set(input_shape)) != 1:
                raise RuntimeError('Inputs have different dimensions.')

        gt_shape = []

        for gt in self.gt_handle:
            if gt is not None:
                shape = imed_loader_utils.orient_shapes_hwd(
                    gt.header.get_data_shape(), self.slice_axis)
                gt_shape.append(tuple(shape))

                # All ground-truth volumes must share the same shape
                if len(set(gt_shape)) != 1:
                    raise RuntimeError('Labels have different dimensions.')

        return input_shape[0], gt_shape[0] if len(gt_shape) else None
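The consistency check above boils down to set deduplication over shape tuples: every volume must map to the same tuple. The same idea in isolation (toy shapes for illustration):

    shapes = [(48, 48, 16), (48, 48, 16)]
    if len(set(shapes)) != 1:
        raise RuntimeError('Inputs have different dimensions.')
    print(shapes[0])  # the shape shared by all volumes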
Example #4
def run_visualization(input, config, number, output, roi):
    """Utility function to visualize Data Augmentation transformations.

    Data augmentation is a key part of the Deep Learning training scheme. This script aims at facilitating the
    fine-tuning of data augmentation parameters. To do so, this script provides a step-by-step visualization of the
    transformations that are applied on data.

    This function applies a series of transformations (defined in a configuration file,
    ``-c``) to ``-n`` 2D slices randomly extracted from an input image (``-i``), and saves
    the resulting sample as a PNG after each transform.

    For example::

        ivadomed_visualize_transforms -i t2s.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz

    Provides a visualization of a series of three transformations on a randomly selected slice:

    .. image:: https://raw.githubusercontent.com/ivadomed/doc-figures/main/scripts/transforms_im.png
        :width: 600px
        :align: center

    And on a binary mask::

        ivadomed_visualize_transforms -i t2s_gmseg.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz

    Gives:

    .. image:: https://raw.githubusercontent.com/ivadomed/doc-figures/main/scripts/transforms_gt.png
        :width: 600px
        :align: center

    Args:
         input (string): Image filename. Flag: ``--input``, ``-i``
         config (string): Configuration file filename. Flag: ``--config``, ``-c``
         number (int): Number of slices randomly extracted. Flag: ``--number``, ``-n``
         output (string): Folder path where the results are saved. Flag: ``--ofolder``, ``-o``
         roi (string): Filename of the region of interest. Only needed if ROICrop is part of the transformations.
                       Flag: ``--roi``, ``-r``
    """
    # Load context
    context = imed_config_manager.ConfigurationManager(config).get_config()

    # Create output folder
    if not Path(output).is_dir():
        Path(output).mkdir(parents=True)

    # Slices are extracted along the axis below
    axis = imed_utils.AXIS_DCT[context[ConfigKW.LOADER_PARAMETERS][LoaderParamsKW.SLICE_AXIS]]
    # Get data
    input_img, input_data = get_data(input, axis)
    # Image or Mask
    is_mask = np.array_equal(input_data, input_data.astype(bool))
    # Get zooms
    zooms = imed_loader_utils.orient_shapes_hwd(input_img.header.get_zooms(), slice_axis=axis)
    # Get indexes
    indexes = random.sample(range(0, input_data.shape[2]), number)

    # Get training transforms
    training_transforms, _, _ = imed_transforms.get_subdatasets_transforms(context[ConfigKW.TRANSFORMATION])

    if TransformationKW.ROICROP in training_transforms:
        if roi and Path(roi).is_file():
            roi_img, roi_data = get_data(roi, axis)
        else:
            raise ValueError("\nPlease provide ROI image (-r) in order to apply ROICrop transformation.")

    # Compose transforms
    dict_transforms = {}
    stg_transforms = ""
    after_by_slice = {}  # latest transformed version of each slice, reused as "before"
    for transform_name in training_transforms:
        # We skip the NumpyToTensor transform since it is only a change of data type
        if transform_name == "NumpyToTensor":
            continue

        # Update stg_transforms
        stg_transforms += transform_name + "_"

        # Add new transform to Compose
        dict_transforms.update({transform_name: training_transforms[transform_name]})
        composed_transforms = imed_transforms.Compose(dict_transforms)

        # Loop across slices
        for i in indexes:
            data = [input_data[:, :, i]]
            # Init metadata
            metadata = SampleMetadata({MetadataKW.ZOOMS: zooms, MetadataKW.DATA_TYPE: "gt" if is_mask else "im"})

            # Apply transformations to ROI
            if TransformationKW.CENTERCROP in training_transforms or \
                    (TransformationKW.ROICROP in training_transforms and Path(roi).is_file()):
                metadata[MetadataKW.CROP_PARAMS] = {}

            # Apply transformations to image
            stack_im, _ = composed_transforms(sample=data,
                                              metadata=[metadata for _ in range(number)],
                                              data_type="im")

            # Plot before / after transformation
            fname_out = str(Path(output, stg_transforms + "slice" + str(i) + ".png"))
            logger.debug(f"Fname out: {fname_out}.")
            logger.debug(f"\t{dict(metadata)}")
            # rescale intensities
            if len(stg_transforms[:-1].split("_")) == 1:
                # First transform: "before" is the raw slice
                before = np.rot90(imed_maths.rescale_values_array(data[0], 0.0, 1.0))
            else:
                # Later transforms: "before" is this slice after the previous transforms
                before = after_by_slice[i]
            if isinstance(stack_im[0], torch.Tensor):
                after = np.rot90(imed_maths.rescale_values_array(stack_im[0].numpy(), 0.0, 1.0))
            else:
                after = np.rot90(imed_maths.rescale_values_array(stack_im[0], 0.0, 1.0))
            after_by_slice[i] = after
            # Plot
            imed_utils.plot_transformed_sample(before,
                                               after,
                                               list_title=["\n".join(stg_transforms[:-1].split("_")[:-1]),
                                                           "\n".join(stg_transforms[:-1].split("_"))],
                                               fname_out=fname_out,
                                               cmap="jet" if is_mask else "gray")
Example #5
    def get_pair_metadata(self, slice_index=0, coord=None):
        """Return dictionary containing input and gt metadata.

        Args:
            slice_index (int): Index of 2D slice if 2D model is used, else 0.
            coord (tuple or list): Coordinates of subvolume in volume if 3D model is used, else None.

        Returns:
            dict: Input and gt metadata.
        """
        gt_meta_dict = []
        for gt in self.gt_handle:
            if gt is not None:
                gt_meta_dict.append(imed_loader_utils.SampleMetadata({
                    "zooms": imed_loader_utils.orient_shapes_hwd(gt.header.get_zooms(), self.slice_axis),
                    "data_shape": imed_loader_utils.orient_shapes_hwd(gt.header.get_data_shape(), self.slice_axis),
                    "gt_filenames": self.metadata[0]["gt_filenames"],
                    "bounding_box": self.metadata[0]["bounding_box"] if 'bounding_box' in self.metadata[0] else None,
                    "data_type": 'gt',
                    "crop_params": {}
                }))
            else:
                # Temporarily append null metadata to null gt
                gt_meta_dict.append(None)

        # Replace null metadata with metadata from other existing classes of the same subject
        for idx, gt_metadata in enumerate(gt_meta_dict):
            if gt_metadata is None:
                gt_meta_dict[idx] = list(filter(None, gt_meta_dict))[0]

        input_meta_dict = []
        for handle in self.input_handle:
            input_meta_dict.append(imed_loader_utils.SampleMetadata({
                "zooms": imed_loader_utils.orient_shapes_hwd(handle.header.get_zooms(), self.slice_axis),
                "data_shape": imed_loader_utils.orient_shapes_hwd(handle.header.get_data_shape(), self.slice_axis),
                "data_type": 'im',
                "crop_params": {}
            }))

        dreturn = {
            "input_metadata": input_meta_dict,
            "gt_metadata": gt_meta_dict,
        }

        for idx, metadata in enumerate(self.metadata):  # loop across channels
            metadata["slice_index"] = slice_index
            metadata["coord"] = coord
            self.metadata[idx] = metadata
            for metadata_key in metadata.keys():  # loop across input metadata
                dreturn["input_metadata"][idx][metadata_key] = metadata[
                    metadata_key]

        return dreturn
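The backfill loop above replaces each None entry with the first non-null metadata found for the same subject. The same filter idiom in isolation (toy metadata for illustration):

    meta = [None, {'zooms': (0.5, 0.5, 3.0)}, None]
    for idx, m in enumerate(meta):
        if m is None:
            meta[idx] = list(filter(None, meta))[0]
    print(meta[0] is meta[2])  # True: both now reference the first real metadata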
Example #6
def run_visualization(input, config, number, output, roi):
    """Utility function to visualize Data Augmentation transformations.

    Data augmentation is a key part of the Deep Learning training scheme. This script aims at facilitating the
    fine-tuning of data augmentation parameters. To do so, this script provides a step-by-step visualization of the
    transformations that are applied on data.

    This function applies a series of transformations (defined in a configuration file,
    ``-c``) to ``-n`` 2D slices randomly extracted from an input image (``-i``), and saves
    the resulting sample as a PNG after each transform.

    For example::

        ivadomed_visualize_transforms -i t2s.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz

    Provides a visualization of a series of three transformations on a randomly selected slice:

    .. image:: ../../images/transforms_im.png
        :width: 600px
        :align: center

    And on a binary mask::

        ivadomed_visualize_transforms -i t2s_gmseg.nii.gz -n 1 -c config.json -r t2s_seg.nii.gz

    Gives:

    .. image:: ../../images/transforms_gt.png
        :width: 600px
        :align: center

    Args:
         input (string): Image filename. Flag: ``--input``, ``-i``
         config (string): Configuration file filename. Flag: ``--config``, ``-c``
         number (int): Number of slices randomly extracted. Flag: ``--number``, ``-n``
         output (string): Folder path where the results are saved. Flag: ``--ofolder``, ``-o``
         roi (string): Filename of the region of interest. Only needed if ROICrop is part of the transformations.
                       Flag: ``--roi``, ``-r``
    """
    # Load context
    with open(config, "r") as fhandle:
        context = json.load(fhandle)
    # Create output folder
    if not os.path.isdir(output):
        os.makedirs(output)

    # Slices are extracted along the axis below
    axis = imed_utils.AXIS_DCT[context["loader_parameters"]["slice_axis"]]
    # Get data
    input_img, input_data = get_data(input, axis)
    # Image or Mask
    is_mask = np.array_equal(input_data, input_data.astype(bool))
    # Get zooms
    zooms = imed_loader_utils.orient_shapes_hwd(input_img.header.get_zooms(),
                                                slice_axis=axis)
    # Get indexes
    indexes = random.sample(range(0, input_data.shape[2]), number)

    # Get training transforms
    training_transforms, _, _ = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    if "ROICrop" in training_transforms:
        if roi and os.path.isfile(roi):
            roi_img, roi_data = get_data(roi, axis)
        else:
            print(
                "\nPlease provide ROI image (-r) in order to apply ROICrop transformation."
            )
            exit()

    # Compose transforms
    dict_transforms = {}
    stg_transforms = ""
    after_by_slice = {}  # latest transformed version of each slice, reused as "before"
    for transform_name in training_transforms:
        # We skip the NumpyToTensor transform since it is only a change of data type
        if transform_name == "NumpyToTensor":
            continue

        # Update stg_transforms
        stg_transforms += transform_name + "_"

        # Add new transform to Compose
        dict_transforms.update(
            {transform_name: training_transforms[transform_name]})
        composed_transforms = imed_transforms.Compose(dict_transforms)

        # Loop across slices
        for i in indexes:
            data = [input_data[:, :, i]]
            # Init metadata
            metadata = imed_loader_utils.SampleMetadata({
                "zooms": zooms,
                "data_type": "gt" if is_mask else "im"
            })

            # Apply transformations to ROI
            if "CenterCrop" in training_transforms or (
                    "ROICrop" in training_transforms and os.path.isfile(roi)):
                metadata['crop_params'] = {}

            # Apply transformations to image
            stack_im, _ = composed_transforms(
                sample=data,
                metadata=[metadata for _ in range(number)],
                data_type="im")

            # Plot before / after transformation
            fname_out = os.path.join(
                output, stg_transforms + "slice" + str(i) + ".png")
            print("Fname out: {}.".format(fname_out))
            print("\t{}".format(dict(metadata)))
            # rescale intensities
            if len(stg_transforms[:-1].split("_")) == 1:
                # First transform: "before" is the raw slice
                before = np.rot90(
                    imed_maths.rescale_values_array(data[0], 0.0, 1.0))
            else:
                # Later transforms: "before" is this slice after the previous transforms
                before = after_by_slice[i]
            after = np.rot90(
                imed_maths.rescale_values_array(stack_im[0], 0.0, 1.0))
            after_by_slice[i] = after
            # Plot
            imed_utils.plot_transformed_sample(
                before,
                after,
                list_title=[
                    "\n".join(stg_transforms[:-1].split("_")[:-1]),
                    "\n".join(stg_transforms[:-1].split("_"))
                ],
                fname_out=fname_out,
                cmap="jet" if is_mask else "gray")