Example #1
    def save_to_disk(self, input_data, output_filenames, raw_data):

        for batch_idx, batch in enumerate(input_data):

            # Squeezing here covers for user errors, but could produce unintended results.
            output_shape = batch.shape
            batch = np.squeeze(batch)
            input_affine = raw_data[self.lead_key + '_affine'][batch_idx]

            output_filename = output_filenames[batch_idx]
            return_filenames = []

            # If there is only one channel, only save one file. Otherwise, attempt to stack outputs or save
            # separate files.
            if output_shape[-1] == 1 or (output_shape[-1] == 3 and batch.ndim == 3) or self.stack_outputs:
                self.return_filenames += [
                    save_data(batch,
                              output_filename,
                              reference_data=input_affine)
                ]
            else:
                for channel in range(output_shape[-1]):
                    return_filenames += [
                        save_data(batch[..., channel],
                                  replace_suffix(output_filename,
                                                 input_suffix='',
                                                 output_suffix='_channel_' +
                                                 str(channel)),
                                  reference_data=input_affine)
                    ]
                self.return_filenames += [return_filenames]

        return
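
All of these examples lean on two helpers from deepneuro.utilities: save_data and replace_suffix. Their implementations are not shown here, so the following is a minimal sketch of the behavior the call sites imply (NIfTI output via nibabel, with the output path returned so callers can collect it in return_filenames); the bodies are illustrative guesses, not the library's actual code.

import os
import nibabel as nib
import numpy as np

def save_data(data, output_filename, reference_data=None):
    # Hypothetical sketch: write an array to disk as NIfTI, reusing a
    # reference affine when one is given, and return the output path.
    affine = reference_data if reference_data is not None else np.eye(4)
    nib.save(nib.Nifti1Image(data, affine), output_filename)
    return output_filename

def replace_suffix(filename, input_suffix, output_suffix):
    # Hypothetical sketch: swap input_suffix for output_suffix ahead of the
    # file extension, e.g. 'case.nii.gz' -> 'case_channel_0.nii.gz'.
    for extension in ('.nii.gz', '.nii'):
        if filename.endswith(extension):
            base = filename[:-len(extension)]
            if input_suffix and base.endswith(input_suffix):
                base = base[:-len(input_suffix)]
            return base + output_suffix + extension
    base, extension = os.path.splitext(filename)
    return base + output_suffix + extension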
Example #2
def test_glioblastoma_module(testing_directory="/home/DeepNeuro/tmp", gpu_num='0'):

    import numpy as np
    import os
    from shutil import rmtree

    FLAIR, T1POST, T1PRE = np.random.normal(loc=1000, scale=200, size=(240, 240, 40)), \
                            np.random.normal(loc=1500, scale=200, size=(240, 240, 180)), \
                            np.random.normal(loc=1300, scale=200, size=(120, 120, 60))

    from deepneuro.utilities.conversion import save_data

    try:
        os.mkdir(testing_directory)
        FLAIR_file = save_data(FLAIR, os.path.join(testing_directory, 'FLAIR.nii.gz'))
        T1PRE_file = save_data(T1PRE, os.path.join(testing_directory, 'T1PRE.nii.gz'))
        T1POST_file = save_data(T1POST, os.path.join(testing_directory, 'T1POST.nii.gz'))

        from deepneuro.pipelines.Segment_GBM.predict import predict_GBM

        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_num)
        predict_GBM(testing_directory,
                    T1POST=T1POST_file,
                    FLAIR=FLAIR_file,
                    T1PRE=T1PRE_file)

    finally:
        # Clean up the scratch directory whether the test succeeded or failed.
        rmtree(testing_directory)

    return
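
To run this smoke test directly, a guard like the following works; the scratch directory and GPU index are placeholders to adjust for your machine.

if __name__ == '__main__':
    test_glioblastoma_module(testing_directory='/tmp/deepneuro_test', gpu_num='0')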
Example #3
    def execute(self, data_collection, return_array=False):

        """ There is a lot of repeated code in the preprocessors. Think about preprocessor structures and work on this class.
        """

        if self.verbose:
            docker_print('Working on Preprocessor:', self.name)

        for label, data_group in list(self.data_groups.items()):

            self.generate_output_filenames(data_collection, data_group)

            if type(data_group.preprocessed_case) is not list:
                self.output_data = data_group.preprocessed_case            
            else:

                for file_idx, output_filename in enumerate(self.output_filenames):
                    if os.path.isdir(data_group.preprocessed_case[file_idx]):
                        if self.overwrite or not os.path.exists(output_filename):
                            array_data, affine = read_image_files(data_group.preprocessed_case[file_idx], return_affine=True)
                            # TODO: Check whether the channel subsetting below has edge cases.
                            save_data(array_data[..., 0], output_filename, reference_data=affine)
                    else:
                        self.output_filenames[file_idx] = data_group.preprocessed_case[file_idx]

                data_group.preprocessed_case = self.output_filenames
                self.output_data = data_group.preprocessed_case

            if return_array:
                self.convert_to_array_data(data_group)
Example #4
    def save_to_file(self, data_group):

        if type(self.output_data) is not list:
            for file_idx, output_filename in enumerate(self.output_filenames):
                if self.overwrite or not os.path.exists(output_filename):
                    save_data(np.squeeze(self.output_data[..., file_idx]), output_filename, reference_data=data_group.preprocessed_affine)

        return
Example #5
    def save_to_file(self, data_group):
        """ No idea how this will work if the amount of output files is changed in a preprocessing step
            Also missing affines is a problem.
        """

        if type(self.output_data) is not list:
            for file_idx, output_filename in enumerate(self.output_filenames):
                if self.overwrite or not os.path.exists(output_filename):
                    save_data(np.squeeze(self.output_data[..., file_idx]),
                              output_filename,
                              reference_data=data_group.preprocessed_affine)

        return
Example #6
    def save_output(self, postprocessor=None):

        # Currently assumes NIfTI output. TODO: Detect the output format automatically, or set it with a class variable.
        # Ideally, also split this out into a saving.py function in utils.
        # Case naming is a little convoluted here. TODO: Simplify it.
        # abspath is applied inconsistently here. TODO: abspath all files on input.

        for input_data in self.return_objects:

            casename = self.data_collection.data_groups[self.inputs[0]].base_casename
            input_affine = self.data_collection.data_groups[self.inputs[0]].base_affine

            augmentation_string = self.data_collection.data_groups[self.inputs[0]].augmentation_strings[-1]

            if self.output_directory is None:
                output_directory = os.path.abspath(casename)
            else:
                output_directory = os.path.abspath(self.output_directory)

            if os.path.exists(casename) and not os.path.isdir(casename):
                output_filename = os.path.basename(nifti_splitext(os.path.abspath(casename))[0]) + os.path.abspath(self.output_filename)
            else:
                output_filename = os.path.abspath(self.output_filename)

            if postprocessor is None:
                output_filepath = os.path.join(output_directory, replace_suffix(output_filename, '', augmentation_string + self.postprocessor_string))
            else:
                output_filepath = os.path.join(output_directory, replace_suffix(output_filename, '', augmentation_string + postprocessor.postprocessor_string))

            # If the prediction already exists, skip this case. Useful if the process is interrupted.
            if os.path.exists(output_filepath) and not self.replace_existing:
                continue

            # Squeezing here is a little risky. Maybe explicitly remove the batch dimension instead.
            output_shape = input_data.shape
            input_data = np.squeeze(input_data)

            return_filenames = []

            # If there is only one channel, only save one file.
            if output_shape[-1] == 1 or self.stack_outputs:
                self.return_filenames += [save_data(input_data, output_filepath, reference_data=input_affine)]

            else:
                for channel in range(output_shape[-1]):
                    return_filenames += [save_numpy_2_nifti(input_data[..., channel], output_filepath=replace_suffix(output_filepath, input_suffix='', output_suffix='_channel_' + str(channel)), reference_data=input_affine)]
                self.return_filenames += [return_filenames]

        return
Example #7
def check_data(output_data=None,
               data_collection=None,
               batch_size=4,
               merge_batch=True,
               show_output=True,
               output_filepath=None,
               viz_rows=6,
               viz_mode_3d='2d_center',
               color_range=None,
               output_groups=None,
               combine_outputs=False):

    if data_collection is not None:
        generator = data_collection.data_generator(perpetual=True,
                                                   verbose=False,
                                                   batch_size=batch_size)
        output_data = next(generator)

    if type(output_data) is not dict:
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {
            label: [np.min(data), np.max(data)]
            for label, data in output_data.items()
        }

    if output_groups is not None:
        output_data = {
            label: data
            for label, data in output_data.items() if label in output_groups
        }

    output_images = {}
    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in output_data.items():

        if data.ndim == 5:
            output_images = display_3d_data(data, viz_mode_3d, label,
                                            output_images, viz_rows,
                                            viz_columns)

        elif data.ndim == 4:
            if data.shape[-1] not in [1, 3]:
                for i in range(data.shape[-1]):
                    output_images[label + '_' + str(i)] = merge_data(
                        data[..., i][..., np.newaxis],
                        [viz_rows, viz_columns], 1)

            else:
                output_images[label] = merge_data(data,
                                                  [viz_rows, viz_columns],
                                                  data.shape[-1])

    if show_output:

        fig, axarr = plt.subplots(len(output_images))

        for plot_idx, (label, data) in enumerate(output_images.items()):

            if data.shape[-1] == 3:

                # matplotlib clips RGB float data outside [0, 1]; rescale if negative values are present:
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) -
                                                    np.min(data))

                plt_image = axarr[plot_idx].imshow(np.squeeze(data),
                                                   cmap=plt.get_cmap('hot'),
                                                   vmin=color_range[label][0],
                                                   vmax=color_range[label][1],
                                                   interpolation='none')

                fig.colorbar(plt_image, ax=axarr[plot_idx])

            elif data.shape[-1] == 1:
                plt_image = axarr[plot_idx].imshow(np.squeeze(data),
                                                   cmap='gray',
                                                   vmin=color_range[label][0],
                                                   vmax=color_range[label][1],
                                                   interpolation='none')

                fig.colorbar(plt_image, ax=axarr[plot_idx], cmap='gray')

            axarr[plot_idx].set_title(label)

        plt.show()

    output_filepaths = {}
    for label, data in output_images.items():
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label],
                replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images
Example #8
    def process_case(self, input_data, model=None):

        # It is a little strange to access casename this way. Maybe make it an optional
        # return of the generator.

        # Note that input_modalities as the first input is hard-coded here. Very fragile.

        # If an image is being repatched, its output shape is not certain. We attempt to
        # infer it from the input data. This is awkward; maybe move it to PatchInference.

        if model is not None:
            self.model = model

        if self.channels_first:
            input_data = np.swapaxes(input_data, 1, -1)

        if self.input_channels is not None:
            input_data = np.take(input_data, self.input_channels,
                                 self.channels_dim)

        # Determine patch shape. Currently only extends to spatial patching.
        # The leading-dimensions handling here needs a better solution.
        self.input_patch_shape = self.model.model_input_shape
        if self.output_patch_shape is None:
            self.output_patch_shape = self.model.model_output_shape

        self.input_dim = len(self.input_patch_shape) - 2

        if self.patch_dimensions is None:
            if self.channels_first:
                self.patch_dimensions = [
                    -1 * self.input_dim + x for x in range(self.input_dim)
                ]
            else:
                self.patch_dimensions = [
                    -1 * self.input_dim + x - 1 for x in range(self.input_dim)
                ]

            if self.output_patch_dimensions is None:
                self.output_patch_dimensions = self.patch_dimensions

        self.output_shape = [1] + list(
            self.model.model_output_shape)[1:]  # Replace the batch dimension with 1.
        for i in range(len(self.patch_dimensions)):
            self.output_shape[self.output_patch_dimensions[i]] = \
                input_data.shape[self.patch_dimensions[i]]

        for layer in self.output_layers:

            self.current_layer = layer
            self.create_output_filenames()

            save_data(input_data[0, ..., 0], self.current_data_filename)

            if self.current_patch_filename is not None:
                if not os.path.exists(self.current_patch_filename) or self.overwrite_patches:
                    self.create_patch_hdf5_file()
                    patch_data = self.generate_patch_data(input_data, model)
                    print(patch_data)

                if self.cluster_individual_case:
                    self.cluster_patch_data(input_data)

            if self.open_hdf5_file is not None:
                self.open_hdf5_file.close()

        return None
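
To make the negative-index arithmetic above concrete, here is the patch_dimensions computation traced by hand for a hypothetical channels-last 3D model; the shape is illustrative, not taken from the library.

# For a model with model_input_shape = (None, 64, 64, 64, 4):
# one batch dimension, three spatial dimensions, one channel dimension.
input_patch_shape = (None, 64, 64, 64, 4)
input_dim = len(input_patch_shape) - 2                          # 3 spatial dimensions
patch_dimensions = [-1 * input_dim + x - 1 for x in range(input_dim)]
print(patch_dimensions)                                         # [-4, -3, -2]
# These index the spatial axes from the end, skipping the trailing channel axis.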
Example #9
    def cluster_patch_data(self, input_data):

        if self.open_hdf5_file is not None:
            self.open_hdf5_file.close()

        open_hdf5 = tables.open_file(self.current_patch_filename, "r")
        output_npy_file = replace_suffix(self.current_aggregate_patch_filename,
                                         '', '_' + self.aggregation_method)

        # Load and aggregate data for analysis.
        if not os.path.exists(output_npy_file) or self.overwrite_aggregate:
            patch_data = self.aggregate_patch_data(open_hdf5, output_npy_file)
        else:
            patch_data = np.load(output_npy_file)

        print(patch_data.shape)
        print(input_data.shape)

        # Calculate Features and Clusters
        if self.baseline_mean_intensity:
            umap_features = patch_data
        else:
            if not os.path.exists(self.current_umap_feature_output) or self.overwrite_features:
                umap_features = umap.UMAP(
                    n_neighbors=30, min_dist=0.0,
                    verbose=True).fit_transform(patch_data)
                np.save(self.current_umap_feature_output, umap_features)
            else:
                umap_features = np.load(self.current_umap_feature_output)

        if not os.path.exists(self.current_umap_cluster_output) or self.overwrite_clusters:
            k_clusters = KMeans(
                n_clusters=self.cluster_num).fit_predict(umap_features)
            np.save(self.current_umap_cluster_output, k_clusters)
        else:
            k_clusters = np.load(self.current_umap_cluster_output)

        # Plot UMAP and Save Output
        if not self.baseline_mean_intensity:
            if not os.path.exists(self.current_umap_plot_filename) or self.overwrite_plot:
                self.plot_umap(umap_features,
                               clusters=k_clusters,
                               show_plot=self.show_umap_plot,
                               output_filename=self.current_umap_plot_filename)

        # Map Back to Original Data
        corners = open_hdf5.root.corners
        input_data = input_data[0, ..., 0]
        output_array = np.zeros_like(input_data)
        print(output_array.shape)
        print(corners.shape)
        for idx, coordinate in enumerate(corners):
            output_array[int(coordinate[0]),
                         int(coordinate[1]),
                         int(coordinate[2])] = k_clusters[idx] + 1
        padded_points = ndimage.maximum_filter(output_array, 3)
        padded_points[output_array != 0] = output_array[output_array != 0]

        save_data(padded_points, self.current_label_filename)

        return
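
The corner-mapping step above marks one voxel per patch with its cluster label and then dilates each mark into a 3x3x3 block. A self-contained toy version of that dilation, using the same scipy call:

import numpy as np
from scipy import ndimage

# Mark single voxels with cluster labels, as the corner loop above does.
output_array = np.zeros((10, 10, 10))
output_array[2, 2, 2] = 1
output_array[7, 7, 7] = 2

# Grow each mark into a 3x3x3 block, then reassert the original marks wherever
# the dilation would have overwritten them.
padded_points = ndimage.maximum_filter(output_array, 3)
padded_points[output_array != 0] = output_array[output_array != 0]
print(np.count_nonzero(padded_points))  # 54: two 27-voxel blocks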
Example #10
def check_data(output_data=None,
               data_collection=None,
               batch_size=4,
               merge_batch=True,
               show_output=True,
               output_filepath=None,
               viz_rows=None,
               viz_mode_2d=None,
               viz_mode_3d='2d_center',
               color_range=None,
               output_groups=None,
               combine_outputs=False,
               rgb_output=True,
               colorbar=True,
               subplot_rows=None,
               title=None,
               subplot_titles=None,
               **kwargs):

    if data_collection is not None:
        if batch_size > data_collection.total_cases * data_collection.multiplier:
            batch_size = data_collection.total_cases * data_collection.multiplier

        generator = data_collection.data_generator(perpetual=True,
                                                   verbose=False,
                                                   batch_size=batch_size)
        output_data = next(generator)

    if type(output_data) is not dict:
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {
            label: [np.min(data), np.max(data)]
            for label, data in list(output_data.items())
        }

    if output_groups is not None:
        output_data = {
            label: data
            for label, data in list(output_data.items())
            if label in output_groups
        }

    output_images = OrderedDict()

    if viz_rows is None:
        viz_rows = int(np.ceil(np.sqrt(batch_size)))

    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in list(output_data.items()):

        if data.ndim == 5:
            output_images, color_range = display_3d_data(
                data,
                color_range,
                viz_mode_3d,
                label,
                output_images,
                viz_rows,
                viz_columns,
                subplot_titles=subplot_titles,
                **kwargs)

        elif data.ndim == 4:

            if data.shape[-1] == 2:
                for i in range(data.shape[-1]):

                    if subplot_titles is None:
                        subplot_title = label + '_' + str(i)
                    else:
                        subplot_title = subplot_titles[label][i]

                    output_images[subplot_title] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns],
                        1)
                    color_range[subplot_title] = color_range[label]

            elif data.shape[-1] not in [1, 3]:

                output_images[label + '_RGB'] = merge_data(
                    data[..., 0:3], [viz_rows, viz_columns], 3)
                color_range[label + '_RGB'] = color_range[label]
                for i in range(3, data.shape[-1]):
                    output_images[label + '_' + str(i)] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns],
                        1)
                    color_range[label + '_' + str(i)] = color_range[label]
            else:
                output_images[label] = merge_data(data,
                                                  [viz_rows, viz_columns],
                                                  data.shape[-1])

    if show_output:

        plots = len(list(output_images.keys()))
        if subplot_rows is None:
            subplot_rows = int(np.ceil(np.sqrt(plots)))
        plot_columns = int(np.ceil(plots / float(subplot_rows)))
        fig, axarr = plt.subplots(subplot_rows, plot_columns)

        # plt.subplots returns a bare Axes (or a 1-D array) for degenerate grids; normalize to 2-D.
        if subplot_rows == 1 and plot_columns == 1:
            axarr = np.array([axarr]).reshape(1, 1)
        elif subplot_rows == 1 or plot_columns == 1:
            axarr = axarr.reshape(subplot_rows, plot_columns)

        for plot_idx, (label, data) in enumerate(output_images.items()):

            image_column = plot_idx % plot_columns
            image_row = plot_idx // plot_columns

            if data.shape[-1] == 3:

                # matplotlib clips RGB float data outside [0, 1]; rescale if negative values are present:
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) -
                                                    np.min(data))

                plt_image = axarr[image_row, image_column].imshow(
                    np.squeeze(data),
                    cmap=plt.get_cmap('hot'),
                    vmin=color_range[label][0],
                    vmax=color_range[label][1],
                    interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image, ax=axarr[image_row, image_column])

            elif data.shape[-1] == 1:
                plt_image = axarr[image_row, image_column].imshow(
                    np.squeeze(data),
                    cmap='gray',
                    vmin=color_range[label][0],
                    vmax=color_range[label][1],
                    interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image,
                                 ax=axarr[image_row, image_column],
                                 cmap='gray')

            axarr[image_row, image_column].set_title(label)

        for plot_idx in range(len(output_images), subplot_rows * plot_columns):
            image_column = plot_idx % plot_columns
            image_row = plot_idx // plot_columns
            fig.delaxes(axarr[image_row, image_column])

        if title is not None:
            fig.suptitle(title, fontsize=28)

        plt.show()

    output_filepaths = {}
    for label, data in list(output_images.items()):
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label],
                replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images
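
A minimal invocation of this version of check_data, assuming only the signature shown above; the random batch stands in for real generator output.

import numpy as np

# Visualize a random 4-case batch of 2D single-channel data without saving files.
batch = np.random.normal(size=(4, 64, 64, 1))
output_filepaths, output_images = check_data(output_data={'example': batch},
                                             batch_size=4,
                                             show_output=True)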