示例#1
0
    def save_output(self, postprocessor=None):
        """Save every array in ``self.return_objects`` to disk.

        The output path is built from the base case name of the first input
        data group, that group's latest augmentation string, and a
        postprocessor string. When the output has a single channel (checked
        before squeezing) or ``self.stack_outputs`` is set, one file is
        written with ``save_data``. Otherwise one file per channel is written
        with a ``_channel_<i>`` suffix. Written paths are appended to
        ``self.return_filenames``.

        Args:
            postprocessor: Optional object whose ``postprocessor_string`` is
                used in place of ``self.postprocessor_string`` in the name.
        """
        # Currently assumes Nifti output. TODO: Make automatically detect output or determine with a class variable.
        # Ideally, split also this out into a saving.py function in utils.
        # Case naming is a little wild here, TODO: make more simple.
        # Some bad practice with abspath here. TODO: abspath all files on input

        for input_data in self.return_objects:

            lead_group = self.data_collection.data_groups[self.inputs[0]]
            casename = lead_group.base_casename
            input_affine = lead_group.base_affine
            augmentation_string = lead_group.augmentation_strings[-1]

            if self.output_directory is None:
                # NOTE(review): when casename is a file (branch below), this
                # makes the output "directory" the case file itself - confirm.
                output_directory = os.path.abspath(casename)
            else:
                output_directory = os.path.abspath(self.output_directory)

            if os.path.exists(casename) and not os.path.isdir(casename):
                # Prefix the output name with the case file's name minus its
                # extension (as split by nifti_splitext). Only the basename of
                # output_filename may be appended: appending an absolute path
                # gives "<stem>/abs/path/..." nested beneath output_directory.
                case_stem = os.path.basename(
                    nifti_splitext(os.path.abspath(casename))[0])
                output_filename = case_stem + os.path.basename(self.output_filename)
            else:
                # NOTE(review): an absolute output_filename makes os.path.join
                # below discard output_directory; kept as-is for compatibility.
                output_filename = os.path.abspath(self.output_filename)

            if postprocessor is None:
                postprocessor_string = self.postprocessor_string
            else:
                postprocessor_string = postprocessor.postprocessor_string

            output_filepath = os.path.join(
                output_directory,
                replace_suffix(output_filename, '',
                               augmentation_string + postprocessor_string))

            # If prediction already exists, skip it (but keep processing the
            # remaining outputs). Useful if process is interrupted.
            if os.path.exists(output_filepath) and not self.replace_existing:
                continue

            # Squeezing is a little cagey. Maybe explicitly remove batch dimension instead.
            # The channel count is taken from the shape *before* squeezing.
            output_shape = input_data.shape
            input_data = np.squeeze(input_data)

            # If there is only one channel, only save one file.
            if output_shape[-1] == 1 or self.stack_outputs:
                self.return_filenames += [save_data(input_data, output_filepath, reference_data=input_affine)]
            else:
                return_filenames = []
                for channel in range(output_shape[-1]):
                    channel_filepath = replace_suffix(output_filepath, input_suffix='', output_suffix='_channel_' + str(channel))
                    return_filenames += [save_numpy_2_nifti(input_data[..., channel], output_filepath=channel_filepath, reference_data=input_affine)]
                self.return_filenames += [return_filenames]

        return
示例#2
0
    def save_to_disk(self, input_data, output_filenames, raw_data):
        """Write each batch item to disk and record the written paths.

        For item ``idx`` of ``input_data``, the reference affine is taken from
        ``raw_data[self.lead_key + '_affine'][idx]`` and the target path from
        ``output_filenames[idx]``.

        A single file is written (via ``save_data``) when any of these hold:
        the pre-squeeze last axis has size 1; it has size 3 and the squeezed
        item is 3-D; or ``self.stack_outputs`` is set. Its path is appended to
        ``self.return_filenames``.

        Otherwise each channel is written to its own ``_channel_<i>`` file,
        and the list of those paths is appended as one entry.
        """
        for idx, item in enumerate(input_data):

            # Squeeze data here covers for user errors, but could result in
            # unintended outcomes. The channel count comes from the original shape.
            original_shape = item.shape
            squeezed = np.squeeze(item)
            affine = raw_data[self.lead_key + '_affine'][idx]
            target = output_filenames[idx]

            n_channels = original_shape[-1]
            write_single_file = (n_channels == 1
                                 or (n_channels == 3 and squeezed.ndim == 3)
                                 or self.stack_outputs)

            if write_single_file:
                self.return_filenames += [
                    save_data(squeezed, target, reference_data=affine)
                ]
                continue

            # Too many channels to stack: one file per channel.
            channel_files = [
                save_data(squeezed[..., ch],
                          replace_suffix(target,
                                         input_suffix='',
                                         output_suffix='_channel_' + str(ch)),
                          reference_data=affine)
                for ch in range(n_channels)
            ]
            self.return_filenames += [channel_files]

        return
示例#3
0
    def generate_output_filename(self, filename, suffix=None):
        """Return the output path for ``filename`` with ``suffix`` appended.

        Args:
            filename: Input file or directory path.
            suffix: String to append; defaults to ``self.preprocessor_string``.

        Returns:
            - ``filename`` unchanged for a 'Conversion' step given a .nii or
              .nii.gz file.
            - For a directory: a ``.nii.gz`` name built from
              ``os.path.dirname(filename) + suffix``, placed inside the
              directory itself, or inside ``self.output_folder`` when set.
            - Otherwise: ``replace_suffix(filename, '', suffix)``, moved into
              ``self.output_folder`` when that is set.
        """
        if suffix is None:
            suffix = self.preprocessor_string

        # A bit hacky, here: conversion of files that are already NIfTI keeps
        # the original name.
        if self.name == 'Conversion' and filename.endswith(('.nii', '.nii.gz')):
            return filename

        if os.path.isdir(filename):
            # dirname() of a path with a trailing separator is the directory
            # itself; without one it is the parent directory.
            volume_name = os.path.basename(
                os.path.dirname(filename) + suffix + '.nii.gz')
            destination = filename if self.output_folder is None else self.output_folder
            return os.path.join(destination, volume_name)

        renamed = replace_suffix(filename, '', suffix)
        if self.output_folder is None:
            return renamed
        return os.path.join(self.output_folder, os.path.basename(renamed))
示例#4
0
    def generate_output_filename(self,
                                 filename,
                                 suffix=None,
                                 file_extension='.nii.gz'):
        """Return the CLI-sanitized output path for ``filename``.

        Args:
            filename: Input file or directory path; made absolute first.
            suffix: String to append; defaults to ``self.preprocessor_string``.
            file_extension: Extension of the generated output name.

        Returns:
            ``cli_sanitize`` applied to one of:
            - the absolute ``filename`` for a 'Conversion' step given a .nii
              or .nii.gz file;
            - for a directory, ``<dirname><suffix><file_extension>``, where
              ``<dirname>`` is the directory's own name. It is placed inside
              the directory, or inside ``self.output_folder`` when set;
            - otherwise ``replace_suffix(filename, '', suffix,
              file_extension=file_extension)``, moved into
              ``self.output_folder`` when that is set.
        """
        if suffix is None:
            suffix = self.preprocessor_string

        # abspath also strips any trailing separator, so for a directory
        # os.path.dirname(filename) would be its *parent*. Use basename instead
        # so both directory branches name the output after the directory itself.
        filename = os.path.abspath(filename)

        # A bit hacky: conversion of files that are already NIfTI keeps the name.
        if self.name == 'Conversion' and filename.endswith(('.nii', '.nii.gz')):
            output_filename = filename
        elif os.path.isdir(filename):
            volume_name = os.path.basename(filename) + suffix + file_extension
            if self.output_folder is None:
                output_filename = os.path.join(filename, volume_name)
            else:
                output_filename = os.path.join(self.output_folder, volume_name)
        else:
            renamed = replace_suffix(filename,
                                     '',
                                     suffix,
                                     file_extension=file_extension)
            if self.output_folder is None:
                output_filename = renamed
            else:
                output_filename = os.path.join(self.output_folder,
                                               os.path.basename(renamed))

        return cli_sanitize(output_filename)
示例#5
0
    def create_output_filenames(self):
        """Compute per-case output paths and ensure the case directory exists.

        Output files go into ``<self.output_directory>/<basename of the
        current casename>``. Every filename except ``self.data_filename`` gets
        a layer suffix via ``replace_suffix``: ``'_prediction'`` when
        ``self.current_layer`` is None, else ``'_' + str(self.current_layer)``.
        """
        if self.current_layer is None:
            suffix = '_prediction'
        else:
            suffix = '_' + str(self.current_layer)

        self.current_output_directory = os.path.join(
            self.output_directory,
            os.path.basename(self.data_collection.get_current_casename()))
        # makedirs also creates a missing output_directory, and exist_ok avoids
        # the race between a separate exists() check and mkdir().
        os.makedirs(self.current_output_directory, exist_ok=True)

        def _suffixed(filename):
            """Place ``filename`` (with the layer suffix) in the case directory."""
            return os.path.join(self.current_output_directory,
                                replace_suffix(filename, '', suffix))

        self.current_patch_filename = _suffixed(self.patch_filename)
        self.current_aggregate_patch_filename = _suffixed(
            self.aggregate_patch_filename)
        self.current_umap_feature_output = _suffixed(self.features_filename)
        self.current_umap_cluster_output = _suffixed(self.clusters_filename)
        self.current_umap_plot_filename = _suffixed(self.umap_plot_filename)
        # The data file is shared across layers, so it gets no suffix.
        self.current_data_filename = os.path.join(
            self.current_output_directory, self.data_filename)
        self.current_label_filename = _suffixed(self.label_filename)

        return
示例#6
0
def check_data(output_data=None,
               data_collection=None,
               batch_size=4,
               merge_batch=True,
               show_output=True,
               output_filepath=None,
               viz_rows=6,
               viz_mode_3d='2d_center',
               color_range=None,
               output_groups=None,
               combine_outputs=False):
    """Visualize (and optionally save) a batch of data as merged image grids.

    Args:
        output_data: An array, or a dict mapping labels to arrays. Replaced by
            one generated batch when ``data_collection`` is given.
        data_collection: Optional object whose ``data_generator`` supplies the
            batch.
        batch_size: Batch size requested from the generator; also sizes the
            merged grid.
        merge_batch, combine_outputs: Accepted but not used here.
        show_output: If True, plot every merged image with matplotlib.
        output_filepath: If given, each merged image is saved to this path with
            ``'_' + label`` inserted via ``replace_suffix``.
        viz_rows: Maximum rows of each merged grid (capped at ``batch_size``).
        viz_mode_3d: Passed to ``display_3d_data`` for 5-D arrays.
        color_range: Dict of label -> [vmin, vmax]; defaults to each array's
            min/max. The caller's dict is not modified.
        output_groups: If given, only labels in this collection are kept.

    Returns:
        (output_filepaths, output_images): dicts keyed by image label.
    """
    if data_collection is not None:
        generator = data_collection.data_generator(perpetual=True,
                                                   verbose=False,
                                                   batch_size=batch_size)
        output_data = next(generator)

    if not isinstance(output_data, dict):
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {
            label: [np.min(data), np.max(data)]
            for label, data in output_data.items()
        }
    else:
        # Copy: per-channel entries are added below.
        color_range = dict(color_range)

    if output_groups is not None:
        output_data = {
            label: data
            for label, data in output_data.items() if label in output_groups
        }

    output_images = {}
    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in output_data.items():

        if data.ndim == 5:
            output_images = display_3d_data(data, viz_mode_3d, label,
                                            output_images, viz_rows,
                                            viz_columns)

        elif data.ndim == 4:
            if data.shape[-1] not in [1, 3]:
                # Neither grayscale nor RGB: one grayscale grid per channel,
                # sharing the parent label's color range.
                for channel in range(data.shape[-1]):
                    channel_label = label + '_' + str(channel)
                    output_images[channel_label] = merge_data(
                        data[..., channel][..., np.newaxis],
                        [viz_rows, viz_columns], 1)
                    if label in color_range:
                        color_range[channel_label] = color_range[label]

            else:
                output_images[label] = merge_data(data,
                                                  [viz_rows, viz_columns],
                                                  data.shape[-1])

    if show_output and output_images:

        # squeeze=False keeps axarr 2-D (n x 1), even for a single image.
        fig, axarr = plt.subplots(len(output_images), squeeze=False)

        for plot_idx, (label, data) in enumerate(output_images.items()):

            ax = axarr[plot_idx, 0]
            # Labels missing from color_range (possibly ones added by
            # display_3d_data) fall back to matplotlib autoscaling.
            vmin, vmax = color_range.get(label, (None, None))

            if data.shape[-1] == 3:

                # Weird matplotlib bug: rescale to [0, 1] when negatives exist.
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) -
                                                    np.min(data))

                plt_image = ax.imshow(np.squeeze(data),
                                      cmap=plt.get_cmap('hot'),
                                      vmin=vmin,
                                      vmax=vmax,
                                      interpolation='none')

                fig.colorbar(plt_image, ax=ax)

            elif data.shape[-1] == 1:
                plt_image = ax.imshow(np.squeeze(data),
                                      cmap='gray',
                                      vmin=vmin,
                                      vmax=vmax,
                                      interpolation='none')

                fig.colorbar(plt_image, ax=ax, cmap='gray')

            ax.set_title(label)

        plt.show()

    output_filepaths = {}
    for label, data in output_images.items():
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label],
                replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images
示例#7
0
    def cluster_patch_data(self, input_data):
        """Cluster aggregated patch data and save a cluster label volume.

        Each intermediate result is read from disk when it already exists,
        unless the matching ``overwrite_*`` flag is set:

        1. Aggregate the patch data in the HDF5 file
           ``self.current_patch_filename`` into an .npy file suffixed with
           ``'_' + self.aggregation_method``.
        2. Embed it with UMAP, unless ``self.baseline_mean_intensity`` is set,
           in which case the aggregated data are clustered directly.
        3. Cluster with KMeans into ``self.cluster_num`` clusters.
        4. Plot the embedding (UMAP mode only).
        5. Write ``cluster + 1`` at every patch corner of a volume shaped like
           ``input_data[0, ..., 0]``. Spread each label with a size-3 maximum
           filter and save the result to ``self.current_label_filename``.
        """
        if self.open_hdf5_file is not None:
            self.open_hdf5_file.close()

        open_hdf5 = tables.open_file(self.current_patch_filename, "r")
        try:
            output_npy_file = replace_suffix(
                self.current_aggregate_patch_filename, '',
                '_' + self.aggregation_method)

            # Load and aggregate data for analysis.
            if not os.path.exists(output_npy_file) or self.overwrite_aggregate:
                patch_data = self.aggregate_patch_data(open_hdf5, output_npy_file)
            else:
                patch_data = np.load(output_npy_file)

            # Calculate Features and Clusters
            if self.baseline_mean_intensity:
                umap_features = patch_data
            elif (not os.path.exists(self.current_umap_feature_output)
                  or self.overwrite_features):
                umap_features = umap.UMAP(
                    n_neighbors=30, min_dist=0.0,
                    verbose=True).fit_transform(patch_data)
                np.save(self.current_umap_feature_output, umap_features)
            else:
                umap_features = np.load(self.current_umap_feature_output)

            if (not os.path.exists(self.current_umap_cluster_output)
                    or self.overwrite_clusters):
                k_clusters = KMeans(
                    n_clusters=self.cluster_num).fit_predict(umap_features)
                np.save(self.current_umap_cluster_output, k_clusters)
            else:
                k_clusters = np.load(self.current_umap_cluster_output)

            # Plot UMAP and Save Output
            if not self.baseline_mean_intensity and (
                    not os.path.exists(self.current_umap_plot_filename)
                    or self.overwrite_plot):
                self.plot_umap(umap_features,
                               clusters=k_clusters,
                               show_plot=self.show_umap_plot,
                               output_filename=self.current_umap_plot_filename)

            # Map Back to Original Data: label each patch corner with its
            # cluster id + 1, so that 0 is left for unlabelled voxels.
            corners = open_hdf5.root.corners
            label_volume = np.zeros_like(input_data[0, ..., 0])
            for idx, coordinate in enumerate(corners):
                label_volume[int(coordinate[0]),
                             int(coordinate[1]),
                             int(coordinate[2])] = k_clusters[idx] + 1
        finally:
            # The patch file was previously never closed.
            open_hdf5.close()

        # Spread labels to neighbouring voxels, but keep each corner's own label
        # where the maximum filter would overwrite it with a larger one.
        padded_points = ndimage.maximum_filter(label_volume, 3)
        padded_points[label_volume != 0] = label_volume[label_volume != 0]

        save_data(padded_points, self.current_label_filename)

        return
示例#8
0
def check_data(output_data=None,
               data_collection=None,
               batch_size=4,
               merge_batch=True,
               show_output=True,
               output_filepath=None,
               viz_rows=None,
               viz_mode_2d=None,
               viz_mode_3d='2d_center',
               color_range=None,
               output_groups=None,
               combine_outputs=False,
               rgb_output=True,
               colorbar=True,
               subplot_rows=None,
               title=None,
               subplot_titles=None,
               **kwargs):
    """Visualize (and optionally save) a batch of data as merged image grids.

    Args:
        output_data: An array, or a dict mapping labels to arrays. Replaced by
            one generated batch when ``data_collection`` is given.
        data_collection: Optional object providing ``data_generator``; the
            batch size is capped at ``total_cases * multiplier``.
        batch_size: Cases per batch; also sizes each merged grid.
        show_output: If True, draw all merged images in one matplotlib figure.
        output_filepath: If given, each merged image is saved to this path with
            ``'_' + label`` inserted via ``replace_suffix``.
        viz_rows: Rows of each merged grid. Defaults to ceil(sqrt(batch_size))
            and is capped at ``batch_size``.
        viz_mode_3d: Forwarded to ``display_3d_data`` for 5-D arrays.
        color_range: Dict of label -> [vmin, vmax]; defaults to each array's
            min/max. Entries for derived per-channel labels are added to it.
        output_groups: If given, only labels in this collection are kept.
        colorbar: Whether to draw a colorbar beside each image.
        subplot_rows: Rows of the figure's subplot grid. Defaults to
            ceil(sqrt(number of images)).
        title: Optional figure title.
        subplot_titles: Optional dict of label -> per-channel titles. Used for
            two-channel 4-D data and forwarded to ``display_3d_data``.
        merge_batch, viz_mode_2d, combine_outputs, rgb_output: Accepted but
            not used here.
        **kwargs: Forwarded to ``display_3d_data``.

    Returns:
        (output_filepaths, output_images): dicts keyed by image label.
    """
    if data_collection is not None:
        if batch_size > data_collection.total_cases * data_collection.multiplier:
            batch_size = data_collection.total_cases * data_collection.multiplier

        generator = data_collection.data_generator(perpetual=True,
                                                   verbose=False,
                                                   batch_size=batch_size)
        output_data = next(generator)

    if not isinstance(output_data, dict):
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {
            label: [np.min(data), np.max(data)]
            for label, data in list(output_data.items())
        }

    if output_groups is not None:
        output_data = {
            label: data
            for label, data in list(output_data.items())
            if label in output_groups
        }

    output_images = OrderedDict()

    if viz_rows is None:
        viz_rows = int(np.ceil(np.sqrt(batch_size)))

    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in list(output_data.items()):

        if data.ndim == 5:
            output_images, color_range = display_3d_data(
                data,
                color_range,
                viz_mode_3d,
                label,
                output_images,
                viz_rows,
                viz_columns,
                subplot_titles=subplot_titles,
                **kwargs)

        elif data.ndim == 4:

            if data.shape[-1] == 2:
                # Two channels: one grayscale grid per channel.
                for i in range(data.shape[-1]):

                    if subplot_titles is None:
                        subplot_title = label + '_' + str(i)
                    else:
                        subplot_title = subplot_titles[label][i]

                    output_images[subplot_title] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns],
                        1)
                    color_range[subplot_title] = color_range[label]

            # elif, not if: a two-channel array must not also fall into the
            # RGB-composite branch, which needs at least three channels.
            elif data.shape[-1] not in [1, 3]:
                # Other channel counts (e.g. four or more): the first three as
                # an RGB composite, the rest as individual grayscale grids.
                output_images[label + '_RGB'] = merge_data(
                    data[..., 0:3], [viz_rows, viz_columns], 3)
                color_range[label + '_RGB'] = color_range[label]
                for i in range(3, data.shape[-1]):
                    output_images[label + '_' + str(i)] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns],
                        1)
                    color_range[label + '_' + str(i)] = color_range[label]
            else:
                output_images[label] = merge_data(data,
                                                  [viz_rows, viz_columns],
                                                  data.shape[-1])

    if show_output and output_images:

        plots = len(output_images)
        if subplot_rows is None:
            subplot_rows = int(np.ceil(np.sqrt(plots)))
        plot_columns = int(np.ceil(plots / float(subplot_rows)))
        # squeeze=False always yields a 2-D array of Axes, so 1x1, 1xN and Nx1
        # grids can all be indexed as axarr[row, column].
        fig, axarr = plt.subplots(subplot_rows, plot_columns, squeeze=False)

        for plot_idx, (label, data) in enumerate(output_images.items()):

            ax = axarr[plot_idx // plot_columns, plot_idx % plot_columns]

            if data.shape[-1] == 3:

                # Weird matplotlib bug/feature: rescale to [0, 1] when
                # negatives are present.
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) -
                                                    np.min(data))

                plt_image = ax.imshow(np.squeeze(data),
                                      cmap=plt.get_cmap('hot'),
                                      vmin=color_range[label][0],
                                      vmax=color_range[label][1],
                                      interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image, ax=ax)

            elif data.shape[-1] == 1:
                plt_image = ax.imshow(np.squeeze(data),
                                      cmap='gray',
                                      vmin=color_range[label][0],
                                      vmax=color_range[label][1],
                                      interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image, ax=ax, cmap='gray')

            ax.set_title(label)

        # Remove the unused trailing axes of the grid.
        for plot_idx in range(plots, subplot_rows * plot_columns):
            fig.delaxes(axarr[plot_idx // plot_columns, plot_idx % plot_columns])

        if title is not None:
            fig.suptitle(title, fontsize=28)

        plt.show()

    output_filepaths = {}
    for label, data in list(output_images.items()):
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label],
                replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images