def get_local_segmentation(self, roi: daisy.Roi, threshold: float):

    # open fragments
    fragments = daisy.open_ds(self.fragments_file, self.fragments_dataset)

    # open RAG DB
    rag_provider = MongoDbRagProvider(
        self.fragments_db,
        host=self.fragments_host,
        mode="r",
        edges_collection=self.edges_collection,
    )

    segmentation = fragments[roi]
    segmentation.materialize()

    ids = [int(id) for id in list(np.unique(segmentation.data))]

    rag = rag_provider.read_rag(ids)

    if len(rag.nodes()) == 0:
        raise Exception('RAG is empty')

    components = rag.get_connected_components(threshold)

    values_map = np.array(
        [[fragment, i]
         for i in range(1, len(components) + 1)
         for fragment in components[i - 1]],
        dtype=np.uint64,
    )
    old_values = values_map[:, 0]
    new_values = values_map[:, 1]

    replace_values(segmentation.data, old_values, new_values, inplace=True)

    return segmentation
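# A minimal numpy-only sketch (not part of the pipeline) of the values_map
# construction above: the RAG's connected components are flattened into
# parallel old_values/new_values arrays, and every fragment ID is replaced by
# the 1-based label of its component. The component lists and fragment IDs
# below are made up, and a plain loop stands in for replace_values.
import numpy as np

components = [[17, 42], [99]]  # hypothetical components of fragment IDs

values_map = np.array(
    [[fragment, i]
     for i in range(1, len(components) + 1)
     for fragment in components[i - 1]],
    dtype=np.uint64,
)
old_values = values_map[:, 0]  # [17, 42, 99]
new_values = values_map[:, 1]  # [ 1,  1,  2]

fragments = np.array([[0, 17], [42, 99]], dtype=np.uint64)

segmentation = fragments.copy()
for old, new in zip(old_values, new_values):
    segmentation[fragments == old] = new

print(segmentation)  # [[0 1]
                     #  [1 2]]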
def get_segmentation(
        fragments,
        fragments_file,
        lut_fragment_segment,
        edges_collection,
        threshold):

    logging.info(
        "Loading fragment - segment lookup table for threshold %s..."
        % threshold)

    fragment_segment_lut_dir = os.path.join(
        fragments_file,
        lut_fragment_segment)

    fragment_segment_lut_file = os.path.join(
        fragment_segment_lut_dir,
        'seg_%s_%d.npz' % (edges_collection, int(threshold * 100)))

    fragment_segment_lut = np.load(
        fragment_segment_lut_file)['fragment_segment_lut']

    assert fragment_segment_lut.dtype == np.uint64

    # fragments = fragments.to_ndarray(block.write_roi)

    logging.info("Relabeling fragment ids with segment ids...")

    segment_ids = replace_values(
        fragments,
        fragment_segment_lut[0],
        fragment_segment_lut[1])

    return segment_ids
def get_segmentation(fragments, fragments_file, edges_collection, threshold,
                     run_type):

    logging.info(
        f"Loading fragment - segment lookup table for threshold {threshold}...")

    fragment_segment_lut_dir = os.path.join(fragments_file, 'luts',
                                            'fragment_segment')

    if run_type:
        logging.info(f"Run type set, evaluating on {run_type} dataset")
        fragment_segment_lut_dir = os.path.join(fragment_segment_lut_dir,
                                                run_type)

    fragment_segment_lut_file = os.path.join(
        fragment_segment_lut_dir,
        f'seg_{edges_collection}_{int(threshold*100)}.npz')

    fragment_segment_lut = np.load(
        fragment_segment_lut_file)['fragment_segment_lut']

    assert fragment_segment_lut.dtype == np.uint64

    logging.info("Relabeling fragment ids with segment ids...")

    segment_ids = replace_values(fragments, fragment_segment_lut[0],
                                 fragment_segment_lut[1])

    return segment_ids
def get_site_segment_ids(self, threshold):

    # get fragment-segment LUT
    logging.info("Reading fragment-segment LUT...")
    start = time.time()

    fragment_segment_lut_dir = os.path.join(self.fragments_file,
                                            'luts/fragment_segment')

    if self.run_type:
        logging.info(f"Using lookup tables for {self.run_type} data")
        fragment_segment_lut_dir = os.path.join(fragment_segment_lut_dir,
                                                self.run_type)

    logging.info("Reading fragment segment luts from: "
                 f"{fragment_segment_lut_dir}")

    fragment_segment_lut_file = os.path.join(
        fragment_segment_lut_dir,
        'seg_%s_%d.npz' % (self.edges_collection, int(threshold * 100)))

    fragment_segment_lut = np.load(
        fragment_segment_lut_file)['fragment_segment_lut']

    assert fragment_segment_lut.dtype == np.uint64

    # get the segment ID for each site
    logging.info("Mapping sites to segments...")
    site_mask = np.isin(fragment_segment_lut[0], self.site_fragment_ids)
    site_segment_ids = replace_values(self.site_fragment_ids,
                                      fragment_segment_lut[0][site_mask],
                                      fragment_segment_lut[1][site_mask])

    return site_segment_ids, fragment_segment_lut
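# A minimal sketch of how a LUT file compatible with the three loaders above
# could be written; this is an assumption about the producing side, not code
# from the pipeline. The loaders only require an .npz file whose
# 'fragment_segment_lut' entry is a uint64 array with fragment IDs in row 0
# and the corresponding segment IDs in row 1, stored under the name
# 'seg_<edges_collection>_<int(threshold*100)>.npz'.
import os
import numpy as np

def write_fragment_segment_lut(lut_dir, edges_collection, threshold,
                               fragment_ids, segment_ids):
    # hypothetical helper, named for illustration only
    os.makedirs(lut_dir, exist_ok=True)
    lut = np.stack([
        np.asarray(fragment_ids, dtype=np.uint64),
        np.asarray(segment_ids, dtype=np.uint64)])
    lut_file = os.path.join(
        lut_dir,
        'seg_%s_%d.npz' % (edges_collection, int(threshold * 100)))
    np.savez_compressed(lut_file, fragment_segment_lut=lut)

# e.g. fragments 17 and 42 merged into segment 17, fragment 99 on its own:
# write_fragment_segment_lut('luts/fragment_segment', 'my_edges_collection',
#                            0.5, [17, 42, 99], [17, 17, 99])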
def __relabel(self, array, components, component_labels):

    old_values = []
    new_values = []

    for component, label in zip(components, component_labels):
        for c in component:
            old_values.append(c)
            new_values.append(label)

    array[:] = replace_values(array, old_values, new_values)
def segment_in_block(block, fragments_file, segmentation, fragments, lut):

    logging.info("Copying fragments to memory...")

    # load fragments
    fragments = fragments.to_ndarray(block.write_roi)

    # replace values, write to empty array
    relabelled = np.zeros_like(fragments)
    relabelled = replace_values(fragments, lut[0], lut[1],
                                out_array=relabelled)

    segmentation[block.write_roi] = relabelled
def segment_in_block(block, fragments_file, segmentation, fragments, lut):

    logging.info("Copying fragments to memory...")
    start = time.time()
    fragments = fragments.to_ndarray(block.write_roi)
    logging.info("%.3fs" % (time.time() - start))

    # get segments
    num_segments = len(np.unique(lut[1]))
    logging.info("Relabelling fragments to %d segments", num_segments)
    start = time.time()
    relabelled = replace_values(fragments, lut[0], lut[1])
    logging.info("%.3fs" % (time.time() - start))

    segmentation[block.write_roi] = relabelled
def watershed_in_block(
        affs,
        block,
        context,
        rag_provider,
        fragments_out,
        num_voxels_in_block,
        mask=None,
        fragments_in_xy=False,
        epsilon_agglomerate=0.0,
        filter_fragments=0.0,
        min_seed_distance=10,
        replace_sections=None):
    '''
    Args:

        filter_fragments (float):

            Filter fragments that have an average affinity lower than this
            value.

        min_seed_distance (int):

            Controls distance between seeds in the initial watershed. Reducing
            this value improves downsampled segmentation.
    '''

    total_roi = affs.roi

    logger.debug("reading affs from %s", block.read_roi)
    affs = affs.intersect(block.read_roi)
    affs.materialize()

    if affs.dtype == np.uint8:
        logger.info("Assuming affinities are in [0,255]")
        max_affinity_value = 255.0
        affs.data = affs.data.astype(np.float32)
    else:
        max_affinity_value = 1.0

    if mask is not None:
        logger.debug("reading mask from %s", block.read_roi)
        mask_data = get_mask_data_in_roi(mask, affs.roi, affs.voxel_size)
        logger.debug("masking affinities")
        affs.data *= mask_data

    # extract fragments
    fragments_data, _ = watershed_from_affinities(
        affs.data,
        max_affinity_value,
        fragments_in_xy=fragments_in_xy,
        min_seed_distance=min_seed_distance)

    if mask is not None:
        fragments_data *= mask_data.astype(np.uint64)

    if filter_fragments > 0:

        if fragments_in_xy:
            average_affs = np.mean(affs.data[0:2]/max_affinity_value, axis=0)
        else:
            average_affs = np.mean(affs.data/max_affinity_value, axis=0)

        filtered_fragments = []

        fragment_ids = np.unique(fragments_data)

        for fragment, mean in zip(
                fragment_ids,
                measurements.mean(
                    average_affs, fragments_data, fragment_ids)):
            if mean < filter_fragments:
                filtered_fragments.append(fragment)

        filtered_fragments = np.array(
            filtered_fragments, dtype=fragments_data.dtype)
        replace = np.zeros_like(filtered_fragments)
        replace_values(fragments_data, filtered_fragments, replace,
                       inplace=True)

    if epsilon_agglomerate > 0:

        logger.info(
            "Performing initial fragment agglomeration until %f",
            epsilon_agglomerate)

        generator = waterz.agglomerate(
            affs=affs.data/max_affinity_value,
            thresholds=[epsilon_agglomerate],
            fragments=fragments_data,
            scoring_function='OneMinus<HistogramQuantileAffinity<RegionGraphType, 25, ScoreValue, 256, false>>',
            discretize_queue=256,
            return_merge_history=False,
            return_region_graph=False)
        fragments_data[:] = next(generator)

        # cleanup generator
        for _ in generator:
            pass

    if replace_sections:

        logger.info("Replacing sections...")

        block_begin = block.write_roi.get_begin()
        shape = block.write_roi.get_shape()

        z_context = context[0]/affs.voxel_size[0]
        logger.info("Z context: %i", z_context)

        mapping = {}

        voxel_offset = block_begin[0]/affs.voxel_size[0]

        for i, j in zip(
                range(fragments_data.shape[0]),
                range(shape[0])):
            mapping[i] = i
            mapping[j] = int(voxel_offset + i) \
                if block_begin[0] == total_roi.get_begin()[0] \
                else int(voxel_offset + (i - z_context))

        logging.info('Mapping: %s', mapping)

        replace = [k for k, v in mapping.items() if v in replace_sections]

        for r in replace:
            logger.info("Replacing mapped section %i with zero", r)
            fragments_data[r] = 0

        # todo add key value replacement option

    fragments = daisy.Array(fragments_data, affs.roi, affs.voxel_size)

    # crop fragments to write_roi
    fragments = fragments[block.write_roi]
    fragments.materialize()
    max_id = fragments.data.max()

    # ensure we don't have IDs larger than the number of voxels (that would
    # break uniqueness of IDs below)
    if max_id > num_voxels_in_block:
        logger.warning(
            "fragments in %s have max ID %d, relabelling...",
            block.write_roi, max_id)
        fragments.data, max_id = relabel(fragments.data)

        assert max_id < num_voxels_in_block

    # ensure unique IDs
    id_bump = block.block_id[1]*num_voxels_in_block
    logger.debug("bumping fragment IDs by %i", id_bump)
    fragments.data[fragments.data > 0] += id_bump
    fragment_ids = range(id_bump + 1, id_bump + 1 + int(max_id))

    # store fragments
    logger.debug("writing fragments to %s", block.write_roi)
    fragments_out[block.write_roi] = fragments

    # following only makes a difference if fragments were found
    if max_id == 0:
        return

    # get fragment centers
    fragment_centers = {
        fragment: block.write_roi.get_offset() +
        affs.voxel_size*daisy.Coordinate(center)
        for fragment, center in zip(
            fragment_ids,
            measurements.center_of_mass(
                fragments.data, fragments.data, fragment_ids))
        if not np.isnan(center[0])
    }

    # store nodes
    rag = rag_provider[block.write_roi]
    rag.add_nodes_from([
        (node, {
            'center_z': c[0],
            'center_y': c[1],
            'center_x': c[2]
        })
        for node, c in fragment_centers.items()
    ])
    rag.write_nodes(block.write_roi)
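# A small worked example (toy numbers, not pipeline code) of why the ID bump
# above keeps fragment IDs unique across blocks: after relabelling, every
# per-block ID satisfies 0 < id <= max_id < num_voxels_in_block, so adding
# block_index * num_voxels_in_block moves each block into its own disjoint
# ID range.
num_voxels_in_block = 1000  # e.g. a 10x10x10 voxel block

for block_index, max_id in [(0, 7), (1, 5), (2, 999)]:
    assert max_id < num_voxels_in_block
    id_bump = block_index * num_voxels_in_block
    print(block_index, (id_bump + 1, id_bump + max_id))

# 0 (1, 7)
# 1 (1001, 1005)
# 2 (2001, 2999)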
def parse_rag_excerpt(self, nodes_list, edges_list):
    # TODO parametrize the used names
    id_field = 'id'
    node1_field = 'u'
    node2_field = 'v'
    merge_score_field = 'merge_score'
    gt_merge_score_field = 'gt_merge_score'
    merge_labeled_field = 'merge_labeled'

    # TODO remove duplicate code, this is also used in hemibrain_graph
    def to_np_arrays(inp):
        d = {}
        for i in inp:
            for k, v in i.items():
                d.setdefault(k, []).append(v)
        for k, v in d.items():
            d[k] = np.array(v)
        return d

    node_attrs = to_np_arrays(nodes_list)

    # TODO maybe port to numpy, but generally fast
    # Drop edges for which one of the incident nodes is not in the
    # extracted node set
    start = time.time()
    for e in reversed(edges_list):
        if e[node1_field] not in node_attrs[id_field] or \
                e[node2_field] not in node_attrs[id_field]:
            edges_list.remove(e)
    logger.debug(f'drop edges at the border in {time.time() - start}s')

    # If all edges were removed in the step above, raise a ValueError
    # that is caught later on
    if len(edges_list) == 0:
        raise ValueError(
            'Removed all edges in ROI, as one node is outside of ROI')

    edges_attrs = to_np_arrays(edges_list)

    node_ids_np = node_attrs[id_field].astype(np.int64)
    node_ids = torch.tensor(node_ids_np, dtype=torch.long)

    # By not operating inplace and providing out_array, we always use
    # the C++ implementation of replace_values
    logger.debug(
        f'before: interval {node_ids_np.max() - node_ids_np.min()}, '
        f'min id {node_ids_np.min()}, max id {node_ids_np.max()}, '
        f'shape {node_ids_np.shape}')

    start = time.time()
    edges_node1 = np.zeros_like(edges_attrs[node1_field], dtype=np.int64)
    edges_node1 = replace_values(
        in_array=edges_attrs[node1_field].astype(np.int64),
        old_values=node_ids_np,
        new_values=np.arange(len(node_attrs[id_field]), dtype=np.int64),
        inplace=False,
        out_array=edges_node1)
    edges_attrs[node1_field] = edges_node1

    logger.debug(
        f'remapping {len(edges_attrs[node1_field])} edges (u) '
        f'in {time.time() - start} s')
    logger.debug(
        f'edges after: min id {edges_attrs[node1_field].min()}, '
        f'max id {edges_attrs[node1_field].max()}')

    start = time.time()
    edges_node2 = np.zeros_like(edges_attrs[node2_field], dtype=np.int64)
    edges_node2 = replace_values(
        in_array=edges_attrs[node2_field].astype(np.int64),
        old_values=node_ids_np,
        new_values=np.arange(len(node_attrs[id_field]), dtype=np.int64),
        inplace=False,
        out_array=edges_node2)
    edges_attrs[node2_field] = edges_node2

    logger.debug(
        f'remapping {len(edges_attrs[node2_field])} edges (v) '
        f'in {time.time() - start} s')
    logger.debug(
        f'edges after: min id {edges_attrs[node2_field].min()}, '
        f'max id {edges_attrs[node2_field].max()}')

    # TODO I could potentially avoid transposing twice
    # edge index requires dimensionality of (2,e)
    # pyg works with directed edges, duplicate each edge here
    edge_index_undir = np.array(
        [edges_attrs[node1_field], edges_attrs[node2_field]]).transpose()
    edge_index_dir = np.repeat(edge_index_undir, 2, axis=0)
    edge_index_dir[1::2, :] = np.flip(edge_index_dir[1::2, :], axis=1)
    edge_index = torch.tensor(
        edge_index_dir.astype(np.int64).transpose(), dtype=torch.long)

    edge_attr_undir = np.expand_dims(edges_attrs[merge_score_field], axis=1)
    edge_attr_dir = np.repeat(edge_attr_undir, 2, axis=0)
    edge_attr = torch.tensor(edge_attr_dir, dtype=torch.float)

    pos = torch.transpose(
        input=torch.tensor([
            node_attrs['center_z'],
            node_attrs['center_y'],
            node_attrs['center_x']
        ], dtype=torch.float),
        dim0=0,
        dim1=1)

    # TODO node features go here
    x = torch.ones(len(node_attrs[id_field]), 1, dtype=torch.float)

    # Targets operate on undirected edges, therefore no duplicate necessary
    mask = torch.tensor(edges_attrs[merge_labeled_field], dtype=torch.float)
    y = torch.tensor(edges_attrs[gt_merge_score_field], dtype=torch.long)

    return edge_index, edge_attr, x, pos, node_ids, mask, y
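# A minimal numpy-only sketch (toy edges, not pipeline code) of the edge
# duplication used above to turn undirected RAG edges into the directed
# edge_index layout that PyTorch Geometric expects; the final
# torch.tensor(...) conversion is omitted.
import numpy as np

u = np.array([0, 1, 2])  # remapped source node indices
v = np.array([1, 2, 3])  # remapped target node indices

edge_index_undir = np.array([u, v]).transpose()          # shape (e, 2)
edge_index_dir = np.repeat(edge_index_undir, 2, axis=0)  # each edge twice
edge_index_dir[1::2, :] = np.flip(
    edge_index_dir[1::2, :], axis=1)                     # reverse every second copy

print(edge_index_dir.transpose())  # shape (2, 2e), (u, v) and (v, u) adjacent
# [[0 1 1 2 2 3]
#  [1 0 2 1 3 2]]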
def simulate_random_cages(
        volume,
        segmentation,
        cages,
        min_density,
        max_density,
        fm_intensity,
        point_spread_function,
        return_cage_map=False,
        return_density_map=False,
        no_cage_probability=0.0):
    '''Randomly render cages with a range of densities for each segment into
    a volume.

    Args:

        volume (Volume): The volume to render to. The volume is expected to
            be real valued with values between 0 and 1.

        segmentation (Volume): A segmentation of the volume. The segmentation
            is expected to be int valued with values between 1 and n. 0 will
            be treated as background.

        cages (list of Cages): A list of cages to randomly select from.

        min_density, max_density (float): The minimum and maximum density to
            uniformly choose from.

        fm_intensity (float): Render intensity for element 100 (Fermium), to
            be used as reference point for cubic intensity transfer function.

        point_spread_function (PointSpreadFunction): The PSF to use to render
            points.

        return_cage_map (bool): Return a map of which segment contains which
            type of cage (as an integer).

        return_density_map (bool): Return a map of the cage densities per
            segment.

        no_cage_probability (float): The probability of expressing no cage,
            per segment.
    '''

    assert (volume.data.min() >= 0 and volume.data.max() <= 1)

    id_list = np.unique(segmentation.data)
    id_list = id_list[np.nonzero(id_list)]

    random_cages = {}
    random_densities = {}
    for id_element in id_list:
        test = random.random()
        if test > no_cage_probability:
            random_cages[id_element] = random.choice(cages)
            random_densities[id_element] = random.uniform(
                min_density, max_density)
        else:
            random_cages[id_element] = None
            random_densities[id_element] = 0

    simulate_cages(volume, segmentation, random_cages, random_densities,
                   fm_intensity, point_spread_function)

    ret = ()

    if return_cage_map:

        # replace segmentation IDs with cage IDs
        cage_map = replace_values(segmentation.data, id_list, [
            random_cages[i].cage_id if random_cages[i] else 0
            for i in id_list
        ])
        ret = ret + (cage_map,)

    if return_density_map:

        densities = np.array(
            [random_densities[i] for i in id_list], dtype=np.float64)

        # (almost) the same for the density map:
        density_map = replace_values(
            segmentation.data.astype(np.uint64),
            id_list.astype(np.uint64),
            densities.view(np.uint64)).view(np.float64)
        density_map = density_map.astype(np.float32)

        ret = ret + (density_map,)

    if len(ret) > 0:
        return ret
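# A tiny numpy-only sketch (made-up values) of the bit-reinterpretation trick
# used for the density map above: float64 densities are viewed as uint64 so
# they can travel through an integer value replacement, then viewed back as
# float64 without any loss. A plain loop stands in for replace_values.
import numpy as np

segment_ids = np.array([1, 2], dtype=np.uint64)
densities = np.array([0.25, 0.8], dtype=np.float64)

densities_as_uint = densities.view(np.uint64)  # reinterpret bits, no rounding

seg = np.array([[1, 1], [2, 0]], dtype=np.uint64)
density_bits = seg.copy()
for sid, bits in zip(segment_ids, densities_as_uint):
    density_bits[seg == sid] = bits

density_map = density_bits.view(np.float64)  # background 0 stays 0.0
print(density_map)  # [[0.25 0.25]
                    #  [0.8  0.  ]]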
def atexit_tasks(model):

    # -----------------------------------------------
    # ---------------- EVALUATION ROUTINE -----------
    # -----------------------------------------------

    _log.info('saving tensorboardx summary files ...')

    # save the tensorboardx summary files
    summary_dir_exit = os.path.join(config.run_abs_path, config.summary_dir)
    summary_compressed = summary_dir_exit + '.tar.gz'

    # remove old tar file
    if os.path.isfile(summary_compressed):
        os.remove(summary_compressed)

    with tarfile.open(summary_compressed, mode='w:gz') as archive:
        archive.add(summary_dir_exit, arcname='summary', recursive=True)
    _run.add_artifact(filename=summary_compressed, name='summary.tar.gz')

    model.eval()
    model.current_writer = None

    # final print routine
    train_dataset.print_summary()
    _log.info(f'Total number of parameters: {total_params}')

    if config.final_training_pass:
        # train loss
        final_loss_train = 0.0
        final_metric_train = 0.0
        final_nr_nodes_train = 0

        _log.info('final training pass ...')
        start = time.time()
        for data_ft in data_loader_train:
            data_ft = data_ft.to(device)
            out_ft = model(data_ft)
            final_loss_train += model.loss(
                out_ft, data_ft.y, data_ft.mask).item() * data_ft.num_nodes
            final_metric_train += model.out_to_metric(
                out_ft, data_ft.y) * data_ft.num_nodes
            final_nr_nodes_train += data_ft.num_nodes
        final_loss_train /= final_nr_nodes_train
        final_metric_train /= final_nr_nodes_train

        _run.log_scalar('loss_train_final', final_loss_train,
                        config.training_epochs)
        _run.log_scalar('accuracy_train_final', final_metric_train,
                        config.training_epochs)
        _log.info(f'final training pass in {time.time() - start:.3f}s')
    else:
        # report training loss of last epoch
        final_loss_train = epoch_loss
        final_metric_train = epoch_metric_train

    _log.info(
        f'Mean train loss ({train_dataset.__len__()} samples): '
        f'{final_loss_train:.3f}')
    _log.info(f'Mean accuracy on train set: {final_metric_train:.3f}')

    if config.final_test_pass:

        # test loss
        data_loader_test = DataLoader(
            test_dataset,
            batch_size=config.batch_size_eval,
            shuffle=False,
            num_workers=config.num_workers,
            worker_init_fn=lambda idx: np.random.seed())

        test_loss = 0.0
        test_metric = 0.0
        nr_nodes_test = 0
        test_predictions = []
        test_targets = []
        test_1d_outputs = dict()

        _log.info('test pass ...')
        start_test_pass = time.time()
        for data_fe in data_loader_test:
            data_fe = data_fe.to(device)
            out_fe = model(data_fe)

            if config.write_to_db:
                start = time.time()
                out_1d = model.out_to_one_dim(out_fe).cpu()

                # TODO this assumes again that every pair of directed edges
                # is next to each other
                edges = torch.transpose(data_fe.edge_index, 0, 1)[0::2]
                edges = edges[data_fe.roi_mask].cpu().numpy().astype(np.int64)

                edges_orig_labels = np.zeros_like(edges, dtype=np.int64)
                edges_orig_labels = replace_values(
                    in_array=edges,
                    out_array=edges_orig_labels,
                    old_values=np.arange(data_fe.num_nodes, dtype=np.int64),
                    new_values=data_fe.node_ids.cpu().numpy().astype(np.int64),
                    inplace=False)

                # TODO min max might be unnecessary here
                # convert to tuples, make sure that directedness is not a
                # problem
                edges_list = [
                    tuple([np.min(i), np.max(i)]) for i in edges_orig_labels
                ]

                for k, v in zip(edges_list, out_1d):
                    if k not in test_1d_outputs:
                        test_1d_outputs[k] = v
                    else:
                        # TODO adapt strategy here if desired
                        test_1d_outputs[k] = max(test_1d_outputs[k], v)

                _log.info(f'writing outputs to dict in {time.time() - start}s')

            test_loss += model.loss(
                out_fe, data_fe.y, data_fe.mask).item() * data_fe.num_nodes
            test_metric += model.out_to_metric(
                out_fe, data_fe.y) * data_fe.num_nodes
            nr_nodes_test += data_fe.num_nodes

            pred = model.out_to_predictions(out_fe)
            test_predictions.extend(model.predictions_to_list(pred))
            test_targets.extend(data_fe.y.tolist())

        test_loss /= nr_nodes_test
        test_metric /= nr_nodes_test

        _run.log_scalar('loss_test', test_loss, config.training_epochs)
        _run.log_scalar('accuracy_test', test_metric, config.training_epochs)

        _log.info(f'test pass in {time.time() - start_test_pass:.3f}s\n')
        _log.info(
            f'Mean test loss ({test_dataset.__len__()} samples): '
            f'{test_loss:.3f}')
        _log.info(f'Mean accuracy on test set: {test_metric:.3f}\n')

        if config.write_to_db:
            # timestamp = datetime.datetime.now(
            #     pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M%S.%f%z')
            comment = _run.meta_info['options']['--comment']
            test_dataset.write_outputs_to_db(
                outputs_dict=test_1d_outputs,
                collection_name=f'{_run.start_time}_{comment}')

        if config.plot_targets_vs_predictions:
            # TODO fix to run on cluster
            # plot targets vs predictions. default is a confusion matrix
            model.plot_targets_vs_predictions(
                targets=test_targets, predictions=test_predictions)
            _run.add_artifact(
                filename=os.path.join(
                    config.run_abs_path, config.confusion_matrix_path),
                name=config.confusion_matrix_path)

            # if Regression, plot targets vs. continuous outputs
            # if isinstance(model.model_type, RegressionProblem):
            #     test_outputs = []
            #     for data in data_loader_test:
            #         data = data.to(device)
            #         out = torch.squeeze(model(data)).tolist()
            #         test_outputs.extend(out)
            #     model.model_type.plot_targets_vs_outputs(
            #         targets=test_targets, outputs=test_outputs)

        # plot the graphs in the test dataset for visual inspection
        if config.plot_graphs_testset:
            if config.plot_graphs_testset < 0 or \
                    config.plot_graphs_testset > test_dataset.__len__():
                plot_limit = test_dataset.__len__()
            else:
                plot_limit = config.plot_graphs_testset

            for i in range(plot_limit):
                g = test_dataset[i]
                g.to(device)
                out_p = model(g)
                g.plot_predictions(
                    config=config,
                    pred=model.predictions_to_list(
                        model.out_to_predictions(out_p)),
                    graph_nr=i,
                    run=_run,
                    acc=model.out_to_metric(out_p, g.y),
                    logger=_log)
    else:
        # report validation loss of last epoch
        test_loss = validation_loss
        test_metric = epoch_metric_val

        _log.info(
            f'Mean validation loss ({test_dataset.__len__()} samples): '
            f'{test_loss:.3f}')
        _log.info(f'Mean accuracy on validation set: {test_metric:.3f}\n')

    return '\n{0}\ntrain acc: {1:.3f}\ntest acc: {2:.3f}'.format(
        _run.meta_info['options']['--comment'], final_metric_train,
        test_metric)
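# A minimal numpy-only sketch (toy values) of the inverse remapping performed
# in the test pass above: contiguous node indices 0..n-1 are mapped back to
# the original node IDs by replacing old_values=arange(n) with
# new_values=node_ids. A plain loop stands in for replace_values.
import numpy as np

node_ids = np.array([1001, 1005, 2001, 2999], dtype=np.int64)  # original IDs
edges = np.array([[0, 1], [2, 3]], dtype=np.int64)             # contiguous indices

old_values = np.arange(len(node_ids), dtype=np.int64)
edges_orig_labels = edges.copy()
for old, new in zip(old_values, node_ids):
    edges_orig_labels[edges == old] = new

print(edges_orig_labels)  # [[1001 1005]
                          #  [2001 2999]]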