Example #1
def connected_components_nonconsecutive(edges, node_ids):
    """
    Run connected components on the graph encoded by 'edges' and node_ids.
    All nodes from edges must be present in node_ids.
    (Additional nodes are permitted, in which case each is assigned its own CC.)
    The node_ids do not need to be consecutive, and can have arbitrarily large values.

    For graphs with consecutively-valued nodes, connected_components() (below)
    will be faster because it avoids a relabeling step.
    
    Args:
        edges:
            ndarray, shape=(E,2), dtype np.uint32 or uint64
        
        node_ids:
            ndarray, shape=(N,), dtype np.uint32 or uint64

    Returns:
        ndarray, same shape as node_ids, labeled by component index from 0..C
    """
    assert node_ids.ndim == 1
    assert node_ids.dtype in (np.uint32, np.uint64)

    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)
    cons_edges = mapper.apply(edges)
    return connected_components(cons_edges, len(node_ids))
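
A minimal usage sketch for the function above. It assumes the connected_components() helper mentioned in the docstring is importable alongside it; the node IDs and the exact output labels are illustrative only.

import numpy as np

# Hypothetical, non-consecutive node IDs forming two components: {10, 20, 30} and {999}.
node_ids = np.array([10, 20, 30, 999], dtype=np.uint64)
edges = np.array([[10, 20], [20, 30]], dtype=np.uint64)

cc_labels = connected_components_nonconsecutive(edges, node_ids)
assert len(cc_labels) == len(node_ids)   # one component label per entry of node_ids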
Example #2
def _overwrite_body_id_column(block_sv_stats, segment_to_body_df=None):
    """
    Given a stats array with 'columns' as defined in STATS_DTYPE,
    overwrite the body_id column according to the given agglomeration mapping DataFrame.
    
    If no mapping is given, simply copy the segment_id column into the body_id column.
    
    Args:
        block_sv_stats: numpy.ndarray, with dtype=STATS_DTYPE
        segment_to_body_df: pandas.DataFrame, with columns ['segment_id', 'body_id']
    """
    assert block_sv_stats.dtype == STATS_DTYPE

    assert STATS_DTYPE[0][0] == 'body_id'
    assert STATS_DTYPE[1][0] == 'segment_id'
    
    block_sv_stats = block_sv_stats.view( [STATS_DTYPE[0], STATS_DTYPE[1], ('other_cols', STATS_DTYPE[2:])] )

    if segment_to_body_df is None:
        # No agglomeration
        block_sv_stats['body_id'] = block_sv_stats['segment_id']
    else:
        assert list(segment_to_body_df.columns) == AGGLO_MAP_COLUMNS
        
        # This could be done via pandas merge(), followed by fillna(), etc.,
        # but I suspect LabelMapper is faster and more frugal with RAM.
        mapper = LabelMapper(segment_to_body_df['segment_id'].values, segment_to_body_df['body_id'].values)
        del segment_to_body_df
    
        # Remap in batches to save RAM
        batch_size = 1_000_000
        for chunk_start in range(0, len(block_sv_stats), batch_size):
            chunk_stop = min(chunk_start+batch_size, len(block_sv_stats))
            chunk_segments = block_sv_stats['segment_id'][chunk_start:chunk_stop]
            block_sv_stats['body_id'][chunk_start:chunk_stop] = mapper.apply(chunk_segments, allow_unmapped=True)
Example #3
def _find_disconnected_components(cleaned_edges, output_labels):
    """
    Given a graph defined by cleaned_edges and a node labeling in output_labels,
    check whether any output labels are split among discontiguous groups,
    and return the set of output label IDs for such objects.
    """
    # Figure out which edges were 'cut' (endpoints got different labels)
    # and which were preserved
    mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32),
                         output_labels)
    labeled_edges = mapper.apply(cleaned_edges)
    preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]

    # Compute CC on the graph WITHOUT cut edges (keep only preserved edges)
    component_labels = connected_components(preserved_edges,
                                            len(output_labels))
    assert len(component_labels) == len(output_labels)

    # Align node output labels to their connected component labels
    cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels})

    # How many unique connected component labels are associated with each output label?
    cc_counts = cc_df.groupby('label').nunique()['cc']

    # Any output labels that map to multiple CC labels are 'disconnected components' in the output.
    disconnected_cc_counts = cc_counts[cc_counts > 1]
    disconnected_components = set(disconnected_cc_counts.index) - set([0])

    return disconnected_components
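
To make the 'cut edge' trick above concrete, here is a toy sketch. It assumes LabelMapper is imported from dvidutils, as these examples appear to do.

import numpy as np
from dvidutils import LabelMapper  # assumed import

# Four nodes labeled [7, 7, 8, 8]; the middle edge (1-2) straddles two labels and is 'cut'.
output_labels = np.array([7, 7, 8, 8], dtype=np.uint32)
cleaned_edges = np.array([[0, 1], [1, 2], [2, 3]], dtype=np.uint32)

mapper = LabelMapper(np.arange(len(output_labels), dtype=np.uint32), output_labels)
labeled_edges = mapper.apply(cleaned_edges)
preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]
# preserved_edges -> [[0, 1], [2, 3]]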
Example #4
def _overwrite_body_id_column(block_sv_stats, segment_to_body_df=None):
    """
    Given a stats array with 'columns' as defined in STATS_DTYPE,
    overwrite the body_id column according to the given agglomeration mapping DataFrame.
    
    If no mapping is given, simply copy the segment_id column into the body_id column.
    
    Args:
        block_sv_stats: numpy.ndarray, with dtype=STATS_DTYPE
        segment_to_body_df: pandas.DataFrame, with columns ['segment_id', 'body_id']
    """
    assert block_sv_stats.dtype == STATS_DTYPE

    assert STATS_DTYPE[0][0] == 'body_id'
    assert STATS_DTYPE[1][0] == 'segment_id'
    
    block_sv_stats = block_sv_stats.view( [STATS_DTYPE[0], STATS_DTYPE[1], ('other_cols', STATS_DTYPE[2:])] )

    if segment_to_body_df is None:
        # No agglomeration
        block_sv_stats['body_id'] = block_sv_stats['segment_id']
    else:
        assert list(segment_to_body_df.columns) == AGGLO_MAP_COLUMNS
        
        # This could be done via pandas merge(), followed by fillna(), etc.,
        # but I suspect LabelMapper is faster and more frugal with RAM.
        mapper = LabelMapper(segment_to_body_df['segment_id'].values, segment_to_body_df['body_id'].values)
        del segment_to_body_df
    
        # Remap in batches to save RAM
        batch_size = 1_000_000
        for chunk_start in range(0, len(block_sv_stats), batch_size):
            chunk_stop = min(chunk_start+batch_size, len(block_sv_stats))
            chunk_segments = block_sv_stats['segment_id'][chunk_start:chunk_stop]
            block_sv_stats['body_id'][chunk_start:chunk_stop] = mapper.apply(chunk_segments, allow_unmapped=True)
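
A small sketch of the allow_unmapped=True behavior the function above relies on (assuming dvidutils' LabelMapper): IDs that are missing from the mapping appear to pass through unchanged instead of raising.

import numpy as np
from dvidutils import LabelMapper  # assumed import

segment_ids = np.array([1, 2, 3], dtype=np.uint64)
body_ids    = np.array([10, 10, 30], dtype=np.uint64)
mapper = LabelMapper(segment_ids, body_ids)

segments = np.array([1, 2, 3, 4], dtype=np.uint64)   # segment 4 has no mapping entry
bodies = mapper.apply(segments, allow_unmapped=True)
# bodies -> [10, 10, 30, 4]: the unmapped segment keeps its own ID.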
Example #5
def infer_hierarchy(neuron_df, connection_df, min_weight=10, init='groundtruth', verbose=True, special_debug=False):
    ##
    ## TODO: If filtering connections for min_weight drops some neurons entirely, they should be removed from neuron_df
    ##
    lsf_slots = os.environ.get('LSB_DJOB_NUMPROC', default=0)
    if lsf_slots:
        os.environ['OMP_NUM_THREADS'] = lsf_slots
        logger.info(f"Using {lsf_slots} CPUs for OpenMP")

    assert init in ('groundtruth', 'random')
    neuron_df = load_table(neuron_df)
    connection_df = load_table(connection_df)

    assert {*neuron_df.columns} >= {'bodyId', 'instance', 'type'}
    assert {*connection_df.columns} >= {'bodyId_pre', 'bodyId_post', 'weight'}

    if special_debug:
        # Choose a very small subset of the data
        neuron_df = neuron_df.iloc[::100]
        bodies = neuron_df['bodyId']
        connection_df = connection_df.query('bodyId_pre in @bodies and bodyId_post in @bodies')

    if init == "groundtruth":
        with Timer("Computing initial hierarchy from groundtruth", logger):
            assign_morpho_indexes(neuron_df)
            num_morpho_groups = neuron_df.morpho_index.max()+1
            init_bs = [neuron_df['morpho_index'].values, np.zeros(num_morpho_groups, dtype=int)]
    else:
        init_bs = None

    # If this is a per-ROI table, sum up the ROIs.
    if 'roi' in connection_df:
        connection_df = connection_df.groupby(['bodyId_pre', 'bodyId_post'], as_index=False)['weight'].sum()

    strong_connections_df = connection_df.query('weight >= @min_weight')
    strong_bodies = pd.unique(strong_connections_df[['bodyId_pre', 'bodyId_post']].values.reshape(-1))
    weights = strong_connections_df.set_index(['bodyId_pre', 'bodyId_post'])['weight']
    
    logger.info(f"Strong connectome (cutoff={min_weight}) has {len(strong_bodies)} bodies and {len(weights)} edges")
    
    vertexes = np.arange(len(strong_bodies), dtype=np.uint32)
    vertex_mapper = LabelMapper(strong_bodies.astype(np.uint64), vertexes)
    vertex_reverse_mapper = LabelMapper(vertexes, strong_bodies.astype(np.uint64))

    g = construct_graph(weights, vertexes, vertex_mapper)
    
    with Timer("Running inference"):
        # Computes a NestedBlockState
        nbs = graph_tool.inference.minimize_nested_blockmodel_dl(g,
                                                                 bs=init_bs,
                                                                 mcmc_args=dict(parallel=True), # see graph-tool docs and mailing list for caveats 
                                                                 deg_corr=True,
                                                                 verbose=verbose)

    partition_df = construct_partition_table(nbs, neuron_df, vertexes, vertex_reverse_mapper)
    return strong_connections_df, g, nbs, partition_df
Example #6
    def __init__(self, edge_array, directed=True):
        import graph_tool as gt

        assert edge_array.dtype in (np.uint32, np.uint64)

        self.node_ids = pd.unique(edge_array.reshape(-1))
        self.cons_node_ids = np.arange(len(self.node_ids), dtype=np.uint32)

        self.mapper = LabelMapper(self.node_ids, self.cons_node_ids)
        cons_edges = self.mapper.apply(edge_array)

        self.g = gt.Graph(directed=directed)
        self.g.add_edge_list(cons_edges, hashed=False)
Example #7
    def remap_bricks(partition_bricks):
        domain, codomain = mapping_pairs.transpose()
        mapper = LabelMapper(domain, codomain)

        partition_bricks = list(partition_bricks)
        for brick in partition_bricks:
            # TODO: Apparently LabelMapper can't handle non-contiguous arrays right now.
            #       (It yields incorrect results)
            #       Check to see if this is still a problem in the latest version of xtensor-python.
            brick.volume = np.asarray(brick.volume, order='C')

            mapper.apply_inplace(brick.volume, allow_unmapped=True)
        return partition_bricks
Example #8
def apply_mappings(supervoxels, mappings):
    assert isinstance(mappings, dict)
    df = pd.DataFrame(index=supervoxels.astype(np.uint64, copy=False))
    df.index.name = 'sv'

    for name, mapping in mappings.items():
        assert isinstance(mapping, pd.Series)
        index_values = mapping.index.values.astype(np.uint64, copy=False)
        mapping_values = mapping.values.astype(np.uint64, copy=False)
        mapper = LabelMapper(index_values, mapping_values)
        df[name] = mapper.apply(df.index.values, allow_unmapped=True)

    return df
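
A hypothetical call to apply_mappings() with two small mappings; the supervoxel IDs and series contents are made up for illustration.

import numpy as np
import pandas as pd

svs = np.array([101, 102, 103], dtype=np.uint64)
mappings = {
    'agglo_v1': pd.Series(index=np.array([101, 102], np.uint64),
                          data=np.array([5, 5], np.uint64)),
    'agglo_v2': pd.Series(index=np.array([101, 103], np.uint64),
                          data=np.array([7, 9], np.uint64)),
}
df = apply_mappings(svs, mappings)
# df is indexed by 'sv' with one column per mapping name;
# supervoxels absent from a mapping keep their own ID (allow_unmapped=True).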
Example #9
class SparseNodeGraph:
    """
    Wrapper around gt.Graph() that permits arbitrarily large
    node IDs, which are internally mapped to consecutive node IDs.
    
    Ideally, we could just use g.add_edge_list(..., hashed=True),
    but that feature appears to be unstable.
    (In my build, at least, it tends to segfault).
    """
    
    def __init__(self, edge_array, directed=True):
        import graph_tool as gt
    
        assert edge_array.dtype in (np.uint32, np.uint64)
    
        self.node_ids = pd.unique(edge_array.reshape(-1))
        self.cons_node_ids = np.arange(len(self.node_ids), dtype=np.uint32)
    
        self.mapper = LabelMapper(self.node_ids, self.cons_node_ids)
        cons_edges = self.mapper.apply(edge_array)
    
        self.g = gt.Graph(directed=directed)
        self.g.add_edge_list(cons_edges, hashed=False)
        

    def _map_node(self, node):
        return self.mapper.apply(np.array([node], self.node_ids.dtype))[0]


    def add_edge_weights(self, weights):
        if weights is not None:
            assert len(weights) == self.g.num_edges()
            if np.issubdtype(weights.dtype, np.integer):
                assert weights.dtype.itemsize <= 4, "Can't handle 8-byte ints, sadly."
                # Store the property map on the graph so it isn't discarded.
                self.g.ep["weight"] = self.g.new_edge_property("int", vals=weights)
            elif np.issubdtype(weights.dtype, np.floating):
                self.g.ep["weight"] = self.g.new_edge_property("float", vals=weights)
            else:
                raise AssertionError("Can't handle non-numeric weights")


    def get_out_neighbors(self, node):
        cons_node = self._map_node(node)
        cons_neighbors = self.g.get_out_neighbors(cons_node)
        return self.node_ids[cons_neighbors]
    

    def get_in_neighbors(self, node):
        cons_node = self._map_node(node)
        cons_neighbors = self.g.get_in_neighbors(cons_node)
        return self.node_ids[cons_neighbors]
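
A sketch of how SparseNodeGraph might be used; it requires graph_tool and dvidutils, and the node IDs here are arbitrary.

import numpy as np

# Edges between arbitrarily large node IDs.
edges = np.array([[10**12, 10**12 + 1],
                  [10**12 + 1, 5]], dtype=np.uint64)

sng = SparseNodeGraph(edges, directed=True)
sng.add_edge_weights(np.array([0.5, 1.5], dtype=np.float32))

print(sng.get_out_neighbors(10**12))   # neighbors come back as the original (large) node IDs
print(sng.get_in_neighbors(5))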
Example #10
    def remap_bricks(partition_bricks):
        domain, codomain = mapping_pairs.transpose()
        mapper = LabelMapper(domain, codomain)

        partition_bricks = list(partition_bricks)
        for brick in partition_bricks:
            # TODO: Apparently LabelMapper can't handle non-contiguous arrays right now.
            #       (It yields incorrect results)
            #       Check to see if this is still a problem in the latest version of xtensor-python.
            brick.volume = np.asarray(brick.volume, order='C')

            mapper.apply_inplace(brick.volume, allow_unmapped=True)
            brick.compress()
        return partition_bricks
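
As the TODO above notes, apply_inplace() is only trusted on C-contiguous arrays here; a short sketch of the guard, assuming dvidutils' LabelMapper.

import numpy as np
from dvidutils import LabelMapper  # assumed import

vol = np.zeros((4, 4, 4), np.uint64)[:, ::2, :]   # a non-contiguous view
vol = np.ascontiguousarray(vol)                   # same effect as np.asarray(vol, order='C')

mapper = LabelMapper(np.array([0], np.uint64), np.array([1], np.uint64))
mapper.apply_inplace(vol, allow_unmapped=True)    # vol is now all ones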
Example #11
def apply_mapping_to_mergetable(merge_table_df, mapping):
    """
    Set the 'body' column of the given merge table (append one if it didn't exist)
    by applying the given SV->body mapping to the merge table's id_a column.
    """
    if isinstance(mapping, str):
        with Timer("Loading mapping", logger):
            mapping = load_mapping(mapping)

    assert isinstance(mapping, pd.Series), "Mapping must be a pd.Series"
    with Timer("Applying mapping to merge table", logger):
        mapper = LabelMapper(mapping.index.values, mapping.values)
        merge_table_df['body'] = mapper.apply(merge_table_df['id_a'].values,
                                              allow_unmapped=True)
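
A hypothetical merge table and mapping, to show what the function above does to the 'body' column. It assumes the function and its module-level Timer/logger helpers are importable.

import numpy as np
import pandas as pd

merge_table_df = pd.DataFrame({'id_a': np.array([1, 2, 3], np.uint64),
                               'id_b': np.array([2, 3, 4], np.uint64)})
mapping = pd.Series(index=np.array([1, 2, 3, 4], np.uint64),
                    data=np.array([10, 10, 10, 40], np.uint64))

apply_mapping_to_mergetable(merge_table_df, mapping)
# merge_table_df now has a 'body' column: [10, 10, 10]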
Example #12
def check_tarsupervoxels_status_via_exists(server, uuid, tsv_instance, bodies, seg_instance=None, mapping=None, kafka_msgs=None):
    """
    For the given bodies, query the given tarsupervoxels instance and return a
    DataFrame indicating which supervoxels are 'missing' from the instance.
    
    Bodies that no longer exist in the segmentation instance are ignored.

    This function downloads the complete mapping in advance and uses it to determine
    which supervoxels belong to each body.  Then it uses the /exists endpoint to
    query for missing supervoxels, rather than /missing, which incurs a disk
    read in DVID.
    """
    if seg_instance is None:
        # Determine segmentation instance
        info = fetch_instance_info(server, uuid, tsv_instance)
        seg_instance = info["Base"]["Syncs"][0]
    
    if mapping is None:
        mapping = fetch_complete_mappings(server, uuid, seg_instance, kafka_msgs=kafka_msgs)
    
    # Filter out bodies we don't care about,
    # and append unmapped (singleton/identity) bodies
    _bodies = set(bodies)
    mapping = pd.DataFrame(mapping).query('body in @_bodies')['body'].copy()
    unmapped_bodies = _bodies - set(mapping)
    unmapped_bodies = np.fromiter(unmapped_bodies, np.uint64)
    singleton_mapping = pd.Series(index=unmapped_bodies, data=unmapped_bodies, dtype=np.uint64)
    mapping = pd.concat((mapping, singleton_mapping))
    
    assert mapping.index.values.dtype == np.uint64
    assert mapping.values.dtype == np.uint64
    
    # Faster than mapping.loc[], apparently
    mapper = LabelMapper(mapping.index.values, mapping.values)

    statuses = fetch_exists(server, uuid, tsv_instance, mapping.index, batch_size=10_000, processes=16)
    missing_svs = statuses[~statuses].index.values
    assert missing_svs.dtype == np.uint64
    missing_bodies = mapper.apply(missing_svs, True)

    missing_df = pd.DataFrame({'sv': missing_svs, 'body': missing_bodies})
    assert missing_df['sv'].dtype == np.uint64
    assert missing_df['body'].dtype == np.uint64
    
    # Return a series, indexed by sv
    return missing_df.set_index('sv')['body']
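
The 'append unmapped bodies as singletons' step above is easy to get wrong, so here is the same idea on toy data (pure pandas, no DVID involved; the IDs are made up).

import numpy as np
import pandas as pd

# sv -> body mapping, already restricted to the bodies of interest
mapping = pd.Series(index=np.array([11, 12], np.uint64),
                    data=np.array([100, 100], np.uint64), name='body')

_bodies = {100, 200}                            # body 200 has no multi-sv entry
unmapped_bodies = np.fromiter(_bodies - set(mapping), np.uint64)
singletons = pd.Series(index=unmapped_bodies, data=unmapped_bodies, dtype=np.uint64)
mapping = pd.concat((mapping, singletons))      # now 200 maps to itself, like the real code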
Example #13
def find_all_hotknife_edges_for_plane(server, uuid, instance, plane_center_coord_s0, tile_shape_s0, plane_spacing_s0, min_overlap_s0=1, min_jaccard=0.0, *, scale=0, mapping=None):
    """
    Find all hotknife edges around the given X-plane, processed in batches according to tile_shape_s0.
     
    See find_hotknife_edges() for more details.
    """
    plane_box_s0 = fetch_volume_box(server, uuid, instance)
    plane_bounds_s0 = plane_box_s0[:,:2]

    edge_tables = []
    tile_boxes = boxes_from_grid(plane_bounds_s0, tile_shape_s0, clipped=True)
    for tile_index, tile_bounds_s0 in enumerate(tile_boxes):
        tile_logger = PrefixedLogger(logger, f"Tile {tile_index:03d}/{len(tile_boxes):03d}: ")
        edge_table = find_hotknife_edges( server,
                                          uuid,
                                          instance,
                                          plane_center_coord_s0,
                                          tile_bounds_s0,
                                          plane_spacing_s0,
                                          min_overlap_s0,
                                          min_jaccard,
                                          scale=scale,
                                          supervoxels=True, # Compute edges across supervoxels, and then filter out already-merged bodies afterwards
                                          logger=tile_logger )
        edge_tables.append(edge_table)

    edge_table = pd.concat(edge_tables, ignore_index=True)
    assert (edge_table.columns == ['left', 'right', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size']).all()
    edge_table.columns = ['id_a', 'id_b', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size']
    assert edge_table['id_a'].dtype == np.uint64
    assert edge_table['id_b'].dtype == np.uint64
    
    # "Complete" mappings not necessary for our purposes.
    if mapping is None:
        mapping = fetch_mappings(server, uuid, instance, as_array=True)
    mapper = LabelMapper(mapping[:,0], mapping[:,1])
    
    edge_table['body_a'] = mapper.apply(edge_table['id_a'].values, True)
    edge_table['body_b'] = mapper.apply(edge_table['id_b'].values, True)

    edge_table.query('body_a != body_b', inplace=True)
    return edge_table
Example #14
        def remap_cc_to_final(orig_brick, cc_brick, wrapped_brick_mapping_df):
            """
            Given an original brick and the corresponding CC brick,
            Relabel the CC brick according to the final label mapping,
            as provided in wrapped_brick_mapping_df.
            """
            assert (orig_brick.logical_box == cc_brick.logical_box).all()
            assert (orig_brick.physical_box == cc_brick.physical_box).all()

            # Check for NaN, which implies that the mapping for this
            # brick is empty (no objects to relabel).
            if isinstance(wrapped_brick_mapping_df, float):
                assert np.isnan(wrapped_brick_mapping_df)
                final_vol = orig_brick.volume
                orig_brick.compress()
            else:
                # Construct mapper from only the rows we need
                cc_labels = pd.unique(cc_brick.volume.reshape(-1)) # @UnusedVariable
                mapping = wrapped_brick_mapping_df.df.query('cc in @cc_labels')[['cc', 'final_cc']].values
                mapper = LabelMapper(*mapping.transpose())
    
                # Apply mapping to CC vol, writing zeros wherever the CC isn't mapped.
                final_vol = mapper.apply_with_default(cc_brick.volume, 0)
                
                # Overwrite zero voxels from the original segmentation.
                final_vol = np.where(final_vol, final_vol, orig_brick.volume)
            
                orig_brick.compress()
                cc_brick.compress()
            
            final_brick = Brick( orig_brick.logical_box,
                                 orig_brick.physical_box,
                                 final_vol,
                                 location_id=orig_brick.location_id,
                                 compression=orig_brick.compression )
            return final_brick
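
The apply_with_default() call above writes a default value wherever a label has no mapping entry; a toy sketch, assuming dvidutils' LabelMapper.

import numpy as np
from dvidutils import LabelMapper  # assumed import

cc_vol = np.array([[1, 1, 2],
                   [2, 3, 3]], dtype=np.uint64)

# Only CC labels 1 and 2 are mapped; label 3 falls back to the default (0).
mapper = LabelMapper(np.array([1, 2], np.uint64), np.array([100, 200], np.uint64))
final_vol = mapper.apply_with_default(cc_vol, 0)
# final_vol -> [[100, 100, 200],
#               [200,   0,   0]]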
Example #15
def _erase_tiny_interior_segments(seg_vol, min_size):
    """
    Erase any segments that are smaller than the given
    size and don't touch the edge of the volume.
    """
    edge_mitos = (set(pd.unique(seg_vol[0, :, :].ravel()))
                  | set(pd.unique(seg_vol[:, 0, :].ravel()))
                  | set(pd.unique(seg_vol[:, :, 0].ravel()))
                  | set(pd.unique(seg_vol[-1, :, :].ravel()))
                  | set(pd.unique(seg_vol[:, -1, :].ravel()))
                  | set(pd.unique(seg_vol[:, :, -1].ravel())))

    mito_sizes = pd.Series(seg_vol.ravel()).value_counts()
    nontiny_mitos = mito_sizes[mito_sizes >= min_size].index

    keep_mitos = (edge_mitos | set(nontiny_mitos))
    keep_mitos = np.array([*keep_mitos], np.uint64)
    if len(keep_mitos) == 0:
        return np.zeros_like(seg_vol)

    # Erase everything that isn't in the keep set
    seg_vol = LabelMapper(keep_mitos,
                          keep_mitos).apply_with_default(seg_vol, 0)
    return seg_vol
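
The identity mapping above (keep_mitos -> keep_mitos) combined with apply_with_default(..., 0) acts as a "keep only these labels" filter; a toy sketch, assuming dvidutils' LabelMapper.

import numpy as np
from dvidutils import LabelMapper  # assumed import

seg = np.array([[5, 5, 6],
                [7, 7, 6]], dtype=np.uint64)
keep = np.array([5, 7], dtype=np.uint64)

masked = LabelMapper(keep, keep).apply_with_default(seg, 0)
# masked -> [[5, 5, 0],
#            [7, 7, 0]]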
Example #16
    def sparse_brick_coords_for_labels(self, labels, clip=True):
        """
        Return a DataFrame indicating the brick
        coordinates (starting corner) that encompass the given labels.

        Args:
            labels:
                A list of body IDs (if ``self.supervoxels`` is False),
                or supervoxel IDs (if ``self.supervoxels`` is True).

            clip:
                If True, filter the results to exclude any coordinates
                that fall outside this service's bounding-box.
                Otherwise, all brick coordinates that encompass the given labels
                will be returned, whether or not they fall within the bounding box.

        Returns:
            DataFrame with columns [z,y,x,label],
            where z,y,x represents the starting corner (in full-res coordinates)
            of a brick that contains the label.
        """
        assert not isinstance(labels,
                              set), "Pass labels as a list or array, not a set"
        labels = pd.unique(labels)
        is_supervoxels = self.supervoxels
        brick_shape = self.preferred_message_shape
        assert (brick_shape % self.block_width == 0).all(), \
            ("Brick shape ('preferred-message-shape') must be a multiple of the "
             f"block width ({self.block_width}) in all dimensions, not {brick_shape}")

        bad_labels = []

        if not is_supervoxels:
            # No supervoxel filtering.
            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = {label: None for label in sorted(labels)}
        else:
            # Arbitrary heuristic for whether to do the body-lookups on DVID or on the client.
            if len(labels) < 100_000:
                # If we're only dealing with a few supervoxels,
                # ask dvid to map them to bodies for us.
                mapping = fetch_mapping(*self.instance_triple,
                                        labels,
                                        as_series=True)
            else:
                # If we're dealing with a lot of supervoxels, ask for
                # the entire mapping, and look up the bodies ourselves.
                complete_mapping = fetch_mappings(*self.instance_triple)
                mapper = LabelMapper(complete_mapping.index.values,
                                     complete_mapping.values)

                labels = np.asarray(labels, np.uint64)
                bodies = mapper.apply(labels, True)
                mapping = pd.Series(index=labels, data=bodies, name='body')
                mapping.index.rename('sv', inplace=True)

            bad_svs = mapping[mapping == 0]
            bad_labels.extend(bad_svs.index.tolist())

            # Group by body
            mapping = mapping[mapping != 0]
            grouped_svs = mapping.reset_index().groupby('body').agg(
                {'sv': list})['sv']

            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = grouped_svs.sort_index().to_dict()

        # Extract these to avoid pickling 'self' (just for speed)
        server, uuid, instance = self.instance_triple
        if self._use_resource_manager_for_sparse_coords:
            mgr = self.resource_manager_client
        else:
            mgr = ResourceManagerClient("", 0)

        def fetch_brick_coords(body, supervoxel_subset):
            """
            Fetch the block coordinates for the given body,
            filter them for the given supervoxels (if any),
            and convert the block coordinates to brick coordinates.
            """
            assert is_supervoxels or supervoxel_subset is None

            try:
                with mgr.access_context(server, True, 1, 1):
                    labelindex = fetch_labelindex(server, uuid, instance, body,
                                                  'protobuf')
                coords_df = convert_labelindex_to_pandas(labelindex).blocks

            except HTTPError as ex:
                if (ex.response is not None
                        and ex.response.status_code == 404):
                    return (body, None)
                raise
            except RuntimeError as ex:
                if 'does not map to any body' in str(ex):
                    return (body, None)
                raise

            if len(coords_df) == 0:
                return (body, None)

            if is_supervoxels:
                supervoxel_subset = set(supervoxel_subset)
                coords_df = coords_df.query('sv in @supervoxel_subset').copy()

            coords_df[['z', 'y', 'x']] //= brick_shape
            coords_df['body'] = np.uint64(body)
            coords_df.drop_duplicates(inplace=True)
            return (body, coords_df)

        def fetch_and_concatenate_brick_coords(bodies_and_supervoxels):
            """
            To reduce the number of tiny DataFrames collected to the driver,
            it's best to concatenate the partitions first, on the workers,
            rather than a straightforward call to starmap(fetch_brick_coords).

            Hence, this function that consolidates each partition.
            """
            bad_bodies = []
            coord_dfs = []
            for (body, supervoxel_subset) in bodies_and_supervoxels:
                _, coords_df = fetch_brick_coords(body, supervoxel_subset)
                if coords_df is None:
                    bad_bodies.append(body)
                else:
                    coord_dfs.append(coords_df)
                    del coords_df

            if coord_dfs:
                return [(pd.concat(coord_dfs, ignore_index=True), bad_bodies)]
            else:
                return [(None, bad_bodies)]

        with Timer(
                f"Fetching coarse sparsevols for {len(labels)} labels ({len(bodies_and_svs)} bodies)",
                logger=logger):
            import dask.bag as db
            coords_and_bad_bodies = (
                db.from_sequence(
                    bodies_and_svs.items(), npartitions=4096
                )  # Instead of fancy heuristics, just pick 4096
                .map_partitions(fetch_and_concatenate_brick_coords).compute())

        coords_df_partitions, bad_body_partitions = zip(*coords_and_bad_bodies)

        for body in chain(*bad_body_partitions):
            if is_supervoxels:
                bad_labels.extend(bodies_and_svs[body])
            else:
                bad_labels.append(body)

        if bad_labels:
            name = 'sv' if is_supervoxels else 'body'
            pd.Series(bad_labels,
                      name=name).to_csv('labels-without-sparsevols.csv',
                                        index=False,
                                        header=True)
            if len(bad_labels) < 100:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels: {bad_labels}"
            else:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels. See labels-without-sparsevols.csv"

            logger.warning(msg)

        coords_df_partitions = list(
            filter(lambda df: df is not None, coords_df_partitions))
        if len(coords_df_partitions) == 0:
            raise RuntimeError(
                "Could not find bricks for any of the given labels")

        coords_df = pd.concat(coords_df_partitions, ignore_index=True)

        if self.supervoxels:
            coords_df['label'] = coords_df['sv']
        else:
            coords_df['label'] = coords_df['body']

        coords_df.drop_duplicates(['z', 'y', 'x', 'label'], inplace=True)
        coords_df[['z', 'y', 'x']] *= brick_shape

        if clip:
            # Keep if the last pixel in the brick is to the right of the bounding-box start
            # and the first pixel in the brick is to the left of the bounding-box stop
            keep = (coords_df[['z', 'y', 'x']] + brick_shape >
                    self.bounding_box_zyx[0]).all(axis=1)
            keep &= (coords_df[['z', 'y', 'x']] <
                     self.bounding_box_zyx[1]).all(axis=1)
            coords_df = coords_df.loc[keep]

        return coords_df[['z', 'y', 'x', 'label']]
Example #17
def split_disconnected_bodies(labels_orig):
    """
    Produces 3D volume split into connected components.

    This function identifies bodies that are the same label
    but are not connected.  It splits these bodies and
    produces a dict that maps these newly split bodies to
    the original body label.

    Special exception: Segments with label 0 are not relabeled.
    
    Note:
        Requires scikit-image (which, currently, is not otherwise
        listed as a dependency of neuclease's conda-recipe).

    Args:
        labels_orig (numpy.array): 3D array of labels

    Returns:
        (labels_new, new_to_orig, new_unique_labels)

        labels_new:
            The partially relabeled array.
            Segments that were not split will keep their original IDs.
            Among split segments, the largest 'child' of a split segment retains the original ID.
            The smaller segments are assigned new labels in the range (N+1)..(N+1+S), where N is the
            highest original label and S is the number of new segments after splitting.
        
        new_to_orig:
            A pseudo-minimal (but not quite minimal) mapping of labels
            (N+1)..(N+1+S) -> some subset of (1..N),
            which maps new segment IDs to the segments they came from.
            Segments that were not split at all are not mentioned in this mapping;
            for split segments, every mapping pair for the split is returned, including the k->k (identity) pair.
        
        new_unique_labels:
            An array of all label IDs in the newly relabeled volume.
            The original label set can be selected via:
            
                new_unique_labels[new_unique_labels < min(new_to_orig.keys())]
        
    """
    import skimage.measure as skm
    # Compute connected components and cast back to original dtype
    labels_cc = skm.label(labels_orig, background=0, connectivity=1)
    assert labels_cc.dtype == np.int64
    if labels_orig.dtype == np.uint64:
        labels_cc = labels_cc.view(np.uint64)
    else:
        labels_cc = labels_cc.astype(labels_orig.dtype, copy=False)

    # Find overlapping segments between orig and CC volumes
    overlap_table_df = contingency_table(labels_orig, labels_cc).reset_index()
    assert overlap_table_df.columns.tolist() == [
        'left', 'right', 'voxel_count'
    ]
    overlap_table_df.columns = ['orig', 'cc', 'voxels']
    overlap_table_df.sort_values('voxels', ascending=False, inplace=True)

    # If a label in 'orig' is duplicated, it has multiple components in labels_cc.
    # The largest component gets to keep the original ID;
    # the other components must take on new values.
    # (The new values must not conflict with any of the IDs in the original, so start at orig_max+1)
    new_cc_pos = overlap_table_df['orig'].duplicated()
    orig_max = overlap_table_df['orig'].max()
    new_cc_values = np.arange(orig_max + 1,
                              orig_max + 1 + new_cc_pos.sum(),
                              dtype=labels_orig.dtype)

    overlap_table_df['final_cc'] = overlap_table_df['orig'].copy()
    overlap_table_df.loc[new_cc_pos, 'final_cc'] = new_cc_values

    # Relabel the CC volume to use the 'final_cc' labels
    mapper = LabelMapper(overlap_table_df['cc'].values,
                         overlap_table_df['final_cc'].values)
    mapper.apply_inplace(labels_cc)

    # Generate the mapping that could (if desired) convert the new
    # volume into the original one, as described in the docstring above.
    emitted_mapping_rows = overlap_table_df['orig'].duplicated(keep=False)
    emitted_mapping_pairs = overlap_table_df.loc[emitted_mapping_rows,
                                                 ['final_cc', 'orig']].values

    new_to_orig = dict(emitted_mapping_pairs)

    new_unique_labels = pd.unique(overlap_table_df['final_cc'].values)
    new_unique_labels = new_unique_labels.astype(
        overlap_table_df['final_cc'].dtype)
    new_unique_labels.sort()

    return labels_cc, new_to_orig, new_unique_labels
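
A toy call to split_disconnected_bodies(), assuming scikit-image and the module's contingency_table()/LabelMapper dependencies are available; the exact new IDs are illustrative.

import numpy as np

# Label 1 appears as two disconnected pieces (separated by background 0).
labels_orig = np.array([[[1, 1, 0, 1, 1],
                         [0, 0, 0, 1, 1],
                         [2, 2, 0, 0, 0]]], dtype=np.uint64)

labels_new, new_to_orig, new_unique_labels = split_disconnected_bodies(labels_orig)
# The larger piece of label 1 keeps ID 1; the smaller piece receives a fresh ID (> 2),
# and new_to_orig records how the split pieces map back to the original label.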
Example #18
def distance_transform_watershed(mask,
                                 smoothing=0.0,
                                 seed_mask=None,
                                 seed_labels=None,
                                 flood_from='interior'):
    """
    Compute a watershed over the distance transform within a mask.
    You can either compute the watershed from inside-to-outside or outside-to-inside.
    For the former, the watershed is seeded from the most interior points,
    and the distance transform is inverted so the watershed can proceed from low to high as usual.
    For the latter, the watershed is seeded from the voxels immediately outside the mask,
    using labels as found in the seed_labels volume. In this mode, the results effectively tell
    you which exterior segment (in the seed volume) is closest to any given point within the
    interior of the mask.

    Or you can provide your own seeds if you think you know what you're doing.

    Args:
        mask:
            Only the masked area will be processed
        smoothing:
            If non-zero, run gaussian smoothing on the distance transform with the
            given sigma before defining seed points or running the watershed.

        seed_mask:
        seed_labels:
        flood_from:

    Returns:
        dt, labeled_seeds, ws, max_id

    Notes:
        This function provides a subset of the options that can be found in other
        libraries, such as:
            - https://github.com/ilastik/wsdt/blob/3709b27/wsdt/wsDtSegmentation.py#L26
            - https://github.com/constantinpape/elf/blob/34f5e76/elf/segmentation/watershed.py#L69

        This uses vigra's efficient true euclidean distance transform, which is superior to
        distance transform approximations, e.g. as found in https://imagej.net/Distance_Transform_Watershed

        The imglib2 distance transform uses the same algorithm as vigra:
        https://github.com/imglib/imglib2-algorithm/tree/master/src/main/java/net/imglib2/algorithm/morphology/distance
        http://www.theoryofcomputing.org/articles/v008a019/
    """
    mask = mask.astype(bool, copy=False)
    mask = vigra.taggedView(mask, 'zyx').astype(np.uint32)

    imask = np.logical_not(mask)
    outer_edge_mask = binary_edge_mask(mask, 'outer')

    assert flood_from in ('interior', 'exterior')
    if flood_from == 'interior':
        # Negate the distance transform result,
        # since watershed must start at minima, not maxima.
        # Convert to uint8 to benefit from 'turbo' watershed mode
        # (uses a bucket queue).
        dt = distance_transform(mask, False, smoothing, negate=True)

        if seed_mask is None:
            # requires float32 input for some reason
            if dt.ndim == 2:
                minima = vigra.analysis.localMinima(dt,
                                                    marker=np.nan,
                                                    neighborhood=8,
                                                    allowAtBorder=True,
                                                    allowPlateaus=False)
            else:
                minima = vigra.analysis.localMinima3D(dt,
                                                      marker=np.nan,
                                                      neighborhood=26,
                                                      allowAtBorder=True,
                                                      allowPlateaus=False)
            seed_mask = np.isnan(minima)
            del minima

        dt = normalize_image_range(dt, np.uint8)
    else:
        if seed_labels is None and seed_mask is None:
            logger.warning(
                "Without providing your own seed mask and/or seed labels, "
                "the watershed operation will simply be the same as a "
                "connected components operation.  Is that what you meant?")

        if seed_mask is None:
            seed_mask = outer_edge_mask.copy()

        # Dilate the mask once more.
        outer_edge_mask[:] |= binary_edge_mask(outer_edge_mask | mask, 'outer')

        dt = distance_transform(mask, False, smoothing, negate=False)
        dt = normalize_image_range(dt, np.uint8)

    if seed_labels is None:
        seed_mask = vigra.taggedView(seed_mask, 'zyx')
        labeled_seeds = vigra.analysis.labelMultiArrayWithBackground(
            seed_mask.view('uint8'))
    else:
        labeled_seeds = np.where(seed_mask, seed_labels, 0)

    # Make sure seed_mask matches labeled_seeds,
    # Even if some seed_labels were zero-valued
    seed_mask = (labeled_seeds != 0)

    # Must remap to uint32 before calling vigra's watershed.
    seed_mapper = None
    seed_values = None
    if labeled_seeds.dtype in (np.uint64, np.int64):
        labeled_seeds = labeled_seeds.astype(np.uint64)
        seed_values = np.sort(pd.unique(labeled_seeds.ravel()))
        if seed_values[0] != 0:
            seed_values = np.array([0] + list(seed_values), np.uint64)

        assert seed_values.dtype == np.uint64
        assert labeled_seeds.dtype == np.uint64

        ws_seed_values = np.arange(len(seed_values), dtype=np.uint32)
        seed_mapper = LabelMapper(seed_values, ws_seed_values)
        ws_seeds = seed_mapper.apply(labeled_seeds)
        assert ws_seeds.dtype == np.uint32
    else:
        ws_seeds = labeled_seeds

    # Fill the non-masked area with one big seed,
    # except for a thin border around the mask.
    # This saves time in the watershed step,
    # since these voxels now don't need to be
    # consumed in the watershed.
    dummy_seed = ws_seeds.max() + np.uint32(1)
    ws_seeds[np.logical_not(mask | outer_edge_mask)] = dummy_seed
    ws_seeds[outer_edge_mask & ~seed_mask] = 0

    dt[outer_edge_mask] = 255
    dt[seed_mask] = 0

    dt = vigra.taggedView(dt, 'zyx')
    ws_seeds = vigra.taggedView(ws_seeds, 'zyx')
    ws, max_id = vigra.analysis.watershedsNew(dt,
                                              seeds=ws_seeds,
                                              method='Turbo')

    # Areas that were unreachable without crossing over the border
    # could end up with the dummy seed.
    # We treat such areas as if they are outside of the mask.
    ws[ws == dummy_seed] = 0
    ws_seeds[imask] = 0

    # If we converted from uint64 to uint32 to perform the watershed,
    # convert back before returning.
    if seed_mapper is not None:
        ws = seed_values[ws]
        max_id = ws.max()

    return dt, labeled_seeds, ws, max_id
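
A minimal invocation sketch for the function above; it requires vigra and the module-level helpers it calls (distance_transform, binary_edge_mask, normalize_image_range). The mask is a toy solid block.

import numpy as np

mask = np.zeros((32, 32, 32), dtype=bool)
mask[4:28, 4:28, 4:28] = True

dt, labeled_seeds, ws, max_id = distance_transform_watershed(mask, smoothing=1.0)
# 'ws' labels the masked region; areas far outside the mask come back as 0.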
Example #19
def select_hulls_for_mito_bodies(mito_body_ct,
                                 mito_bodies_mask,
                                 mito_binary,
                                 body_seg,
                                 hull_masks,
                                 seed_bodies,
                                 box,
                                 scale,
                                 viewer=None,
                                 res0=8,
                                 progress=False):

    mito_bodies_mito_seg = np.where(mito_bodies_mask & mito_binary, body_seg,
                                    0)
    nonmito_body_seg = np.where(mito_bodies_mask, 0, body_seg)

    hull_cc_overlap_stats = []
    for hull_cc, (mask_box, mask) in tqdm_proxy(hull_masks.items(),
                                                disable=not progress):
        mbms = mito_bodies_mito_seg[box_to_slicing(*mask_box)]
        masked_hull_cc_bodies = np.where(mask, mbms, 0)
        # Faster to check for any non-zero values at all before trying to count them.
        # This early check saves a lot of time in practice.
        if not masked_hull_cc_bodies.any():
            continue

        # This hull was generated from a particular seed body (non-mito body).
        # If it accidentally overlaps with any other non-mito bodies,
        # then delete those voxels from the hull.
        # If that causes the hull to become split apart into multiple connected components,
        # then keep only the component(s) which overlap the seed body.
        seed_body = seed_bodies[hull_cc]
        nmbs = nonmito_body_seg[box_to_slicing(*mask_box)]
        other_bodies = set(pd.unique(nmbs[mask])) - {0, seed_body}
        if other_bodies:
            # Keep only the voxels on mito bodies or on the
            # particular non-mito body for this hull (the "seed body").
            mbm = mito_bodies_mask[box_to_slicing(*mask_box)]
            mask[:] &= (mbm | (nmbs == seed_body))
            mask = vigra.taggedView(mask, 'zyx')
            mask_cc = vigra.analysis.labelMultiArrayWithBackground(
                mask.view(np.uint8))
            if mask_cc.max() > 1:
                mask_ct = contingency_table(mask_cc, nmbs).reset_index()
                keep_ccs = mask_ct['left'].loc[(mask_ct['left'] != 0) &
                                               (mask_ct['right'] == seed_body)]
                mask[:] = mask_for_labels(mask_cc, keep_ccs)

        mito_bodies, counts = np.unique(masked_hull_cc_bodies,
                                        return_counts=True)
        overlaps = pd.DataFrame({
            'mito_body': mito_bodies,
            'overlap': counts,
            'hull_cc': hull_cc,
            'hull_size': mask.sum(),
            'hull_body': seed_body
        })
        hull_cc_overlap_stats.append(overlaps)

    if len(hull_cc_overlap_stats) == 0:
        logger.warning("Could not find any matches for any mito bodies!")
        mito_body_ct['hull_body'] = np.uint64(0)
        return mito_body_ct

    hull_cc_overlap_stats = pd.concat(hull_cc_overlap_stats, ignore_index=True)
    hull_cc_overlap_stats = hull_cc_overlap_stats.query(
        'mito_body != 0').copy()

    # Aggregate the stats for each body and the hull bodies it overlaps with,
    # Select the hull_body with the most overlap, or in the case of ties, the hull body that is largest overall.
    # (Ties are probably more common in the event that two hulls completely encompass a small mito body.)
    hull_body_overlap_stats = hull_cc_overlap_stats.groupby(
        ['mito_body', 'hull_body'])[['overlap', 'hull_size']].sum()
    hull_body_overlap_stats = hull_body_overlap_stats.sort_values(
        ['mito_body', 'overlap', 'hull_size'], ascending=False)
    hull_body_overlap_stats = hull_body_overlap_stats.reset_index()

    mito_hull_selections = (hull_body_overlap_stats.drop_duplicates(
        'mito_body').set_index('mito_body')['hull_body'])
    mito_body_ct = mito_body_ct.merge(mito_hull_selections,
                                      'left',
                                      left_index=True,
                                      right_index=True)
    mito_body_ct['hull_body'] = mito_body_ct['hull_body'].fillna(0)

    dtypes = {col: np.float32 for col in mito_body_ct.columns}
    dtypes['hull_body'] = np.uint64
    mito_body_ct = mito_body_ct.astype(dtypes)

    if viewer:
        assert mito_hull_selections.index.dtype == mito_hull_selections.values.dtype == np.uint64
        mito_hull_mapper = LabelMapper(mito_hull_selections.index.values,
                                       mito_hull_selections.values)
        remapped_body_seg = mito_hull_mapper.apply(body_seg, True)
        remapped_body_seg = apply_mask_for_labels(remapped_body_seg,
                                                  mito_hull_selections.values)
        update_seg_layer(viewer, 'altered-bodies', remapped_body_seg, scale,
                         box)

        # Show the final hull masks (after erasure of non-target bodies)
        assert sorted(hull_masks.keys()) == [*range(1, 1 + len(hull_masks))]
        hull_cc_overlap_stats = hull_cc_overlap_stats.sort_values('hull_size')
        hull_seg = np.zeros_like(remapped_body_seg)
        for row in hull_cc_overlap_stats.itertuples():
            mask_box, mask = hull_masks[row.hull_cc]
            view = hull_seg[box_to_slicing(*mask_box)]
            view[:] = np.where(mask, row.hull_body, view)
        update_seg_layer(viewer, 'final-hull-seg', hull_seg, scale, box)

    return mito_body_ct
Example #20
def region_features(label_img,
                    grayscale_img=None,
                    features=['Box', 'Count'],
                    ignore_label=None):
    """
    Wrapper around vigra.analysis.extractRegionFeatures() that supports uint64 and
    returns each feature as a pandas Series or DataFrame, indexed by object ID.

    For simple features such as 'Box' and 'Count', most of the time is spent remapping the
    input array from uint64 to uint32, which is the only label image type supported by vigra.

    See vigra docs regarding the supported features:
        - http://ukoethe.github.io/vigra/doc-release/vigranumpy/index.html#vigra.analysis.extractRegionFeatures
        - http://ukoethe.github.io/vigra/doc-release/vigra/group__FeatureAccumulators.html

    Args:
        label_img:
            An integer-valued label image, containing no negative values

        grayscale_img:
            Optional.  If provided, then weighted features are available.
            See GRAYSCALE_FEATURE_NAMES, above.

        features:
            List of strings.  If no grayscale image was provided, you can only
            ask for the features in ``SEGMENTATION_FEATURE_NAMES``, above.

        ignore_label:
            A background label to ignore. If you don't want to ignore anything, pass ``None``.

    Returns:
        dict {name: feature}, where each feature value is indexed by label ID.
        For keys where the feature is scalar-valued for each label, the returned value is a Series.
        For keys where the feature is a 1D array for each label, a DataFrame is returned, with columns 'zyx'.
        For keys where the feature is a 2D array (e.g. Box, RegionAxes, etc.), a Series is returned whose
        dtype=object, and each item in the series is a 2D array.
        TODO: This might be a good place to use Xarray
    """
    assert label_img.ndim in (2, 3)
    axes = 'zyx'[-label_img.ndim:]

    vfeatures = {*features}

    valid_names = {*SEGMENTATION_FEATURE_NAMES, *GRAYSCALE_FEATURE_NAMES}
    invalid_names = vfeatures - valid_names
    assert not invalid_names, \
        f"Invalid feature names: {invalid_names}"

    if 'Box' in features:
        vfeatures -= {'Box'}
        vfeatures |= {'Coord<Minimum>', 'Coord<Maximum>'}

    if 'Box0' in features:
        vfeatures -= {'Box0'}
        vfeatures |= {'Coord<Minimum>'}

    if 'Box1' in features:
        vfeatures -= {'Box1'}
        vfeatures |= {'Coord<Maximum>'}

    assert np.issubdtype(label_img.dtype, np.integer)

    label_ids = None
    if label_img.dtype == np.uint32:
        label_img32 = label_img
    elif label_img.dtype == np.int32:
        label_img32 = label_img.view(np.uint32)
    elif label_img.dtype in (np.int64, np.uint64):
        label_img = label_img.view(np.uint64)
        label_ids = np.sort(pd.unique(label_img.ravel()))

        # Map from uint64 -> uint32
        label_ids_32 = np.arange(len(label_ids), dtype=np.uint32)
        mapper = LabelMapper(label_ids, label_ids_32)
        label_img32 = mapper.apply(label_img)
        if ignore_label is not None:
            ignore_label = mapper.apply(np.array([ignore_label], np.uint64))[0]
    else:
        label_img32 = label_img.astype(np.uint32)

    assert label_img32.dtype == np.uint32

    if grayscale_img is None:
        invalid_names = vfeatures - {*SEGMENTATION_FEATURE_NAMES}
        assert not invalid_names, \
            f"Invalid segmentation feature names: {invalid_names}"
        # vigra needs a 'grayscale' image even when only segmentation features are
        # requested; its values don't matter here, so reuse the label buffer as float32.
        grayscale_img = label_img32.view(np.float32)
    else:
        assert grayscale_img.dtype == np.float32, \
            "Grayscale image must be float32"

    grayscale_img = vigra.taggedView(grayscale_img, axes)
    label_img32 = vigra.taggedView(label_img32, axes)

    # TODO: provide histogramRange options
    acc = vigra.analysis.extractRegionFeatures(grayscale_img,
                                               label_img32, [*vfeatures],
                                               ignoreLabel=ignore_label)

    results = {}
    if 'Box0' in features:
        v = acc['Coord<Minimum >'].astype(np.int32)
        results['Box0'] = pd.DataFrame(v, columns=[*axes])
    if 'Box1' in features:
        v = 1 + acc['Coord<Maximum >'].astype(np.int32)
        results['Box1'] = pd.DataFrame(v, columns=[*axes])
    if 'Box' in features:
        box0 = acc['Coord<Minimum >'].astype(np.int32)
        box1 = (1 + acc['Coord<Maximum >']).astype(np.int32)
        boxes = np.stack((box0, box1), axis=1)
        obj_boxes = np.zeros(len(boxes), object)
        obj_boxes[:] = list(boxes)
        results['Box'] = pd.Series(obj_boxes, name='Box')

    for k, v in acc.items():
        k = k.replace(' ', '')

        # Only return the features the user explicitly requested.
        if k not in features:
            continue

        if v.ndim == 1:
            results[k] = pd.Series(v, name=k)
        elif v.ndim == 2:
            results[k] = pd.DataFrame(v, columns=[*axes])
        else:
            # If the data doesn't neatly fit into a 1-d Series
            # or a 2-d DataFrame, then construct a Series with dtype=object
            # and make each row a separate ndarray object.
            obj_v = np.zeros(len(v), dtype=object)
            obj_v[:] = list(v)
            results[k] = pd.Series(obj_v, name=k)

    # Set index to the original uint64 values
    if label_img.dtype == np.uint64:
        for v in results.values():
            v.index = label_ids

    return results
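
A small sketch of calling region_features() on a uint64 label image; it requires vigra and the module-level feature-name constants, and the label values are arbitrary.

import numpy as np

label_img = np.zeros((16, 16, 16), dtype=np.uint64)
label_img[2:8, 2:8, 2:8] = 10**12      # arbitrarily large label IDs are fine
label_img[10:14, 10:14, 10:14] = 42

feats = region_features(label_img, features=['Box', 'Count'], ignore_label=0)
counts = feats['Count']   # pd.Series of voxel counts, indexed by the original uint64 labels
boxes = feats['Box']      # pd.Series whose entries are (2, ndim) start/stop box arrays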
Example #21
def compute_focused_bodies(server,
                           uuid,
                           instance,
                           synapse_samples,
                           min_tbars,
                           min_psds,
                           root_sv_sizes,
                           min_body_size,
                           sv_classifications=None,
                           marked_bad_bodies=None,
                           return_table=False,
                           kafka_msgs=None):
    """
    Compute the complete set of focused bodies, based on criteria for
    number of tbars, psds, or overall size, and excluding explicitly
    listed bad bodies.
    
    This function takes ~20 minutes to run on hemibrain inputs, with a ton of RAM.
    
    The procedure is:

    1. Apply synapse-based criteria
      a. Load synapse CSV file
      b. Map synapse SVs -> bodies (if needed)
        b2. If any SVs are 'retired', update those synapses to use the new IDs.
      c. Calculate synapses (tbars, psds) per body
      d. Initialize set with bodies that have enough synapses
    
    2. Apply size-based criteria
      a. Calculate body sizes (based on supervoxel sizes and current mapping)
      b. Add "big" bodies to the set
    
    3. Apply "bad body" criteria
      a. Read the list of "bad bodies"
      b. Remove bad bodies from the set
    
    Example:

        server = 'emdata3:8900'
        uuid = '7254'
        instance = 'segmentation'

        synapse_samples = '/nrs/flyem/bergs/complete-ffn-agglo/sampled-synapses-ef1d-locked.csv'

        min_tbars = 2
        min_psds = 10

        # old repo supervoxels (before server rebase)
        #
        # Note: This was taken from node 5501ae83e31247498303a159eef824d8, which is from a different repo.
        #       But that segmentation was eventually copied to the production repo as root node a776af.
        #       See /groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/copy-fixed-from-emdata2-to-emdata3-20180402.214505
        #
        #root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015'
        #root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes.h5'
        
        root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/flattened/compute-stats-from-corrupt-20181016.203848'
        root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes-2884.h5'
        
        min_body_size = int(10e6)

        sv_classifications = '/nrs/flyem/bergs/sv-classifications.h5'
        marked_bad_bodies = '/nrs/flyem/bergs/complete-ffn-agglo/bad-bodies-2019-02-26.csv'
        
        table_description = f'{uuid}-{min_tbars}tbars-{min_psds}psds-{min_body_size / 1e6:.1f}Mv'
        focused_table = compute_focused_bodies(server, uuid, instance, synapse_samples, min_tbars, min_psds, root_sv_sizes, min_body_size, sv_classifications, marked_bad_bodies, return_table=True)

        # As npy:
        np.save(f'focused-{table_description}.npy', focused_table.to_records(index=True))

        # As CSV:
        focused_table.to_csv(f'focused-{table_description}.csv', index=True, header=True)
    
    Args:

        server, uuid, instance:
            labelmap instance

        root_sv_sizes:
            mapping of supervoxel sizes from the root node, as returned by load_supervoxel_sizes(),
            or a path to an hdf5 file from which it can be loaded
        
        synapse_samples:
            A DataFrame with columns 'body' (or 'sv') and 'kind', or a path to a CSV file with those columns.
            The 'kind' column is expected to have only 'PreSyn' and 'PostSyn' entries.
            If the table has an 'sv' column, any "retired" supervoxel IDs will be updated before
            continuing with the analysis.
        
        min_tbars:
            The minimum number of PreSyn entries to pass the filter.
            Bodies with fewer tbars may still be included if they satisfy the min_psds criteria.
        
        min_psds:
            The minimum number of PostSyn entries to pass the filter.
            Bodies with fewer psds may still pass the filter if they satisfy the min_tbars criteria.

        min_body_size:
            Determines which bodies are included on the basis of their size alone,
            regardless of synapse count.
        
        sv_classifications:
            Optional. Path to an hdf5 file containing supervoxel classifications.
            Must have datasets: 'supervoxel_ids', 'classifications', and 'class_names'.
            Used to exclude known-bad supervoxels. The file need not include all supervoxels.
            Any supervoxels MISSING from this file are not considered 'bad'.

        marked_bad_bodies:
            Optional. A list of known-bad bodies to exclude from the results,
            or a path to a .csv file with that list (in the first column),
            or a keyvalue instance name from which the list can be loaded as JSON.
        
        return_table:
            If True, return the focused bodies in a DataFrame, indexed by body,
            with columns for body size and synapse counts.
            If False, simply return the list of bodies (saves about 4 minutes).
    
    Returns:
        A list of body IDs that passed all filters, or a DataFrame with columns:
            ['voxel_count', 'PreSyn', 'PostSyn']
        (See return_table option.)
    """
    split_source = 'dvid'

    # Load full mapping. It's needed for both synapses and body sizes.
    mapping = fetch_complete_mappings(server,
                                      uuid,
                                      instance,
                                      include_retired=True,
                                      kafka_msgs=kafka_msgs)
    mapper = LabelMapper(mapping.index.values, mapping.values)

    ##
    ## Synapses
    ##
    if isinstance(synapse_samples, str):
        synapse_samples = load_synapses(synapse_samples)

    assert set(['sv', 'body']).intersection(set(synapse_samples.columns)), \
        "synapse samples must have either 'body' or 'sv' column"

    # If 'sv' column is present, use it to create (or update) the body column
    if 'sv' in synapse_samples.columns:
        synapse_samples = update_synapse_table(server,
                                               uuid,
                                               instance,
                                               synapse_samples,
                                               split_source=split_source)
        assert synapse_samples['sv'].dtype == np.uint64
        synapse_samples['body'] = mapper.apply(synapse_samples['sv'].values,
                                               True)

    with Timer("Filtering for synapses", logger):
        synapse_body_table = body_synapse_counts(synapse_samples)
        synapse_bodies = synapse_body_table.query(
            'PreSyn >= @min_tbars or PostSyn >= @min_psds').index.values
    logger.info(f"Found {len(synapse_bodies)} with sufficient synapses")

    focused_bodies = set(synapse_bodies)

    ##
    ## Body sizes
    ##
    with Timer("Filtering for body size", logger):
        sv_sizes = load_all_supervoxel_sizes(server,
                                             uuid,
                                             instance,
                                             root_sv_sizes,
                                             split_source=split_source)
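        # (The third argument is include_unmapped_singletons: supervoxels with no
        #  entry in the mapping are counted as single-supervoxel bodies, too.)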
        body_stats = compute_body_sizes(sv_sizes, mapping, True)
        big_body_stats = body_stats.query('voxel_count >= @min_body_size')
        big_bodies = big_body_stats.index
    logger.info(f"Found {len(big_bodies)} with sufficient size")

    focused_bodies |= set(big_bodies)

    ##
    ## SV classifications
    ##
    if sv_classifications is not None:
        with Timer(f"Filtering by supervoxel classifications"), h5py.File(
                sv_classifications, 'r') as f:
            sv_classes = pd.DataFrame({
                'sv':
                f['supervoxel_ids'][:],
                'klass':
                f['classifications'][:].astype(np.uint8)
            })

            # Get the set of bad supervoxels
            all_class_names = list(map(bytes.decode, f['class_names'][:]))
            bad_class_names = [
                'unknown', 'blood vessels', 'broken white tissue', 'glia',
                'oob'
            ]
            _bad_class_ids = set(map(all_class_names.index, bad_class_names))
            bad_svs = sv_classes.query('klass in @_bad_class_ids')['sv']

            # Add column for sizes
            bad_sv_sizes = pd.DataFrame(index=bad_svs).merge(
                pd.DataFrame(sv_sizes),
                how='left',
                left_index=True,
                right_index=True)

            # Append body
            bad_sv_sizes['body'] = mapper.apply(bad_sv_sizes.index.values,
                                                True)

            # For bodies that contain at least one bad supervoxel,
            # compute the total size of the bad supervoxels they contain
            body_bad_voxels = bad_sv_sizes.groupby('body').agg(
                {'voxel_count': 'sum'})

            # Append total body size for comparison
            body_bad_voxels = body_bad_voxels.merge(body_stats[['voxel_count']],
                                                    how='left',
                                                    left_index=True,
                                                    right_index=True,
                                                    suffixes=('_bad', '_total'))

            bad_bodies = body_bad_voxels.query(
                'voxel_count_bad > voxel_count_total//2').index

            bad_focused_bodies = focused_bodies & set(bad_bodies)
            logger.info(
                f"Dropping {len(bad_focused_bodies)} bodies with more than 50% bad supervoxels"
            )

            focused_bodies -= bad_focused_bodies

    ##
    ## Marked Bad bodies
    ##
    if marked_bad_bodies is not None:
        if isinstance(marked_bad_bodies, str):
            if marked_bad_bodies.endswith('.csv'):
                marked_bad_bodies = read_csv_col(marked_bad_bodies, 0)
            else:
                # If it ain't a CSV, maybe it's a key-value instance and key to read from.
                raise AssertionError(
                    "FIXME: Need convention for specifying key-value instance and key"
                )
                #marked_bad_bodies = fetch_key(server, uuid, marked_bad_bodies, as_json=True)

        with Timer(
                f"Dropping {len(marked_bad_bodies)} bad bodies (from {len(focused_bodies)})"
        ):
            focused_bodies -= set(marked_bad_bodies)

    ##
    ## Prepare results
    ##
    logger.info(f"Found {len(focused_bodies)} focused bodies")
    focused_bodies = np.fromiter(focused_bodies, dtype=np.uint64)
    focused_bodies.sort()

    if not return_table:
        return focused_bodies

    with Timer("Computing full focused table", logger):
        # Start with an empty DataFrame (except index)
        focus_table = pd.DataFrame(index=focused_bodies)

        # Merge size/sv_count
        focus_table = focus_table.merge(body_stats,
                                        how='left',
                                        left_index=True,
                                        right_index=True,
                                        copy=False)

        # Add synapse columns
        focus_table = focus_table.merge(synapse_body_table,
                                        how='left',
                                        left_index=True,
                                        right_index=True,
                                        copy=False)

        focus_table.fillna(0, inplace=True)
        focus_table['voxel_count'] = focus_table['voxel_count'].astype(
            np.uint64)
        focus_table['sv_count'] = focus_table['sv_count'].astype(np.uint32)
        focus_table['PreSyn'] = focus_table['PreSyn'].astype(np.uint32)
        focus_table['PostSyn'] = focus_table['PostSyn'].astype(np.uint32)

        # Sort biggest..smallest
        focus_table.sort_values('voxel_count', ascending=False, inplace=True)
        focus_table.index.name = 'body'

    return focus_table
Beispiel #22
0
def edge_weighted_watershed(cleaned_edges, edge_weights, seed_labels):
    """
    Run nifty.graph.edgeWeightedWatershedsSegmentation() on the given graph with N nodes and E edges.
    The graph node IDs must be consecutive, starting with zero, dtype=np.uint32
    
    
    Args:
        cleaned_edges:
            array, (E,2), uint32
            Node IDs should be consecutive (more-or-less).
            To avoid segfaults:
                - Must not contain duplicates.
                - Must not contain 'loops' (no self-edges).
        
        edge_weights:
            array, (E,), float32
        
        seed_labels:
            array (N,), uint32
            All un-seeded nodes should be marked as 0.
        
    Returns:
        (output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
        
            output_labels:
                array (N,), uint32
                Agglomerated node labeling.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
    """
    assert cleaned_edges.dtype == np.uint32
    assert cleaned_edges.ndim == 2
    assert cleaned_edges.shape[1] == 2
    assert edge_weights.shape == (len(cleaned_edges), )
    assert seed_labels.ndim == 1
    assert cleaned_edges.max() < len(seed_labels)

    g = nifty.graph.UndirectedGraph(len(seed_labels))
    g.insertEdges(cleaned_edges)
    output_labels = nifty.graph.edgeWeightedWatershedsSegmentation(
        g, seed_labels, edge_weights)
    contains_unlabeled_components = not output_labels.all()

    mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32),
                         output_labels)
    labeled_edges = mapper.apply(cleaned_edges)
    preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]

    component_labels = connected_components(preserved_edges,
                                            len(output_labels))
    assert len(component_labels) == len(output_labels) == len(seed_labels)
    cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels})
    cc_counts = cc_df.groupby('label').nunique()['cc']
    disconnected_cc_counts = cc_counts[cc_counts > 1]
    disconnected_components = set(disconnected_cc_counts.index) - set([0])

    return output_labels, disconnected_components, contains_unlabeled_components
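# A minimal usage sketch for edge_weighted_watershed() above, on a hypothetical toy graph.
# (Assumes numpy and nifty are available; the expected outcome is described informally.)
import numpy as np

toy_edges = np.array([[0, 1], [1, 2], [3, 4]], dtype=np.uint32)   # two components: {0,1,2} and {3,4}
toy_weights = np.array([0.1, 0.9, 0.2], dtype=np.float32)
toy_seeds = np.array([1, 0, 2, 1, 0], dtype=np.uint32)            # seed 1 on nodes 0 and 3, seed 2 on node 2

labels, disconnected, has_unlabeled = edge_weighted_watershed(toy_edges, toy_weights, toy_seeds)
# We'd expect node 1 to join seed 1 (via the cheap 0-1 edge) and node 4 to join seed 1 (via node 3),
# so seed 1 spans two disjoint pieces and appears in 'disconnected', while every node receives
# a label, so 'has_unlabeled' is False.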
Beispiel #23
0
def agglomerative_clustering(cleaned_edges,
                             edge_weights,
                             seed_labels,
                             node_sizes=None,
                             num_classes=None):
    """
    Run vigra.graphs.agglomerativeClustering() on the given graph with N nodes and E edges.
    The graph node IDs must be consecutive, starting with zero, dtype=np.uint32
    
    
    Args:
        cleaned_edges:
            array, (E,2), uint32
            Node IDs should be consecutive (more-or-less).
            To avoid segfaults:
                - Must not contain duplicates.
                - Must not contain 'loops' (no self-edges).
        
        edge_weights:
            array, (E,), float32
        
        seed_labels:
            array (N,), uint32
            All un-seeded nodes should be marked as 0.
        
    Returns:
        (output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
        
            output_labels:
                array (N,), uint32
                Agglomerated node labeling.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
    """
    #
    # Notes:
    #
    # vigra.graphs.agglomerativeClustering() is somewhat sophisticated.
    #
    # During agglomeration, edges are selected for 'contraction' and the corresponding nodes are merged.
    # The newly merged node contains the superset of the edges from its constituent nodes, with duplicate
    # edges combined via weighted average according to their relative 'edgeLengths'.
    #
    # The edge weights used in the optimization are adjusted dynamically after every merge.
    # The dynamic edge weight is computed as a weighted average of its original 'edgeWeight'
    # and the similarity of its two nodes (by distance between 'nodeFeatures',
    # using the distance measure defined by 'metric').
    #
    # The relative importance of the original edgeWeight and the node similarity is determined by 'beta'.
    # To ignore node feature similarity completely, use beta=0.0.  To ignore edgeWeights completely, use beta=1.0.
    #
    # After computing that weighted average, the dynamic edge weight is then scaled by a 'Ward factor',
    # which seems to give priority to edges that connect smaller components.
    # The importance of the 'Ward factor' is determined by 'wardness'. To disable it, set wardness=0.0.
    #
    #
    # For reference, here are the relevant lines from vigra/hierarchical_clustering.hxx:
    #
    #    ValueType getEdgeWeight(const Edge & e){
    #        ...
    #        const ValueType wardFac = 2.0 / ( 1.0/std::pow(sizeU,wardness_) + 1/std::pow(sizeV,wardness_) );
    #        const ValueType fromEdgeIndicator = edgeIndicatorMap_[ee];
    #        ValueType fromNodeDist = metric_(nodeFeatureMap_[uu],nodeFeatureMap_[vv]);
    #        ValueType totalWeight = ((1.0-beta_)*fromEdgeIndicator + beta_*fromNodeDist)*wardFac;
    #        ...
    #    }
    #
    #
    # To achieve the "most naive" version of hierarchical clustering,
    # i.e. based purely on pre-computed edge weights (and no node features),
    # use beta=0.0, wardness=0.0.
    #
    # (Ideally, we would also set nodeSizes=[0,...], but unfortunately,
    # setting nodeSizes of 0.0 seems to result in strange bugs.
    # Therefore, we can't avoid the effect of using cumulative node size during the agglomeration.)
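    #
    # Illustrative numbers (hypothetical): with wardness=1.0, two nodes of size 100 and 400 give
    # wardFac = 2 / (1/100 + 1/400) = 160, while two nodes of size 10 give 2 / (1/10 + 1/10) = 10.
    # Assuming (as noted above) that lower dynamic weights are contracted first, edges between
    # smaller components are therefore favored. With wardness=0.0, wardFac = 2 / (1 + 1) = 1.0
    # for every edge, so node sizes have no influence.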

    assert cleaned_edges.dtype == np.uint32
    assert cleaned_edges.ndim == 2
    assert cleaned_edges.shape[1] == 2
    assert edge_weights.shape == (len(cleaned_edges), )
    assert seed_labels.ndim == 1
    assert cleaned_edges.max() < len(seed_labels)

    # Initialize graph
    # (These params merely reserve RAM in advance. They don't initialize actual graph state.)
    g = vg.AdjacencyListGraph(len(seed_labels), len(cleaned_edges))

    # Make sure there are the correct number of nodes.
    # (Internally, AdjacencyListGraph ensures contiguous nodes are created
    # up to the max id it has seen, so adding the max node is sufficient to
    # ensure all nodes are present.)
    g.addNode(len(seed_labels) - 1)

    # Insert edges.
    g.addEdges(cleaned_edges)

    if num_classes is None:
        num_classes = len(set(pd.unique(seed_labels)) - set([0]))

    output_labels = vg.agglomerativeClustering(
        graph=g,
        edgeWeights=edge_weights,
        #edgeLengths=...,
        #nodeFeatures=...,
        #nodeSizes=...,
        nodeLabels=seed_labels,
        nodeNumStop=num_classes,
        beta=0.0,
        #metric='l1',
        wardness=0.0)

    # For some reason, the output labels do not necessarily
    # have the same values as the seed labels. We have to relabel them ourselves.
    #
    # Furthermore, there are some special cases to consider:
    #
    # 1. It is possible that some seeds will map to disconnected components,
    #    if one of the following is true:
    #      - The input contains disconnected components with identical seeds
    #      - The input contains no disconnected components, but it failed to
    #        connect two components with identical seeds (some other seeded
    #        component ended up blocking the path between the two disconnected
    #        components).
    #    In those cases, we should ensure that the disconnected components are
    #    still labeled with the right input seed, but add the seed to the returned
    #    'disconnected components' set.
    #
    # 2. If the input contains any disconnected components that were NOT seeded,
    #    we should relabel those as 0, and return contains_unlabeled_components=True
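    #
    # (Hypothetical illustration of point 1: if seed 3 was placed on two pieces of the graph that
    # never got merged, the raw agglomeration output might label them, say, 7 and 9. Both values
    # are mapped back to 3 below, and 3 is added to 'disconnected_components'.)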

    # Get mapping of seeds -> corresponding agg values.
    # (There might be more than one agg value for a given seed, as explained in point 1 above)
    df = pd.DataFrame({'seed': seed_labels, 'agg': output_labels})
    df.drop_duplicates(inplace=True)

    # How many unique agg values are there for each seed class?
    seed_mapping_df = df.query('seed != 0')
    seed_component_counts = seed_mapping_df.groupby(['seed']).agg({'agg': 'size'})
    seed_component_counts.columns = ['component_count']

    # More than one agg value for a seed class implies that it wasn't fully agglomerated.
    disconnected_components = set(
        seed_component_counts.query('component_count > 1').index)

    # If there are 'extra' agg values (not corresponding to seeds),
    # then some component(s) are unlabeled. (Point 2 above.)
    _seeded_agg_ids = set(seed_mapping_df['agg'])
    nonseeded_agg_ids = df.query('agg not in @_seeded_agg_ids')['agg']
    contains_unlabeled_components = (len(nonseeded_agg_ids) > 0)

    # Map from output agg values back to original seed classes.
    agg_values = seed_mapping_df['agg'].values
    seed_values = seed_mapping_df['seed'].values
    if len(nonseeded_agg_ids) > 0:
        nonseeded_agg_ids = np.fromiter(nonseeded_agg_ids, np.uint32)
        agg_values = np.concatenate((agg_values, nonseeded_agg_ids))
        seed_values = np.concatenate(
            (seed_values, np.zeros((len(nonseeded_agg_ids), ), np.uint32)))

    mapper = LabelMapper(agg_values, seed_values)
    mapper.apply_inplace(output_labels)

    return CleaveResults(output_labels, disconnected_components,
                         contains_unlabeled_components)
Beispiel #24
0
def cleave(edges,
           edge_weights,
           seeds_dict,
           node_ids,
           node_sizes=None,
           method='seeded-mst'):
    """
    Cleave the graph with the given edges and edge weights.
    
    Args:
        
        edges:
            array, (E,2), uint32
        
        edge_weights:
            array, (E,), float32
        
        seeds_dict:
            dict, { seed_class : [node_id, node_id, ...] }
        
        node_ids:
            The complete list of node IDs in the graph. Must contain a superset of the ids given in edges.
            Extra ids in node_ids (i.e. not mentioned in 'edges') will be included
            in the results as disconnected components.
        
        method:
            One of: 'seeded-mst', 'seeded-watershed', 'agglomerative-clustering', 'echo-seeds'

    Returns:
    
        CleaveResults, namedtuple with fields:
        (node_ids, output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
            node_ids:
                The graph node_ids.
                
            output_labels:
                array (N,), uint32
                Agglomerated node labeling, in the same order as node_ids.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
        
    """
    assert isinstance(node_ids, np.ndarray)
    assert node_ids.dtype in (np.uint32, np.uint64)
    assert node_ids.ndim == 1
    assert node_sizes is None or node_sizes.shape == node_ids.shape

    cleave_func, requires_sizes = get_cleave_method(method)
    assert not requires_sizes or node_sizes is not None, \
        f"The specified cleave method ({method}) requires node sizes but none were provided."

    # Relabel node ids consecutively
    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)

    # Initialize sparse seed label array
    seed_labels = np.zeros_like(cons_node_ids)
    for seed_class, seed_nodes in seeds_dict.items():
        seed_nodes = np.asarray(seed_nodes, dtype=np.uint64)
        mapper.apply_inplace(seed_nodes)
        seed_labels[seed_nodes] = seed_class

    if len(edges) == 0:
        # No edges: Return empty results (just seeds)
        return CleaveResults(seed_labels, set(seeds_dict.keys()),
                             not seed_labels.all())

    # Clean the edges (normalized form, no duplicates, no loops)
    edges.sort(axis=1)
    edges_df = pd.DataFrame({
        'u': edges[:, 0],
        'v': edges[:, 1],
        'weight': edge_weights
    })
    edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True)
    edges_df = edges_df.query('u != v')
    edges = edges_df[['u', 'v']].values
    edge_weights = edges_df['weight'].values

    # Relabel edges for consecutive nodes
    cons_edges = mapper.apply(edges)
    assert cons_edges.dtype == np.uint32

    cleave_results = cleave_func(cons_edges, edge_weights, seed_labels,
                                 node_sizes)
    assert isinstance(cleave_results, CleaveResults)
    return cleave_results
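# A minimal usage sketch for cleave() above, with hypothetical supervoxel IDs.
# (The edge dtype is chosen to match node_ids so LabelMapper can relabel both.)
import numpy as np

sv_ids = np.array([101, 205, 207, 300], dtype=np.uint64)
sv_edges = np.array([[101, 205], [205, 207], [207, 300]], dtype=np.uint64)
sv_weights = np.array([0.2, 0.8, 0.3], dtype=np.float32)
seeds = {1: [101], 2: [300]}   # seed class -> list of node IDs

results = cleave(sv_edges, sv_weights, seeds, sv_ids, method='seeded-mst')
# Per the docstring, results.output_labels is in the same order as sv_ids;
# the unseeded nodes 205 and 207 are attached to one of the two seed classes.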
def measure_tbar_mito_distances(seg_src,
                                mito_src,
                                body,
                                *,
                                search_configs=DEFAULT_SEARCH_CONFIGS,
                                npclient=None,
                                tbars=None,
                                valid_mitos=None):
    """
    Search for the closest mito to each tbar in a list of tbars
    (or any set of points, really).

    FIXME: Rename this function.  It works for more than just tbars.

    Args:
        seg_src:
            (server, uuid, instance) OR a flyemflows VolumeService
            Labelmap instance for the neuron segmentation.
        mito_src:
            (server, uuid, instance) OR a flyemflows VolumeService
            Labelmap instance for the mitochondria "supervoxel"
            segmentation -- not just the "masks".
        body:
            The body ID of interest, on which the tbars reside.
        search_configs:
            A list of ``SearchConfig`` tuples.
            For each tbar, this function tries to locate a mitochondrion within
            the given search radius, using data downloaded from the given scale.
            If the search fails and no mito can be found, the function tries
            again using the next search criteria in the list.
            The radius should always be specified in scale-0 units,
            regardless of the scale at which you want to perform the analysis.
            Additionally, the data will be downloaded at the specified scale,
            then downsampled (with continuity-preserving downsampling) to a coarser scale for analysis.
            Notes:
                - Scale 4 is too low-res.  Stick with scale-3 or better.
                - Higher radius is more expensive, but some of that expense is
                  recouped because all points that fall within the radius are
                  analyzed at once.  See _measure_tbar_mito_distances()
                  implementation for details.
            dilation_radius_s0:
                If dilation_radius_s0 is non-zero, the segmentation will be "repaired" to close
                gaps, using a procedure involving a dilation of the given radius.
            dilation_exclusion_buffer_s0:
                We want to close small gaps in the segmentation, but only if we think
                they're really a gap in the actual segmentation, not if they are merely
                fingers of the same branch that are actually connected outside of our
                analysis volume. The dilation procedure tends to form such spurious
                connections near the volume border, so this parameter can be used to
                exclude a buffer (inner halo) near the border from dilation repairs.
        npclient:
            ``neuprint.Client`` to use when fetching the list of tbars that belong
            to the given body, unless you provide your own tbar points in the next
            argument.
        tbars:
            A DataFrame of tbar coordinates at least with columns ``['x', 'y', 'z']``.
        valid_mitos:
            If provided, only the listed mito IDs will be considered valid as search targets.
    Returns:
        DataFrame of tbar coordinates, mito distances, and mito coordinates.
        Points for which no nearby mito could be found (after trying all the given search_configs)
        will be marked with `done=False` in the results.
    """
    assert search_configs[-1].is_final, "Last search config should be marked is_final"
    assert all([not cfg.is_final for cfg in search_configs[:-1]]), \
        "Only the last search config should be marked is_final (no others)."

    # Fetch tbars
    if tbars is None:
        tbars = fetch_synapses(body, SC(type='pre', primary_only=True), client=npclient)

    tbars = initialize_results(body, tbars)

    if valid_mitos is None or len(valid_mitos) == 0:
        valid_mito_mapper = None
    else:
        valid_mitos = np.asarray(valid_mitos, dtype=np.uint64)
        valid_mito_mapper = LabelMapper(valid_mitos, valid_mitos)

    with tqdm_proxy(total=len(tbars)) as progress:
        for row in tbars.itertuples():
            # can't use row.done -- itertuples might be out-of-sync
            done = tbars['done'].loc[row.Index]
            if done:
                continue

            loop_logger = None
            for cfg in search_configs:
                prefix = (f"({row.x}, {row.y}, {row.z}) [ds={cfg.download_scale} "
                          f"as={cfg.analysis_scale} r={cfg.radius_s0:4} dil={cfg.dilation_radius_s0:2}] ")
                loop_logger = PrefixedLogger(logger, prefix)

                prev_num_done = tbars['done'].sum()
                _measure_tbar_mito_distances(
                    seg_src, mito_src, body, tbars, row.Index, cfg, valid_mito_mapper, loop_logger)
                num_done = tbars['done'].sum()

                progress.update(num_done - prev_num_done)
                done = tbars['done'].loc[row.Index]
                if done:
                    break

                if not cfg.is_final:
                    loop_logger.info("Search failed for primary tbar. Trying next search config!")

            if not done:
                loop_logger.warning(f"Failed to find a nearby mito for tbar at point {(row.x, row.y, row.z)}")
                progress.update(1)

    failed = np.isinf(tbars['mito-distance'])
    succeeded = ~failed
    logger.info(f"Found mitos for {succeeded.sum()} tbars, failed for {failed.sum()} tbars")

    return tbars
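# A hypothetical usage sketch for the function above, shown commented out because it needs
# live DVID / neuprint resources. The server, uuid, instance names and body ID are placeholders.
#
#     seg_src = ('<dvid-server>', '<uuid>', 'segmentation')
#     mito_src = ('<dvid-server>', '<uuid>', 'mito-supervoxels')
#     tbar_table = measure_tbar_mito_distances(seg_src, mito_src, 123456789, npclient=my_client)
#
# Rows with done=False (infinite 'mito-distance') are the points for which no nearby mito
# was found with any of the search configs.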
Beispiel #26
0
def cleave(edges,
           edge_weights,
           seeds_dict,
           node_ids=None,
           method='seeded-watershed'):
    """
    Cleave the graph with the given edges and edge weights.
    If node_ids is given, it must contain a superset of the ids given in edges.
    Extra ids in node_ids (i.e. not mentioned in 'edges') will be included
    in the results as disconnected components.
    
    Args:
        
        edges:
            array, (E,2), uint32
        
        edge_weights:
            array, (E,), float32
        
        seeds_dict:
            dict, { seed_class : [node_id, node_id, ...] }
        
        node_ids:
            The complete list of node IDs in the graph.
        
        method:
            Either 'seeded-watershed' or 'agglomerative-clustering'

    Returns:
    
        CleaveResults, namedtuple with fields:
        (node_ids, output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
            node_ids:
                The graph node_ids.
                
            output_labels:
                array (N,), uint32
                Agglomerated node labeling, in the same order as node_ids.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
        
    """
    if node_ids is None:
        node_ids = pd.unique(edges.flat)
        node_ids.sort()

    assert isinstance(node_ids, np.ndarray)
    assert node_ids.dtype in (np.uint32, np.uint64)
    assert node_ids.ndim == 1

    assert method in ('seeded-watershed', 'agglomerative-clustering')

    # Clean the edges (normalized form, no duplicates, no loops)
    edges.sort(axis=1)
    edges_df = pd.DataFrame({
        'u': edges[:, 0],
        'v': edges[:, 1],
        'weight': edge_weights
    })
    edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True)
    edges_df = edges_df.query('u != v')
    edges = edges_df[['u', 'v']].values
    edge_weights = edges_df['weight'].values

    # Relabel node ids consecutively
    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)
    cons_edges = mapper.apply(edges)
    assert cons_edges.dtype == np.uint32

    # Initialize sparse seed label array
    seed_labels = np.zeros_like(cons_node_ids)
    for seed_class, seed_nodes in seeds_dict.items():
        seed_nodes = np.asarray(seed_nodes, dtype=np.uint64)
        mapper.apply_inplace(seed_nodes)
        seed_labels[seed_nodes] = seed_class

    if method == 'agglomerative-clustering':
        output_labels, disconnected_components, contains_unlabeled_components = agglomerative_clustering(
            cons_edges, edge_weights, seed_labels)
    elif method == 'seeded-watershed':
        output_labels, disconnected_components, contains_unlabeled_components = edge_weighted_watershed(
            cons_edges, edge_weights, seed_labels)

    return CleaveResults(node_ids, output_labels, disconnected_components,
                         contains_unlabeled_components)
Beispiel #27
0
def compute_body_sizes(sv_sizes, mapping, include_unmapped_singletons=False):
    """
    Given a Series of supervoxel sizes and an sv-to-body mapping,
    compute the size of each body in the mapping, and its count of supervoxels.
    
    Any supervoxels in the mapping that are missing from sv_sizes will be ignored.

    Body 0 will be excluded from the results, even if it is present in the mapping.
    
    Args:
        sv_sizes:
            pd.Series, indexed by sv, or a path to an hdf5 file which
            can be loaded via load_supervoxel_sizes()
        
        mapping:
            pd.Series, indexed by sv, with body as value,
            or a path to a file which can be loaded by load_mapping()
       
        include_unmapped_singletons:
            If True, then the result will also include all
            supervoxels from sv_sizes that weren't mentioned in the mapping.
            (They are presumed to be singleton-supervoxel bodies.)
    
    Returns:
        pd.DataFrame, indexed by body, with columns ['voxel_count', 'sv_count'],
        and sorted by decreasing voxel_count.

    Example:
    
        >>> mesh_job_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm'
        >>> sv_sizes_path = f'{mesh_job_dir}/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015/supervoxel-sizes.h5'
        >>> sv_sizes = load_supervoxel_sizes(sv_sizes_path)
        >>> mapping = fetch_complete_mappings('emdata3:8900', '52f9', 'segmentation')
        >>> body_sizes = compute_body_sizes(sv_sizes, mapping)
    """
    if isinstance(sv_sizes, str):
        logger.info("Loading supervoxel sizes")
        assert os.path.splitext(sv_sizes)[1] == '.h5'
        sv_sizes = load_supervoxel_sizes(sv_sizes)

    if isinstance(mapping, str):
        logger.info("Loading mapping")
        mapping = load_mapping(mapping)

    assert isinstance(sv_sizes, pd.Series)
    assert isinstance(mapping, pd.Series)

    assert sv_sizes.index.dtype == np.uint64

    sv_sizes = sv_sizes.astype(np.uint64, copy=False)
    size_mapper = LabelMapper(sv_sizes.index.values, sv_sizes.values)

    # Just drop SVs that we don't have sizes for.
    logger.info("Dropping unknown supervoxels")
    mapping = mapping.loc[mapping.index.isin(sv_sizes.index)]

    logger.info("Applying sizes to mapping")
    df = pd.DataFrame({'body': mapping})
    df['voxel_count'] = size_mapper.apply(mapping.index.values)

    logger.info("Aggregating sizes by body")
    body_stats = df.groupby('body').agg({'voxel_count': ['sum', 'size']})
    body_stats.columns = ['voxel_count', 'sv_count']
    body_stats['sv_count'] = body_stats['sv_count'].astype(np.uint32)
    #body_sizes = body_stats['voxel_count']

    if include_unmapped_singletons:
        logger.info("Appending singleton sizes")
        nonsingleton_rows = sv_sizes.index.isin(mapping.index)
        singleton_sizes = sv_sizes[~nonsingleton_rows]
        singleton_stats = pd.DataFrame({'voxel_count': singleton_sizes})
        singleton_stats['sv_count'] = np.uint32(1)
        body_stats = pd.concat((body_stats, singleton_stats))

    if 0 in body_stats.index:
        body_stats.drop(0, inplace=True)

    logger.info("Sorting sizes")
    body_stats.index.name = 'body'
    body_stats.sort_values(['voxel_count', 'sv_count'],
                           inplace=True,
                           ascending=False)
    return body_stats
Beispiel #28
0
def find_missing_adjacencies(server, uuid, instance, body, known_edges, svs=None, search_distance=1, connect_non_adjacent=False):
    """
    Given a body and an intra-body merge graph defined by the given
    list of "known" supervoxel-to-supervoxel edges within that body,
    
    1. Determine whether or not all supervoxels in the body are
       connected by a single component within the given graph.
       If so, return immediately.
    
    2. Attempt to augment the graph with additional edges based on
       supervoxel adjacencies in the segmentation from DVID.
       This is done by downloading the DVID labelindex to determine
       which blocks might contain adjacent supervoxels that could unify
       the graph, and then downloading those blocks (only) to search
       for the adjacencies.
    
    Notes:
        - Requires scikit-image (which, currently, is not otherwise
          listed as a dependency of neuclease's conda-recipe).
        
        - This function does not attempt to find ALL adjacencies between supervoxels;
          it stops looking as soon as they form a single connected component.

        - This function only considers two supervoxels "adjacent" if they are
          literally touching each other in the scale-0 segmentation. If there is
          a small gap between them, then they are not considered adjacent.
        
        - This function does not attempt to find inter-block adjacencies;
          only adjacencies within each block are detected.
          So, in pathological cases where a supervoxel is only adjacent to the
          rest of the body on a block-aligned edge, the adjacency will not be
          detected by this function.
        
    Args:
        server, uuid, instance:
            DVID segmentation labelmap instance
        
        body:
            ID of the body to inspect
        
        known_edges:
            ndarray (N,2), array of supervoxel pairs;
            known edges of the intra-body merge graph
        
        svs:
            Optional. The complete list of supervoxels
            that belong to this body, according to DVID.
            Providing this enhances performance in one important case:
            If the known_edges ALREADY constitute a single connected component
            which covers all supervoxels in the body, there is no need to
            download the labelindex.
        
        search_distance:
            If > 1, supervoxels are considered adjacent if they are within
            the given distance from each other, even if they aren't directly adjacent.
        
        connect_non_adjacent:
            If searching by adjacency failed to fully connect all supervoxels in the
            body into a single connected component, generate edges for supervoxels
            that are not adjacent, but merely are in the same block (if it helps
            unify the body).
    
    Returns:
        (new_edges, orig_num_cc, final_num_cc, block_tables),
        Where:
            new_edges are the new edges found via inspection of supervoxel adjacencies,
            
            orig_num_cc is the number of disjoint components in the given merge graph before
                this function runs,
            
            final_num_cc is the number of disjoint components after adding the new_edges,
            
            block_tables contains debug information about the adjacencies found in each
                block of analyzed segmentation
                
        Ideally, final_num_cc == 1, but in some cases the body's supervoxels may not be
        directly adjacent, or the adjacencies were not detected.  (See notes above.)
    """
    from skimage.morphology import dilation
    
    BLOCK_TABLE_COLS = ['z', 'y', 'x', 'sv_a', 'sv_b', 'cc_a', 'cc_b', 'detected', 'applied']
    known_edges = np.asarray(known_edges, np.uint64)
    if svs is None:
        # We could compute the supervoxel list ourselves from 
        # the labelindex, but dvid can do it faster.
        svs = fetch_supervoxels_for_body(server, uuid, instance, body)

    cc = connected_components_nonconsecutive(known_edges, svs)
    orig_num_cc = final_num_cc = cc.max()+1
    
    if orig_num_cc == 1:
        return np.zeros((0,2), np.uint64), orig_num_cc, final_num_cc, pd.DataFrame(columns=BLOCK_TABLE_COLS)

    labelindex = fetch_labelindex(server, uuid, instance, body, format='protobuf')
    encoded_block_coords = np.fromiter(labelindex.blocks.keys(), np.uint64, len(labelindex.blocks))
    coords_zyx = decode_labelindex_blocks(encoded_block_coords)

    cc_mapper = LabelMapper(svs, cc)
    svs_set = set(svs)

    sv_adj_found = []
    cc_adj_found = set()
    block_tables = {}
    
    searched_block_svs = {}
    
    for coord_zyx, sv_counts in zip(coords_zyx, labelindex.blocks.values()):
        # Given the supervoxels in this block, what CC adjacencies
        # MIGHT we find if we were to inspect the segmentation?
        block_svs = np.fromiter(sv_counts.counts.keys(), np.uint64)
        block_ccs = cc_mapper.apply(block_svs)
        possible_cc_adjacencies = set(combinations( set(block_ccs), 2 ))
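        # (E.g., hypothetically, if this block contains supervoxels from CCs {0, 2, 5},
        # possible_cc_adjacencies starts out as the three unordered pairs {0,2}, {0,5}, {2,5}.)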
        
        # We only aim to find (at most) a single link between each CC pair.
        # That is, we don't care about adjacencies between CC that we've already linked so far.
        possible_cc_adjacencies -= cc_adj_found
        if not possible_cc_adjacencies:
            continue

        searched_block_svs[(*coord_zyx,)] = block_svs
        
        # Not used in the search; only returned for debug purposes.
        block_adj_table = _init_adj_table(coord_zyx, block_svs, cc_mapper)

        block_vol = fetch_block_vol(server, uuid, instance, coord_zyx, svs_set)
        if search_distance > 0:
            # It would be nice to do a proper spherical dilation,
            # but apparently dilation() is special-cased to be WAY
            # faster with a square structuring element, and we prefer
            # speed over cleaner dilation.
            # footprint = skimage.morphology.ball(dilation)
            radius = search_distance//2
            footprint = np.ones(3*(1+2*radius,), np.uint8)
            dilated_block_vol = dilation(block_vol, footprint)
            
            # Since dilation is a max-filter, we might have accidentally
            # erased small, low-valued supervoxels, erasing the adjacencies.
            # Overlay the original volume to make sure they still count.
            block_vol = np.where(block_vol, block_vol, dilated_block_vol)
        
        sv_adjacencies = compute_label_adjacencies(block_vol)
        sv_adjacencies['cc_a'] = cc_mapper.apply( sv_adjacencies['sv_a'].values )
        sv_adjacencies['cc_b'] = cc_mapper.apply( sv_adjacencies['sv_b'].values )
        
        found_new_adj = False
        for row in sv_adjacencies.itertuples(index=False):
            if (row.cc_a != row.cc_b):
                sv_adj = (row.sv_a, row.sv_b)
                cc_adj = (row.cc_a, row.cc_b)
                
                # Normalize
                if row.cc_a > row.cc_b:
                    cc_adj = (row.cc_b, row.cc_a)

                if row.sv_a > row.sv_b:
                    sv_adj = (row.sv_b, row.sv_a)

                block_adj_table.loc[sv_adj, 'detected'] = True
                    
                if cc_adj not in cc_adj_found:
                    found_new_adj = True
                    cc_adj_found.add( cc_adj )
                    sv_adj_found.append( sv_adj )
                    
                    block_adj_table.loc[sv_adj, 'applied'] = True

        block_tables[(*coord_zyx,)] = block_adj_table

        # If we made at least one change and we've 
        # finally unified all components, then we're done.
        if found_new_adj:
            final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1
            if final_num_cc == 1:
                break
    
    # If we couldn't connect everything via direct adjacencies,
    # we can just add edges for any supervoxels that share a block.
    if final_num_cc > 1 and connect_non_adjacent:
        for coord_zyx, block_svs in searched_block_svs.items():
            block_ccs = cc_mapper.apply(block_svs)
            
            # We only need one SV per connected component,
            # so load them into a dict.
            selected_svs = dict(zip(block_ccs, block_svs))
            for (sv_a, sv_b) in combinations(sorted(selected_svs.values()), 2):
                (cc_a, cc_b) = cc_mapper.apply(np.array([sv_a, sv_b], np.uint64))
                if cc_a > cc_b:
                    cc_a, cc_b = cc_b, cc_a
                
                if (cc_a, cc_b) not in cc_adj_found:
                    if sv_a > sv_b:
                        sv_a, sv_b = sv_b, sv_a

                    cc_adj_found.add( (cc_a, cc_b) )
                    sv_adj_found.append( (sv_a, sv_b) )

                    block_tables[(*coord_zyx,)].loc[(sv_a, sv_b), 'applied'] = True

        final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1
    
    if len(block_tables) == 0:
        block_table = pd.DataFrame(columns=BLOCK_TABLE_COLS)
    else:
        block_table = pd.concat(block_tables.values(), sort=False).reset_index()
        block_table = block_table[BLOCK_TABLE_COLS]
    
    new_edges = np.array(sv_adj_found, np.uint64)
    return new_edges, int(orig_num_cc), int(final_num_cc), block_table
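# A hypothetical usage sketch for the function above. The DVID instance info, body ID, and
# supervoxel IDs below are placeholders, not real data.
#
#     known = np.array([[1230001, 1230002]], dtype=np.uint64)
#     new_edges, orig_cc, final_cc, blocks = find_missing_adjacencies(
#         '<server>', '<uuid>', 'segmentation', 123456, known, search_distance=3)
#
# If final_cc == 1, the known edges plus new_edges unify all of the body's supervoxels
# into a single connected component.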
Beispiel #29
0
def compute_focused_paths(server,
                          uuid,
                          instance,
                          original_mapping,
                          important_bodies,
                          speculative_merge_tables,
                          split_mapping=None,
                          max_depth=10,
                          stop_after_endpoint_num=None,
                          return_after_setup=False):

    from ..merge_graph import LabelmapMergeGraph

    with Timer("Loading speculative merge graph", logger):
        merge_graph = LabelmapMergeGraph(speculative_merge_tables, uuid)

    if split_mapping is not None:
        _bad_edges = merge_graph.append_edges_for_split_supervoxels(
            split_mapping, server, uuid, instance, parent_sv_handling='drop')
    merge_table_df = merge_graph.merge_table_df

    if isinstance(original_mapping, str):
        with Timer("Loading mapping", logger):
            original_mapping = load_mapping(original_mapping)
    else:
        assert isinstance(original_mapping, pd.Series)

    with Timer("Applying mapping", logger):
        mapper = LabelMapper(original_mapping.index.values,
                             original_mapping.values)
        merge_table_df['body_a'] = mapper.apply(merge_table_df['id_a'].values,
                                                allow_unmapped=True)
        merge_table_df['body_b'] = mapper.apply(merge_table_df['id_b'].values,
                                                allow_unmapped=True)

    if isinstance(important_bodies, str):
        with Timer("Reading importances", logger):
            important_bodies = read_csv_col(important_bodies).values
        important_bodies = pd.Index(important_bodies, dtype=np.uint64)

    with Timer("Assigning importances", logger):
        merge_table_df['important_a'] = merge_table_df['body_a'].isin(
            important_bodies)
        merge_table_df['important_b'] = merge_table_df['body_b'].isin(
            important_bodies)

    with Timer("Discarding merged edges within 'important' bodies ", logger):
        size_before = len(merge_table_df)
        merge_table_df.query(
            '(body_a != body_b) and not (important_a and important_b)',
            inplace=True)
        size_after = len(merge_table_df)
        logger.info(
            f"Discarded {size_before - size_after} edges, ended with {len(merge_table_df)} edges"
        )

    edges = merge_table_df[['id_a', 'id_b']].values
    assert edges.dtype == np.uint64

    if return_after_setup:
        logger.info("Returning setup instead of searching for paths")
        return (edges, original_mapping, important_bodies, max_depth,
                stop_after_endpoint_num)

    logger.info(
        f"Finding paths among {len(important_bodies)} important bodies")
    # FIXME: Need to augment original_mapping with rows for single-sv bodies that are 'important'
    all_paths = find_all_paths(edges, original_mapping, important_bodies,
                               max_depth, stop_after_endpoint_num)
    return all_paths
def split_disconnected_bodies(labels_orig):
    """
    Produces 3D volume split into connected components.

    This function identifies bodies that are the same label
    but are not connected.  It splits these bodies and
    produces a dict that maps these newly split bodies to
    the original body label.

    Special exception: Segments with label 0 are not relabeled.

    Args:
        labels_orig (numpy.array): 3D array of labels

    Returns:
        (labels_new, new_to_orig)

        labels_new:
            The partially relabeled array.
            Segments that were not split will keep their original IDs.
            Among split segments, the largest 'child' of a split segment retains the original ID.
            The smaller segments are assigned new labels in the range (N+1)..(N+1+S), where N is
            the highest original label and S is the number of new segments after splitting.
        
        new_to_orig:
            A pseudo-minimal (but not quite minimal) mapping of labels
            (N+1)..(N+1+S) -> some subset of (1..N),
            which maps new segment IDs to the segments they came from.
            Segments that were not split at all are not mentioned in this mapping;
            for split segments, every mapping pair for the split is returned, including the k->k (identity) pair.
    """
    # Compute connected components and cast back to original dtype
    labels_cc = skm.label(labels_orig, background=0, connectivity=1)
    assert labels_cc.dtype == np.int64
    if labels_orig.dtype == np.uint64:
        labels_cc = labels_cc.view(np.uint64)
    else:
        labels_cc = labels_cc.astype(labels_orig.dtype, copy=False)

    # Find overlapping segments between orig and CC volumes
    overlap_table_df = contingency_table(labels_orig, labels_cc)
    overlap_table_df.columns = ['orig', 'cc', 'voxels']
    overlap_table_df.sort_values('voxels', ascending=False, inplace=True)
    
    # If a label in 'orig' is duplicated, it has multiple components in labels_cc.
    # The largest component gets to keep the original ID;
    # the other components must take on new values.
    # (The new values must not conflict with any of the IDs in the original, so start at orig_max+1)
    new_cc_pos = overlap_table_df['orig'].duplicated()
    orig_max = overlap_table_df['orig'].max()
    new_cc_values = np.arange(orig_max+1, orig_max+1+new_cc_pos.sum(), dtype=labels_orig.dtype)

    overlap_table_df['final_cc'] = overlap_table_df['orig'].copy()
    overlap_table_df.loc[new_cc_pos, 'final_cc'] = new_cc_values
    
    # Relabel the CC volume to use the 'final_cc' labels
    mapper = LabelMapper(overlap_table_df['cc'].values, overlap_table_df['final_cc'].values)
    mapper.apply_inplace(labels_cc)

    # Generate the mapping that could (if desired) convert the new
    # volume into the original one, as described in the docstring above.
    emitted_mapping_rows = overlap_table_df['orig'].duplicated(keep=False)
    emitted_mapping_pairs = overlap_table_df.loc[emitted_mapping_rows, ['final_cc', 'orig']].values

    # Use tolist() to ensure plain Python int types
    # (This is required by some client code in Evaluate.py)
    new_to_orig = dict(emitted_mapping_pairs.tolist())
    
    return labels_cc, new_to_orig
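# A small worked example for split_disconnected_bodies() above. The function expects a 3D
# volume, so this toy 2D pattern gets a singleton z-axis. (The results described are what
# we'd expect from the logic above.)
import numpy as np

toy = np.array([[[1, 1, 0, 1],
                 [0, 0, 0, 1],
                 [2, 2, 0, 1]]], dtype=np.uint64)

labels_new, new_to_orig = split_disconnected_bodies(toy)
# Label 1 occupies two disconnected pieces: the 3-voxel piece (right column) keeps ID 1,
# and the 2-voxel piece (top-left) should become 3 (= max original label + 1).
# new_to_orig should then be {1: 1, 3: 1} -- only split labels appear, including the identity pair.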
Beispiel #31
0
def extract_important_merges(speculative_merge_tables,
                             important_bodies,
                             body_mapping=None,
                             mapping_instance_info=None,
                             drop_duplicate_body_pairs=False):
    assert (body_mapping is None) ^ (mapping_instance_info is None), \
        "You must set either body_mapping or mapping_instance_info (but not both)"

    if mapping_instance_info is not None:
        body_mapping = fetch_complete_mappings(mapping_instance_info)

    assert isinstance(body_mapping, pd.Series)
    mapper = LabelMapper(body_mapping.index.values, body_mapping.values)

    # pd.Index is faster than builtin set for large sets
    important_bodies = pd.Index(important_bodies)

    results = []
    for spec_merge_table_df in tqdm(speculative_merge_tables):
        logger.info(f"Processing table with {len(spec_merge_table_df)} rows")

        with Timer("Applying mapping", logger):
            spec_merge_table_df['body_a'] = mapper.apply(
                spec_merge_table_df['id_a'].values, allow_unmapped=True)
            spec_merge_table_df['body_b'] = mapper.apply(
                spec_merge_table_df['id_b'].values, allow_unmapped=True)

        with Timer("Dropping identity merges", logger):
            orig_size = len(spec_merge_table_df)
            spec_merge_table_df.query('body_a != body_b', inplace=True)
            logger.info(
                f"Dropped {orig_size-len(spec_merge_table_df)}/{orig_size} edges."
            )

        with Timer("Normalizing edges", logger):
            # Normalize for body ID, not SV ID
            # (This involves a lot of copying, but you've got plenty of RAM, right?)
            a_cols = list(
                filter(lambda s: s[-1] == 'a', spec_merge_table_df.columns))
            b_cols = list(
                filter(lambda s: s[-1] == 'b', spec_merge_table_df.columns))
            spec_merge_table = spec_merge_table_df.to_records(index=False)
            normalize_recarray_inplace(spec_merge_table, 'body_a', 'body_b',
                                       a_cols, b_cols)
            spec_merge_table_df = pd.DataFrame(spec_merge_table)

        with Timer("Filtering edges", logger):
            q = 'body_a in @important_bodies and body_b in @important_bodies'
            orig_len = len(spec_merge_table_df)
            spec_merge_table_df.query(q, inplace=True)
            logger.info(
                f"Filtered out {orig_len - len(spec_merge_table_df)} non-important edges."
            )

        if drop_duplicate_body_pairs:
            with Timer("Dropping duplicate body pairs", logger):
                orig_len = len(spec_merge_table_df)
                spec_merge_table_df.drop_duplicates(['body_a', 'body_b'],
                                                    inplace=True)
                logger.info(
                    f"Dropped {orig_len - len(spec_merge_table_df)} duplicate body pairs"
                )

        results.append(spec_merge_table_df)

    return results
Beispiel #32
0
def update_merge_table(server,
                       uuid,
                       instance,
                       table_df,
                       complete_mapping=None,
                       split_mapping=None):
    """
    Given a merge table (such as a focused proofreading decision table),
    find rows whose supervoxels no longer exist in the given instance (due to splits).

    For those invalid rows, determine the new supervoxel and body ID at the given coordinates
    to determine the updated supervoxel/body IDs.
    
    Updates (in-place) the supervoxel and body columns.
    
    Note: If any coordinate appears to be misplaced (i.e. the supervoxel ID at
    the coordinate is not a descendant of the listed supervoxel), the supervoxel is
    left unchanged and the body is mapped to 0.
    
    Args:
        server, uuid, instance:
            Table will be updated with respect to the given segmentation instance info
        
        table_df:
            DataFrame with SV columns and coordinate columns ('xa', 'ya', 'za', etc.)
        
        complete_mapping:
            Optional.  Will be fetched if not provided.
            Must be the complete mapping as returned by fetch_complete_mappings(..., include_retired=True)
        
        split_mapping:
            Optional.  Will be fetched if not provided.
            A mapping from supervoxel fragments to root supervoxel IDs.

    Returns:
        None. (The table is modified in place.)
    """
    seg_info = server, uuid, instance

    # Ensure proper table columns/dtypes
    if 'id_a' in table_df.columns:
        col_sv_a, col_sv_b = 'id_a', 'id_b'
    elif 'sv_a' in table_df.columns:
        col_sv_a, col_sv_b = 'sv_a', 'sv_b'
    else:
        raise RuntimeError("table has no sv columns")

    assert set([col_sv_a, col_sv_b, 'xa', 'ya', 'za', 'xb', 'yb',
                'zb']).issubset(table_df.columns)
    for col in ['xa', 'ya', 'za', 'xb', 'yb', 'zb']:
        table_df[col] = table_df[col].fillna(0).astype(np.int32)

    # Construct mappings if necessary
    kafka_msgs = None
    if complete_mapping is None or split_mapping is None:
        kafka_msgs = read_kafka_messages(*seg_info)

    if complete_mapping is None:
        complete_mapping = fetch_complete_mappings(*seg_info,
                                                   include_retired=True,
                                                   kafka_msgs=kafka_msgs)
    complete_mapper = LabelMapper(complete_mapping.index.values,
                                  complete_mapping.values)

    if split_mapping is None:
        split_events = fetch_supervoxel_splits(*seg_info)
        split_mapping = split_events_to_mapping(split_events,
                                                leaves_only=False)
    split_mapper = LabelMapper(split_mapping.index.values,
                               split_mapping.values)

    # Apply up-to-date body mapping
    # (Retired supervoxels will map to body 0)
    table_df['body_a'] = complete_mapper.apply(table_df[col_sv_a].values, True)
    table_df['body_b'] = complete_mapper.apply(table_df[col_sv_b].values, True)

    successfully_updated = 0
    failed_update = 0

    # Update the rows with invalid bodies.
    for index in tqdm_proxy(
            table_df.query('body_a == 0 or body_b == 0').index):

        def update_row(index, col_body, col_sv, col_z, col_y, col_x):
            nonlocal successfully_updated, failed_update

            # Extract from table
            coord = table_df.loc[index, [col_z, col_y, col_x]]
            old_sv = table_df.loc[index, col_sv]

            # Check current SV in the volume
            new_sv = fetch_label(*seg_info, coord, supervoxels=True)

            # The old/new SV must have come from the same root SV.
            # If not, the coordinate must be misplaced and can't be used here.
            svs = np.asarray([new_sv, old_sv], np.uint64)
            mapped_svs = split_mapper.apply(svs, True)
            if mapped_svs[0] != mapped_svs[1]:
                failed_update += 1
            else:
                body = complete_mapper.apply(np.array([new_sv], np.uint64))[0]
                table_df.loc[index, col_body] = body
                table_df.loc[index, col_sv] = new_sv
                successfully_updated += 1

        # id_a/body_a
        if table_df.loc[index, 'body_a'] == 0:
            update_row(index, 'body_a', col_sv_a, 'za', 'ya', 'xa')

        # id_b/body_b
        if table_df.loc[index, 'body_b'] == 0:
            update_row(index, 'body_b', col_sv_b, 'zb', 'yb', 'xb')

    logger.info(
        f"Updated {successfully_updated}, failed to update {failed_update}")
def find_edges_in_brick(brick,
                        closest_scale=None,
                        subset_groups=[],
                        subset_requirement=2):
    """
    Find all pairs of adjacent labels in the given brick,
    and find the central-most point along the edge between them.
    
    (Edges to/from label 0 are discarded.)
    
    If closest_scale is not None, then non-adjacent pairs will be considered,
    according to a particular heuristic to decide which pairs to consider.
    
    Args:
        brick:
            A Brick to analyze
        
        closest_scale:
            If None, then consider direct (touching) adjacencies only.
            If not None, then non-direct "adjacencies" (i.e. close-but-not-touching pairs) are also found.
            In that case, `closest_scale` should be an integer >= 0 indicating the scale at which
            the analysis will be performed.
            Higher scales are faster, but less precise.
            See ``neuclease.util.approximate_closest_approach`` for more information.
        
        subset_groups:
            A DataFrame with columns [label, group].  Only the given labels will be analyzed
            for adjacencies.  Furthermore, edges (pairs) will only be returned if both labels
            in the edge are from the same group.
            
        subset_requirement:
            How many of the two labels in each edge must be present in subset_groups:
            2 means both endpoints, 1 means at least one endpoint.
            (Currently, only subset_requirement=2 is supported.)
        
    Returns:
        If the brick contains no edges at all (other than edges to label 0), returns None.

        Otherwise, returns a pd.DataFrame with columns:
            [label_a, label_b, forwardness, z, y, x, axis, edge_area, distance]  # fixme

        where label_a < label_b, and 'axis' indicates which axis the edge crosses
        at the chosen coordinate.

        (z,y,x) is always given as the coordinate to the left/above/front of the edge
        (depending on the axis).

        If 'forwardness' is True, then the given coordinate falls on label_a and
        label_b is one voxel "after" it (to the right/below/behind the coordinate).
        Otherwise, the coordinate falls on label_b, and label_a is "after".

        'edge_area' is the total count of voxels that touch both labels.
    """
    # Profiling indicates that query('... in ...') spends
    # most of its time in np.unique, believe it or not.
    # After looking at the implementation, I think it might help a
    # little if we sort the array first.
    brick_labels = np.sort(pd.unique(brick.volume.reshape(-1)))
    if (len(brick_labels) == 1) or (len(brick_labels) == 2 and
                                    (0 in brick_labels)):
        return None  # brick is solid -- no possible edges

    # Drop labels that aren't even present
    subset_groups = subset_groups.query('label in @brick_labels').copy()

    # Drop groups that don't have enough members (usually 2) in this brick.
    group_counts = subset_groups['group'].value_counts()
    _kept_groups = group_counts.loc[(group_counts >= subset_requirement)].index
    subset_groups = subset_groups.query('group in @_kept_groups').copy()

    if len(subset_groups) == 0:
        return None  # No possible edges to find in this brick.

    # Construct a mapper that includes only the labels we'll keep.
    # (Other labels will be mapped to 0.)
    # Also, the mapper converts to uint32 (required by _find_and_select_central_edges,
    # but also just good for RAM reasons).
    kept_labels = np.sort(np.unique(subset_groups['label'].values))
    remapped_kept_labels = np.arange(1, len(kept_labels) + 1, dtype=np.uint32)
    mapper = LabelMapper(kept_labels, remapped_kept_labels)
    reverse_mapper = LabelMapper(remapped_kept_labels, kept_labels)

    # Construct RAG -- finds all edges in the volume, on a per-pixel basis.
    remapped_volume = mapper.apply_with_default(brick.volume, 0)
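    # The uncompressed volume isn't needed below (we work on remapped_volume from here on),
    # so compress the brick to save RAM.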
    brick.compress()
    remapped_subset_groups = subset_groups.copy()
    remapped_subset_groups['label'] = mapper.apply(
        subset_groups['label'].values)

    try:
        if closest_scale is None:
            best_edges_df = _find_and_select_central_edges(
                remapped_volume, remapped_subset_groups, subset_requirement)
        else:
            best_edges_df = _find_closest_approaches(remapped_volume,
                                                     closest_scale,
                                                     remapped_subset_groups)
    except:
        brick_name = f"{brick.logical_box[:,::-1].tolist()}"
        np.save(f'problematic-remapped-brick-{brick_name}.npy', remapped_volume)
        # This will appear in the worker log.
        logger.error(f"Error in brick (XYZ): {brick_name}")
        raise

    if best_edges_df is None:
        return None

    # Translate coordinates to global space
    best_edges_df.loc[:, ['za', 'ya', 'xa']] += brick.physical_box[0]
    best_edges_df.loc[:, ['zb', 'yb', 'xb']] += brick.physical_box[0]

    # Restore to original label set
    best_edges_df['label_a'] = reverse_mapper.apply(
        best_edges_df['label_a'].values)
    best_edges_df['label_b'] = reverse_mapper.apply(
        best_edges_df['label_b'].values)

    # Normalize the edge orientation so that label_a < label_b,
    # swapping the 'a'/'b' column pairs wherever necessary.
    swap_df_cols(best_edges_df, None, best_edges_df.eval('label_a > label_b'),
                 ['a', 'b'])

    return best_edges_df
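
# The remap/reverse-map pattern used above (pack arbitrary uint64 labels into a small
# consecutive uint32 range, analyze the compact volume, then translate results back)
# is reusable on its own. The following is a minimal sketch with invented toy labels,
# again assuming LabelMapper comes from the dvidutils package.
import numpy as np
from dvidutils import LabelMapper

# Toy volume containing three sparse, arbitrarily large labels (plus background 0)
volume = np.array([[0, 10**12, 10**12],
                   [5_000_000_000, 5_000_000_000, 42]], dtype=np.uint64)

kept_labels = np.unique(volume[volume != 0])                     # sorted uint64 labels
remapped_kept_labels = np.arange(1, len(kept_labels) + 1, dtype=np.uint32)

mapper = LabelMapper(kept_labels, remapped_kept_labels)
reverse_mapper = LabelMapper(remapped_kept_labels, kept_labels)

# Forward: labels not in kept_labels (here, just the background) become 0
remapped_volume = mapper.apply_with_default(volume, 0)           # uint32, values in {0, 1, 2, 3}

# ... run the expensive analysis on the small uint32 volume here ...

# Reverse: translate the analysis results (here, just the label IDs) back to the originals
original_ids = reverse_mapper.apply(np.array([1, 2, 3], dtype=np.uint32))
assert set(original_ids.tolist()) == {42, 5_000_000_000, 10**12}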