def infer_hierarchy(neuron_df, connection_df, min_weight=10, init='groundtruth', verbose=True, special_debug=False): ## ## TODO: If filtering connections for min_weight drops some neurons entirely, they should be removed from neuron_df ## lsf_slots = os.environ.get('LSB_DJOB_NUMPROC', default=0) if lsf_slots: os.environ['OMP_NUM_THREADS'] = lsf_slots logger.info(f"Using {lsf_slots} CPUs for OpenMP") assert init in ('groundtruth', 'random') neuron_df = load_table(neuron_df) connection_df = load_table(connection_df) assert {*neuron_df.columns} >= {'bodyId', 'instance', 'type'} assert {*connection_df.columns} >= {'bodyId_pre', 'bodyId_post', 'weight'} if special_debug: # Choose a very small subset of the data neuron_df = neuron_df.iloc[::100] bodies = neuron_df['bodyId'] connection_df = connection_df.query('bodyId_pre in @bodies and bodyId_post in @bodies') if init == "groundtruth": with Timer("Computing initial hierarchy from groundtruth", logger): assign_morpho_indexes(neuron_df) num_morpho_groups = neuron_df.morpho_index.max()+1 init_bs = [neuron_df['morpho_index'].values, np.zeros(num_morpho_groups, dtype=int)] else: init_bs = None # If this is a per-ROI table, sum up the ROIs. if 'roi' in connection_df: connection_df = connection_df.groupby(['bodyId_pre', 'bodyId_post'], as_index=False)['weight'].sum() strong_connections_df = connection_df.query('weight >= @min_weight') strong_bodies = pd.unique(strong_connections_df[['bodyId_pre', 'bodyId_post']].values.reshape(-1)) weights = strong_connections_df.set_index(['bodyId_pre', 'bodyId_post'])['weight'] logger.info(f"Strong connectome (cutoff={min_weight}) has {len(strong_bodies)} bodies and {len(weights)} edges") vertexes = np.arange(len(strong_bodies), dtype=np.uint32) vertex_mapper = LabelMapper(strong_bodies.astype(np.uint64), vertexes) vertex_reverse_mapper = LabelMapper(vertexes, strong_bodies.astype(np.uint64)) g = construct_graph(weights, vertexes, vertex_mapper) with Timer("Running inference"): # Computes a NestedBlockState nbs = graph_tool.inference.minimize_nested_blockmodel_dl(g, bs=init_bs, mcmc_args=dict(parallel=True), # see graph-tool docs and mailing list for caveats deg_corr=True, verbose=verbose) partition_df = construct_partition_table(nbs, neuron_df, vertexes, vertex_reverse_mapper) return strong_connections_df, g, nbs, partition_df
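# NOTE: The following example is an added illustration, not part of the original module.
# It sketches the body-ID <-> vertex-ID round trip that infer_hierarchy() performs before
# building the graph, assuming dvidutils.LabelMapper (used throughout this repo).
# The body IDs and weights below are made up.
def _example_vertex_relabeling():
    import numpy as np
    import pandas as pd
    from dvidutils import LabelMapper

    edges = pd.DataFrame({'bodyId_pre':  np.array([1001, 2002, 1001], np.uint64),
                          'bodyId_post': np.array([2002, 3003, 3003], np.uint64),
                          'weight': [20, 15, 30]})

    # Map arbitrary uint64 body IDs to consecutive uint32 vertex IDs (and back).
    bodies = pd.unique(edges[['bodyId_pre', 'bodyId_post']].values.reshape(-1))
    vertexes = np.arange(len(bodies), dtype=np.uint32)
    to_vertex = LabelMapper(bodies, vertexes)
    to_body = LabelMapper(vertexes, bodies)

    vertex_edges = to_vertex.apply(edges[['bodyId_pre', 'bodyId_post']].values)
    assert (to_body.apply(vertex_edges) == edges[['bodyId_pre', 'bodyId_post']].values).all()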
def _find_disconnected_components(cleaned_edges, output_labels):
    """
    Given a graph defined by cleaned_edges and a node labeling in output_labels,
    check whether any output labels are split among discontiguous groups,
    and return the set of output label IDs for such objects.
    """
    # Figure out which edges were 'cut' (endpoints got different labels)
    # and which were preserved.
    mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32), output_labels)
    labeled_edges = mapper.apply(cleaned_edges)
    preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]

    # Compute CC on the graph WITHOUT the cut edges (keep only preserved edges).
    component_labels = connected_components(preserved_edges, len(output_labels))
    assert len(component_labels) == len(output_labels)

    # Align node output labels to their connected component labels.
    cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels})

    # How many unique connected component labels are associated with each output label?
    cc_counts = cc_df.groupby('label').nunique()['cc']

    # Any output label that maps to multiple CC labels is a 'disconnected component' in the output.
    disconnected_cc_counts = cc_counts[cc_counts > 1]
    disconnected_components = set(disconnected_cc_counts.index) - set([0])

    return disconnected_components
def connected_components_nonconsecutive(edges, node_ids):
    """
    Run connected components on the graph encoded by 'edges' and node_ids.
    All nodes from edges must be present in node_ids.
    (Additional nodes are permitted, in which case each is assigned its own CC.)
    The node_ids do not need to be consecutive, and can have arbitrarily large values.

    For graphs with consecutively-valued nodes, connected_components() (below)
    will be faster because it avoids a relabeling step.

    Args:
        edges:
            ndarray, shape=(E,2), dtype np.uint32 or uint64

        node_ids:
            ndarray, shape=(N,), dtype np.uint32 or uint64

    Returns:
        ndarray, same shape as node_ids, labeled by component index from 0..C
    """
    assert node_ids.ndim == 1
    assert node_ids.dtype in (np.uint32, np.uint64)

    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)
    cons_edges = mapper.apply(edges)
    return connected_components(cons_edges, len(node_ids))
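# Added usage example (hypothetical IDs): connected_components_nonconsecutive() accepts
# sparse, arbitrarily large node IDs; nodes absent from 'edges' get their own component.
def _example_connected_components_nonconsecutive():
    import numpy as np

    node_ids = np.array([10, 20, 30, 999_999_999_999], dtype=np.uint64)
    edges = np.array([[10, 20]], dtype=np.uint64)

    cc = connected_components_nonconsecutive(edges, node_ids)
    assert len(cc) == len(node_ids)
    # e.g. cc could be [0, 0, 1, 2]: nodes 10 and 20 share a component,
    # while 30 and 999999999999 are each alone in theirs.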
def _overwrite_body_id_column(block_sv_stats, segment_to_body_df=None): """ Given a stats array with 'columns' as defined in STATS_DTYPE, overwrite the body_id column according to the given agglomeration mapping DataFrame. If no mapping is given, simply copy the segment_id column into the body_id column. Args: block_sv_stats: numpy.ndarray, with dtype=STATS_DTYPE segment_to_body_df: pandas.DataFrame, with columns ['segment_id', 'body_id'] """ assert block_sv_stats.dtype == STATS_DTYPE assert STATS_DTYPE[0][0] == 'body_id' assert STATS_DTYPE[1][0] == 'segment_id' block_sv_stats = block_sv_stats.view( [STATS_DTYPE[0], STATS_DTYPE[1], ('other_cols', STATS_DTYPE[2:])] ) if segment_to_body_df is None: # No agglomeration block_sv_stats['body_id'] = block_sv_stats['segment_id'] else: assert list(segment_to_body_df.columns) == AGGLO_MAP_COLUMNS # This could be done via pandas merge(), followed by fillna(), etc., # but I suspect LabelMapper is faster and more frugal with RAM. mapper = LabelMapper(segment_to_body_df['segment_id'].values, segment_to_body_df['body_id'].values) del segment_to_body_df # Remap in batches to save RAM batch_size = 1_000_000 for chunk_start in range(0, len(block_sv_stats), batch_size): chunk_stop = min(chunk_start+batch_size, len(block_sv_stats)) chunk_segments = block_sv_stats['segment_id'][chunk_start:chunk_stop] block_sv_stats['body_id'][chunk_start:chunk_stop] = mapper.apply(chunk_segments, allow_unmapped=True)
def apply_mappings(supervoxels, mappings):
    """
    Apply each of the given sv -> body mappings (a dict of named pd.Series) to the
    given supervoxels, and return the results as columns of a DataFrame indexed by
    supervoxel. Supervoxels without an entry in a given mapping are left unchanged
    in that column (via allow_unmapped=True).
    """
    assert isinstance(mappings, dict)
    df = pd.DataFrame(index=supervoxels.astype(np.uint64, copy=False))
    df.index.name = 'sv'

    for name, mapping in mappings.items():
        assert isinstance(mapping, pd.Series)
        index_values = mapping.index.values.astype(np.uint64, copy=False)
        mapping_values = mapping.values.astype(np.uint64, copy=False)
        mapper = LabelMapper(index_values, mapping_values)
        df[name] = mapper.apply(df.index.values, allow_unmapped=True)

    return df
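# Added usage example (hypothetical IDs) for apply_mappings().
# Each mapping is an sv -> body pd.Series (the convention used by load_mapping() elsewhere).
def _example_apply_mappings():
    import numpy as np
    import pandas as pd

    svs = np.array([1, 2, 3], dtype=np.uint64)
    map_v1 = pd.Series(index=np.array([1, 2], np.uint64), data=np.array([100, 100], np.uint64))
    map_v2 = pd.Series(index=np.array([1, 2, 3], np.uint64), data=np.array([100, 200, 200], np.uint64))

    df = apply_mappings(svs, {'v1': map_v1, 'v2': map_v2})
    # df is indexed by 'sv', with one column per mapping.
    # sv 3 has no entry in 'v1'; thanks to allow_unmapped=True it is left unchanged there.
    return df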
def __init__(self, edge_array, directed=True):
    """
    Construct a graph-tool Graph from the given (N,2) edge array of node IDs,
    after relabeling the (possibly non-consecutive) node IDs to consecutive uint32 values.
    """
    import graph_tool as gt

    assert edge_array.dtype in (np.uint32, np.uint64)
    self.node_ids = pd.unique(edge_array.reshape(-1))
    self.cons_node_ids = np.arange(len(self.node_ids), dtype=np.uint32)

    self.mapper = LabelMapper(self.node_ids, self.cons_node_ids)
    cons_edges = self.mapper.apply(edge_array)

    self.g = gt.Graph(directed=directed)
    self.g.add_edge_list(cons_edges, hashed=False)
def apply_mapping_to_mergetable(merge_table_df, mapping):
    """
    Set the 'body' column of the given merge table (appending the column if it
    doesn't already exist) by applying the given SV -> body mapping to the
    merge table's id_a column.
    """
    if isinstance(mapping, str):
        with Timer("Loading mapping", logger):
            mapping = load_mapping(mapping)

    assert isinstance(mapping, pd.Series), "Mapping must be a pd.Series"
    with Timer("Applying mapping to merge table", logger):
        mapper = LabelMapper(mapping.index.values, mapping.values)
        merge_table_df['body'] = mapper.apply(merge_table_df['id_a'].values, allow_unmapped=True)
def remap_bricks(partition_bricks):
    domain, codomain = mapping_pairs.transpose()
    mapper = LabelMapper(domain, codomain)

    partition_bricks = list(partition_bricks)
    for brick in partition_bricks:
        # TODO: Apparently LabelMapper can't handle non-contiguous arrays right now.
        #       (It yields incorrect results.)
        #       Check to see if this is still a problem in the latest version of xtensor-python.
        brick.volume = np.asarray(brick.volume, order='C')

        mapper.apply_inplace(brick.volume, allow_unmapped=True)
        brick.compress()

    return partition_bricks
def check_tarsupervoxels_status_via_exists(server, uuid, tsv_instance, bodies, seg_instance=None, mapping=None, kafka_msgs=None): """ For the given bodies, query the given tarsupervoxels instance and return a DataFrame indicating which supervoxels are 'missing' from the instance. Bodies that no longer exist in the segmentation instance are ignored. This function downloads the complete mapping in advance and uses it to determine which supervoxels belong to each body. Then uses the /exists endpoint to query for missing supervoxels, rather than /missing, which incurs a disk read in DVID. """ if seg_instance is None: # Determine segmentation instance info = fetch_instance_info(server, uuid, tsv_instance) seg_instance = info["Base"]["Syncs"][0] if mapping is None: mapping = fetch_complete_mappings(server, uuid, seg_instance, kafka_msgs=kafka_msgs) # Filter out bodies we don't care about, # and append unmapped (singleton/identity) bodies _bodies = set(bodies) mapping = pd.DataFrame(mapping).query('body in @_bodies')['body'].copy() unmapped_bodies = _bodies - set(mapping) unmapped_bodies = np.fromiter(unmapped_bodies, np.uint64) singleton_mapping = pd.Series(index=unmapped_bodies, data=unmapped_bodies, dtype=np.uint64) mapping = pd.concat((mapping, singleton_mapping)) assert mapping.index.values.dtype == np.uint64 assert mapping.values.dtype == np.uint64 # Faster than mapping.loc[], apparently mapper = LabelMapper(mapping.index.values, mapping.values) statuses = fetch_exists(server, uuid, tsv_instance, mapping.index, batch_size=10_000, processes=16) missing_svs = statuses[~statuses].index.values assert missing_svs.dtype == np.uint64 missing_bodies = mapper.apply(missing_svs, True) missing_df = pd.DataFrame({'sv': missing_svs, 'body': missing_bodies}) assert missing_df['sv'].dtype == np.uint64 assert missing_df['body'].dtype == np.uint64 # Return a series, indexed by sv return missing_df.set_index('sv')['body']
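# Added standalone sketch (made-up IDs) of the 'singleton body' padding performed above:
# bodies of interest that never appear in the mapping are assumed to consist of a single
# supervoxel whose ID equals the body ID.
def _example_pad_mapping_with_singletons():
    import numpy as np
    import pandas as pd

    mapping = pd.Series(index=np.array([11, 12], np.uint64),   # sv
                        data=np.array([1, 1], np.uint64),      # body
                        name='body')
    bodies_of_interest = {1, 7}

    unmapped_bodies = np.fromiter(bodies_of_interest - set(mapping), np.uint64)
    singletons = pd.Series(index=unmapped_bodies, data=unmapped_bodies, dtype=np.uint64, name='body')
    full_mapping = pd.concat((mapping, singletons))
    # full_mapping now covers svs 11, 12 (body 1) and sv 7 (body 7).
    return full_mapping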
def find_all_hotknife_edges_for_plane(server, uuid, instance, plane_center_coord_s0, tile_shape_s0, plane_spacing_s0, min_overlap_s0=1, min_jaccard=0.0, *, scale=0, mapping=None): """ Find all hotknife edges around the given X-plane, found in batches according to tile_shape_s0. See find_hotknife_edges() for more details. """ plane_box_s0 = fetch_volume_box(server, uuid, instance) plane_bounds_s0 = plane_box_s0[:,:2] edge_tables = [] tile_boxes = boxes_from_grid(plane_bounds_s0, tile_shape_s0, clipped=True) for tile_index, tile_bounds_s0 in enumerate(tile_boxes): tile_logger = PrefixedLogger(logger, f"Tile {tile_index:03d}/{len(tile_boxes):03d}: ") edge_table = find_hotknife_edges( server, uuid, instance, plane_center_coord_s0, tile_bounds_s0, plane_spacing_s0, min_overlap_s0, min_jaccard, scale=scale, supervoxels=True, # Compute edges across supervoxels, and then filter out already-merged bodies afterwards logger=tile_logger ) edge_tables.append(edge_table) edge_table = pd.concat(edge_tables, ignore_index=True) assert (edge_table.columns == ['left', 'right', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size']).all() edge_table.columns = ['id_a', 'id_b', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size'] assert edge_table['id_a'].dtype == np.uint64 assert edge_table['id_b'].dtype == np.uint64 # "Complete" mappings not necessary for our purposes. if mapping is None: mapping = fetch_mappings(server, uuid, instance, as_array=True) mapper = LabelMapper(mapping[:,0], mapping[:,1]) edge_table['body_a'] = mapper.apply(edge_table['id_a'].values, True) edge_table['body_b'] = mapper.apply(edge_table['id_b'].values, True) edge_table.query('body_a != body_b', inplace=True) return edge_table
def _erase_tiny_interior_segments(seg_vol, min_size): """ Erase any segments that are smaller than the given size and don't touch the edge of the volume. """ edge_mitos = (set(pd.unique(seg_vol[0, :, :].ravel())) | set(pd.unique(seg_vol[:, 0, :].ravel())) | set(pd.unique(seg_vol[:, :, 0].ravel())) | set(pd.unique(seg_vol[-1, :, :].ravel())) | set(pd.unique(seg_vol[:, -1, :].ravel())) | set(pd.unique(seg_vol[:, :, -1].ravel()))) mito_sizes = pd.Series(seg_vol.ravel()).value_counts() nontiny_mitos = mito_sizes[mito_sizes >= min_size].index keep_mitos = (edge_mitos | set(nontiny_mitos)) keep_mitos = np.array([*keep_mitos], np.uint64) if len(keep_mitos) == 0: return np.zeros_like(seg_vol) # Erase everything that isn't in the keep set seg_vol = LabelMapper(keep_mitos, keep_mitos).apply_with_default(seg_vol, 0) return seg_vol
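# Added usage example for _erase_tiny_interior_segments() on a toy volume.
def _example_erase_tiny_interior_segments():
    import numpy as np

    seg = np.zeros((5, 5, 5), dtype=np.uint64)
    seg[0, :, :] = 1      # touches the volume edge -> always kept
    seg[2, 2, 2] = 2      # single interior voxel  -> erased when min_size > 1

    cleaned = _erase_tiny_interior_segments(seg, min_size=2)
    assert (cleaned[0] == 1).all()
    assert cleaned[2, 2, 2] == 0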
def remap_cc_to_final(orig_brick, cc_brick, wrapped_brick_mapping_df):
    """
    Given an original brick and the corresponding CC brick,
    relabel the CC brick according to the final label mapping,
    as provided in wrapped_brick_mapping_df.
    """
    assert (orig_brick.logical_box == cc_brick.logical_box).all()
    assert (orig_brick.physical_box == cc_brick.physical_box).all()

    # Check for NaN, which implies that the mapping for this
    # brick is empty (no objects to relabel).
    if isinstance(wrapped_brick_mapping_df, float):
        assert np.isnan(wrapped_brick_mapping_df)
        final_vol = orig_brick.volume
        orig_brick.compress()
    else:
        # Construct mapper from only the rows we need
        cc_labels = pd.unique(cc_brick.volume.reshape(-1)) # @UnusedVariable
        mapping = wrapped_brick_mapping_df.df.query('cc in @cc_labels')[['cc', 'final_cc']].values
        mapper = LabelMapper(*mapping.transpose())

        # Apply mapping to the CC vol, writing zeros wherever the CC isn't mapped.
        final_vol = mapper.apply_with_default(cc_brick.volume, 0)

        # Overwrite zero voxels from the original segmentation.
        final_vol = np.where(final_vol, final_vol, orig_brick.volume)

        orig_brick.compress()
        cc_brick.compress()

    final_brick = Brick( orig_brick.logical_box,
                         orig_brick.physical_box,
                         final_vol,
                         location_id=orig_brick.location_id,
                         compression=orig_brick.compression )
    return final_brick
def region_features(label_img, grayscale_img=None, features=['Box', 'Count'], ignore_label=None): """ Wrapper around vigra.analysis.extractRegionFeatures() that supports uint64 and returns each feature as a pandas Series or DataFrame, indexed by object ID. For simple features such as 'Box' and 'Count', most of the time is spent remapping the input array from uint64 to uint32, which is the only label image type supported by vigra. See vigra docs regarding the supported features: - http://ukoethe.github.io/vigra/doc-release/vigranumpy/index.html#vigra.analysis.extractRegionFeatures - http://ukoethe.github.io/vigra/doc-release/vigra/group__FeatureAccumulators.html Args: label_img: An integer-valued label image, containing no negative values grayscle_img: Optional. If provided, then weighted features are available. See GRAYSCALE_FEATURE_NAMES, above. features: List of strings. If no grayscale image was provided, you can only ask for the features in ``SEGMENTATION_FEATURE_NAMES``, above. ignore_label: A background label to ignore. If you don't want to ignore any thing, pass ``None``. Returns: dict {name: feature}, where each feature value is indexed by label ID. For keys where the feature is scalar-valued for each label, the returned value is a Series. For keys where the feature is a 1D array for each label, a DataFrame is returned, with columns 'zyx'. For keys where the feature is a 2D array (e.g. Box, RegionAxes, etc.), a Series is returned whose dtype=object, and each item in the series is a 2D array. TODO: This might be a good place to use Xarray """ assert label_img.ndim in (2, 3) axes = 'zyx'[-label_img.ndim:] vfeatures = {*features} valid_names = {*SEGMENTATION_FEATURE_NAMES, *GRAYSCALE_FEATURE_NAMES} invalid_names = vfeatures - valid_names assert not invalid_names, \ f"Invalid feature names: {invalid_names}" if 'Box' in features: vfeatures -= {'Box'} vfeatures |= {'Coord<Minimum>', 'Coord<Maximum>'} if 'Box0' in features: vfeatures -= {'Box0'} vfeatures |= {'Coord<Minimum>'} if 'Box1' in features: vfeatures -= {'Box1'} vfeatures |= {'Coord<Maximum>'} assert np.issubdtype(label_img.dtype, np.integer) label_ids = None if label_img.dtype == np.uint32: label_img32 = label_img elif label_img.dtype == np.int32: label_img32 = label_img.view(np.uint32) elif label_img.dtype in (np.int64, np.uint64): label_img = label_img.view(np.uint64) label_ids = np.sort(pd.unique(label_img.ravel())) # Map from uint64 -> uint32 label_ids_32 = np.arange(len(label_ids), dtype=np.uint32) mapper = LabelMapper(label_ids, label_ids_32) label_img32 = mapper.apply(label_img) if ignore_label is not None: ignore_label = mapper.apply(np.array([ignore_label], np.uint64))[0] else: label_img32 = label_img.astype(np.uint32) assert label_img32.dtype == np.uint32 if grayscale_img is None: invalid_names = vfeatures - {*SEGMENTATION_FEATURE_NAMES} assert not invalid_names, \ f"Invalid segmentation feature names: {invalid_names}" grayscale_img = label_img32.view(np.float32) else: assert grayscale_img.dtype == np.float32, \ "Grayscale image must be float32" grayscale_img = vigra.taggedView(grayscale_img, axes) label_img32 = vigra.taggedView(label_img32, axes) # TODO: provide histogramRange options acc = vigra.analysis.extractRegionFeatures(grayscale_img, label_img32, [*vfeatures], ignoreLabel=ignore_label) results = {} if 'Box0' in features: v = acc['Coord<Minimum >'].astype(np.int32) results['Box0'] = pd.DataFrame(v, columns=[*axes]) if 'Box1' in features: v = 1 + acc['Coord<Maximum >'].astype(np.int32) results['Box1'] = 
pd.DataFrame(v, columns=[*axes]) if 'Box' in features: box0 = acc['Coord<Minimum >'].astype(np.int32) box1 = (1 + acc['Coord<Maximum >']).astype(np.int32) boxes = np.stack((box0, box1), axis=1) obj_boxes = np.zeros(len(boxes), object) obj_boxes[:] = list(boxes) results['Box'] = pd.Series(obj_boxes, name='Box') for k, v in acc.items(): k = k.replace(' ', '') # Only return the features the user explicitly requested. if k not in features: continue if v.ndim == 1: results[k] = pd.Series(v, name=k) elif v.ndim == 2: results[k] = pd.DataFrame(v, columns=[*axes]) else: # If the data doesn't neatly fit into a 1-d Series # or a 2-d DataFrame, then construct a Series with dtype=object # and make each row a separate ndarray object. obj_v = np.zeros(len(v), dtype=object) obj_v[:] = list(v) results[k] = pd.Series(obj_v, name=k) # Set index to the original uint64 values if label_img.dtype == np.uint64: for v in results.values(): v.index = label_ids return results
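# Added usage example (hypothetical labels) for region_features(); requires vigra.
# Label 0 is included in the results here because ignore_label=None.
def _example_region_features():
    import numpy as np

    labels = np.zeros((4, 4), dtype=np.uint64)
    labels[:2, :2] = 10
    labels[2:, 2:] = 20

    feats = region_features(labels, features=['Count', 'Box'])
    # feats['Count'] is a pd.Series of voxel counts indexed by the original uint64 IDs,
    # e.g. feats['Count'].loc[10] == 4.
    # feats['Box'].loc[20] is a (2,2) array of [start, stop) coordinates.
    return feats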
def select_hulls_for_mito_bodies(mito_body_ct, mito_bodies_mask, mito_binary, body_seg, hull_masks, seed_bodies, box, scale, viewer=None, res0=8, progress=False): mito_bodies_mito_seg = np.where(mito_bodies_mask & mito_binary, body_seg, 0) nonmito_body_seg = np.where(mito_bodies_mask, 0, body_seg) hull_cc_overlap_stats = [] for hull_cc, (mask_box, mask) in tqdm_proxy(hull_masks.items(), disable=not progress): mbms = mito_bodies_mito_seg[box_to_slicing(*mask_box)] masked_hull_cc_bodies = np.where(mask, mbms, 0) # Faster to check for any non-zero values at all before trying to count them. # This early check saves a lot of time in practice. if not masked_hull_cc_bodies.any(): continue # This hull was generated from a particular seed body (non-mito body). # If it accidentally overlaps with any other non-mito bodies, # then delete those voxels from the hull. # If that causes the hull to become split apart into multiple connected components, # then keep only the component(s) which overlap the seed body. seed_body = seed_bodies[hull_cc] nmbs = nonmito_body_seg[box_to_slicing(*mask_box)] other_bodies = set(pd.unique(nmbs[mask])) - {0, seed_body} if other_bodies: # Keep only the voxels on mito bodies or on the # particular non-mito body for this hull (the "seed body"). mbm = mito_bodies_mask[box_to_slicing(*mask_box)] mask[:] &= (mbm | (nmbs == seed_body)) mask = vigra.taggedView(mask, 'zyx') mask_cc = vigra.analysis.labelMultiArrayWithBackground( mask.view(np.uint8)) if mask_cc.max() > 1: mask_ct = contingency_table(mask_cc, nmbs).reset_index() keep_ccs = mask_ct['left'].loc[(mask_ct['left'] != 0) & (mask_ct['right'] == seed_body)] mask[:] = mask_for_labels(mask_cc, keep_ccs) mito_bodies, counts = np.unique(masked_hull_cc_bodies, return_counts=True) overlaps = pd.DataFrame({ 'mito_body': mito_bodies, 'overlap': counts, 'hull_cc': hull_cc, 'hull_size': mask.sum(), 'hull_body': seed_body }) hull_cc_overlap_stats.append(overlaps) if len(hull_cc_overlap_stats) == 0: logger.warning("Could not find any matches for any mito bodies!") mito_body_ct['hull_body'] = np.uint64(0) return mito_body_ct hull_cc_overlap_stats = pd.concat(hull_cc_overlap_stats, ignore_index=True) hull_cc_overlap_stats = hull_cc_overlap_stats.query( 'mito_body != 0').copy() # Aggregate the stats for each body and the hull bodies it overlaps with, # Select the hull_body with the most overlap, or in the case of ties, the hull body that is largest overall. # (Ties are probably more common in the event that two hulls completely encompass a small mito body.) 
hull_body_overlap_stats = hull_cc_overlap_stats.groupby( ['mito_body', 'hull_body'])[['overlap', 'hull_size']].sum() hull_body_overlap_stats = hull_body_overlap_stats.sort_values( ['mito_body', 'overlap', 'hull_size'], ascending=False) hull_body_overlap_stats = hull_body_overlap_stats.reset_index() mito_hull_selections = (hull_body_overlap_stats.drop_duplicates( 'mito_body').set_index('mito_body')['hull_body']) mito_body_ct = mito_body_ct.merge(mito_hull_selections, 'left', left_index=True, right_index=True) mito_body_ct['hull_body'] = mito_body_ct['hull_body'].fillna(0) dtypes = {col: np.float32 for col in mito_body_ct.columns} dtypes['hull_body'] = np.uint64 mito_body_ct = mito_body_ct.astype(dtypes) if viewer: assert mito_hull_selections.index.dtype == mito_hull_selections.values.dtype == np.uint64 mito_hull_mapper = LabelMapper(mito_hull_selections.index.values, mito_hull_selections.values) remapped_body_seg = mito_hull_mapper.apply(body_seg, True) remapped_body_seg = apply_mask_for_labels(remapped_body_seg, mito_hull_selections.values) update_seg_layer(viewer, 'altered-bodies', remapped_body_seg, scale, box) # Show the final hull masks (after erasure of non-target bodies) assert sorted(hull_masks.keys()) == [*range(1, 1 + len(hull_masks))] hull_cc_overlap_stats = hull_cc_overlap_stats.sort_values('hull_size') hull_seg = np.zeros_like(remapped_body_seg) for row in hull_cc_overlap_stats.itertuples(): mask_box, mask = hull_masks[row.hull_cc] view = hull_seg[box_to_slicing(*mask_box)] view[:] = np.where(mask, row.hull_body, view) update_seg_layer(viewer, 'final-hull-seg', hull_seg, scale, box) return mito_body_ct
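# Added standalone sketch of the hull-selection idiom used above (made-up numbers):
# sort by the ranking criteria, then keep the first row per group via drop_duplicates().
def _example_pick_best_hull_per_mito():
    import pandas as pd

    stats = pd.DataFrame({'mito_body': [5, 5, 6],
                          'hull_body': [100, 200, 100],
                          'overlap':   [10, 30, 7],
                          'hull_size': [50, 40, 50]})

    best = (stats.sort_values(['mito_body', 'overlap', 'hull_size'], ascending=False)
                 .drop_duplicates('mito_body')
                 .set_index('mito_body')['hull_body'])
    # best: mito_body 5 -> hull_body 200 (largest overlap), mito_body 6 -> hull_body 100
    return best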
def stitch_adjacent_faces(self, drop_unused_vertices=True, drop_duplicate_faces=True):
    """
    Search for duplicate vertices and remove all references to them in self.faces,
    by replacing them with the index of the first matching vertex in the list.
    Works in-place.

    Note: Normals are recomputed iff they were present originally.

    Args:
        drop_unused_vertices:
            If True, drop the unused (duplicate) vertices from self.vertices_zyx
            (since no faces refer to them any more, this saves some RAM).

        drop_duplicate_faces:
            If True, remove faces with an identical vertex list to any previous face.

    Returns:
        False if no stitching was performed (none was needed), or True otherwise.
    """
    need_normals = (self.normals_zyx.shape[0] > 0)

    mapping_pairs = remap_duplicates(self.vertices_zyx)

    dup_indices, orig_indices = mapping_pairs.transpose()
    if len(dup_indices) == 0:
        if need_normals:
            self.recompute_normals(True)
        return False # No stitching was needed.

    # Discard old normals
    self.drop_normals()

    # Remap faces to no longer refer to the duplicates
    mapper = LabelMapper(dup_indices, orig_indices)
    mapper.apply_inplace(self.faces, allow_unmapped=True)
    del mapper
    del orig_indices

    # Now the faces have been stitched, but the duplicate
    # vertices are still unnecessarily present,
    # and the face vertex indexes still reflect that.
    # Also, we may have uncovered duplicate faces now that the
    # vertexes have been canonicalized.
    if drop_unused_vertices:
        self.drop_unused_vertices()

    def _drop_duplicate_faces():
        # Normalize face vertex order before checking for duplicates.
        # Technically, this means we don't distinguish
        # between clockwise/counter-clockwise ordering,
        # but that seems unlikely to be a problem in practice.
        sorted_faces = pd.DataFrame(np.sort(self.faces, axis=1))
        duplicate_faces_mask = sorted_faces.duplicated()
        faces_df = pd.DataFrame(self.faces)
        # Note: Series.nonzero() was removed from pandas, so use the underlying ndarray.
        faces_df.drop(duplicate_faces_mask.values.nonzero()[0], inplace=True)
        self.faces = np.asarray(faces_df.values, order='C')

    if drop_duplicate_faces:
        _drop_duplicate_faces()

    if need_normals:
        self.recompute_normals(True)

    return True # stitching was needed.
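# remap_duplicates() is defined elsewhere in this repo and is not shown here.
# Purely to illustrate the (duplicate_index, first_occurrence_index) contract that
# stitch_adjacent_faces() consumes, this added sketch shows one hypothetical way such
# pairs could be computed with pandas. It is NOT the actual implementation.
def _sketch_remap_duplicates(vertices_zyx):
    import numpy as np
    import pandas as pd

    df = pd.DataFrame(vertices_zyx, columns=[*'zyx'])
    df['row'] = np.arange(len(df), dtype=np.uint32)

    # For each vertex, find the row index of the first vertex with identical coordinates.
    firsts = df.drop_duplicates([*'zyx'])
    merged = df.merge(firsts, on=[*'zyx'], suffixes=('', '_first'))

    # Keep only the rows that are NOT their own first occurrence.
    dups = merged.query('row != row_first')
    return dups[['row', 'row_first']].values.astype(np.uint32)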
def compute_focused_bodies(server, uuid, instance, synapse_samples, min_tbars, min_psds, root_sv_sizes, min_body_size, sv_classifications=None, marked_bad_bodies=None, return_table=False, kafka_msgs=None): """ Compute the complete set of focused bodies, based on criteria for number of tbars, psds, or overall size, and excluding explicitly listed bad bodies. This function takes ~20 minutes to run on hemibrain inputs, with a ton of RAM. The procedure is: 1. Apply synapse-based criteria a. Load synapse CSV file b. Map synapse SVs -> bodies (if needed) b2. If any SVs are 'retired', update those synapses to use the new IDs. c. Calculate synapses (tbars, psds) per body d. Initialize set with bodies that have enough synapses 2. Apply size-based criteria a. Calculate body sizes (based on supervoxel sizes and current mapping) b. Add "big" bodies to the set 3. Apply "bad body" criteria a. Read the list of "bad bodies" b. Remove bad bodies from the set Example: server = 'emdata3:8900' uuid = '7254' instance = 'segmentation' synapse_samples = '/nrs/flyem/bergs/complete-ffn-agglo/sampled-synapses-ef1d-locked.csv' min_tbars = 2 min_psds = 10 # old repo supervoxels (before server rebase) # # Note: This was taken from node 5501ae83e31247498303a159eef824d8, which is from a different repo. # But that segmentation was eventually copied to the production repo as root node a776af. # See /groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/copy-fixed-from-emdata2-to-emdata3-20180402.214505 # #root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015' #root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes.h5' root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/flattened/compute-stats-from-corrupt-20181016.203848' root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes-2884.h5' min_body_size = int(10e6) sv_classifications = '/nrs/flyem/bergs/sv-classifications.h5' marked_bad_bodies = '/nrs/flyem/bergs/complete-ffn-agglo/bad-bodies-2019-02-26.csv' table_description = f'{uuid}-{min_tbars}tbars-{min_psds}psds-{min_body_size / 1e6:.1f}Mv' focused_table = compute_focused_bodies(server, uuid, instance, synapse_samples, min_tbars, min_psds, root_sv_sizes, min_body_size, sv_classifications, marked_bad_bodies, return_table=True) # As npy: np.save(f'focused-{table_description}.npy', focused_table.to_records(index=True)) # As CSV: focused_table.to_csv(f'focused-{table_description}.npy', index=True, header=True) Args: server, uuid, instance: labelmap instance root_sv_sizes: mapping of supervoxel sizes from the root node, as returned by load_supervoxel_sizes(), or a path to an hdf5 file from which it can be loaded synapse_samples: A DataFrame with columns 'body' (or 'sv') and 'kind', or a path to a CSV file with those columns. The 'kind' column is expected to have only 'PreSyn' and 'PostSyn' entries. If the table has an 'sv' column, any "retired" supervoxel IDs will be updated before continuing with the analysis. min_tbars: The minimum number pf PreSyn entries to pass the filter. Bodies with fewer tbars may still be included if they satisfy the min_psds criteria. min_psds: The minimum numer of PostSyn entires to pass the filter. Bodies with fewer psds may still pass the filter if they satisfy the min_tbars criteria. min_body_size: Determines which bodies are included on the basis of their size alone, regardless of synapse count. sv_classifications: Optional. 
Path to an hdf5 file containing supervoxel classifications. Must have datasets: 'supervoxel_ids', 'classifications', and 'class_names'. Used to exclude known-bad supervoxels. The file need not include all supervoxels. Any supervoxels MISSING from this file are not considered 'bad'. marked_bad_bodies: Optional. A list of known-bad bodies to exclude from the results, or a path to a .csv file with that list (in the first column), or a keyvalue instance name from which the list can be loaded as JSON. return_table: If True, return the focused bodies in a DataFrame, indexed by body, with columns for body size and synapse counts. If False, simply return the list of bodies (saves about 4 minutes). Returns: A list of body IDs that passed all filters, or a DataFrame with columns: ['voxel_count', 'PreSyn', 'PostSyn'] (See return_table option.) """ split_source = 'dvid' # Load full mapping. It's needed for both synapses and body sizes. mapping = fetch_complete_mappings(server, uuid, instance, include_retired=True, kafka_msgs=kafka_msgs) mapper = LabelMapper(mapping.index.values, mapping.values) ## ## Synapses ## if isinstance(synapse_samples, str): synapse_samples = load_synapses(synapse_samples) assert set(['sv', 'body']).intersection(set(synapse_samples.columns)), \ "synapse samples must have either 'body' or 'sv' column" # If 'sv' column is present, use it to create (or update) the body column if 'sv' in synapse_samples.columns: synapse_samples = update_synapse_table(server, uuid, instance, synapse_samples, split_source=split_source) assert synapse_samples['sv'].dtype == np.uint64 synapse_samples['body'] = mapper.apply(synapse_samples['sv'].values, True) with Timer("Filtering for synapses", logger): synapse_body_table = body_synapse_counts(synapse_samples) synapse_bodies = synapse_body_table.query( 'PreSyn >= @min_tbars or PostSyn >= @min_psds').index.values logger.info(f"Found {len(synapse_bodies)} with sufficient synapses") focused_bodies = set(synapse_bodies) ## ## Body sizes ## with Timer("Filtering for body size", logger): sv_sizes = load_all_supervoxel_sizes(server, uuid, instance, root_sv_sizes, split_source=split_source) body_stats = compute_body_sizes(sv_sizes, mapping, True) big_body_stats = body_stats.query('voxel_count >= @min_body_size') big_bodies = big_body_stats.index logger.info(f"Found {len(big_bodies)} with sufficient size") focused_bodies |= set(big_bodies) ## ## SV classifications ## if sv_classifications is not None: with Timer(f"Filtering by supervoxel classifications"), h5py.File( sv_classifications, 'r') as f: sv_classes = pd.DataFrame({ 'sv': f['supervoxel_ids'][:], 'klass': f['classifications'][:].astype(np.uint8) }) # Get the set of bad supervoxels all_class_names = list(map(bytes.decode, f['class_names'][:])) bad_class_names = [ 'unknown', 'blood vessels', 'broken white tissue', 'glia', 'oob' ] _bad_class_ids = set(map(all_class_names.index, bad_class_names)) bad_svs = sv_classes.query('klass in @_bad_class_ids')['sv'] # Add column for sizes bad_sv_sizes = pd.DataFrame(index=bad_svs).merge( pd.DataFrame(sv_sizes), how='left', left_index=True, right_index=True) # Append body bad_sv_sizes['body'] = mapper.apply(bad_sv_sizes.index.values, True) # For bodies that contain at least one bad supervoxel, # compute the total size of the bad supervoxels they contain body_bad_voxels = bad_sv_sizes.groupby('body').agg( {'voxel_count': 'sum'}) # Append total body size for comparison body_bad_voxels = body_bad_voxels.merge(body_stats[['voxel_count' ]], how='left', left_index=True, 
right_index=True, suffixes=('_bad', '_total')) bad_bodies = body_bad_voxels.query( 'voxel_count_bad > voxel_count_total//2').index bad_focused_bodies = focused_bodies & set(bad_bodies) logger.info( f"Dropping {len(bad_focused_bodies)} bodies with more than 50% bad supervoxels" ) focused_bodies -= bad_focused_bodies ## ## Marked Bad bodies ## if marked_bad_bodies is not None: if isinstance(marked_bad_bodies, str): if marked_bad_bodies.endswith('.csv'): marked_bad_bodies = read_csv_col(marked_bad_bodies, 0) else: # If it ain't a CSV, maybe it's a key-value instance and key to read from. raise AssertionError( "FIXME: Need convention for specifying key-value instance and key" ) #marked_bad_bodies = fetch_key(server, uuid, marked_bad_bodies, as_json=True) with Timer( f"Dropping {len(marked_bad_bodies)} bad bodies (from {len(focused_bodies)})" ): focused_bodies -= set(marked_bad_bodies) ## ## Prepare results ## logger.info(f"Found {len(focused_bodies)} focused bodies") focused_bodies = np.fromiter(focused_bodies, dtype=np.uint64) focused_bodies.sort() if not return_table: return focused_bodies with Timer("Computing full focused table", logger): # Start with an empty DataFrame (except index) focus_table = pd.DataFrame(index=focused_bodies) # Merge size/sv_count focus_table = focus_table.merge(body_stats, how='left', left_index=True, right_index=True, copy=False) # Add synapse columns focus_table = focus_table.merge(synapse_body_table, how='left', left_index=True, right_index=True, copy=False) focus_table.fillna(0, inplace=True) focus_table['voxel_count'] = focus_table['voxel_count'].astype( np.uint64) focus_table['sv_count'] = focus_table['sv_count'].astype(np.uint32) focus_table['PreSyn'] = focus_table['PreSyn'].astype(np.uint32) focus_table['PostSyn'] = focus_table['PostSyn'].astype(np.uint32) # Sort biggest..smallest focus_table.sort_values('voxel_count', ascending=False, inplace=True) focus_table.index.name = 'body' return focus_table
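# body_synapse_counts() is defined elsewhere; the synapse-based filter above boils down
# to a per-body tally of 'kind', roughly as in this added sketch (made-up bodies/counts).
def _example_synapse_filter():
    import pandas as pd

    synapse_samples = pd.DataFrame({'body': [1, 1, 1, 2, 2],
                                    'kind': ['PreSyn', 'PreSyn', 'PostSyn', 'PostSyn', 'PostSyn']})
    min_tbars, min_psds = 2, 10

    counts = pd.crosstab(synapse_samples['body'], synapse_samples['kind'])
    passing = counts.query('PreSyn >= @min_tbars or PostSyn >= @min_psds').index
    # passing contains body 1 (2 tbars) but not body 2 (no tbars, only 2 psds).
    return passing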
def cleave(edges, edge_weights, seeds_dict, node_ids, node_sizes=None, method='seeded-mst'): """ Cleave the graph with the given edges and edge weights. Args: edges: array, (E,2), uint32 edge_weights: array, (E,), float32 seeds_dict: dict, { seed_class : [node_id, node_id, ...] } node_ids: The complete list of node IDs in the graph. Must contain a superset of the ids given in edges. Extra ids in node_ids (i.e. not mentioned in 'edges') will be included in the results as disconnected components. method: One of: 'seeded-mst', 'seeded-watershed', 'agglomerative-clustering', 'echo-seeds' Returns: CleaveResults, namedtuple with fields: (node_ids, output_labels, disconnected_components, contains_unlabeled_components) Where: node_ids: The graph node_ids. output_labels: array (N,), uint32 Agglomerated node labeling, in the same order as node_ids. disconnected_components: A set of seeds which ended up with more than one component in the result. contains_unlabeled_components: True if the input contains one or more disjoint components that were not seeded and thus not labeled during agglomeration. False otherwise. """ assert isinstance(node_ids, np.ndarray) assert node_ids.dtype in (np.uint32, np.uint64) assert node_ids.ndim == 1 assert node_sizes is None or node_sizes.shape == node_ids.shape cleave_func, requires_sizes = get_cleave_method(method) assert not requires_sizes or node_sizes is not None, \ f"The specified cleave method ({method}) requires node sizes but none were provided." # Relabel node ids consecutively cons_node_ids = np.arange(len(node_ids), dtype=np.uint32) mapper = LabelMapper(node_ids, cons_node_ids) # Initialize sparse seed label array seed_labels = np.zeros_like(cons_node_ids) for seed_class, seed_nodes in seeds_dict.items(): seed_nodes = np.asarray(seed_nodes, dtype=np.uint64) mapper.apply_inplace(seed_nodes) seed_labels[seed_nodes] = seed_class if len(edges) == 0: # No edges: Return empty results (just seeds) return CleaveResults(seed_labels, set(seeds_dict.keys()), not seed_labels.all()) # Clean the edges (normalized form, no duplicates, no loops) edges.sort(axis=1) edges_df = pd.DataFrame({ 'u': edges[:, 0], 'v': edges[:, 1], 'weight': edge_weights }) edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True) edges_df = edges_df.query('u != v') edges = edges_df[['u', 'v']].values edge_weights = edges_df['weight'].values # Relabel edges for consecutive nodes cons_edges = mapper.apply(edges) assert cons_edges.dtype == np.uint32 cleave_results = cleave_func(cons_edges, edge_weights, seed_labels, node_sizes) assert isinstance(cleave_results, CleaveResults) return cleave_results
def compute_body_sizes(sv_sizes, mapping, include_unmapped_singletons=False): """ Given a Series of supervoxel sizes and an sv-to-body mapping, compute the size of each body in the mapping, and its count of supervoxels. Any supervoxels in the mapping that are missing from sv_sizes will be ignored. Body 0 will be excluded from the results, even if it is present in the mapping. Args: sv_sizes: pd.Series, indexed by sv, or a path to an hdf5 file which can be loaded via load_supervoxel_sizes() mapping: pd.Series, indexed by sv, with body as value, or a path to a file which can be loaded by load_mapping() include_unmapped_singletons: If True, then the result will also include all supervoxels from sv_sizes that weren't mentioned in the mapping. (They are presumed to be singleton-supervoxel bodies.) Returns: pd.DataFrame, indexed by body, with columns ['voxel_count', 'sv_count'], and sorted by decreasing voxel_count. Example: >>> mesh_job_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm' >>> sv_sizes_path = f'{mesh_job_dir}/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015/supervoxel-sizes.h5' >>> sv_sizes = load_supervoxel_sizes(sv_sizes_path) >>> mapping = fetch_complete_mappings('emdata3:8900', '52f9', 'segmentation') >>> body_sizes = compute_body_sizes(sv_sizes, mapping) """ if isinstance(sv_sizes, str): logger.info("Loading supervoxel sizes") assert os.path.splitext(sv_sizes)[1] == '.h5' sv_sizes = load_supervoxel_sizes(sv_sizes) if isinstance(mapping, str): logger.info("Loading mapping") mapping = load_mapping(mapping) assert isinstance(sv_sizes, pd.Series) assert isinstance(mapping, pd.Series) assert sv_sizes.index.dtype == np.uint64 sv_sizes = sv_sizes.astype(np.uint64, copy=False) size_mapper = LabelMapper(sv_sizes.index.values, sv_sizes.values) # Just drop SVs that we don't have sizes for. logger.info("Dropping unknown supervoxels") mapping = mapping.loc[mapping.index.isin(sv_sizes.index)] logger.info("Applying sizes to mapping") df = pd.DataFrame({'body': mapping}) df['voxel_count'] = size_mapper.apply(mapping.index.values) logger.info("Aggregating sizes by body") body_stats = df.groupby('body').agg({'voxel_count': ['sum', 'size']}) body_stats.columns = ['voxel_count', 'sv_count'] body_stats['sv_count'] = body_stats['sv_count'].astype(np.uint32) #body_sizes = body_stats['voxel_count'] if include_unmapped_singletons: logger.info("Appending singleton sizes") nonsingleton_rows = sv_sizes.index.isin(mapping.index) singleton_sizes = sv_sizes[~nonsingleton_rows] singleton_stats = pd.DataFrame({'voxel_count': singleton_sizes}) singleton_stats['sv_count'] = np.uint32(1) body_stats = pd.concat((body_stats, singleton_stats)) if 0 in body_stats.index: body_stats.drop(0, inplace=True) logger.info("Sorting sizes") body_stats.index.name = 'body' body_stats.sort_values(['voxel_count', 'sv_count'], inplace=True, ascending=False) return body_stats
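# Added usage example: compute_body_sizes() also accepts in-memory Series inputs,
# which is convenient for small tests. (IDs and sizes below are made up.)
def _example_compute_body_sizes():
    import numpy as np
    import pandas as pd

    sv_sizes = pd.Series(index=np.array([1, 2, 3], np.uint64),
                         data=np.array([100, 50, 7], np.uint64))
    mapping = pd.Series(index=np.array([1, 2], np.uint64),
                        data=np.array([10, 10], np.uint64))

    body_stats = compute_body_sizes(sv_sizes, mapping, include_unmapped_singletons=True)
    # body 10: voxel_count=150, sv_count=2
    # body  3: voxel_count=7,   sv_count=1   (unmapped singleton)
    return body_stats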
def measure_tbar_mito_distances(seg_src, mito_src, body, *, search_configs=DEFAULT_SEARCH_CONFIGS, npclient=None, tbars=None, valid_mitos=None): """ Search for the closest mito to each tbar in a list of tbars (or any set of points, really). FIXME: Rename this function. It works for more than just tbars. Args: seg_src: (server, uuid, instance) OR a flyemflows VolumeService Labelmap instance for the neuron segmentation. mito_src: (server, uuid, instance) OR a flyemflows VolumeService Labelmap instance for the mitochondria "supervoxel" segmentation -- not just the "masks". body: The body ID of interest, on which the tbars reside. search_configs: A list ``SearchConfig`` tuples. For each tbar, this function tries to locate a mitochondria within he given search radius, using data downloaded from the given scale. If the search fails and no mito can be found, the function tries again using the next search criteria in the list. The radius should always be specified in scale-0 units, regardless of the scale at which you want to perform the analysis. Additionally, the data will be downloaded at the specified scale, then downsampled (with continuity preserving downsampling) to a lower scale for analysis. Notes: - Scale 4 is too low-res. Stick with scale-3 or better. - Higher radius is more expensive, but some of that expense is recouped because all points that fall within the radius are analyzed at once. See _measure_tbar_mito_distances() implementation for details. dilation_radius_s0: If dilation_radius_s0 is non-zero, the segmentation will be "repaired" to close gaps, using a procedure involving a dilation of the given radius. dilation_exclusion_buffer_s0: We want to close small gaps in the segmentation, but only if we think they're really a gap in the actual segmentation, not if they are merely fingers of the same branch that are actually connected outside of our analysis volume. The dilation procedure tends to form such spurious connections near the volume border, so this parameter can be used to exclude a buffer (inner halo) near the border from dilation repairs. npclient: ``neuprint.Client`` to use when fetching the list of tbars that belong to the given body, unless you provide your own tbar points in the next argument. tbars: A DataFrame of tbar coordinates at least with columns ``['x', 'y', 'z']``. valid_mitos: If provided, only the listed mito IDs will be considered valid as search targets. Returns: DataFrame of tbar coordinates, mito distances, and mito coordinates. Points for which no nearby mito could be found (after trying all the given search_configs) will be marked with `done=False` in the results. """ assert search_configs[-1].is_final, "Last search config should be marked is_final" assert all([not cfg.is_final for cfg in search_configs[:-1]]), \ "Only the last search config should be marked is_final (no others)." 
# Fetch tbars if tbars is None: tbars = fetch_synapses(body, SC(type='pre', primary_only=True), client=npclient) tbars = initialize_results(body, tbars) if valid_mitos is None or len(valid_mitos) == 0: valid_mito_mapper = None else: valid_mitos = np.asarray(valid_mitos, dtype=np.uint64) valid_mito_mapper = LabelMapper(valid_mitos, valid_mitos) with tqdm_proxy(total=len(tbars)) as progress: for row in tbars.itertuples(): # can't use row.done -- itertuples might be out-of-sync done = tbars['done'].loc[row.Index] if done: continue loop_logger = None for cfg in search_configs: prefix = (f"({row.x}, {row.y}, {row.z}) [ds={cfg.download_scale} " f"as={cfg.analysis_scale} r={cfg.radius_s0:4} dil={cfg.dilation_radius_s0:2}] ") loop_logger = PrefixedLogger(logger, prefix) prev_num_done = tbars['done'].sum() _measure_tbar_mito_distances( seg_src, mito_src, body, tbars, row.Index, cfg, valid_mito_mapper, loop_logger) num_done = tbars['done'].sum() progress.update(num_done - prev_num_done) done = tbars['done'].loc[row.Index] if done: break if not cfg.is_final: loop_logger.info("Search failed for primary tbar. Trying next search config!") if not done: loop_logger.warning(f"Failed to find a nearby mito for tbar at point {(row.x, row.y, row.z)}") progress.update(1) failed = np.isinf(tbars['mito-distance']) succeeded = ~failed logger.info(f"Found mitos for {succeeded.sum()} tbars, failed for {failed.sum()} tbars") return tbars
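# Added sketch: valid_mito_mapper above is an "identity" LabelMapper used as a fast
# membership filter -- apply_with_default() keeps IDs in the valid set and zeros the rest.
# (The same trick appears in _erase_tiny_interior_segments().)
def _example_identity_mapper_filter():
    import numpy as np
    from dvidutils import LabelMapper

    valid = np.array([7, 9], dtype=np.uint64)
    seg = np.array([[7, 8],
                    [9, 0]], dtype=np.uint64)

    filtered = LabelMapper(valid, valid).apply_with_default(seg, 0)
    assert (filtered == [[7, 0], [9, 0]]).all()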
def find_edges_in_brick(brick, closest_scale=None, subset_groups=[], subset_requirement=2): """ Find all pairs of adjacent labels in the given brick, and find the central-most point along the edge between them. (Edges to/from label 0 are discarded.) If closest_scale is not None, then non-adjacent pairs will be considered, according to a particular heuristic to decide which pairs to consider. Args: brick: A Brick to analyze closest_scale: If None, then consider direct (touching) adjacencies only. If not-None, then non-direct "adjacencies" (i.e. close-but-not-touching) are found. In that case `closest_scale` should be an integer >=0 indicating the scale at which the analysis will be performed. Higher scales are faster, but less precise. See ``neuclease.util.approximate_closest_approach`` for more information. subset_groups: A DataFrame with columns [label, group]. Only the given labels will be analyzed for adjacencies. Furthermore, edges (pairs) will only be returned if both labels in the edge are from the same group. subset_requirement: Whether or not both labels in each edge must be in subset_groups, or only one in each edge. (Currently, subset_requirement must be 2.) Returns: If the brick contains no edges at all (other than edges to label 0), return None. Otherwise, returns pd.DataFrame with columns: [label_a, label_b, forwardness, z, y, x, axis, edge_area, distance]. # fixme where label_a < label_b, 'axis' indicates which axis the edge crosses at the chosen coordinate, (z,y,x) is always given as the coordinate to the left/above/front of the edge (depending on the axis). If 'forwardness' is True, then the given coordinate falls on label_a and label_b is one voxel "after" it (to the right/below/behind the coordinate). Otherwise, the coordinate falls on label_b, and label_a is "after". And 'edge_area' is the total count of the voxels touching both labels. """ # Profiling indicates that query('... in ...') spends # most of its time in np.unique, believe it or not. # After looking at the implementation, I think it might help a # little if we sort the array first. brick_labels = np.sort(pd.unique(brick.volume.reshape(-1))) if (len(brick_labels) == 1) or (len(brick_labels) == 2 and (0 in brick_labels)): return None # brick is solid -- no possible edges # Drop labels that aren't even present subset_groups = subset_groups.query('label in @brick_labels').copy() # Drop groups that don't have enough members (usually 2) in this brick. group_counts = subset_groups['group'].value_counts() _kept_groups = group_counts.loc[(group_counts >= subset_requirement)].index subset_groups = subset_groups.query('group in @_kept_groups').copy() if len(subset_groups) == 0: return None # No possible edges to find in this brick. # Contruct a mapper that includes only the labels we'll keep. # (Other labels will be mapped to 0). # Also, the mapper converts to uint32 (required by _find_and_select_central_edges, # but also just good for RAM reasons). kept_labels = np.sort(np.unique(subset_groups['label'].values)) remapped_kept_labels = np.arange(1, len(kept_labels) + 1, dtype=np.uint32) mapper = LabelMapper(kept_labels, remapped_kept_labels) reverse_mapper = LabelMapper(remapped_kept_labels, kept_labels) # Construct RAG -- finds all edges in the volume, on a per-pixel basis. 
remapped_volume = mapper.apply_with_default(brick.volume, 0) brick.compress() remapped_subset_groups = subset_groups.copy() remapped_subset_groups['label'] = mapper.apply( subset_groups['label'].values) try: if closest_scale is None: best_edges_df = _find_and_select_central_edges( remapped_volume, remapped_subset_groups, subset_requirement) else: best_edges_df = _find_closest_approaches(remapped_volume, closest_scale, remapped_subset_groups) except: brick_name = f"{brick.logical_box[:,::-1].tolist()}" np.save(f'problematic-remapped-brick-{brick_name}.npy', remapped_volume) logger.error(f"Error in brick (XYZ): {brick_name}" ) # This will appear in the worker log. raise if best_edges_df is None: return None # Translate coordinates to global space best_edges_df.loc[:, ['za', 'ya', 'xa']] += brick.physical_box[0] best_edges_df.loc[:, ['zb', 'yb', 'xb']] += brick.physical_box[0] # Restore to original label set best_edges_df['label_a'] = reverse_mapper.apply( best_edges_df['label_a'].values) best_edges_df['label_b'] = reverse_mapper.apply( best_edges_df['label_b'].values) # Normalize swap_df_cols(best_edges_df, None, best_edges_df.eval('label_a > label_b'), ['a', 'b']) return best_edges_df
def cleave(edges, edge_weights, seeds_dict, node_ids=None, method='seeded-watershed'): """ Cleave the graph with the given edges and edge weights. If node_ids is given, it must contain a superset of the ids given in edges. Extra ids in node_ids (i.e. not mentioned in 'edges') will be included in the results as disconnected components. Args: edges: array, (E,2), uint32 edge_weights: array, (E,), float32 seeds_dict: dict, { seed_class : [node_id, node_id, ...] } node_ids: The complete list of node IDs in the graph. method: Either 'seeded-watershed' or 'agglomerative-clustering' Returns: CleaveResults, namedtuple with fields: (node_ids, output_labels, disconnected_components, contains_unlabeled_components) Where: node_ids: The graph node_ids. output_labels: array (N,), uint32 Agglomerated node labeling, in the same order as node_ids. disconnected_components: A set of seeds which ended up with more than one component in the result. contains_unlabeled_components: True if the input contains one or more disjoint components that were not seeded and thus not labeled during agglomeration. False otherwise. """ if node_ids is None: node_ids = pd.unique(edges.flat) node_ids.sort() assert isinstance(node_ids, np.ndarray) assert node_ids.dtype in (np.uint32, np.uint64) assert node_ids.ndim == 1 assert method in ('seeded-watershed', 'agglomerative-clustering') # Clean the edges (normalized form, no duplicates, no loops) edges.sort(axis=1) edges_df = pd.DataFrame({ 'u': edges[:, 0], 'v': edges[:, 1], 'weight': edge_weights }) edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True) edges_df = edges_df.query('u != v') edges = edges_df[['u', 'v']].values edge_weights = edges_df['weight'].values # Relabel node ids consecutively cons_node_ids = np.arange(len(node_ids), dtype=np.uint32) mapper = LabelMapper(node_ids, cons_node_ids) cons_edges = mapper.apply(edges) assert cons_edges.dtype == np.uint32 # Initialize sparse seed label array seed_labels = np.zeros_like(cons_node_ids) for seed_class, seed_nodes in seeds_dict.items(): seed_nodes = np.asarray(seed_nodes, dtype=np.uint64) mapper.apply_inplace(seed_nodes) seed_labels[seed_nodes] = seed_class if method == 'agglomerative-clustering': output_labels, disconnected_components, contains_unlabeled_components = agglomerative_clustering( cons_edges, edge_weights, seed_labels) elif method == 'seeded-watershed': output_labels, disconnected_components, contains_unlabeled_components = edge_weighted_watershed( cons_edges, edge_weights, seed_labels) return CleaveResults(node_ids, output_labels, disconnected_components, contains_unlabeled_components)
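# Added hypothetical call showing this cleave() variant's inputs and outputs.
# Node 40 is listed in node_ids but has no edges and no seed, so it should be reported
# via contains_unlabeled_components. (How edge_weights are interpreted depends on the
# chosen method, so no particular labeling of node 20 is asserted here.)
def _example_cleave():
    import numpy as np

    edges = np.array([[10, 20], [20, 30]], dtype=np.uint64)
    edge_weights = np.array([0.5, 0.5], dtype=np.float32)
    seeds = {1: [10], 2: [30]}
    node_ids = np.array([10, 20, 30, 40], dtype=np.uint64)

    node_ids, output_labels, disconnected, has_unlabeled = cleave(edges, edge_weights, seeds, node_ids)
    # output_labels is aligned with node_ids; seeded/reachable nodes receive a seed class,
    # and node 40 should remain unlabeled (0).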
def find_missing_adjacencies(server, uuid, instance, body, known_edges, svs=None, search_distance=1, connect_non_adjacent=False): """ Given a body and an intra-body merge graph defined by the given list of "known" supervoxel-to-supervoxel edges within that body, 1. Determine whether or not all supervoxels in the body are connected by a single component within the given graph. If so, return immediately. 2. Attempt to augment the graph with additional edges based on supervoxel adjacencies in the segmentation from DVID. This is done by downloading the DVID labelindex to determine which blocks might contain adjacent supervoxels that could unify the graph, and then downloading those blocks (only) to search for the adjacencies. Notes: - Requires scikit-image (which, currently, is not otherwise listed as a dependency of neuclease's conda-recipe). - This function does not attempt to find ALL adjacencies between supervoxels; it stops looking as soon as they form a single connected component. - This function only considers two supervoxels "adjacent" if they are literally touching each other in the scale-0 segmentation. If there is a small gap between them, then they are not considered adjacent. - This function does not attempt to find inter-block adjacencies; only adjacencies within each block are detected. So, in pathological cases where a supervoxel is only adjacent to the rest of the body on a block-aligned edge, the adjacency will not be detected by this funciton. Args: server, uuid, instance: DVID segmentation labelmap instance body: ID of the body to inspect known_edges: ndarray (N,2), array of supervoxel pairs; known edges of the intra-body merge graph svs: Optional. The complete list of supervoxels that belong to this body, according to DVID. Providing this enhances performance in one important case: If the known_edges ALREADY constitute a single connected component which covers all supervoxels in the body, there is no need to download the labelindex. search_distance: If > 1, supervoxels are considered adjacent if they are within the given distance from each other, even if they aren't directly adjacent. connect_non_adjacent: If searching by adjacency failed to fully connect all supervoxels in the body into a single connected component, generate edges for supervoxels that are not adjacent, but merely are in the same block (if it helps unify the body). Returns: (new_edges, orig_num_cc, final_num_cc, block_tables), Where: new_edges are the new edges found via inspection of supervoxel adjacencies, orig_num_cc is the number of disjoint components in the given merge graph before this function runs, final_num_cc is the number of disjoint components after adding the new_edges, block_tables contains debug information about the adjacencies found in each block of analyzed segmentation Ideally, final_num_cc == 1, but in some cases the body's supervoxels may not be directly adjacent, or the adjacencies were not detected. (See notes above.) """ from skimage.morphology import dilation BLOCK_TABLE_COLS = ['z', 'y', 'x', 'sv_a', 'sv_b', 'cc_a', 'cc_b', 'detected', 'applied'] known_edges = np.asarray(known_edges, np.uint64) if svs is None: # We could compute the supervoxel list ourselves from # the labelindex, but dvid can do it faster. 
svs = fetch_supervoxels_for_body(server, uuid, instance, body) cc = connected_components_nonconsecutive(known_edges, svs) orig_num_cc = final_num_cc = cc.max()+1 if orig_num_cc == 1: return np.zeros((0,2), np.uint64), orig_num_cc, final_num_cc, pd.DataFrame(columns=BLOCK_TABLE_COLS) labelindex = fetch_labelindex(server, uuid, instance, body, format='protobuf') encoded_block_coords = np.fromiter(labelindex.blocks.keys(), np.uint64, len(labelindex.blocks)) coords_zyx = decode_labelindex_blocks(encoded_block_coords) cc_mapper = LabelMapper(svs, cc) svs_set = set(svs) sv_adj_found = [] cc_adj_found = set() block_tables = {} searched_block_svs = {} for coord_zyx, sv_counts in zip(coords_zyx, labelindex.blocks.values()): # Given the supervoxels in this block, what CC adjacencies # MIGHT we find if we were to inspect the segmentation? block_svs = np.fromiter(sv_counts.counts.keys(), np.uint64) block_ccs = cc_mapper.apply(block_svs) possible_cc_adjacencies = set(combinations( set(block_ccs), 2 )) # We only aim to find (at most) a single link between each CC pair. # That is, we don't care about adjacencies between CC that we've already linked so far. possible_cc_adjacencies -= cc_adj_found if not possible_cc_adjacencies: continue searched_block_svs[(*coord_zyx,)] = block_svs # Not used in the search; only returned for debug purposes. try: block_adj_table = _init_adj_table(coord_zyx, block_svs, cc_mapper) except: raise block_vol = fetch_block_vol(server, uuid, instance, coord_zyx, svs_set) if search_distance > 0: # It would be nice to do a proper spherical dilation, # but apparently dilation() is special-cased to be WAY # faster with a square structuring element, and we prefer # speed over cleaner dilation. # footprint = skimage.morphology.ball(dilation) radius = search_distance//2 footprint = np.ones(3*(1+2*radius,), np.uint8) dilated_block_vol = dilation(block_vol, footprint) # Since dilation is a max-filter, we might have accidentally # erased small, low-valued supervoxels, erasing the adjacendies. # Overlay the original volume to make sure they still count. block_vol = np.where(block_vol, block_vol, dilated_block_vol) sv_adjacencies = compute_label_adjacencies(block_vol) sv_adjacencies['cc_a'] = cc_mapper.apply( sv_adjacencies['sv_a'].values ) sv_adjacencies['cc_b'] = cc_mapper.apply( sv_adjacencies['sv_b'].values ) found_new_adj = False for row in sv_adjacencies.itertuples(index=False): if (row.cc_a != row.cc_b): sv_adj = (row.sv_a, row.sv_b) cc_adj = (row.cc_a, row.cc_b) # Normalize if row.cc_a > row.cc_b: cc_adj = (row.cc_b, row.cc_a) if row.sv_a > row.sv_b: sv_adj = (row.sv_b, row.sv_a) block_adj_table.loc[sv_adj, 'detected'] = True if cc_adj not in cc_adj_found: found_new_adj = True cc_adj_found.add( cc_adj ) sv_adj_found.append( sv_adj ) block_adj_table.loc[sv_adj, 'applied'] = True block_tables[(*coord_zyx,)] = block_adj_table # If we made at least one change and we've # finally unified all components, then we're done. if found_new_adj: final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1 if final_num_cc == 1: break # If we couldn't connect everything via direct adjacencies, # we can just add edges for any supervoxels that share a block. if final_num_cc > 1 and connect_non_adjacent: for coord_zyx, block_svs in searched_block_svs.items(): block_ccs = cc_mapper.apply(block_svs) # We only need one SV per connected component, # so load them into a dict. 
selected_svs = dict(zip(block_ccs, block_svs)) for (sv_a, sv_b) in combinations(sorted(selected_svs.values()), 2): (cc_a, cc_b) = cc_mapper.apply(np.array([sv_a, sv_b], np.uint64)) if cc_a > cc_b: cc_a, cc_b = cc_b, cc_a if (cc_a, cc_b) not in cc_adj_found: if sv_a > sv_b: sv_a, sv_b = sv_b, sv_a cc_adj_found.add( (cc_a, cc_b) ) sv_adj_found.append( (sv_a, sv_b) ) block_tables[(*coord_zyx,)].loc[(sv_a, sv_b), 'applied'] = True final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1 if len(block_tables) == 0: block_table = pd.DataFrame(columns=BLOCK_TABLE_COLS) else: block_table = pd.concat(block_tables.values(), sort=False).reset_index() block_table = block_table[BLOCK_TABLE_COLS] new_edges = np.array(sv_adj_found, np.uint64) return new_edges, int(orig_num_cc), int(final_num_cc), block_table
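
def _example_find_missing_adjacencies():
    """
    Usage sketch for find_missing_adjacencies() (not called anywhere in this module).
    The server/uuid/instance/body values and the edge list below are hypothetical
    placeholders, not taken from any real dataset.
    """
    seg = ('emdata.example.org:8900', 'abc123', 'segmentation')  # hypothetical
    body = 123456789
    known_edges = np.array([[1001, 1002],
                            [1003, 1004]], np.uint64)

    new_edges, orig_cc, final_cc, block_table = find_missing_adjacencies(
        *seg, body, known_edges, search_distance=3, connect_non_adjacent=True)

    logger.info(f"Found {len(new_edges)} new edges; {orig_cc} -> {final_cc} components")
    if final_cc > 1:
        # The body's supervoxels could not be unified via adjacencies alone.
        logger.warning(f"Body {body} still has {final_cc} disconnected components")
    return new_edges, block_table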
def update_merge_table(server, uuid, instance, table_df, complete_mapping=None, split_mapping=None): """ Given a merge table (such as a focused proofreading decision table), find rows whose supervoxels no longer exist in the given instance (due to splits). For those invalid rows, determine the new supervoxel and body ID at the given coordinates to determine the updated supervoxel/body IDs. Updates (in-place) the supervoxel and body columns. Note: If any coordinate appears to be misplaced (i.e. the supervoxel ID at the coordinate is not a descendant of the listed supervoxel), the supervoxel is left unchanged and the body is mapped to 0. Args: server, uuid, instance: Table will be updated with respect to the given segmentation instance info table_df: DataFrame with SV columns and coordinate columns ('xa', 'ya', 'za', etc.) complete_mapping: Optional. Will be fetched if not provided. Must be the complete mapping as returned by fetch_complete_mappings(..., include_retired=True) split_mapping: Optional. Will be fetched if not provided. A mapping from supervoxel fragments to root supervoxel IDs. Returns: None. (The table is modified in place.) """ seg_info = server, uuid, instance # Ensure proper table columns/dtypes if 'id_a' in table_df.columns: col_sv_a, col_sv_b = 'id_a', 'id_b' elif 'sv_a' in table_df.columns: col_sv_a, col_sv_b = 'sv_a', 'sv_b' else: raise RuntimeError("table has no sv columns") assert set([col_sv_a, col_sv_b, 'xa', 'ya', 'za', 'xb', 'yb', 'zb']).issubset(table_df.columns) for col in ['xa', 'ya', 'za', 'xb', 'yb', 'zb']: table_df[col] = table_df[col].fillna(0).astype(np.int32) # Construct mappings if necessary kafka_msgs = None if complete_mapping is None or split_mapping is None: kafka_msgs = read_kafka_messages(*seg_info) if complete_mapping is None: complete_mapping = fetch_complete_mappings(*seg_info, include_retired=True, kafka_msgs=kafka_msgs) complete_mapper = LabelMapper(complete_mapping.index.values, complete_mapping.values) if split_mapping is None: split_events = fetch_supervoxel_splits(*seg_info) split_mapping = split_events_to_mapping(split_events, leaves_only=False) split_mapper = LabelMapper(split_mapping.index.values, split_mapping.values) # Apply up-to-date body mapping # (Retired supervoxels will map to body 0) table_df['body_a'] = complete_mapper.apply(table_df[col_sv_a].values, True) table_df['body_b'] = complete_mapper.apply(table_df[col_sv_b].values, True) successfully_updated = 0 failed_update = 0 # Update the rows with invalid bodies. for index in tqdm_proxy( table_df.query('body_a == 0 or body_b == 0').index): def update_row(index, col_body, col_sv, col_z, col_y, col_x): nonlocal successfully_updated, failed_update # Extract from table coord = table_df.loc[index, [col_z, col_y, col_x]] old_sv = table_df.loc[index, col_sv] # Check current SV in the volume new_sv = fetch_label(*seg_info, coord, supervoxels=True) # The old/new SV must have come from the same root SV. # If not, the coordinate must be misplaced and can't be used here. 
svs = np.asarray([new_sv, old_sv], np.uint64) mapped_svs = split_mapper.apply(svs, True) if mapped_svs[0] != mapped_svs[1]: failed_update += 1 else: body = complete_mapper.apply(np.array([new_sv], np.uint64))[0] table_df.loc[index, col_body] = body table_df.loc[index, col_sv] = new_sv successfully_updated += 1 # id_a/body_a if table_df.loc[index, 'body_a'] == 0: update_row(index, 'body_a', col_sv_a, 'za', 'ya', 'xa') # id_b/body_b if table_df.loc[index, 'body_b'] == 0: update_row(index, 'body_b', col_sv_b, 'zb', 'yb', 'xb') logger.info( f"Updated {successfully_updated}, failed to update {failed_update}")
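
def _example_update_merge_table():
    """
    Usage sketch for update_merge_table() (not part of the original module).
    The server/uuid/instance values and the CSV path are hypothetical placeholders.
    The table is updated in-place, so there is no return value to capture.
    """
    seg = ('emdata.example.org:8900', 'abc123', 'segmentation')  # hypothetical
    table_df = pd.read_csv('focused-decisions.csv')              # hypothetical path

    # Requires SV columns ('id_a'/'id_b' or 'sv_a'/'sv_b')
    # and coordinate columns ('xa', 'ya', 'za', 'xb', 'yb', 'zb').
    update_merge_table(*seg, table_df)

    # Rows whose coordinates appeared to be misplaced are left with body 0.
    unresolved = table_df.query('body_a == 0 or body_b == 0')
    logger.info(f"{len(unresolved)} rows could not be updated")
    return table_df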
def extract_important_merges(speculative_merge_tables, important_bodies, body_mapping=None,
                             mapping_instance_info=None, drop_duplicate_body_pairs=False):
    assert (body_mapping is None) ^ (mapping_instance_info is None), \
        "You must set either body_mapping or mapping_instance_info (but not both)"

    if mapping_instance_info is not None:
        body_mapping = fetch_complete_mappings(mapping_instance_info)

    assert isinstance(body_mapping, pd.Series)
    mapper = LabelMapper(body_mapping.index.values, body_mapping.values)

    # pd.Index is faster than builtin set for large sets
    important_bodies = pd.Index(important_bodies)

    results = []
    for spec_merge_table_df in tqdm(speculative_merge_tables):
        logger.info(f"Processing table with {len(spec_merge_table_df)} rows")

        with Timer("Applying mapping", logger):
            spec_merge_table_df['body_a'] = mapper.apply(spec_merge_table_df['id_a'].values, allow_unmapped=True)
            spec_merge_table_df['body_b'] = mapper.apply(spec_merge_table_df['id_b'].values, allow_unmapped=True)

        with Timer("Dropping identity merges", logger):
            orig_size = len(spec_merge_table_df)
            spec_merge_table_df.query('body_a != body_b', inplace=True)
            logger.info(f"Dropped {orig_size-len(spec_merge_table_df)}/{orig_size} edges.")

        with Timer("Normalizing edges", logger):
            # Normalize by body ID, not SV ID
            # (This involves a lot of copying, but you've got plenty of RAM, right?)
            a_cols = list(filter(lambda s: s[-1] == 'a', spec_merge_table_df.columns))
            b_cols = list(filter(lambda s: s[-1] == 'b', spec_merge_table_df.columns))
            spec_merge_table = spec_merge_table_df.to_records(index=False)
            normalize_recarray_inplace(spec_merge_table, 'body_a', 'body_b', a_cols, b_cols)
            spec_merge_table_df = pd.DataFrame(spec_merge_table)

        with Timer("Filtering edges", logger):
            q = 'body_a in @important_bodies and body_b in @important_bodies'
            orig_len = len(spec_merge_table_df)
            # Note: query(inplace=True) returns None, so don't assign its result.
            spec_merge_table_df.query(q, inplace=True)
            logger.info(f"Filtered out {orig_len - len(spec_merge_table_df)} non-important edges.")

        if drop_duplicate_body_pairs:
            with Timer("Dropping duplicate body pairs", logger):
                orig_len = len(spec_merge_table_df)
                spec_merge_table_df.drop_duplicates(['body_a', 'body_b'], inplace=True)
                logger.info(f"Dropped {orig_len - len(spec_merge_table_df)} duplicate body pairs")

        results.append(spec_merge_table_df)

    return results
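
def _example_extract_important_merges():
    """
    Usage sketch for extract_important_merges() (not part of the original module).
    The merge table, mapping, and body list below are tiny hypothetical examples,
    just to illustrate the expected inputs.
    """
    # Speculative merge table: SV pairs plus their coordinates (columns ending in 'a'/'b').
    spec_df = pd.DataFrame({'id_a': np.array([11, 22], np.uint64),
                            'id_b': np.array([33, 44], np.uint64),
                            'xa': [0, 10], 'ya': [0, 10], 'za': [0, 10],
                            'xb': [1, 11], 'yb': [1, 11], 'zb': [1, 11]})

    # Supervoxel -> body mapping (pd.Series), as produced by fetch_complete_mappings()
    body_mapping = pd.Series(index=np.array([11, 22, 33, 44], np.uint64),
                             data=np.array([100, 300, 200, 100], np.uint64))

    important_bodies = [100, 300]
    results = extract_important_merges([spec_df], important_bodies,
                                       body_mapping=body_mapping,
                                       drop_duplicate_body_pairs=True)

    # Only edges whose two bodies are both 'important' (and distinct) survive the filters.
    return results[0]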
def split_disconnected_bodies(labels_orig):
    """
    Produces a 3D volume split into connected components.

    This function identifies bodies that have the same label
    but are not connected.  It splits those bodies and produces
    a dict that maps the newly split bodies to the original body label.

    Special exception: Segments with label 0 are not relabeled.

    Note:
        Requires scikit-image (which, currently, is not otherwise
        listed as a dependency of neuclease's conda-recipe).

    Args:
        labels_orig (numpy.array): 3D array of labels

    Returns:
        (labels_new, new_to_orig, new_unique_labels)

        labels_new:
            The partially relabeled array.
            Segments that were not split will keep their original IDs.
            Among split segments, the largest 'child' of a split segment
            retains the original ID.
            The smaller segments are assigned new labels in the range
            (N+1)..(N+1+S), where N is the highest original label
            and S is the number of new segments after splitting.

        new_to_orig:
            A nearly (but not quite) minimal mapping of labels
            (N+1)..(N+1+S) -> some subset of (1..N),
            which maps new segment IDs to the segments they came from.
            Segments that were not split at all are not mentioned in this mapping.
            For split segments, every mapping pair for the split is returned,
            including the k->k (identity) pair.

        new_unique_labels:
            An array of all label IDs in the newly relabeled volume.
            The original label set can be selected via:

                new_unique_labels[new_unique_labels < min(new_to_orig.keys())]
    """
    import skimage.measure as skm

    # Compute connected components and cast back to original dtype
    labels_cc = skm.label(labels_orig, background=0, connectivity=1)
    assert labels_cc.dtype == np.int64
    if labels_orig.dtype == np.uint64:
        labels_cc = labels_cc.view(np.uint64)
    else:
        labels_cc = labels_cc.astype(labels_orig.dtype, copy=False)

    # Find overlapping segments between orig and CC volumes
    overlap_table_df = contingency_table(labels_orig, labels_cc).reset_index()
    assert overlap_table_df.columns.tolist() == ['left', 'right', 'voxel_count']
    overlap_table_df.columns = ['orig', 'cc', 'voxels']
    overlap_table_df.sort_values('voxels', ascending=False, inplace=True)

    # If a label in 'orig' is duplicated, it has multiple components in labels_cc.
    # The largest component gets to keep the original ID;
    # the other components must take on new values.
    # (The new values must not conflict with any of the IDs in the original,
    #  so start at orig_max+1.)
    new_cc_pos = overlap_table_df['orig'].duplicated()
    orig_max = overlap_table_df['orig'].max()
    new_cc_values = np.arange(orig_max + 1, orig_max + 1 + new_cc_pos.sum(), dtype=labels_orig.dtype)

    overlap_table_df['final_cc'] = overlap_table_df['orig'].copy()
    overlap_table_df.loc[new_cc_pos, 'final_cc'] = new_cc_values

    # Relabel the CC volume to use the 'final_cc' labels
    mapper = LabelMapper(overlap_table_df['cc'].values, overlap_table_df['final_cc'].values)
    mapper.apply_inplace(labels_cc)

    # Generate the mapping that could (if desired) convert the new
    # volume into the original one, as described in the docstring above.
    emitted_mapping_rows = overlap_table_df['orig'].duplicated(keep=False)
    emitted_mapping_pairs = overlap_table_df.loc[emitted_mapping_rows, ['final_cc', 'orig']].values
    new_to_orig = dict(emitted_mapping_pairs)

    new_unique_labels = pd.unique(overlap_table_df['final_cc'].values)
    new_unique_labels = new_unique_labels.astype(overlap_table_df['final_cc'].dtype)
    new_unique_labels.sort()

    return labels_cc, new_to_orig, new_unique_labels
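
def _example_split_disconnected_bodies():
    """
    Tiny self-contained sketch for split_disconnected_bodies()
    (not part of the original module).  Label 1 appears as two
    disconnected segments, so one of them is assigned a new ID.
    """
    labels = np.zeros((1, 1, 5), np.uint64)
    labels[0, 0, :] = [1, 1, 0, 1, 1]   # label 1 is split by background

    labels_new, new_to_orig, new_unique_labels = split_disconnected_bodies(labels)

    # One of the two '1' components keeps its ID; the other becomes label 2 (orig_max+1).
    # Both final IDs appear in the mapping, including the identity pair.
    assert set(new_to_orig.values()) == {1}
    assert labels_new.max() == 2
    return labels_new, new_to_orig, new_unique_labels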
def sparse_brick_coords_for_labels(self, labels, clip=True): """ Return a DataFrame indicating the brick coordinates (starting corner) that encompass the given labels. Args: labels: A list of body IDs (if ``self.supervoxels`` is False), or supervoxel IDs (if ``self.supervoxels`` is True). clip: If True, filter the results to exclude any coordinates that fall outside this service's bounding-box. Otherwise, all brick coordinates that encompass the given labels will be returned, whether or not they fall within the bounding box. Returns: DataFrame with columns [z,y,x,label], where z,y,x represents the starting corner (in full-res coordinates) of a brick that contains the label. """ assert not isinstance(labels, set), "Pass labels as a list or array, not a set" labels = pd.unique(labels) is_supervoxels = self.supervoxels brick_shape = self.preferred_message_shape assert (brick_shape % self.block_width == 0).all(), \ ("Brick shape ('preferred-message-shape') must be a multiple of the " f"block width ({self.block_width}) in all dimensions, not {brick_shape}") bad_labels = [] if not is_supervoxels: # No supervoxel filtering. # Sort by body, since that should be slightly nicer for dvid performance. bodies_and_svs = {label: None for label in sorted(labels)} else: # Arbitrary heuristic for whether to do the body-lookups on DVID or on the client. if len(labels) < 100_000: # If we're only dealing with a few supervoxels, # ask dvid to map them to bodies for us. mapping = fetch_mapping(*self.instance_triple, labels, as_series=True) else: # If we're dealing with a lot of supervoxels, ask for # the entire mapping, and look up the bodies ourselves. complete_mapping = fetch_mappings(*self.instance_triple) mapper = LabelMapper(complete_mapping.index.values, complete_mapping.values) labels = np.asarray(labels, np.uint64) bodies = mapper.apply(labels, True) mapping = pd.Series(index=labels, data=bodies, name='body') mapping.index.rename('sv', inplace=True) bad_svs = mapping[mapping == 0] bad_labels.extend(bad_svs.index.tolist()) # Group by body mapping = mapping[mapping != 0] grouped_svs = mapping.reset_index().groupby('body').agg( {'sv': list})['sv'] # Sort by body, since that should be slightly nicer for dvid performance. bodies_and_svs = grouped_svs.sort_index().to_dict() # Extract these to avoid pickling 'self' (just for speed) server, uuid, instance = self.instance_triple if self._use_resource_manager_for_sparse_coords: mgr = self.resource_manager_client else: mgr = ResourceManagerClient("", 0) def fetch_brick_coords(body, supervoxel_subset): """ Fetch the block coordinates for the given body, filter them for the given supervoxels (if any), and convert the block coordinates to brick coordinates. 
""" assert is_supervoxels or supervoxel_subset is None try: with mgr.access_context(server, True, 1, 1): labelindex = fetch_labelindex(server, uuid, instance, body, 'protobuf') coords_df = convert_labelindex_to_pandas(labelindex).blocks except HTTPError as ex: if (ex.response is not None and ex.response.status_code == 404): return (body, None) raise except RuntimeError as ex: if 'does not map to any body' in str(ex): return (body, None) raise if len(coords_df) == 0: return (body, None) if is_supervoxels: supervoxel_subset = set(supervoxel_subset) coords_df = coords_df.query('sv in @supervoxel_subset').copy() coords_df[['z', 'y', 'x']] //= brick_shape coords_df['body'] = np.uint64(body) coords_df.drop_duplicates(inplace=True) return (body, coords_df) def fetch_and_concatenate_brick_coords(bodies_and_supervoxels): """ To reduce the number of tiny DataFrames collected to the driver, it's best to concatenate the partitions first, on the workers, rather than a straightforward call to starmap(fetch_brick_coords). Hence, this function that consolidates each partition. """ bad_bodies = [] coord_dfs = [] for (body, supervoxel_subset) in bodies_and_supervoxels: _, coords_df = fetch_brick_coords(body, supervoxel_subset) if coords_df is None: bad_bodies.append(body) else: coord_dfs.append(coords_df) del coords_df if coord_dfs: return [(pd.concat(coord_dfs, ignore_index=True), bad_bodies)] else: return [(None, bad_bodies)] with Timer( f"Fetching coarse sparsevols for {len(labels)} labels ({len(bodies_and_svs)} bodies)", logger=logger): import dask.bag as db coords_and_bad_bodies = ( db.from_sequence( bodies_and_svs.items(), npartitions=4096 ) # Instead of fancy heuristics, just pick 4096 .map_partitions(fetch_and_concatenate_brick_coords).compute()) coords_df_partitions, bad_body_partitions = zip(*coords_and_bad_bodies) for body in chain(*bad_body_partitions): if is_supervoxels: bad_labels.extend(bodies_and_svs[body]) else: bad_labels.append(body) if bad_labels: name = 'sv' if is_supervoxels else 'body' pd.Series(bad_labels, name=name).to_csv('labels-without-sparsevols.csv', index=False, header=True) if len(bad_labels) < 100: msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels: {bad_labels}" else: msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels. See labels-without-sparsevols.csv" logger.warning(msg) coords_df_partitions = list( filter(lambda df: df is not None, coords_df_partitions)) if len(coords_df_partitions) == 0: raise RuntimeError( "Could not find bricks for any of the given labels") coords_df = pd.concat(coords_df_partitions, ignore_index=True) if self.supervoxels: coords_df['label'] = coords_df['sv'] else: coords_df['label'] = coords_df['body'] coords_df.drop_duplicates(['z', 'y', 'x', 'label'], inplace=True) coords_df[['z', 'y', 'x']] *= brick_shape if clip: # Keep if the last pixel in the brick is to the right of the bounding-box start # and the first pixel in the brick is to the left of the bounding-box stop keep = (coords_df[['z', 'y', 'x']] + brick_shape > self.bounding_box_zyx[0]).all(axis=1) keep &= (coords_df[['z', 'y', 'x']] < self.bounding_box_zyx[1]).all(axis=1) coords_df = coords_df.loc[keep] return coords_df[['z', 'y', 'x', 'label']]
def distance_transform_watershed(mask, smoothing=0.0, seed_mask=None, seed_labels=None, flood_from='interior'): """ Compute a watershed over the distance transform within a mask. You can either compute the watershed from inside-to-outside or outside-to-inside. For the former, the watershed is seeded from the most interior points, and the distance transform is inverted so the watershed can proceed from low to high as usual. For the latter, the distance transform is seeded from the voxels immediately outside the mask, using labels as found in the seed_labels volume. In this mode, the results effectively tell you which exterior segment (in the seed volume) is closest to any given point within the interior of the mask. Or you can provide your own seeds if you think you know what you're doing. Args: mask: Only the masked area will be processed smoothing: If non-zero, run gaussian smoothing on the distance transform with the given sigma before defining seed points or running the watershed. seed_mask: seed_labels: flood_from: Returns: dt, labeled_seeds, ws, max_id Notes: This function provides a subset of the options that can be found in other libraries, such as: - https://github.com/ilastik/wsdt/blob/3709b27/wsdt/wsDtSegmentation.py#L26 - https://github.com/constantinpape/elf/blob/34f5e76/elf/segmentation/watershed.py#L69 This uses vigra's efficient true euclidean distance transform, which is superior to distance transform approximations, e.g. as found in https://imagej.net/Distance_Transform_Watershed The imglib2 distance transform uses the same algorithm as vigra: https://github.com/imglib/imglib2-algorithm/tree/master/src/main/java/net/imglib2/algorithm/morphology/distance http://www.theoryofcomputing.org/articles/v008a019/ """ mask = mask.astype(bool, copy=False) mask = vigra.taggedView(mask, 'zyx').astype(np.uint32) imask = np.logical_not(mask) outer_edge_mask = binary_edge_mask(mask, 'outer') assert flood_from in ('interior', 'exterior') if flood_from == 'interior': # Negate the distance transform result, # since watershed must start at minima, not maxima. # Convert to uint8 to benefit from 'turbo' watershed mode # (uses a bucket queue). dt = distance_transform(mask, False, smoothing, negate=True) if seed_mask is None: # requires float32 input for some reason if dt.ndim == 2: minima = vigra.analysis.localMinima(dt, marker=np.nan, neighborhood=8, allowAtBorder=True, allowPlateaus=False) else: minima = vigra.analysis.localMinima3D(dt, marker=np.nan, neighborhood=26, allowAtBorder=True, allowPlateaus=False) seed_mask = np.isnan(minima) del minima dt = normalize_image_range(dt, np.uint8) else: if seed_labels is None and seed_mask is None: logger.warning( "Without providing your own seed mask and/or seed labels, " "the watershed operation will simply be the same as a " "connected components operation. Is that what you meant?") if seed_mask is None: seed_mask = outer_edge_mask.copy() # Dilate the mask once more. outer_edge_mask[:] |= binary_edge_mask(outer_edge_mask | mask, 'outer') dt = distance_transform(mask, False, smoothing, negate=False) dt = normalize_image_range(dt, np.uint8) if seed_labels is None: seed_mask = vigra.taggedView(seed_mask, 'zyx') labeled_seeds = vigra.analysis.labelMultiArrayWithBackground( seed_mask.view('uint8')) else: labeled_seeds = np.where(seed_mask, seed_labels, 0) # Make sure seed_mask matches labeled_seeds, # Even if some seed_labels were zero-valued seed_mask = (labeled_seeds != 0) # Must remap to uint32 before calling vigra's watershed. 
seed_mapper = None seed_values = None if labeled_seeds.dtype in (np.uint64, np.int64): labeled_seeds = labeled_seeds.astype(np.uint64) seed_values = np.sort(pd.unique(labeled_seeds.ravel())) if seed_values[0] != 0: seed_values = np.array([0] + list(seed_values), np.uint64) assert seed_values.dtype == np.uint64 assert labeled_seeds.dtype == np.uint64 ws_seed_values = np.arange(len(seed_values), dtype=np.uint32) seed_mapper = LabelMapper(seed_values, ws_seed_values) ws_seeds = seed_mapper.apply(labeled_seeds) assert ws_seeds.dtype == np.uint32 else: ws_seeds = labeled_seeds # Fill the non-masked area with one big seed, # except for a thin border around the mask. # This saves time in the watershed step, # since these voxels now don't need to be # consumed in the watershed. dummy_seed = ws_seeds.max() + np.uint32(1) ws_seeds[np.logical_not(mask | outer_edge_mask)] = dummy_seed ws_seeds[outer_edge_mask & ~seed_mask] = 0 dt[outer_edge_mask] = 255 dt[seed_mask] = 0 dt = vigra.taggedView(dt, 'zyx') ws_seeds = vigra.taggedView(ws_seeds, 'zyx') ws, max_id = vigra.analysis.watershedsNew(dt, seeds=ws_seeds, method='Turbo') # Areas that were unreachable without crossing over the border # could end up with the dummy seed. # We treat such areas as if they are outside of the mask. ws[ws == dummy_seed] = 0 ws_seeds[imask] = 0 # If we converted from uint64 to uint32 to perform the watershed, # convert back before returning. if seed_mapper is not None: ws = seed_values[ws] return dt, labeled_seeds, ws
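
def _example_distance_transform_watershed():
    """
    Tiny sketch for distance_transform_watershed() (not part of the original module).
    A dumbbell-shaped mask (two cubes joined by a thin bridge) is flooded from its
    interior, which should split it roughly at the bridge.  Requires vigra.
    """
    mask = np.zeros((32, 32, 96), bool)
    mask[8:24, 8:24, 8:40] = True      # left cube
    mask[8:24, 8:24, 56:88] = True     # right cube
    mask[14:18, 14:18, 40:56] = True   # thin bridge

    dt, labeled_seeds, ws = distance_transform_watershed(mask, smoothing=1.0,
                                                         flood_from='interior')

    # Expect (roughly) one watershed label per cube, meeting near the bridge.
    num_labels = len(pd.unique(ws[mask].ravel()))
    logger.info(f"Watershed produced {num_labels} labels inside the mask")
    return dt, labeled_seeds, ws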
def agglomerative_clustering(cleaned_edges, edge_weights, seed_labels, node_sizes=None, num_classes=None): """ Run vigra.graphs.agglomerativeClustering() on the given graph with N nodes and E edges. The graph node IDs must be consecutive, starting with zero, dtype=np.uint32 Args: cleaned_edges: array, (E,2), uint32 Node IDs should be consecutive (more-or-less). To avoid segfaults: - Must not contain duplicates. - Must not contain 'loops' (no self-edges). edge_weights: array, (E,), float32 seed_labels: array (N,), uint32 All un-seeded nodes should be marked as 0. Returns: (output_labels, disconnected_components, contains_unlabeled_components) Where: output_labels: array (N,), uint32 Agglomerated node labeling. disconnected_components: A set of seeds which ended up with more than one component in the result. contains_unlabeled_components: True if the input contains one or more disjoint components that were not seeded and thus not labeled during agglomeration. False otherwise. """ # # Notes: # # vigra.graphs.agglomerativeClustering() is somewhat sophisticated. # # During agglomeration, edges are selected for 'contraction' and the corresponding nodes are merged. # The newly merged node contains the superset of the edges from its constituent nodes, with duplicate # edges combined via weighted average according to their relative 'edgeLengths'. # # The edge weights used in the optimization are adjusted dynamically after every merge. # The dynamic edge weight is computed as a weighted average of it's original 'edgeWeight' # and the similarity of its two nodes (by distance between 'nodeFeatures', # using the distance measure defined by 'metric'). # # The relative importances of the original edgeWeight and the node similarity is determined by 'beta'. # To ignore node feature similarity completely, use beta=0.0. To ignore edgeWeights completely, use beta=1.0. # # After computing that weighted average, the dynamic edge weight is then scaled by a 'Ward factor', # which seems to give priority to edges that connect smaller components. # The importance of the 'Ward factor' is determined by 'wardness'. To disable it, set wardness=0.0. # # # For reference, here are the relevant lines from vigra/hierarchical_clustering.hxx: # # ValueType getEdgeWeight(const Edge & e){ # ... # const ValueType wardFac = 2.0 / ( 1.0/std::pow(sizeU,wardness_) + 1/std::pow(sizeV,wardness_) ); # const ValueType fromEdgeIndicator = edgeIndicatorMap_[ee]; # ValueType fromNodeDist = metric_(nodeFeatureMap_[uu],nodeFeatureMap_[vv]); # ValueType totalWeight = ((1.0-beta_)*fromEdgeIndicator + beta_*fromNodeDist)*wardFac; # ... # } # # # To achieve the "most naive" version of hierarchical clustering, # i.e. based purely on pre-computed edge weights (and no node features), # use beta=0.0, wardness=0.0. # # (Ideally, we would also set nodeSizes=[0,...], but unfortunately, # setting nodeSizes of 0.0 seems to result in strange bugs. # Therefore, we can't avoid the affect of using cumulative node size during the agglomeration.) assert cleaned_edges.dtype == np.uint32 assert cleaned_edges.ndim == 2 assert cleaned_edges.shape[1] == 2 assert edge_weights.shape == (len(cleaned_edges), ) assert seed_labels.ndim == 1 assert cleaned_edges.max() < len(seed_labels) # Initialize graph # (These params merely reserve RAM in advance. They don't initialize actual graph state.) g = vg.AdjacencyListGraph(len(seed_labels), len(cleaned_edges)) # Make sure there are the correct number of nodes. 
# (Internally, AdjacencyListGraph ensures contiguous nodes are created # up to the max id it has seen, so adding the max node is sufficient to # ensure all nodes are present.) g.addNode(len(seed_labels) - 1) # Insert edges. g.addEdges(cleaned_edges) if num_classes is None: num_classes = len(set(pd.unique(seed_labels)) - set([0])) output_labels = vg.agglomerativeClustering( graph=g, edgeWeights=edge_weights, #edgeLengths=..., #nodeFeatures=..., #nodeSizes=..., nodeLabels=seed_labels, nodeNumStop=num_classes, beta=0.0, #metric='l1', wardness=0.0) # For some reason, the output labels do not necessarily # have the same values as the seed labels. We have to relabel them ourselves. # # Furthermore, there are some special cases to consider: # # 1. It is possible that some seeds will map to disconnected components, # if one of the following is true: # - The input contains disconnected components with identical seeds # - The input contains no disconnected components, but it failed to # connect two components with identical seeds (some other seeded # component ended up blocking the path between the two disconnected # components). # In those cases, we should ensure that the disconnected components are # still labeled with the right input seed, but add the seed to the returned # 'disconnected components' set. # # 2. If the input contains any disconnected components that were NOT seeded, # we should relabel those as 0, and return contains_unlabeled_components=True # Get mapping of seeds -> corresponding agg values. # (There might be more than one agg value for a given seed, as explained in point 1 above) df = pd.DataFrame({'seed': seed_labels, 'agg': output_labels}) df.drop_duplicates(inplace=True) # How many unique agg values are there for each seed class? seed_mapping_df = df.query('seed != 0') seed_component_counts = seed_mapping_df.groupby(['seed' ]).agg({'agg': 'size'}) seed_component_counts.columns = ['component_count'] # More than one agg value for a seed class implies that it wasn't fully agglomerated. disconnected_components = set( seed_component_counts.query('component_count > 1').index) # If there are 'extra' agg values (not corresponding to seeds), # then some component(s) are unlabeled. (Point 2 above.) _seeded_agg_ids = set(seed_mapping_df['agg']) nonseeded_agg_ids = df.query('agg not in @_seeded_agg_ids')['agg'] contains_unlabeled_components = (len(nonseeded_agg_ids) > 0) # Map from output agg values back to original seed classes. agg_values = seed_mapping_df['agg'].values seed_values = seed_mapping_df['seed'].values if len(nonseeded_agg_ids) > 0: nonseeded_agg_ids = np.fromiter(nonseeded_agg_ids, np.uint32) agg_values = np.concatenate((agg_values, nonseeded_agg_ids)) seed_values = np.concatenate( (seed_values, np.zeros((len(nonseeded_agg_ids), ), np.uint32))) mapper = LabelMapper(agg_values, seed_values) mapper.apply_inplace(output_labels) return CleaveResults(output_labels, disconnected_components, contains_unlabeled_components)
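
def _example_agglomerative_clustering():
    """
    Tiny sketch for agglomerative_clustering() (not part of the original module).
    A 5-node path graph with seeds on both ends; the high-weight edge in the
    middle is where the two clusters should remain separated.  Requires vigra.
    """
    cleaned_edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]], np.uint32)
    edge_weights = np.array([0.1, 0.1, 0.9, 0.1], np.float32)
    seed_labels = np.array([1, 0, 0, 0, 2], np.uint32)

    output_labels, disconnected, has_unlabeled = \
        agglomerative_clustering(cleaned_edges, edge_weights, seed_labels)

    # Expect the cut to land on the 0.9 edge, e.g. [1, 1, 1, 2, 2].
    logger.info(f"labels: {output_labels.tolist()}, "
                f"disconnected: {disconnected}, unlabeled: {has_unlabeled}")
    return output_labels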
def compute_focused_paths(server, uuid, instance, original_mapping, important_bodies, speculative_merge_tables, split_mapping=None, max_depth=10, stop_after_endpoint_num=None, return_after_setup=False): from ..merge_graph import LabelmapMergeGraph with Timer("Loading speculative merge graph", logger): merge_graph = LabelmapMergeGraph(speculative_merge_tables, uuid) if split_mapping is not None: _bad_edges = merge_graph.append_edges_for_split_supervoxels( split_mapping, server, uuid, instance, parent_sv_handling='drop') merge_table_df = merge_graph.merge_table_df if isinstance(original_mapping, str): with Timer("Loading mapping", logger): original_mapping = load_mapping(original_mapping) else: assert isinstance(original_mapping, pd.Series) with Timer("Applying mapping", logger): mapper = LabelMapper(original_mapping.index.values, original_mapping.values) merge_table_df['body_a'] = mapper.apply(merge_table_df['id_a'].values, allow_unmapped=True) merge_table_df['body_b'] = mapper.apply(merge_table_df['id_b'].values, allow_unmapped=True) if isinstance(important_bodies, str): with Timer("Reading importances", logger): important_bodies = read_csv_col(important_bodies).values important_bodies = pd.Index(important_bodies, dtype=np.uint64) with Timer("Assigning importances", logger): merge_table_df['important_a'] = merge_table_df['body_a'].isin( important_bodies) merge_table_df['important_b'] = merge_table_df['body_b'].isin( important_bodies) with Timer("Discarding merged edges within 'important' bodies ", logger): size_before = len(merge_table_df) merge_table_df.query( '(body_a != body_b) and not (important_a and important_b)', inplace=True) size_after = len(merge_table_df) logger.info( f"Discarded {size_before - size_after} edges, ended with {len(merge_table_df)} edges" ) edges = merge_table_df[['id_a', 'id_b']].values assert edges.dtype == np.uint64 if return_after_setup: logger.info("Returning setup instead of searching for paths") return (edges, original_mapping, important_bodies, max_depth, stop_after_endpoint_num) logger.info( f"Finding paths among {len(important_bodies)} important bodies") # FIXME: Need to augment original_mapping with rows for single-sv bodies that are 'important' all_paths = find_all_paths(edges, original_mapping, important_bodies, max_depth, stop_after_endpoint_num) return all_paths
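
def _example_compute_focused_paths():
    """
    Usage sketch for compute_focused_paths() (not part of the original module).
    The server/uuid/instance values and file paths are hypothetical placeholders,
    and the merge-table format is assumed to be whatever LabelmapMergeGraph accepts.
    """
    seg = ('emdata.example.org:8900', 'abc123', 'segmentation')  # hypothetical
    original_mapping = 'mapping.csv'                 # or a pd.Series (sv -> body)
    important_bodies = 'important-bodies.csv'        # or an iterable of body IDs
    spec_tables = ['speculative-merge-table.npy']    # hypothetical path(s)

    all_paths = compute_focused_paths(*seg, original_mapping, important_bodies,
                                      spec_tables, max_depth=5,
                                      stop_after_endpoint_num=100)
    return all_paths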
def edge_weighted_watershed(cleaned_edges, edge_weights, seed_labels): """ Run nifty.graph.edgeWeightedWatershedsSegmentation() on the given graph with N nodes and E edges. The graph node IDs must be consecutive, starting with zero, dtype=np.uint32 Args: cleaned_edges: array, (E,2), uint32 Node IDs should be consecutive (more-or-less). To avoid segfaults: - Must not contain duplicates. - Must not contain 'loops' (no self-edges). edge_weights: array, (E,), float32 seed_labels: array (N,), uint32 All un-seeded nodes should be marked as 0. Returns: (output_labels, disconnected_components, contains_unlabeled_components) Where: output_labels: array (N,), uint32 Agglomerated node labeling. disconnected_components: A set of seeds which ended up with more than one component in the result. contains_unlabeled_components: True if the input contains one or more disjoint components that were not seeded and thus not labeled during agglomeration. False otherwise. """ assert cleaned_edges.dtype == np.uint32 assert cleaned_edges.ndim == 2 assert cleaned_edges.shape[1] == 2 assert edge_weights.shape == (len(cleaned_edges), ) assert seed_labels.ndim == 1 assert cleaned_edges.max() < len(seed_labels) g = nifty.graph.UndirectedGraph(len(seed_labels)) g.insertEdges(cleaned_edges) output_labels = nifty.graph.edgeWeightedWatershedsSegmentation( g, seed_labels, edge_weights) contains_unlabeled_components = not output_labels.all() mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32), output_labels) labeled_edges = mapper.apply(cleaned_edges) preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]] component_labels = connected_components(preserved_edges, len(output_labels)) assert len(component_labels) == len(output_labels) == len(seed_labels) cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels}) cc_counts = cc_df.groupby('label').nunique()['cc'] disconnected_cc_counts = cc_counts[cc_counts > 1] disconnected_components = set(disconnected_cc_counts.index) - set([0]) return output_labels, disconnected_components, contains_unlabeled_components
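
def _example_edge_weighted_watershed():
    """
    Tiny sketch for edge_weighted_watershed() (not part of the original module).
    Same 5-node path graph as the agglomeration example above; the watershed
    grows the two seeds along the cheap edges and cuts at the expensive one.
    Requires nifty.
    """
    cleaned_edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]], np.uint32)
    edge_weights = np.array([0.1, 0.1, 0.9, 0.1], np.float32)
    seed_labels = np.array([1, 0, 0, 0, 2], np.uint32)

    output_labels, disconnected, has_unlabeled = \
        edge_weighted_watershed(cleaned_edges, edge_weights, seed_labels)

    # Expected result: [1, 1, 1, 2, 2] (cut at the 0.9 edge), with no
    # disconnected or unlabeled components.
    return output_labels, disconnected, has_unlabeled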