def save_output(self, postprocessor=None):
    """Write every array in ``self.return_objects`` to disk.

    The path of each written file (or, for multi-channel output that is not
    stacked, a list of per-channel paths) is appended to
    ``self.return_filenames``.

    Args:
        postprocessor: Optional object whose ``postprocessor_string`` is used
            in the filename suffix instead of ``self.postprocessor_string``.

    Returns:
        None. The method returns immediately, without looking at any
        remaining entries, as soon as a target path already exists and
        ``self.replace_existing`` is falsy.
    """
    # Currently assumes Nifti output. TODO: detect the output type
    # automatically or determine it with a class variable.
    # TODO: ideally move this into a saving.py function in utils.
    # TODO: simplify the case naming below.
    # TODO: abspath all files on input instead of doing it here.
    for prediction in self.return_objects:

        lead_group = self.data_collection.data_groups[self.inputs[0]]
        casename = lead_group.base_casename
        input_affine = lead_group.base_affine
        augmentation_string = lead_group.augmentation_strings[-1]

        # Without an explicit output directory, the case path itself is used.
        if self.output_directory is None:
            target_directory = os.path.abspath(casename)
        else:
            target_directory = os.path.abspath(self.output_directory)

        # When the case is a single file, its extension-less basename is
        # prefixed to the configured output filename.
        case_is_file = os.path.exists(casename) and not os.path.isdir(casename)
        if case_is_file:
            case_stem = os.path.basename(nifti_splitext(os.path.abspath(casename))[0])
            target_name = case_stem + os.path.abspath(self.output_filename)
        else:
            target_name = os.path.abspath(self.output_filename)

        suffix_source = self if postprocessor is None else postprocessor
        output_filepath = os.path.join(
            target_directory,
            replace_suffix(target_name, '', augmentation_string + suffix_source.postprocessor_string))

        # If prediction already exists, stop. Useful if process is interrupted.
        if os.path.exists(output_filepath) and not self.replace_existing:
            return

        # Squeezing is a little cagey. Maybe explicitly remove the batch
        # dimension instead. The channel count is read before squeezing.
        original_shape = prediction.shape
        prediction = np.squeeze(prediction)

        # If there is only one channel (or outputs are stacked), save one file.
        if original_shape[-1] == 1 or self.stack_outputs:
            self.return_filenames += [save_data(prediction, output_filepath, reference_data=input_affine)]
        else:
            channel_filenames = [
                save_numpy_2_nifti(
                    prediction[..., channel],
                    output_filepath=replace_suffix(
                        output_filepath, input_suffix='', output_suffix='_channel_' + str(channel)),
                    reference_data=input_affine)
                for channel in range(original_shape[-1])
            ]
            self.return_filenames += [channel_filenames]

    return
def save_to_disk(self, input_data, output_filenames, raw_data):
    """Save each batch entry of ``input_data`` to its matching output filename.

    Args:
        input_data: Iterable of arrays, one per batch entry.
        output_filenames: Sequence indexed by batch position giving the
            target path for each entry.
        raw_data: Mapping that holds, under ``self.lead_key + '_affine'``,
            a sequence of reference affines indexed by batch position.

    Returns:
        None. Saved paths are appended to ``self.return_filenames``: a single
        path when one file is written, otherwise a list of per-channel paths.
    """
    for batch_idx, batch in enumerate(input_data):

        # Squeezing covers for user errors, but could result in unintended
        # outcomes. The channel count is taken from the unsqueezed shape.
        channel_count = batch.shape[-1]
        batch = np.squeeze(batch)

        input_affine = raw_data[self.lead_key + '_affine'][batch_idx]
        output_filename = output_filenames[batch_idx]

        # One file is written for single-channel data, for 3-channel data
        # that is 3-dimensional after squeezing, or when stacking is requested.
        write_single_file = (
            channel_count == 1
            or (channel_count == 3 and batch.ndim == 3)
            or self.stack_outputs
        )

        if write_single_file:
            self.return_filenames += [save_data(batch, output_filename, reference_data=input_affine)]
            continue

        # Otherwise save every channel as its own '_channel_<n>' file.
        per_channel_filenames = [
            save_data(
                batch[..., channel],
                replace_suffix(output_filename, input_suffix='', output_suffix='_channel_' + str(channel)),
                reference_data=input_affine)
            for channel in range(channel_count)
        ]
        self.return_filenames += [per_channel_filenames]

    return
def generate_output_filename(self, filename, suffix=None):
    """Build the output path for a preprocessed version of ``filename``.

    Args:
        filename: Path of the input file or directory.
        suffix: Text inserted into the output name. Defaults to
            ``self.preprocessor_string``.

    Returns:
        The output path:

        * ``filename`` unchanged when this step is named 'Conversion' and the
          input already ends in '.nii' or '.nii.gz';
        * for a directory input, ``<directory name><suffix>.nii.gz`` placed in
          the directory itself, or in ``self.output_folder`` when that is set;
        * for a file input, the result of ``replace_suffix(filename, '', suffix)``,
          relocated into ``self.output_folder`` when that is set.
    """
    if suffix is None:
        suffix = self.preprocessor_string

    # A bit hacky: NIfTI inputs to the 'Conversion' step are passed through.
    if self.name == 'Conversion' and filename.endswith(('.nii', '.nii.gz')):
        return filename

    if os.path.isdir(filename):
        # Take the name from the normalized path. Using dirname() on the raw
        # string only worked when the caller included a trailing separator;
        # without one the file was named after the parent directory.
        directory_name = os.path.basename(os.path.abspath(filename))
        new_name = directory_name + suffix + '.nii.gz'
        target_folder = filename if self.output_folder is None else self.output_folder
        return os.path.join(target_folder, new_name)

    if self.output_folder is None:
        return replace_suffix(filename, '', suffix)

    return os.path.join(self.output_folder, os.path.basename(replace_suffix(filename, '', suffix)))
def generate_output_filename(self, filename, suffix=None, file_extension='.nii.gz'):
    """Build the (sanitized) output path for a preprocessed ``filename``.

    Args:
        filename: Path of the input file or directory; it is converted to an
            absolute path before anything else happens.
        suffix: Text inserted into the output name. Defaults to
            ``self.preprocessor_string``.
        file_extension: Extension used for the output file.

    Returns:
        ``cli_sanitize`` applied to:

        * the absolute ``filename`` when this step is named 'Conversion' and
          the input already ends in '.nii' or '.nii.gz';
        * for a directory input, ``<directory name><suffix><file_extension>``
          placed in the directory itself, or in ``self.output_folder`` when
          that is set;
        * for a file input, the result of ``replace_suffix``, relocated into
          ``self.output_folder`` when that is set.
    """
    if suffix is None:
        suffix = self.preprocessor_string

    filename = os.path.abspath(filename)

    # A bit hacky: NIfTI inputs to the 'Conversion' step are passed through.
    if self.name == 'Conversion' and filename.endswith(('.nii', '.nii.gz')):
        output_filename = filename

    elif os.path.isdir(filename):
        # abspath() has already removed any trailing separator, so the
        # basename here is the directory's own name. The previous
        # output_folder-is-None branch used dirname(), which named the file
        # after the *parent* directory and disagreed with the other branch.
        new_name = os.path.basename(filename + suffix + file_extension)
        target_folder = filename if self.output_folder is None else self.output_folder
        output_filename = os.path.join(target_folder, new_name)

    else:
        output_filename = replace_suffix(filename, '', suffix, file_extension=file_extension)
        if self.output_folder is not None:
            output_filename = os.path.join(self.output_folder, os.path.basename(output_filename))

    return cli_sanitize(output_filename)
def create_output_filenames(self):
    """Set the per-case output directory and all per-case output paths.

    The suffix applied to filenames is '_prediction' when
    ``self.current_layer`` is None, otherwise '_<current_layer>'. The case
    directory is ``self.output_directory`` joined with the basename of the
    data collection's current casename; it is created if it does not exist.

    Sets ``self.current_output_directory`` and the ``self.current_*``
    filename attributes. ``self.current_data_filename`` is the only one that
    does not receive the suffix.

    Returns:
        None.
    """
    if self.current_layer is None:
        suffix = '_prediction'
    else:
        suffix = '_' + str(self.current_layer)

    case_directory = os.path.join(
        self.output_directory,
        os.path.basename(self.data_collection.get_current_casename()))
    self.current_output_directory = case_directory

    # makedirs (rather than mkdir) so a not-yet-existing output_directory
    # does not make the first case fail.
    if not os.path.exists(case_directory):
        os.makedirs(case_directory)

    def in_case_directory(template):
        # Suffix the configured filename and place it in the case directory.
        return os.path.join(case_directory, replace_suffix(template, '', suffix))

    self.current_patch_filename = in_case_directory(self.patch_filename)
    self.current_aggregate_patch_filename = in_case_directory(self.aggregate_patch_filename)
    self.current_umap_feature_output = in_case_directory(self.features_filename)
    self.current_umap_cluster_output = in_case_directory(self.clusters_filename)
    self.current_umap_plot_filename = in_case_directory(self.umap_plot_filename)
    # The data filename is used as configured, without the suffix.
    self.current_data_filename = os.path.join(case_directory, self.data_filename)
    self.current_label_filename = in_case_directory(self.label_filename)

    return
def check_data(output_data=None, data_collection=None, batch_size=4, merge_batch=True, show_output=True,
               output_filepath=None, viz_rows=6, viz_mode_3d='2d_center', color_range=None,
               output_groups=None, combine_outputs=False):
    """Tile batches of image data into mosaics, optionally plot and save them.

    Args:
        output_data: Array, or dict mapping labels to arrays. A non-dict value
            is wrapped as ``{'output_data': value}``. Ignored when
            ``data_collection`` is given.
        data_collection: Optional object whose ``data_generator`` supplies one
            batch that replaces ``output_data``.
        batch_size: Batch size requested from the generator and used to size
            the mosaic grid.
        merge_batch: Not used by this function.
        show_output: If True, show the mosaics with matplotlib.
        output_filepath: If given, each mosaic is saved to this path with
            '_<label>' inserted by ``replace_suffix``.
        viz_rows: Maximum number of mosaic rows (capped at ``batch_size``).
        viz_mode_3d: Passed through to ``display_3d_data`` for 5D arrays.
        color_range: Optional dict mapping labels to ``[vmin, vmax]``.
            Defaults to the min/max of each array. The caller's dict is not
            modified.
        output_groups: Optional collection of labels to keep.
        combine_outputs: Not used by this function.

    Returns:
        Tuple ``(output_filepaths, output_images)``: a dict of saved paths per
        label (empty when ``output_filepath`` is None) and a dict of the
        mosaics after ``image_preprocess``.
    """
    if data_collection is not None:
        generator = data_collection.data_generator(perpetual=True, verbose=False, batch_size=batch_size)
        output_data = next(generator)

    # isinstance so that dict subclasses (e.g. OrderedDict) are not wrapped.
    if not isinstance(output_data, dict):
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {label: [np.min(data), np.max(data)] for label, data in output_data.items()}
    else:
        # Work on a copy: per-channel entries are added below.
        color_range = dict(color_range)

    if output_groups is not None:
        output_data = {label: data for label, data in output_data.items() if label in output_groups}

    output_images = {}
    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in output_data.items():

        if data.ndim == 5:
            output_images = display_3d_data(data, viz_mode_3d, label, output_images, viz_rows, viz_columns)

        elif data.ndim == 4:

            if data.shape[-1] not in [1, 3]:
                # One grayscale mosaic per channel. The slice must use the
                # loop index; slicing with -1 repeated the last channel.
                for i in range(data.shape[-1]):
                    channel_label = label + '_' + str(i)
                    output_images[channel_label] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns], 1)
                    # Each channel inherits its parent's display range so the
                    # plotting code below can look it up.
                    if label in color_range:
                        color_range[channel_label] = color_range[label]
            else:
                output_images[label] = merge_data(data, [viz_rows, viz_columns], data.shape[-1])

    # Nothing to draw when no mosaics were produced.
    if show_output and output_images:

        # squeeze=False always yields a 2D array of Axes, so a single plot
        # can be indexed the same way as several.
        fig, axarr = plt.subplots(len(output_images), squeeze=False)
        axarr = axarr[:, 0]

        for plot_idx, (label, data) in enumerate(output_images.items()):
            axis = axarr[plot_idx]
            # Labels without a known range fall back to matplotlib autoscaling.
            bounds = color_range.get(label, (None, None))

            if data.shape[-1] == 3:
                # Weird matplotlib bug: negative values are rescaled to [0, 1].
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) - np.min(data))

                plt_image = axis.imshow(np.squeeze(data), cmap=plt.get_cmap('hot'),
                                        vmin=bounds[0], vmax=bounds[1], interpolation='none')
                fig.colorbar(plt_image, ax=axis)
            elif data.shape[-1] == 1:
                plt_image = axis.imshow(np.squeeze(data), cmap='gray',
                                        vmin=bounds[0], vmax=bounds[1], interpolation='none')
                fig.colorbar(plt_image, ax=axis, cmap='gray')

            axis.set_title(label)

        plt.show()

    output_filepaths = {}
    for label, data in list(output_images.items()):
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label], replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images
def cluster_patch_data(self, input_data):
    """Cluster the stored patches and write a cluster label map to disk.

    Steps, each cached on disk and recomputed only when the file is missing
    or the matching ``self.overwrite_*`` flag is set:

    1. Aggregate patch data from the HDF5 file at
       ``self.current_patch_filename``.
    2. Compute UMAP features (skipped when ``self.baseline_mean_intensity``
       is set, in which case the aggregated data is clustered directly).
    3. Run KMeans with ``self.cluster_num`` clusters.
    4. Plot the UMAP embedding (skipped for the baseline mode).
    5. Write ``cluster + 1`` at every patch corner into an array shaped like
       ``input_data[0, ..., 0]``, dilate with a size-3 maximum filter, and
       save the result to ``self.current_label_filename``.

    Args:
        input_data: Array indexed as ``input_data[0, ..., 0]`` to obtain the
            reference volume (presumably batch-first, channel-last with three
            spatial axes - confirm against the caller).

    Returns:
        None.
    """
    # Close any handle still held on the object before re-opening for reading.
    if self.open_hdf5_file is not None:
        self.open_hdf5_file.close()

    open_hdf5 = tables.open_file(self.current_patch_filename, "r")

    # The read handle used to be left open; make sure it is released even if
    # one of the steps below raises.
    try:
        output_npy_file = replace_suffix(
            self.current_aggregate_patch_filename, '', '_' + self.aggregation_method)

        # Load and aggregate data for analysis.
        if not os.path.exists(output_npy_file) or self.overwrite_aggregate:
            patch_data = self.aggregate_patch_data(open_hdf5, output_npy_file)
        else:
            patch_data = np.load(output_npy_file)

        print(patch_data.shape)
        print(input_data.shape)

        # Calculate Features and Clusters
        if self.baseline_mean_intensity:
            umap_features = patch_data
        elif not os.path.exists(self.current_umap_feature_output) or self.overwrite_features:
            umap_features = umap.UMAP(
                n_neighbors=30, min_dist=0.0, verbose=True).fit_transform(patch_data)
            np.save(self.current_umap_feature_output, umap_features)
        else:
            umap_features = np.load(self.current_umap_feature_output)

        if not os.path.exists(self.current_umap_cluster_output) or self.overwrite_clusters:
            k_clusters = KMeans(n_clusters=self.cluster_num).fit_predict(umap_features)
            np.save(self.current_umap_cluster_output, k_clusters)
        else:
            k_clusters = np.load(self.current_umap_cluster_output)

        # Plot UMAP and Save Output
        if not self.baseline_mean_intensity:
            if not os.path.exists(self.current_umap_plot_filename) or self.overwrite_plot:
                self.plot_umap(umap_features,
                               clusters=k_clusters,
                               show_plot=self.show_umap_plot,
                               output_filename=self.current_umap_plot_filename)

        # Map Back to Original Data
        corners = open_hdf5.root.corners
        input_data = input_data[0, ..., 0]
        output_array = np.zeros_like(input_data)

        print(output_array.shape)
        print(corners.shape)

        # Labels are stored 1-based so that 0 keeps meaning "no patch here".
        for idx, coordinate in enumerate(corners):
            output_array[int(coordinate[0]), int(coordinate[1]), int(coordinate[2])] = k_clusters[idx] + 1

        # Grow every labelled point into its 3-wide neighbourhood, then
        # restore the exact labels at the original corner positions.
        padded_points = ndimage.maximum_filter(output_array, 3)
        padded_points[output_array != 0] = output_array[output_array != 0]

        save_data(padded_points, self.current_label_filename)
    finally:
        open_hdf5.close()

    return
def check_data(output_data=None, data_collection=None, batch_size=4, merge_batch=True, show_output=True,
               output_filepath=None, viz_rows=None, viz_mode_2d=None, viz_mode_3d='2d_center',
               color_range=None, output_groups=None, combine_outputs=False, rgb_output=True,
               colorbar=True, subplot_rows=None, title=None, subplot_titles=None, **kwargs):
    """Tile batches of image data into mosaics, optionally plot and save them.

    Args:
        output_data: Array, or dict mapping labels to arrays. A non-dict value
            is wrapped as ``{'output_data': value}``. Ignored when
            ``data_collection`` is given.
        data_collection: Optional object whose ``data_generator`` supplies one
            batch that replaces ``output_data``. ``batch_size`` is capped at
            ``total_cases * multiplier`` of the collection.
        batch_size: Batch size requested from the generator and used to size
            the mosaic grid.
        merge_batch: Not used by this function.
        show_output: If True, show the mosaics with matplotlib.
        output_filepath: If given, each mosaic is saved to this path with
            '_<label>' inserted by ``replace_suffix``.
        viz_rows: Mosaic rows; defaults to ``ceil(sqrt(batch_size))`` and is
            capped at ``batch_size``.
        viz_mode_2d: Not used by this function.
        viz_mode_3d: Passed through to ``display_3d_data`` for 5D arrays.
        color_range: Optional dict mapping labels to ``[vmin, vmax]``.
            Defaults to the min/max of each array. The caller's dict is not
            modified.
        output_groups: Optional collection of labels to keep.
        combine_outputs: Not used by this function.
        rgb_output: Not used by this function.
        colorbar: If True, draw a colorbar next to every plotted mosaic.
        subplot_rows: Rows in the figure grid; defaults to
            ``ceil(sqrt(number of mosaics))``.
        title: Optional figure title.
        subplot_titles: Optional mapping ``label -> sequence of titles`` used
            for 2-channel data; also forwarded to ``display_3d_data``.
        **kwargs: Forwarded to ``display_3d_data``.

    Returns:
        Tuple ``(output_filepaths, output_images)``: a dict of saved paths per
        label (empty when ``output_filepath`` is None) and an OrderedDict of
        the mosaics after ``image_preprocess``.
    """
    if data_collection is not None:
        if batch_size > data_collection.total_cases * data_collection.multiplier:
            batch_size = data_collection.total_cases * data_collection.multiplier
        generator = data_collection.data_generator(perpetual=True, verbose=False, batch_size=batch_size)
        output_data = next(generator)

    # isinstance so that dict subclasses (e.g. OrderedDict) are not wrapped.
    if not isinstance(output_data, dict):
        output_data = {'output_data': output_data}

    if color_range is None:
        color_range = {label: [np.min(data), np.max(data)] for label, data in output_data.items()}
    else:
        # Work on a copy: entries for derived labels are added below.
        color_range = dict(color_range)

    if output_groups is not None:
        output_data = {label: data for label, data in output_data.items() if label in output_groups}

    output_images = OrderedDict()

    if viz_rows is None:
        viz_rows = int(np.ceil(np.sqrt(batch_size)))

    viz_rows = min(viz_rows, batch_size)
    viz_columns = int(np.ceil(batch_size / float(viz_rows)))

    for label, data in list(output_data.items()):

        if data.ndim == 5:
            output_images, color_range = display_3d_data(
                data, color_range, viz_mode_3d, label, output_images, viz_rows, viz_columns,
                subplot_titles=subplot_titles, **kwargs)

        elif data.ndim == 4:
            channels = data.shape[-1]

            if channels == 2:
                # Two channels: one grayscale mosaic per channel.
                for i in range(channels):
                    if subplot_titles is None:
                        subplot_title = label + '_' + str(i)
                    else:
                        subplot_title = subplot_titles[label][i]

                    output_images[subplot_title] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns], 1)
                    color_range[subplot_title] = color_range[label]

            # elif, not if: 2-channel data was previously handled a second
            # time here and sliced as if it had three RGB channels.
            elif channels not in [1, 3]:
                # More than three channels: first three as RGB, the rest as
                # individual grayscale mosaics.
                output_images[label + '_RGB'] = merge_data(
                    data[..., 0:3], [viz_rows, viz_columns], 3)
                color_range[label + '_RGB'] = color_range[label]

                for i in range(3, channels):
                    channel_label = label + '_' + str(i)
                    output_images[channel_label] = merge_data(
                        data[..., i][..., np.newaxis], [viz_rows, viz_columns], 1)
                    color_range[channel_label] = color_range[label]

            else:
                output_images[label] = merge_data(data, [viz_rows, viz_columns], channels)

    # With no mosaics there is nothing to draw (and the grid maths below
    # would divide by zero).
    if show_output and output_images:

        plots = len(output_images)
        if subplot_rows is None:
            subplot_rows = int(np.ceil(np.sqrt(plots)))
        plot_columns = int(np.ceil(plots / float(subplot_rows)))

        # squeeze=False always returns a 2D array of Axes, whatever the grid
        # shape, so no manual reshaping is needed.
        fig, axarr = plt.subplots(subplot_rows, plot_columns, squeeze=False)

        for plot_idx, (label, data) in enumerate(output_images.items()):

            image_row, image_column = divmod(plot_idx, plot_columns)
            axis = axarr[image_row, image_column]

            if data.shape[-1] == 3:
                # Weird matplotlib bug/feature: negative values are rescaled
                # to [0, 1].
                if np.min(data) < 0:
                    data = (data - np.min(data)) / (np.max(data) - np.min(data))

                plt_image = axis.imshow(
                    np.squeeze(data), cmap=plt.get_cmap('hot'),
                    vmin=color_range[label][0], vmax=color_range[label][1], interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image, ax=axis)

            elif data.shape[-1] == 1:
                plt_image = axis.imshow(
                    np.squeeze(data), cmap='gray',
                    vmin=color_range[label][0], vmax=color_range[label][1], interpolation='none')

                if colorbar:
                    fig.colorbar(plt_image, ax=axis, cmap='gray')

            axis.set_title(label)

        # Remove the unused cells of the grid.
        for plot_idx in range(plots, subplot_rows * plot_columns):
            image_row, image_column = divmod(plot_idx, plot_columns)
            fig.delaxes(axarr[image_row, image_column])

        if title is not None:
            fig.suptitle(title, fontsize=28)

        plt.show()

    output_filepaths = {}
    for label, data in list(output_images.items()):
        output_images[label] = image_preprocess(data)
        if output_filepath is not None:
            output_filepaths[label] = save_data(
                output_images[label], replace_suffix(output_filepath, '', '_' + label))

    return output_filepaths, output_images