def main(args, max_workers=3):
    """Classify detected cell candidates and save the classified points.

    Cubes are generated around every point in ``args.paths.detected_points``
    from the signal and background image planes.  The classification
    network is run over them, and the resulting points (with their ``type``
    set from the network output) are saved to
    ``args.paths.classified_points``.

    :param args: Parsed arguments.  Provides the image plane paths,
        voxel sizes, cube dimensions, batch size, model selection
        (``trained_model``, ``model_weights``, ``network_depth``),
        ``n_free_cpus``, ``save_csv`` and the ``paths`` object.
    :param max_workers: Upper bound on the number of worker processes used
        while running inference.
    :return: True if reading back ``args.paths.classified_points`` with
        ``cells_only=True`` succeeds, False if it raises
        ``MissingCellsError``.
    """
    signal_paths = args.signal_planes_paths[args.signal_channel]
    background_paths = args.background_planes_path[0]

    signal_images = get_sorted_file_paths(signal_paths, file_extension="tif")
    background_images = get_sorted_file_paths(
        background_paths, file_extension="tif"
    )

    # Too many workers doesn't increase speed, and uses huge amounts of RAM
    workers = get_num_processes(
        min_free_cpu_cores=args.n_free_cpus, n_max_processes=max_workers
    )

    logging.debug("Initialising cube generator")
    inference_generator = CubeGeneratorFromFile(
        args.paths.detected_points,
        signal_images,
        background_images,
        args.voxel_sizes,
        args.network_voxel_sizes,
        batch_size=args.batch_size,
        cube_width=args.cube_width,
        cube_height=args.cube_height,
        cube_depth=args.cube_depth,
    )

    model = get_model(
        existing_model=args.trained_model,
        model_weights=args.model_weights,
        network_depth=models[args.network_depth],
        inference=True,
    )

    logging.info("Running inference")
    predictions = model.predict(
        inference_generator,
        use_multiprocessing=True,
        workers=workers,
        verbose=True,
    )
    # Pick the most probable class straight from the network output.
    # Rounding the scores first can collapse multi-class outputs to all
    # zeros (e.g. [0.3, 0.3, 0.4] -> [0, 0, 0]), after which argmax would
    # silently return class 0.
    predictions = np.argmax(predictions, axis=1)

    cells_list = []
    # only go through the "extractable" cells
    for idx, cell in enumerate(inference_generator.ordered_cells):
        # Class index is shifted by one to map onto the Cell type codes
        # (presumably 1 = non-cell/unknown, 2 = cell -- TODO confirm).
        cell.type = predictions[idx] + 1
        cells_list.append(cell)

    logging.info("Saving classified cells")
    save_cells(cells_list, args.paths.classified_points, save_csv=args.save_csv)

    try:
        get_cells(args.paths.classified_points, cells_only=True)
        return True
    except MissingCellsError:
        return False
def get_results(self):
    """Turn detected structures into cells and save them as XML.

    Each structure from ``self.cell_detector.get_coords_list()`` is handled
    according to its volume (number of points):

    * smaller than the sphere volume derived from ``soma_diameter`` and
      ``soma_size_spread_factor``: its centre becomes one ``Cell.UNKNOWN``;
    * smaller than ``self.max_cluster_size``: it is split with
      ``split_cells`` and every resulting centre becomes a ``Cell.UNKNOWN``;
    * otherwise: its centre is saved as a ``Cell.ARTIFACT``.

    The cells are written to ``<output_folder>/<output_file>.xml``.

    :raises StructureSplitException: if ``split_cells`` raises
        ``ValueError`` or ``AssertionError`` for a structure (the original
        error is chained as the cause).
    """
    logging.info("Splitting cell clusters and writing results")

    max_cell_volume = sphere_volume(
        self.soma_size_spread_factor * self.soma_diameter / 2
    )

    def _cell_at(centre, cell_type):
        # Centres are mappings with "x", "y" and "z" entries.
        return Cell((centre["x"], centre["y"], centre["z"]), cell_type)

    cells = []
    for cell_id, cell_points in self.cell_detector.get_coords_list().items():
        cell_volume = len(cell_points)

        if cell_volume < max_cell_volume:
            cell_centre = get_structure_centre_wrapper(cell_points)
            cells.append(_cell_at(cell_centre, Cell.UNKNOWN))
        elif cell_volume < self.max_cluster_size:
            try:
                cell_centres = split_cells(
                    cell_points, outlier_keep=self.outlier_keep
                )
            except (ValueError, AssertionError) as err:
                raise StructureSplitException(
                    f"Cell {cell_id}, error; {err}"
                ) from err
            cells.extend(
                _cell_at(cell_centre, Cell.UNKNOWN)
                for cell_centre in cell_centres
            )
        else:
            # Too large to be a (splittable) cluster of cells.
            cell_centre = get_structure_centre_wrapper(cell_points)
            cells.append(_cell_at(cell_centre, Cell.ARTIFACT))

    xml_file_path = os.path.join(self.output_folder, self.output_file + ".xml")
    save_cells(
        cells,
        xml_file_path,
        save_csv=self.save_csv,
        artifact_keep=self.artifact_keep,
    )
def xml_scale(
    xml_file,
    x_scale=1,
    y_scale=1,
    z_scale=1,
    output_directory=None,
    integer=True,
):
    """
    To rescale the cell positions within an XML file. For compatibility
    with other software, or if data has been scaled after cell detection.

    The result is saved next to the input (or in ``output_directory``) as
    ``<stem>_rescaled<suffix>``, e.g. ``cells.xml`` -> ``cells_rescaled.xml``.

    :param xml_file: Any cellfinder xml file
    :param x_scale: Rescaling factor in the first dimension
    :param y_scale: Rescaling factor in the second dimension
    :param z_scale: Rescaling factor in the third dimension
    :param output_directory: Directory to save the rescaled XML file.
        Defaults to the same directory as the input XML file
    :param integer: Force integer cell positions (default: True)
    :raises CommandLineInputError: if all three rescaling factors are 1
    :return: None
    """
    # TODO: add a csv option
    if x_scale == y_scale == z_scale == 1:
        raise CommandLineInputError(
            "All rescaling factors are 1, " "please check the input."
        )

    input_file = Path(xml_file)
    start_time = datetime.now()

    cells = cell_io.get_cells(xml_file)
    for cell in cells:
        cell.transform(
            x_scale=x_scale,
            y_scale=y_scale,
            z_scale=z_scale,
            integer=integer,
        )

    if output_directory:
        output_directory = Path(output_directory)
    else:
        output_directory = input_file.parent
    ensure_directory_exists(output_directory)

    # Build the name in one step. Calling with_suffix() on "<stem>_rescaled"
    # would replace everything after the last dot of a dotted stem, e.g.
    # "cells.v2.xml" -> "cells.xml", clobbering an unrelated file.
    output_filename = (
        output_directory / f"{input_file.stem}_rescaled{input_file.suffix}"
    )
    cell_io.save_cells(cells, output_filename)

    print(
        "Finished. Total time taken: {}".format(datetime.now() - start_time)
    )
def write_points_to_xml(path, data, metadata):
    """Write one points layer to an XML cell file.

    The layer's ``metadata["metadata"]["point_type"]`` selects how the
    points are converted: ``Cell.CELL`` layers use the default conversion,
    ``Cell.UNKNOWN`` layers are converted with ``cells=False``.  Any other
    point type yields nothing.  The file is written only when at least one
    cell was produced.

    :param path: Destination file path.
    :param data: Layer point data passed to ``convert_layer_to_cells``.
    :param metadata: Layer state holding the nested ``point_type`` entry.
    :return: ``path``, whether or not anything was written.
    """
    point_type = metadata["metadata"]["point_type"]

    if point_type == Cell.CELL:
        cells_to_save = list(convert_layer_to_cells(data))
    elif point_type == Cell.UNKNOWN:
        cells_to_save = list(convert_layer_to_cells(data, cells=False))
    else:
        cells_to_save = []

    if cells_to_save:
        save_cells(cells_to_save, path)
    return path
def write_multiple_points_to_xml(
    path: str, layer_data: List[Tuple[Any, Dict, str]]
) -> str:
    """Write the points of several layers to a single XML cell file.

    For each layer, ``state["metadata"]["point_type"]`` selects the
    conversion: ``Cell.CELL`` layers use the default conversion,
    ``Cell.UNKNOWN`` layers are converted with ``cells=False``.  Layers of
    any other point type are skipped.  The file is written only when at
    least one cell was produced.

    :param path: Destination file path.
    :param layer_data: ``(data, state, layer_type)`` tuples, one per layer.
    :return: ``path``, whether or not anything was written.
    """
    cells_to_save = []
    # The third element (layer type) is unused; "_" also avoids shadowing
    # the builtin ``type``.
    for data, state, _ in layer_data:
        point_type = state["metadata"]["point_type"]
        if point_type == Cell.CELL:
            cells_to_save.extend(convert_layer_to_cells(data))
        elif point_type == Cell.UNKNOWN:
            cells_to_save.extend(convert_layer_to_cells(data, cells=False))

    if cells_to_save:
        save_cells(cells_to_save, path)
    return path
def save_cell_count(self):
    """Save every point of ``self.cell_layer`` as a ``Cell.CELL``.

    The points are written to ``cells.xml`` inside
    ``self.output_directory``.  The first three coordinates of each layer
    point are reversed before building the Cell (presumably converting the
    layer's (z, y, x) order to (x, y, z) -- TODO confirm).  Progress is
    reported via ``self.status_label`` and stdout.
    """
    self.status_label.setText("Saving cells")
    print("Saving cells")

    # Presumably sets self.output_directory, which is used right below.
    self.get_output_directory()
    filename = self.output_directory / "cells.xml"

    cells_to_save = [
        Cell([point[2], point[1], point[0]], Cell.CELL)
        for point in self.cell_layer.data
    ]
    save_cells(cells_to_save, str(filename))

    self.status_label.setText("Ready")
    print("Done!")
def transform_cells_to_standard_space(args):
    """Map cell positions into standard space and save them.

    Steps:

    1. Fill in a default registration config on ``args`` if none is set.
    2. Generate the deformation field.
    3. Load the cells from ``args.paths.classification_out_file`` (only
       cells unless ``args.transform_all``).
    4. Transform them with the deformation field.
    5. Save them to ``args.paths.cells_in_standard_space``.

    Temporary files are removed afterwards unless ``args.debug`` is set.

    :param args: Parsed arguments; ``args.registration_config`` may be
        assigned in place.
    """
    if args.registration_config is None:
        args.registration_config = source_custom_config_cellfinder()

    # Each registration option is forwarded under the same name it has on
    # ``args``.
    registration_options = (
        "affine_n_steps",
        "affine_use_n_steps",
        "freeform_n_steps",
        "freeform_use_n_steps",
        "bending_energy_weight",
        "grid_spacing",
        "smoothing_sigma_reference",
        "smoothing_sigma_floating",
        "histogram_n_bins_floating",
        "histogram_n_bins_reference",
    )
    reg_params = RegistrationParams(
        args.registration_config,
        **{option: getattr(args, option) for option in registration_options},
    )
    generate_deformation_field(args, reg_params)

    only_cells = not args.transform_all
    loaded_cells = get_cells(
        args.paths.classification_out_file, cells_only=only_cells
    )

    logging.info("Loading deformation field")
    field = load_any_image(args.paths.tmp__deformation_field, as_numpy=True)
    image_scales = get_scales(args, reg_params)
    deformation_scales = get_deformation_field_scales(reg_params)

    logging.info("Transforming cell positions")
    cells_in_standard_space = transform_cell_positions(
        loaded_cells, field, deformation_scales, image_scales
    )

    logging.info("Saving transformed cell positions")
    save_cells(
        cells_in_standard_space,
        args.paths.cells_in_standard_space,
        save_csv=args.save_csv,
    )

    if args.debug:
        # Keep intermediate files around for inspection.
        return
    logging.info("Removing standard space transformation temp files")
    delete_temp(args.paths.standard_space_output_folder, args.paths)
def save_curation(viewer):
    """Save the curated points of ``viewer.layers[1]`` as cells.

    Does nothing but print a message when ``CURATED_POINTS`` is empty.
    Otherwise the unique curated indices select points and their ``"cell"``
    property values; each label is cast to int and shifted by one to give
    the Cell type.  The result is written to ``output_filename``, a name
    resolved from an enclosing or global scope not visible here.
    """
    if not CURATED_POINTS:
        print("No cells have been confirmed or toggled, not saving")
        return

    curated_idx = unique_elements_lists(CURATED_POINTS)
    points_layer = viewer.layers[1]
    selected_points = points_layer.data[curated_idx]
    cell_labels = points_layer.properties["cell"][curated_idx].astype("int") + 1

    # Reverse the first three coordinates (layer order -> Cell order).
    cells_to_save = [
        Cell([pt[2], pt[1], pt[0]], label)
        for pt, label in zip(selected_points, cell_labels)
    ]

    print(f"Saving results to: {output_filename}")
    save_cells(cells_to_save, output_filename)
def run_all(args, what_to_run, atlas):
    """Run the selected cellfinder pipeline stages in order.

    The stages are candidate detection, classification, cell position
    analysis and figure generation.  Each runs only if the matching flag on
    ``what_to_run`` is set.  Where a later stage needs earlier results that
    were not computed in this run, they are read back from disk.

    :param args: Parsed arguments with paths and stage parameters.  It may
        be replaced by the ``prep_*`` helpers.
    :param what_to_run: Stage selection object (``detect``, ``classify``,
        ``analyse``, ``figures`` flags).  ``cells_exist`` is set on it
        after classification.
    :param atlas: Atlas used to build the downsampled space and passed to
        the analysis stage.
    """
    # Imported inside the function rather than at module level; presumably
    # to defer heavy dependencies until a run is actually requested.
    from cellfinder_core.detect import detect
    from cellfinder_core.classify import classify
    from cellfinder_core.tools import prep
    from cellfinder_core.tools.IO import read_with_dask

    from cellfinder.analyse import analyse
    from cellfinder.figures import figures

    from cellfinder.tools.prep import (
        prep_candidate_detection,
        prep_channel_specific_general,
    )

    points = None
    signal_array = None
    args, what_to_run = prep_channel_specific_general(args, what_to_run)

    if what_to_run.detect:
        logging.info("Detecting cell candidates")
        args = prep_candidate_detection(args)
        signal_array = read_with_dask(
            args.signal_planes_paths[args.signal_channel]
        )

        points = detect.main(
            signal_array,
            args.start_plane,
            args.end_plane,
            args.voxel_sizes,
            args.soma_diameter,
            args.max_cluster_size,
            args.ball_xy_size,
            args.ball_z_size,
            args.ball_overlap_fraction,
            args.soma_spread_factor,
            args.n_free_cpus,
            args.log_sigma_size,
            args.n_sds_above_mean_thresh,
        )
        ensure_directory_exists(args.paths.points_directory)

        save_cells(
            points,
            args.paths.detected_points,
            save_csv=args.save_csv,
            artifact_keep=args.artifact_keep,
        )
    else:
        logging.info("Skipping cell detection")
        points = get_cells(args.paths.detected_points)

    if what_to_run.classify:
        model_weights = prep.prep_classification(
            args.trained_model,
            args.model_weights,
            args.install_path,
            args.model,
            args.n_free_cpus,
        )

        # Reuse detection results / signal data if already in memory.
        if points is None:
            points = get_cells(args.paths.detected_points)
        if signal_array is None:
            signal_array = read_with_dask(
                args.signal_planes_paths[args.signal_channel]
            )

        logging.info("Running cell classification")
        background_array = read_with_dask(args.background_planes_path[0])

        points = classify.main(
            points,
            signal_array,
            background_array,
            args.n_free_cpus,
            args.voxel_sizes,
            args.network_voxel_sizes,
            args.batch_size,
            args.cube_height,
            args.cube_width,
            args.cube_depth,
            args.trained_model,
            model_weights,
            args.network_depth,
        )
        save_cells(
            points,
            args.paths.classified_points,
            save_csv=args.save_csv,
        )

        what_to_run.cells_exist = cells_exist(args.paths.classified_points)
    else:
        logging.info("Skipping cell classification")

    what_to_run.update_if_cells_required()

    if what_to_run.analyse or what_to_run.figures:
        downsampled_space = get_downsampled_space(
            atlas, args.brainreg_paths.boundaries_file_path
        )

    if what_to_run.analyse:
        points = get_cells(args.paths.classified_points, cells_only=True)
        if len(points) == 0:
            logging.info("No cells detected, skipping cell position analysis")
        else:
            logging.info("Analysing cell positions")
            analyse.run(args, points, atlas, downsampled_space)
    else:
        logging.info("Skipping cell position analysis")

    if what_to_run.figures:
        # NOTE(review): unlike the analysis step, this gate reads
        # ``detected_points`` (not ``classified_points``) with
        # cells_only=True. If detected candidates are saved with an
        # unclassified type, this may always be empty -- confirm which file
        # is intended.
        points = get_cells(args.paths.detected_points, cells_only=True)
        if len(points) == 0:
            logging.info("No cells detected, skipping")
        else:
            logging.info("Generating figures")
            figures.run(args, atlas, downsampled_space.shape)
    else:
        logging.info("Skipping figure generation")