Example #1
class SparseNodeGraph:
    """
    Wrapper around gt.Graph() that permits arbitrarily large
    node IDs, which are internally mapped to consecutive node IDs.
    
    Ideally, we could just use g.add_edge_list(..., hashed=True),
    but that feature appears to be unstable.
    (In my build, at least, it tends to segfault).
    """
    
    def __init__(self, edge_array, directed=True):
        import graph_tool as gt
    
        assert edge_array.dtype in (np.uint32, np.uint64)
    
        self.node_ids = pd.unique(edge_array.reshape(-1))
        self.cons_node_ids = np.arange(len(self.node_ids), dtype=np.uint32)
    
        self.mapper = LabelMapper(self.node_ids, self.cons_node_ids)
        cons_edges = self.mapper.apply(edge_array)
    
        self.g = gt.Graph(directed=directed)
        self.g.add_edge_list(cons_edges, hashed=False)
        

    def _map_node(self, node):
        return self.mapper.apply(np.array([node], self.node_ids.dtype))[0]


    def add_edge_weights(self, weights):
        if weights is not None:
            assert len(weights) == self.g.num_edges()
            if np.issubdtype(weights.dtype, np.integer):
                assert weights.dtype.itemsize <= 4, "Can't handle 8-byte ints, sadly."
                # Store the new property map on the graph itself;
                # otherwise the result of new_edge_property() would be discarded.
                self.g.ep["weight"] = self.g.new_edge_property("int", vals=weights)
            elif np.issubdtype(weights.dtype, np.floating):
                self.g.ep["weight"] = self.g.new_edge_property("float", vals=weights)
            else:
                raise AssertionError("Can't handle non-numeric weights")


    def get_out_neighbors(self, node):
        cons_node = self._map_node(node)
        cons_neighbors = self.g.get_out_neighbors(cons_node)
        return self.node_ids[cons_neighbors]
    

    def get_in_neighbors(self, node):
        cons_node = self._map_node(node)
        cons_neighbors = self.g.get_in_neighbors(cons_node)
        return self.node_ids[cons_neighbors]
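A minimal usage sketch, assuming numpy is imported as np and that pandas and dvidutils.LabelMapper (used inside the class) are available at module level; the node IDs below are arbitrary placeholders:

import numpy as np

# Edges between arbitrarily large, non-consecutive node IDs.
edges = np.array([[10_000_000_001, 10_000_000_002],
                  [10_000_000_002, 99_000_000_000]], dtype=np.uint64)

sng = SparseNodeGraph(edges, directed=True)

# Queries accept and return the original (large) node IDs.
sng.get_out_neighbors(10_000_000_001)   # -> array([10000000002], dtype=uint64)
sng.get_in_neighbors(99_000_000_000)    # -> array([10000000002], dtype=uint64)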
Example #2
def _overwrite_body_id_column(block_sv_stats, segment_to_body_df=None):
    """
    Given a stats array with 'columns' as defined in STATS_DTYPE,
    overwrite the body_id column according to the given agglomeration mapping DataFrame.
    
    If no mapping is given, simply copy the segment_id column into the body_id column.
    
    Args:
        block_sv_stats: numpy.ndarray, with dtype=STATS_DTYPE
        segment_to_body_df: pandas.DataFrame, with columns ['segment_id', 'body_id']
    """
    assert block_sv_stats.dtype == STATS_DTYPE

    assert STATS_DTYPE[0][0] == 'body_id'
    assert STATS_DTYPE[1][0] == 'segment_id'
    
    block_sv_stats = block_sv_stats.view( [STATS_DTYPE[0], STATS_DTYPE[1], ('other_cols', STATS_DTYPE[2:])] )

    if segment_to_body_df is None:
        # No agglomeration
        block_sv_stats['body_id'] = block_sv_stats['segment_id']
    else:
        assert list(segment_to_body_df.columns) == AGGLO_MAP_COLUMNS
        
        # This could be done via pandas merge(), followed by fillna(), etc.,
        # but I suspect LabelMapper is faster and more frugal with RAM.
        mapper = LabelMapper(segment_to_body_df['segment_id'].values, segment_to_body_df['body_id'].values)
        del segment_to_body_df
    
        # Remap in batches to save RAM
        batch_size = 1_000_000
        for chunk_start in range(0, len(block_sv_stats), batch_size):
            chunk_stop = min(chunk_start+batch_size, len(block_sv_stats))
            chunk_segments = block_sv_stats['segment_id'][chunk_start:chunk_stop]
            block_sv_stats['body_id'][chunk_start:chunk_stop] = mapper.apply(chunk_segments, allow_unmapped=True)
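A hypothetical, self-contained illustration. STATS_DTYPE and AGGLO_MAP_COLUMNS are module-level constants not shown in this snippet, so the definitions below are stand-ins that merely satisfy the assertions above (the real dtype has more columns):

import numpy as np
import pandas as pd

# Stand-in definitions (hypothetical) -- the real constants live elsewhere in the module.
STATS_DTYPE = [('body_id', np.uint64), ('segment_id', np.uint64), ('count', np.uint32)]
AGGLO_MAP_COLUMNS = ['segment_id', 'body_id']

stats = np.zeros(3, dtype=STATS_DTYPE)
stats['segment_id'] = [1, 2, 3]

# Agglomeration mapping: segments 1 and 2 belong to body 100; segment 3 is unmapped.
seg_to_body = pd.DataFrame({'segment_id': np.array([1, 2], np.uint64),
                            'body_id':    np.array([100, 100], np.uint64)})

_overwrite_body_id_column(stats, seg_to_body)
# stats['body_id'] -> [100, 100, 3]  (unmapped segments keep their own IDs)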
Example #3
def connected_components_nonconsecutive(edges, node_ids):
    """
    Run connected components on the graph encoded by 'edges' and node_ids.
    All nodes from edges must be present in node_ids.
    (Additional nodes are permitted, in which case each is assigned its own CC.)
    The node_ids do not need to be consecutive, and can have arbitrarily large values.

    For graphs with consecutively-valued nodes, connected_components() (below)
    will be faster because it avoids a relabeling step.
    
    Args:
        edges:
            ndarray, shape=(E,2), dtype np.uint32 or uint64
        
        node_ids:
            ndarray, shape=(N,), dtype np.uint32 or uint64

    Returns:
        ndarray, same shape as node_ids, labeled by component index from 0..C
    """
    assert node_ids.ndim == 1
    assert node_ids.dtype in (np.uint32, np.uint64)

    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)
    cons_edges = mapper.apply(edges)
    return connected_components(cons_edges, len(node_ids))
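A small toy example, assuming numpy is imported as np and connected_components() is the helper referenced in the docstring:

import numpy as np

node_ids = np.array([10, 20, 30, 40, 99], np.uint64)
edges = np.array([[10, 20], [30, 40]], np.uint64)

cc_labels = connected_components_nonconsecutive(edges, node_ids)
# cc_labels is aligned with node_ids, e.g. [0, 0, 1, 1, 2]:
# {10, 20} form one component, {30, 40} another, and 99 gets its own singleton CC.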
Example #4
def _find_disconnected_components(cleaned_edges, output_labels):
    """
    Given a graph defined by cleaned_edges and a node labeling in output_labels,
    check if any output labels are split among discontiguous groups,
    and return the set of output label IDs for such objects.
    """
    # Figure out which edges were 'cut' (endpoints got different labels)
    # and which were preserved
    mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32),
                         output_labels)
    labeled_edges = mapper.apply(cleaned_edges)
    preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]

    # Compute CC on the graph WITHOUT cut edges (keep only preserved edges)
    component_labels = connected_components(preserved_edges,
                                            len(output_labels))
    assert len(component_labels) == len(output_labels)

    # Align node output labels to their connected component labels
    cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels})

    # How many unique connected component labels are associated with each output label?
    cc_counts = cc_df.groupby('label').nunique()['cc']

    # Any output labels that map to multiple CC labels are 'disconnected components' in the output.
    disconnected_cc_counts = cc_counts[cc_counts > 1]
    disconnected_components = set(disconnected_cc_counts.index) - set([0])

    return disconnected_components
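A toy illustration, assuming numpy is imported as np; node IDs are implicitly 0..N-1, matching output_labels:

import numpy as np

cleaned_edges = np.array([[0, 1], [1, 2], [2, 3]], np.uint32)
output_labels = np.array([7, 7, 8, 7], np.uint32)

_find_disconnected_components(cleaned_edges, output_labels)
# -> {7}: the edges (1,2) and (2,3) were 'cut' because their endpoints received
#    different labels, leaving label 7 split into two disjoint groups ({0,1} and {3}).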
Example #6
def apply_mapping_to_mergetable(merge_table_df, mapping):
    """
    Set the 'body' column of the given merge table (append one if it didn't exist)
    by applying the given SV->body mapping to the merge table's id_a column.
    """
    if isinstance(mapping, str):
        with Timer("Loading mapping", logger):
            mapping = load_mapping(mapping)

    assert isinstance(mapping, pd.Series), "Mapping must be a pd.Series"
    with Timer("Applying mapping to merge table", logger):
        mapper = LabelMapper(mapping.index.values, mapping.values)
        body_a = mapper.apply(merge_table_df['id_a'].values,
                              allow_unmapped=True)
        body_b = mapper.apply(merge_table_df['id_b'].values,
                              allow_unmapped=True)

        # Cut edges that span across bodies
        body_a[body_a != body_b] = 0
        merge_table_df['body'] = body_a
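A toy example, assuming numpy/pandas are imported as np/pd; the Timer/logger objects referenced above are assumed to exist at module level:

import numpy as np
import pandas as pd

merge_table_df = pd.DataFrame({'id_a': np.array([1, 2, 5], np.uint64),
                               'id_b': np.array([2, 3, 6], np.uint64)})

# SV -> body mapping: SVs 1, 2, 3 all belong to body 100; SVs 5 and 6 are unmapped.
mapping = pd.Series(index=np.array([1, 2, 3], np.uint64),
                    data=np.array([100, 100, 100], np.uint64))

apply_mapping_to_mergetable(merge_table_df, mapping)
# merge_table_df['body'] -> [100, 100, 0]
# The (5, 6) edge spans two distinct (unmapped) bodies, so it is "cut" (body == 0).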
Example #7
def apply_mappings(supervoxels, mappings):
    assert isinstance(mappings, dict)
    df = pd.DataFrame(index=supervoxels.astype(np.uint64, copy=False))
    df.index.name = 'sv'

    for name, mapping in mappings.items():
        assert isinstance(mapping, pd.Series)
        index_values = mapping.index.values.astype(np.uint64, copy=False)
        mapping_values = mapping.values.astype(np.uint64, copy=False)
        mapper = LabelMapper(index_values, mapping_values)
        df[name] = mapper.apply(df.index.values, allow_unmapped=True)

    return df
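A toy example, assuming numpy/pandas are imported as np/pd:

import numpy as np
import pandas as pd

supervoxels = np.array([1, 2, 3], np.uint64)
mappings = {
    'v1': pd.Series(index=np.array([1, 2], np.uint64), data=np.array([10, 10], np.uint64)),
    'v2': pd.Series(index=np.array([1], np.uint64),    data=np.array([20], np.uint64)),
}

df = apply_mappings(supervoxels, mappings)
# df is indexed by 'sv', with one column per mapping.
# Unmapped supervoxels pass through unchanged (allow_unmapped=True):
#
#     sv   v1  v2
#      1   10  20
#      2   10   2
#      3    3   3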
Example #8
def check_tarsupervoxels_status_via_exists(server, uuid, tsv_instance, bodies, seg_instance=None, mapping=None, kafka_msgs=None):
    """
    For the given bodies, query the given tarsupervoxels instance and return a
    DataFrame indicating which supervoxels are 'missing' from the instance.
    
    Bodies that no longer exist in the segmentation instance are ignored.

    This function downloads the complete mapping in advance and uses it to determine
    which supervoxels belong to each body.  Then uses the /exists endpoint to
    query for missing supervoxels, rather than /missing, which incurs a disk
    read in DVID.
    """
    if seg_instance is None:
        # Determine segmentation instance
        info = fetch_instance_info(server, uuid, tsv_instance)
        seg_instance = info["Base"]["Syncs"][0]
    
    if mapping is None:
        mapping = fetch_complete_mappings(server, uuid, seg_instance, kafka_msgs=kafka_msgs)
    
    # Filter out bodies we don't care about,
    # and append unmapped (singleton/identity) bodies
    _bodies = set(bodies)
    mapping = pd.DataFrame(mapping).query('body in @_bodies')['body'].copy()
    unmapped_bodies = _bodies - set(mapping)
    unmapped_bodies = np.fromiter(unmapped_bodies, np.uint64)
    singleton_mapping = pd.Series(index=unmapped_bodies, data=unmapped_bodies, dtype=np.uint64)
    mapping = pd.concat((mapping, singleton_mapping))
    
    assert mapping.index.values.dtype == np.uint64
    assert mapping.values.dtype == np.uint64
    
    # Faster than mapping.loc[], apparently
    mapper = LabelMapper(mapping.index.values, mapping.values)

    statuses = fetch_exists(server, uuid, tsv_instance, mapping.index, batch_size=10_000, processes=16)
    missing_svs = statuses[~statuses].index.values
    assert missing_svs.dtype == np.uint64
    missing_bodies = mapper.apply(missing_svs, True)

    missing_df = pd.DataFrame({'sv': missing_svs, 'body': missing_bodies})
    assert missing_df['sv'].dtype == np.uint64
    assert missing_df['body'].dtype == np.uint64
    
    # Return a series, indexed by sv
    return missing_df.set_index('sv')['body']
Example #9
def find_all_hotknife_edges_for_plane(server, uuid, instance, plane_center_coord_s0, tile_shape_s0, plane_spacing_s0, min_overlap_s0=1, min_jaccard=0.0, *, scale=0, mapping=None):
    """
    Find all hotknife edges around the given X-plane, found in batches according to tile_shape_s0.
     
    See find_hotknife_edges() for more details.
    """
    plane_box_s0 = fetch_volume_box(server, uuid, instance)
    plane_bounds_s0 = plane_box_s0[:,:2]

    edge_tables = []
    tile_boxes = boxes_from_grid(plane_bounds_s0, tile_shape_s0, clipped=True)
    for tile_index, tile_bounds_s0 in enumerate(tile_boxes):
        tile_logger = PrefixedLogger(logger, f"Tile {tile_index:03d}/{len(tile_boxes):03d}: ")
        edge_table = find_hotknife_edges( server,
                                          uuid,
                                          instance,
                                          plane_center_coord_s0,
                                          tile_bounds_s0,
                                          plane_spacing_s0,
                                          min_overlap_s0,
                                          min_jaccard,
                                          scale=scale,
                                          supervoxels=True, # Compute edges across supervoxels, and then filter out already-merged bodies afterwards
                                          logger=tile_logger )
        edge_tables.append(edge_table)

    edge_table = pd.concat(edge_tables, ignore_index=True)
    assert (edge_table.columns == ['left', 'right', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size']).all()
    edge_table.columns = ['id_a', 'id_b', 'xa', 'ya', 'za', 'xb', 'yb', 'zb', 'overlap', 'jaccard', 'left_cc_size', 'right_cc_size']
    assert edge_table['id_a'].dtype == np.uint64
    assert edge_table['id_b'].dtype == np.uint64
    
    # "Complete" mappings not necessary for our purposes.
    if mapping is None:
        mapping = fetch_mappings(server, uuid, instance, as_array=True)
    mapper = LabelMapper(mapping[:,0], mapping[:,1])
    
    edge_table['body_a'] = mapper.apply(edge_table['id_a'].values, True)
    edge_table['body_b'] = mapper.apply(edge_table['id_b'].values, True)

    edge_table.query('body_a != body_b', inplace=True)
    return edge_table
Example #10
def compute_body_sizes(sv_sizes, mapping, include_unmapped_singletons=False):
    """
    Given a Series of supervoxel sizes and an sv-to-body mapping,
    compute the size of each body in the mapping, and its count of supervoxels.
    
    Any supervoxels in the mapping that are missing from sv_sizes will be ignored.

    Body 0 will be excluded from the results, even if it is present in the mapping.
    
    Args:
        sv_sizes:
            pd.Series, indexed by sv, or a path to an hdf5 file which
            can be loaded via load_supervoxel_sizes()
        
        mapping:
            pd.Series, indexed by sv, with body as value,
            or a path to a file which can be loaded by load_mapping()
       
        include_unmapped_singletons:
            If True, then the result will also include all
            supervoxels from sv_sizes that weren't mentioned in the mapping.
            (They are presumed to be singleton-supervoxel bodies.)
    
    Returns:
        pd.DataFrame, indexed by body, with columns ['voxel_count', 'sv_count'],
        and sorted by decreasing voxel_count.

    Example:
    
        >>> mesh_job_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm'
        >>> sv_sizes_path = f'{mesh_job_dir}/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015/supervoxel-sizes.h5'
        >>> sv_sizes = load_supervoxel_sizes(sv_sizes_path)
        >>> mapping = fetch_complete_mappings('emdata3:8900', '52f9', 'segmentation')
        >>> body_sizes = compute_body_sizes(sv_sizes, mapping)
    """
    if isinstance(sv_sizes, str):
        logger.info("Loading supervoxel sizes")
        assert os.path.splitext(sv_sizes)[1] == '.h5'
        sv_sizes = load_supervoxel_sizes(sv_sizes)

    if isinstance(mapping, str):
        logger.info("Loading mapping")
        mapping = load_mapping(mapping)

    assert isinstance(sv_sizes, pd.Series)
    assert isinstance(mapping, pd.Series)

    assert sv_sizes.index.dtype == np.uint64

    sv_sizes = sv_sizes.astype(np.uint64, copy=False)
    size_mapper = LabelMapper(sv_sizes.index.values, sv_sizes.values)

    # Just drop SVs that we don't have sizes for.
    logger.info("Dropping unknown supervoxels")
    mapping = mapping.loc[mapping.index.isin(sv_sizes.index)]

    logger.info("Applying sizes to mapping")
    df = pd.DataFrame({'body': mapping})
    df['voxel_count'] = size_mapper.apply(mapping.index.values)

    logger.info("Aggregating sizes by body")
    body_stats = df.groupby('body').agg({'voxel_count': ['sum', 'size']})
    body_stats.columns = ['voxel_count', 'sv_count']
    body_stats['sv_count'] = body_stats['sv_count'].astype(np.uint32)
    #body_sizes = body_stats['voxel_count']

    if include_unmapped_singletons:
        logger.info("Appending singleton sizes")
        nonsingleton_rows = sv_sizes.index.isin(mapping.index)
        singleton_sizes = sv_sizes[~nonsingleton_rows]
        singleton_stats = pd.DataFrame({'voxel_count': singleton_sizes})
        singleton_stats['sv_count'] = np.uint32(1)
        body_stats = pd.concat((body_stats, singleton_stats))

    if 0 in body_stats.index:
        body_stats.drop(0, inplace=True)

    logger.info("Sorting sizes")
    body_stats.index.name = 'body'
    body_stats.sort_values(['voxel_count', 'sv_count'],
                           inplace=True,
                           ascending=False)
    return body_stats
Example #11
def cleave(edges,
           edge_weights,
           seeds_dict,
           node_ids,
           node_sizes=None,
           method='seeded-mst'):
    """
    Cleave the graph with the given edges and edge weights.
    
    Args:
        
        edges:
            array, (E,2), uint32
        
        edge_weights:
            array, (E,), float32
        
        seeds_dict:
            dict, { seed_class : [node_id, node_id, ...] }
        
        node_ids:
            The complete list of node IDs in the graph. Must contain a superset of the ids given in edges.
            Extra ids in node_ids (i.e. not mentioned in 'edges') will be included
            in the results as disconnected components.
        
        method:
            One of: 'seeded-mst', 'seeded-watershed', 'agglomerative-clustering', 'echo-seeds'

    Returns:
    
        CleaveResults, namedtuple with fields:
        (output_labels, disconnected_components, contains_unlabeled_components)

        Where:
            output_labels:
                array (N,), uint32
                Agglomerated node labeling, in the same order as node_ids.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
        
    """
    assert isinstance(node_ids, np.ndarray)
    assert node_ids.dtype in (np.uint32, np.uint64)
    assert node_ids.ndim == 1
    assert node_sizes is None or node_sizes.shape == node_ids.shape

    cleave_func, requires_sizes = get_cleave_method(method)
    assert not requires_sizes or node_sizes is not None, \
        f"The specified cleave method ({method}) requires node sizes but none were provided."

    # Relabel node ids consecutively
    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)

    # Initialize sparse seed label array
    seed_labels = np.zeros_like(cons_node_ids)
    for seed_class, seed_nodes in seeds_dict.items():
        seed_nodes = np.asarray(seed_nodes, dtype=np.uint64)
        mapper.apply_inplace(seed_nodes)
        seed_labels[seed_nodes] = seed_class

    if len(edges) == 0:
        # No edges: Return empty results (just seeds)
        return CleaveResults(seed_labels, set(seeds_dict.keys()),
                             not seed_labels.all())

    # Clean the edges (normalized form, no duplicates, no loops)
    edges.sort(axis=1)
    edges_df = pd.DataFrame({
        'u': edges[:, 0],
        'v': edges[:, 1],
        'weight': edge_weights
    })
    edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True)
    edges_df = edges_df.query('u != v')
    edges = edges_df[['u', 'v']].values
    edge_weights = edges_df['weight'].values

    # Relabel edges for consecutive nodes
    cons_edges = mapper.apply(edges)
    assert cons_edges.dtype == np.uint32

    cleave_results = cleave_func(cons_edges, edge_weights, seed_labels,
                                 node_sizes)
    assert isinstance(cleave_results, CleaveResults)
    return cleave_results
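A toy example of the call, assuming numpy is imported as np; the exact labeling depends on the chosen method and the edge-weight convention, so the expected result is only sketched in the comments:

import numpy as np

node_ids = np.array([1, 2, 3, 4], np.uint64)
edges = np.array([[1, 2], [2, 3], [3, 4]], np.uint64)
edge_weights = np.array([0.1, 0.9, 0.1], np.float32)
seeds = {1: [1], 2: [4]}   # seed class 1 on node 1, seed class 2 on node 4

results = cleave(edges, edge_weights, seeds, node_ids, method='seeded-mst')
# results.output_labels is aligned with node_ids; since this toy graph is fully
# connected and both seed classes are present, every node ends up with class 1 or 2.
# results.disconnected_components and results.contains_unlabeled_components
# report any anomalies (ideally an empty set and False, respectively).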
Example #12
def test_split_disconnected_bodies():

    _ = 2  # for readability in the array below

    # Note that we multiply these by 10 for this test!
    orig = [
        [1, 1, 1, 1, _, _, 3, 3, 3, 3],
        [1, 1, 1, 1, _, _, 3, 3, 3, 3],
        [1, 1, 1, 1, _, _, 3, 3, 3, 3],
        [1, 1, 1, 1, _, _, 3, 3, 3, 3],
        [_, _, _, _, _, _, _, _, _, _],
        [0, 0, _, _, 4, 4, _, _, 0, 0],  # Note that the zeros here will not be touched.
        [_, _, _, _, 4, 4, _, _, _, _],
        [1, 1, 1, _, _, _, _, 3, 3, 3],
        [1, 1, 1, _, 1, 1, _, 3, 3, 3],
        [1, 1, 1, _, 1, 1, _, 3, 3, 3]
    ]

    orig = np.array(orig).astype(np.uint64)
    orig *= 10

    split, mapping, split_unique = split_disconnected_bodies(orig)

    # New IDs are generated starting after the original max value
    assert (split_unique == [0, 10, 20, 30, 40, 41, 42, 43]).all()

    assert ((orig == 20) == (split == 20)).all(), \
        "Label 2 is a single component and therefore should remain untouched in the output"

    assert ((orig == 40) == (split == 40)).all(), \
        "Label 4 is a single component and therefore should remain untouched in the output"

    assert (split[:4,:4] == 10).all(), \
        "The largest segment in each split object is supposed to keep its label"
    assert (split[:4,-4:] == 30).all(), \
        "The largest segment in each split object is supposed to keep its label"

    lower_left_label = split[-1, 1]
    lower_right_label = split[-1, -1]
    bottom_center_label = split[-1, 5]

    assert lower_left_label != 10, "Split object was not relabeled"
    assert bottom_center_label != 10, "Split object was not relabeled"
    assert lower_right_label != 30, "Split object was not relabeled"

    assert lower_left_label in (
        41, 42, 43), "New labels are supposed to be consecutive with the old"
    assert lower_right_label in (
        41, 42, 43), "New labels are supposed to be consecutive with the old"
    assert bottom_center_label in (
        41, 42, 43), "New labels are supposed to be consecutive with the old"

    assert (split[-3:, :3] == lower_left_label).all()
    assert (split[-3:, -3:] == lower_right_label).all()
    assert (split[-2:, 4:6] == bottom_center_label).all()

    assert set(mapping.keys()) == set([10, 30, 41, 42,
                                       43]), f"mapping: {mapping}"

    mapper = LabelMapper(np.fromiter(mapping.keys(), np.uint64),
                         np.fromiter(mapping.values(), np.uint64))
    assert (mapper.apply(split, True) == orig).all(), \
        "Applying mapping to the relabeled image did not recreate the original image."
Example #13
def prepare_graph(cc_df, bois, consecutivize=False):
    """
    Create networkx Graph for the given edges,
    and annotate each node with its body ID and whether
    or not it is a 'boi' (body of interest).
    
    If consecutivize=True, do not use body IDs as the node IDs.
    Instead, use consecutive integers 1..N as the node IDs.
    This is useful when visualizing the graph with hvplot,
    as a workaround for the following issue:
    https://github.com/pyviz/hvplot/issues/218
    
    Args:
        cc_df:
            Edge table with columns:
                ['label_a', 'label_b', 'distance']
            OR
                ['body_a', 'body_b', 'distance']

        bois:
            Either a list of BOI IDs or a DataFrame with a 'body' column.
            If a DataFrame, other columns will be used to populate node attributes.

        consecutivize:
            If True, don't use the body IDs as the graph node IDs.
            Enumerate the bodies from 1..N instead,
            and store 'body' as a node attribute.
            (This is a workaround for a bug in hvplot.)

    Returns:
        nx.Graph
    """
    if isinstance(bois, pd.DataFrame):
        boi_df = bois
        assert boi_df.index.name == 'body'
        bois = boi_df.index
    else:
        boi_df = None

    if 'body_a' in cc_df.columns and 'body_b' in cc_df.columns:
        edges = cc_df[['body_a', 'body_b']].values
    elif 'label_a' in cc_df.columns and 'label_b' in cc_df.columns:
        edges = cc_df[['label_a', 'label_b']].values
    else:
        raise RuntimeError("Could not find label or body columns.")

    # Pre-sort nodes to avoid visualization issues such as:
    # https://github.com/pyviz/hvplot/issues/223
    bodies = np.sort(pd.unique(edges.reshape(-1)))
    nodes = bodies

    if consecutivize:
        assert bodies.dtype == edges.dtype == np.uint64
        nodes = np.arange(1, 1 + len(bodies), dtype=np.uint32)
        mapper = LabelMapper(bodies, nodes)
        edges = mapper.apply(edges)

    g = nx.Graph()
    g.add_nodes_from(nodes.astype(np.int64))
    for (a, b), distance in zip(edges, cc_df['distance']):
        g.add_edge(a, b, distance=distance)

    for node, body in zip(nodes, bodies):
        g.nodes[node]['body'] = body
        g.nodes[node]['boi'] = (body in bois)

        # Append more node metadata if it's available.
        if boi_df is None:
            continue

        for col in boi_df.columns:
            if body in bois:
                g.nodes[node][col] = boi_df.loc[body, col]
            else:
                g.nodes[node][col] = -1
    return g
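A toy example, assuming numpy/pandas/networkx are imported as np/pd/nx:

import numpy as np
import pandas as pd

cc_df = pd.DataFrame({'body_a': np.array([10, 20], np.uint64),
                      'body_b': np.array([20, 30], np.uint64),
                      'distance': [5.0, 12.5]})
bois = [10, 30]

g = prepare_graph(cc_df, bois)
# g has nodes 10, 20, 30, each annotated with 'body' and 'boi' attributes
# (boi is True for 10 and 30), and edges carrying the 'distance' attribute.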
Example #14
    def sparse_brick_coords_for_labels(self, labels, clip=True):
        """
        Return a DataFrame indicating the brick
        coordinates (starting corner) that encompass the given labels.

        Args:
            labels:
                A list of body IDs (if ``self.supervoxels`` is False),
                or supervoxel IDs (if ``self.supervoxels`` is True).

            clip:
                If True, filter the results to exclude any coordinates
                that fall outside this service's bounding-box.
                Otherwise, all brick coordinates that encompass the given labels
                will be returned, whether or not they fall within the bounding box.

        Returns:
            DataFrame with columns [z,y,x,label],
            where z,y,x represents the starting corner (in full-res coordinates)
            of a brick that contains the label.
        """
        assert not isinstance(labels,
                              set), "Pass labels as a list or array, not a set"
        labels = pd.unique(labels)
        is_supervoxels = self.supervoxels
        brick_shape = self.preferred_message_shape
        assert (brick_shape % self.block_width == 0).all(), \
            ("Brick shape ('preferred-message-shape') must be a multiple of the "
             f"block width ({self.block_width}) in all dimensions, not {brick_shape}")

        bad_labels = []

        if not is_supervoxels:
            # No supervoxel filtering.
            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = {label: None for label in sorted(labels)}
        else:
            # Arbitrary heuristic for whether to do the body-lookups on DVID or on the client.
            if len(labels) < 100_000:
                # If we're only dealing with a few supervoxels,
                # ask dvid to map them to bodies for us.
                mapping = fetch_mapping(*self.instance_triple,
                                        labels,
                                        as_series=True)
            else:
                # If we're dealing with a lot of supervoxels, ask for
                # the entire mapping, and look up the bodies ourselves.
                complete_mapping = fetch_mappings(*self.instance_triple)
                mapper = LabelMapper(complete_mapping.index.values,
                                     complete_mapping.values)

                labels = np.asarray(labels, np.uint64)
                bodies = mapper.apply(labels, True)
                mapping = pd.Series(index=labels, data=bodies, name='body')
                mapping.index.rename('sv', inplace=True)

            bad_svs = mapping[mapping == 0]
            bad_labels.extend(bad_svs.index.tolist())

            # Group by body
            mapping = mapping[mapping != 0]
            grouped_svs = mapping.reset_index().groupby('body').agg(
                {'sv': list})['sv']

            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = grouped_svs.sort_index().to_dict()

        # Extract these to avoid pickling 'self' (just for speed)
        server, uuid, instance = self.instance_triple
        if self._use_resource_manager_for_sparse_coords:
            mgr = self.resource_manager_client
        else:
            mgr = ResourceManagerClient("", 0)

        def fetch_brick_coords(body, supervoxel_subset):
            """
            Fetch the block coordinates for the given body,
            filter them for the given supervoxels (if any),
            and convert the block coordinates to brick coordinates.
            """
            assert is_supervoxels or supervoxel_subset is None

            try:
                with mgr.access_context(server, True, 1, 1):
                    labelindex = fetch_labelindex(server, uuid, instance, body,
                                                  'protobuf')
                coords_df = convert_labelindex_to_pandas(labelindex).blocks

            except HTTPError as ex:
                if (ex.response is not None
                        and ex.response.status_code == 404):
                    return (body, None)
                raise
            except RuntimeError as ex:
                if 'does not map to any body' in str(ex):
                    return (body, None)
                raise

            if len(coords_df) == 0:
                return (body, None)

            if is_supervoxels:
                supervoxel_subset = set(supervoxel_subset)
                coords_df = coords_df.query('sv in @supervoxel_subset').copy()

            coords_df[['z', 'y', 'x']] //= brick_shape
            coords_df['body'] = np.uint64(body)
            coords_df.drop_duplicates(inplace=True)
            return (body, coords_df)

        def fetch_and_concatenate_brick_coords(bodies_and_supervoxels):
            """
            To reduce the number of tiny DataFrames collected to the driver,
            it's best to concatenate the partitions first, on the workers,
            rather than a straightforward call to starmap(fetch_brick_coords).

            Hence, this function that consolidates each partition.
            """
            bad_bodies = []
            coord_dfs = []
            for (body, supervoxel_subset) in bodies_and_supervoxels:
                _, coords_df = fetch_brick_coords(body, supervoxel_subset)
                if coords_df is None:
                    bad_bodies.append(body)
                else:
                    coord_dfs.append(coords_df)
                    del coords_df

            if coord_dfs:
                return [(pd.concat(coord_dfs, ignore_index=True), bad_bodies)]
            else:
                return [(None, bad_bodies)]

        with Timer(
                f"Fetching coarse sparsevols for {len(labels)} labels ({len(bodies_and_svs)} bodies)",
                logger=logger):
            import dask.bag as db
            coords_and_bad_bodies = (
                db.from_sequence(
                    bodies_and_svs.items(), npartitions=4096
                )  # Instead of fancy heuristics, just pick 4096
                .map_partitions(fetch_and_concatenate_brick_coords).compute())

        coords_df_partitions, bad_body_partitions = zip(*coords_and_bad_bodies)

        for body in chain(*bad_body_partitions):
            if is_supervoxels:
                bad_labels.extend(bodies_and_svs[body])
            else:
                bad_labels.append(body)

        if bad_labels:
            name = 'sv' if is_supervoxels else 'body'
            pd.Series(bad_labels,
                      name=name).to_csv('labels-without-sparsevols.csv',
                                        index=False,
                                        header=True)
            if len(bad_labels) < 100:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels: {bad_labels}"
            else:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels. See labels-without-sparsevols.csv"

            logger.warning(msg)

        coords_df_partitions = list(
            filter(lambda df: df is not None, coords_df_partitions))
        if len(coords_df_partitions) == 0:
            raise RuntimeError(
                "Could not find bricks for any of the given labels")

        coords_df = pd.concat(coords_df_partitions, ignore_index=True)

        if self.supervoxels:
            coords_df['label'] = coords_df['sv']
        else:
            coords_df['label'] = coords_df['body']

        coords_df.drop_duplicates(['z', 'y', 'x', 'label'], inplace=True)
        coords_df[['z', 'y', 'x']] *= brick_shape

        if clip:
            # Keep if the last pixel in the brick is to the right of the bounding-box start
            # and the first pixel in the brick is to the left of the bounding-box stop
            keep = (coords_df[['z', 'y', 'x']] + brick_shape >
                    self.bounding_box_zyx[0]).all(axis=1)
            keep &= (coords_df[['z', 'y', 'x']] <
                     self.bounding_box_zyx[1]).all(axis=1)
            coords_df = coords_df.loc[keep]

        return coords_df[['z', 'y', 'x', 'label']]
Example #15
def compute_focused_paths(server,
                          uuid,
                          instance,
                          original_mapping,
                          important_bodies,
                          speculative_merge_tables,
                          split_mapping=None,
                          max_depth=10,
                          stop_after_endpoint_num=None,
                          return_after_setup=False):

    from ..merge_graph import LabelmapMergeGraph

    with Timer("Loading speculative merge graph", logger):
        merge_graph = LabelmapMergeGraph(speculative_merge_tables, uuid)

    if split_mapping is not None:
        _bad_edges = merge_graph.append_edges_for_split_supervoxels(
            split_mapping, server, uuid, instance, parent_sv_handling='drop')
    merge_table_df = merge_graph.merge_table_df

    if isinstance(original_mapping, str):
        with Timer("Loading mapping", logger):
            original_mapping = load_mapping(original_mapping)
    else:
        assert isinstance(original_mapping, pd.Series)

    with Timer("Applying mapping", logger):
        mapper = LabelMapper(original_mapping.index.values,
                             original_mapping.values)
        merge_table_df['body_a'] = mapper.apply(merge_table_df['id_a'].values,
                                                allow_unmapped=True)
        merge_table_df['body_b'] = mapper.apply(merge_table_df['id_b'].values,
                                                allow_unmapped=True)

    if isinstance(important_bodies, str):
        with Timer("Reading importances", logger):
            important_bodies = read_csv_col(important_bodies).values
        important_bodies = pd.Index(important_bodies, dtype=np.uint64)

    with Timer("Assigning importances", logger):
        merge_table_df['important_a'] = merge_table_df['body_a'].isin(
            important_bodies)
        merge_table_df['important_b'] = merge_table_df['body_b'].isin(
            important_bodies)

    with Timer("Discarding merged edges within 'important' bodies ", logger):
        size_before = len(merge_table_df)
        merge_table_df.query(
            '(body_a != body_b) and not (important_a and important_b)',
            inplace=True)
        size_after = len(merge_table_df)
        logger.info(
            f"Discarded {size_before - size_after} edges, ended with {len(merge_table_df)} edges"
        )

    edges = merge_table_df[['id_a', 'id_b']].values
    assert edges.dtype == np.uint64

    if return_after_setup:
        logger.info("Returning setup instead of searching for paths")
        return (edges, original_mapping, important_bodies, max_depth,
                stop_after_endpoint_num)

    logger.info(
        f"Finding paths among {len(important_bodies)} important bodies")
    # FIXME: Need to augment original_mapping with rows for single-sv bodies that are 'important'
    all_paths = find_all_paths(edges, original_mapping, important_bodies,
                               max_depth, stop_after_endpoint_num)
    return all_paths
Example #16
def find_missing_adjacencies(server, uuid, instance, body, known_edges, svs=None, search_distance=1, connect_non_adjacent=False):
    """
    Given a body and an intra-body merge graph defined by the given
    list of "known" supervoxel-to-supervoxel edges within that body,
    
    1. Determine whether or not all supervoxels in the body are
       connected by a single component within the given graph.
       If so, return immediately.
    
    2. Attempt to augment the graph with additional edges based on
       supervoxel adjacencies in the segmentation from DVID.
       This is done by downloading the DVID labelindex to determine
       which blocks might contain adjacent supervoxels that could unify
       the graph, and then downloading those blocks (only) to search
       for the adjacencies.
    
    Notes:
        - Requires scikit-image (which, currently, is not otherwise
          listed as a dependency of neuclease's conda-recipe).
        
        - This function does not attempt to find ALL adjacencies between supervoxels;
          it stops looking as soon as they form a single connected component.

        - This function only considers two supervoxels "adjacent" if they are
          literally touching each other in the scale-0 segmentation. If there is
          a small gap between them, then they are not considered adjacent.
        
        - This function does not attempt to find inter-block adjacencies;
          only adjacencies within each block are detected.
          So, in pathological cases where a supervoxel is only adjacent to the
          rest of the body on a block-aligned edge, the adjacency will not be
          detected by this function.
        
    Args:
        server, uuid, instance:
            DVID segmentation labelmap instance
        
        body:
            ID of the body to inspect
        
        known_edges:
            ndarray (N,2), array of supervoxel pairs;
            known edges of the intra-body merge graph
        
        svs:
            Optional. The complete list of supervoxels
            that belong to this body, according to DVID.
            Providing this enhances performance in one important case:
            If the known_edges ALREADY constitute a single connected component
            which covers all supervoxels in the body, there is no need to
            download the labelindex.
        
        search_distance:
            If > 1, supervoxels are considered adjacent if they are within
            the given distance from each other, even if they aren't directly adjacent.
        
        connect_non_adjacent:
            If searching by adjacency failed to fully connect all supervoxels in the
            body into a single connected component, generate edges for supervoxels
            that are not adjacent, but merely are in the same block (if it helps
            unify the body).
    
    Returns:
        (new_edges, orig_num_cc, final_num_cc, block_tables),
        Where:
            new_edges are the new edges found via inspection of supervoxel adjacencies,
            
            orig_num_cc is the number of disjoint components in the given merge graph before
                this function runs,
            
            final_num_cc is the number of disjoint components after adding the new_edges,
            
            block_tables contains debug information about the adjacencies found in each
                block of analyzed segmentation
                
        Ideally, final_num_cc == 1, but in some cases the body's supervoxels may not be
        directly adjacent, or the adjacencies were not detected.  (See notes above.)
    """
    from skimage.morphology import dilation
    
    BLOCK_TABLE_COLS = ['z', 'y', 'x', 'sv_a', 'sv_b', 'cc_a', 'cc_b', 'detected', 'applied']
    known_edges = np.asarray(known_edges, np.uint64)
    if svs is None:
        # We could compute the supervoxel list ourselves from 
        # the labelindex, but dvid can do it faster.
        svs = fetch_supervoxels_for_body(server, uuid, instance, body)

    cc = connected_components_nonconsecutive(known_edges, svs)
    orig_num_cc = final_num_cc = cc.max()+1
    
    if orig_num_cc == 1:
        return np.zeros((0,2), np.uint64), orig_num_cc, final_num_cc, pd.DataFrame(columns=BLOCK_TABLE_COLS)

    labelindex = fetch_labelindex(server, uuid, instance, body, format='protobuf')
    encoded_block_coords = np.fromiter(labelindex.blocks.keys(), np.uint64, len(labelindex.blocks))
    coords_zyx = decode_labelindex_blocks(encoded_block_coords)

    cc_mapper = LabelMapper(svs, cc)
    svs_set = set(svs)

    sv_adj_found = []
    cc_adj_found = set()
    block_tables = {}
    
    searched_block_svs = {}
    
    for coord_zyx, sv_counts in zip(coords_zyx, labelindex.blocks.values()):
        # Given the supervoxels in this block, what CC adjacencies
        # MIGHT we find if we were to inspect the segmentation?
        block_svs = np.fromiter(sv_counts.counts.keys(), np.uint64)
        block_ccs = cc_mapper.apply(block_svs)
        possible_cc_adjacencies = set(combinations( set(block_ccs), 2 ))
        
        # We only aim to find (at most) a single link between each CC pair.
        # That is, we don't care about adjacencies between CC that we've already linked so far.
        possible_cc_adjacencies -= cc_adj_found
        if not possible_cc_adjacencies:
            continue

        searched_block_svs[(*coord_zyx,)] = block_svs
        
        # Not used in the search; only returned for debug purposes.
        block_adj_table = _init_adj_table(coord_zyx, block_svs, cc_mapper)

        block_vol = fetch_block_vol(server, uuid, instance, coord_zyx, svs_set)
        if search_distance > 0:
            # It would be nice to do a proper spherical dilation,
            # but apparently dilation() is special-cased to be WAY
            # faster with a square structuring element, and we prefer
            # speed over cleaner dilation.
            # footprint = skimage.morphology.ball(dilation)
            radius = search_distance//2
            footprint = np.ones(3*(1+2*radius,), np.uint8)
            dilated_block_vol = dilation(block_vol, footprint)
            
            # Since dilation is a max-filter, we might have accidentally
            # erased small, low-valued supervoxels, erasing the adjacencies.
            # Overlay the original volume to make sure they still count.
            block_vol = np.where(block_vol, block_vol, dilated_block_vol)
        
        sv_adjacencies = compute_label_adjacencies(block_vol)
        sv_adjacencies['cc_a'] = cc_mapper.apply( sv_adjacencies['sv_a'].values )
        sv_adjacencies['cc_b'] = cc_mapper.apply( sv_adjacencies['sv_b'].values )
        
        found_new_adj = False
        for row in sv_adjacencies.itertuples(index=False):
            if (row.cc_a != row.cc_b):
                sv_adj = (row.sv_a, row.sv_b)
                cc_adj = (row.cc_a, row.cc_b)
                
                # Normalize
                if row.cc_a > row.cc_b:
                    cc_adj = (row.cc_b, row.cc_a)

                if row.sv_a > row.sv_b:
                    sv_adj = (row.sv_b, row.sv_a)

                block_adj_table.loc[sv_adj, 'detected'] = True
                    
                if cc_adj not in cc_adj_found:
                    found_new_adj = True
                    cc_adj_found.add( cc_adj )
                    sv_adj_found.append( sv_adj )
                    
                    block_adj_table.loc[sv_adj, 'applied'] = True

        block_tables[(*coord_zyx,)] = block_adj_table

        # If we made at least one change and we've 
        # finally unified all components, then we're done.
        if found_new_adj:
            final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1
            if final_num_cc == 1:
                break
    
    # If we couldn't connect everything via direct adjacencies,
    # we can just add edges for any supervoxels that share a block.
    if final_num_cc > 1 and connect_non_adjacent:
        for coord_zyx, block_svs in searched_block_svs.items():
            block_ccs = cc_mapper.apply(block_svs)
            
            # We only need one SV per connected component,
            # so load them into a dict.
            selected_svs = dict(zip(block_ccs, block_svs))
            for (sv_a, sv_b) in combinations(sorted(selected_svs.values()), 2):
                (cc_a, cc_b) = cc_mapper.apply(np.array([sv_a, sv_b], np.uint64))
                if cc_a > cc_b:
                    cc_a, cc_b = cc_b, cc_a
                
                if (cc_a, cc_b) not in cc_adj_found:
                    if sv_a > sv_b:
                        sv_a, sv_b = sv_b, sv_a

                    cc_adj_found.add( (cc_a, cc_b) )
                    sv_adj_found.append( (sv_a, sv_b) )

                    block_tables[(*coord_zyx,)].loc[(sv_a, sv_b), 'applied'] = True

        final_num_cc = connected_components(np.array(list(cc_adj_found), np.uint64), orig_num_cc).max()+1
    
    if len(block_tables) == 0:
        block_table = pd.DataFrame(columns=BLOCK_TABLE_COLS)
    else:
        block_table = pd.concat(block_tables.values(), sort=False).reset_index()
        block_table = block_table[BLOCK_TABLE_COLS]
    
    new_edges = np.array(sv_adj_found, np.uint64)
    return new_edges, int(orig_num_cc), int(final_num_cc), block_table
Example #17
    def execute(self):
        """
        Computes connected components across an entire volume,
        possibly considering only a subset of all labels,
        and writes the result to another volume.
        
        (Even if only some labels were analyzed for CC, all labels are written to the output.)
        
        Objects that were not "disconnected" in the first place are not modified.
        Their voxels retain their original values and they are simply copied to
        the output.  Only objects which consist of multiple disconnected pieces
        are given new label values.

        In addition to exporting a few intermediate dataframes for debugging purposes,
        this workflow also exports the final table of all objects that were relabeled
        in the output, with the IDs that were used to compute the table for every
        processed brick.  The table is written to 'node_df_final.pkl'

        Note:
            The table does NOT include components that didn't need to
            be re-written at all. (This workflow does not alter labels
            that consisted of only one component to begin with.)

        Example:

               lz0    ly0    lx0       orig         cc  link_cc  final_cc
            0  512  10752  11776  110390011  144347596        0         2
            1  512  10752  11776  110390011  144347597        1         3
            2  512  10752  11776  110390011  144347598        2         4
            3  512  10752  11776  110390011  144347599        3         5
            4  512  10752  11776  110390011  144347600        4         6

        Columns Descriptions:

            lz0, ly0, lx0:
                Logical starting corner of a brick.
                Does not include the 1-px halo that was used to ensure bricks overlapped
                in space.  Also does not account for the fact that bricks on the edge
                of the volume may have been cropped during processing due to a user-provided
                bounding-box.
            orig:
                Label in the original volume
            cc:
                label after running connected components on each brick independently
                (but adding an offset to each brick's labels to ensure that no IDs
                are duplicated across bricks)
            link_cc:
                The global component ID after unifying cc labels across neighboring bricks
            final_cc:
                The final component label written to the output.
                Computed by mapping the link_cc values into the correct output
                label range as determined by orig-max-label

        Procedure Overview:

        1. Accepts any segmentation volume as input, divided into
           'bricks' as usual, but with a halo of at least 1 voxel.
        
        2. Computes connected components (CC) on each brick independently.
           Note: Before computing the CC, the "subset-labels" setting is used to
           mask out everything except the desired objects.
           Note that the resulting CC labels for each brick are not unique -- each
           brick has values of 1..cc_max.
           Also, the original label from which each CC label was created is stored in
           the "overlap mapping" (dataframe for each brick with columns: ['cc', 'orig']).
           The CC volume, overlap mapping, cc_max, and raw data max
           for each brick are cached (persisted).
        
        3. The CC values are made unique across all bricks by adding an offset to each brick.
           The offset is determined by multiplying the brick's ID number (scan-order index)
           with the maximum CC value found in all the brick CC volumes.
           Note: This means that the CC values will be unique, but not consecutive across all bricks.
           The overlap mapping (cc->orig) is also updated accordingly.
        
        4. Next, we need to determine how to unify objects which span across brick boundaries.
           First, the halos from each brick are extracted and matched with the halos from 
           their neighboring bricks. This is achieved via a dask DataFrame merge.
           Then these halo pairs are aligned so that the CC label in the "left" halo can
           be matched with the CC label in the "right" halo.
           These pairings "link" the CC objects in one brick to another brick,
           and are referred to as "pairwise links".
        
        5. These links form the edges of a graph, and the connected components *on the graph*
           determine which CC labels should be unified in the output.
           Note that here, we are referring to a graph-CC, not to be confused with the
           connected components operations we ran earlier, on each brick's volume data.
           This graph-CC operation yields a "final mapping" from brick-CC labels to unified
           label values.
           Note: Since we are not interested in relabeling objects that weren't
           discontiguous in the original volume, we drop rows of the mapping for objects whose
           'orig' value only appears once in the final mapping.  This reduces the size of the
           mapping, which must be sent to the workers.  In fact, before we even run the graph-CC
           operation, we drop CC values if their original label doesn't appear more than once in
           the overlap mapping from above. This reduces the size of the graph-CC problem.
        
        6. The final mapping is distributed to all workers in pieces, via pandas DataFrame merge.
        
        7. The final mapping is applied to the CC bricks, and written to the output.
           If the "subset-labels" option was used, the CC brick may consist of zeros for those
           voxels that did not contain a label of interest (see step 1, above).  But the output
           is guaranteed to contain all of the unsplit original objects, so we use the original
           volume to supply the remaining data before it is written to the output.
           
        """
        # TODO:
        #
        #  - Refactor this super-long function.
        #
        #  - Maybe parallelize some of those dataframe operations,
        #    rather than collecting it all on the driver right away?
        #
        #  - Don't bother including LB_COLS in dataframes? (see BrickWall.bricks_as_ddf)
        #
        #  - Unpersist unneeded bags? (maybe I need them all, though...)
        #
        #  - Refactor this so that the bulk of the code is re-usable for workflows in which 
        #    the segmentation did not come from a common source.
        #    (But see notes about necessary changes below.)
        #
        #  - Right now, voxels OUTSIDE the subset-labels are copied verbatim into the output.
        #    If the output segmentation is different from the input segmentation,
        #    it might be nice to fetch the output bricks and then paste the subset mask on top of it,
        #    rather than copying from the input mask.
        #
        #  - For DVID volumes, this workflow only writes supervoxels, not labels.
        #    To write labels, one would need to first split supervoxels (if necessary) via this workflow,
        #    and then partition complete labels according to groups of supervoxels.
        #    That might be best performed via the FindAdjacencies workflow, anyhow.

        self.init_services()

        input_service = self.input_service
        output_service = self.output_service
        options = self.config["connectedcomponents"]

        is_supervoxels = False
        if isinstance(input_service.base_service, DvidVolumeService):
            is_supervoxels = input_service.base_service.supervoxels

        # Load body list and eliminate duplicates
        subset_labels = load_body_list(options["subset-labels"], is_supervoxels)
        subset_labels = set(subset_labels)
        
        sparse_fetch = not options["skip-sparse-fetch"]
        input_wall = self.init_brickwall(input_service, sparse_fetch and subset_labels, options["roi"])
        
        def brick_cc(brick):
            orig_vol = brick.volume
            brick.compress()

            # Track the original max so we know what the first
            # available label is when we write the final results.
            orig_max = orig_vol.max()
            
            if subset_labels:
                orig_vol = apply_mask_for_labels(orig_vol, subset_labels)
            
            # Fast path for all-zero bricks
            if not orig_vol.any():
                cc_vol = orig_vol
                cc_overlaps = pd.DataFrame({'orig': [], 'cc': []}, dtype=np.uint64)
                cc_max = np.uint64(0)
            else:
                cc_vol = skm.label(orig_vol, background=0, connectivity=1)
                assert cc_vol.dtype == np.int64
                cc_vol = cc_vol.view(np.uint64)
                
                # Leave 0-pixels alone.
                cc_vol[orig_vol == 0] = np.uint64(0)
                
                # Keep track of which original values each cc corresponds to.
                cc_overlaps = pd.DataFrame({'orig': orig_vol.reshape(-1), 'cc': cc_vol.reshape(-1)})
                cc_overlaps.query('orig != 0 and cc != 0', inplace=True)
                cc_overlaps = cc_overlaps.drop_duplicates()
                assert (cc_overlaps.dtypes == np.uint64).all()
    
                if len(cc_overlaps) > 0:
                    cc_max = cc_overlaps['cc'].max()
                else:
                    cc_max = np.uint64(0)
            
            cc_brick = Brick( brick.logical_box,
                              brick.physical_box,
                              cc_vol,
                              location_id=brick.location_id,
                              compression=brick.compression )

            return cc_brick, cc_overlaps, cc_max, orig_max

        cc_results = input_wall.bricks.map(brick_cc).persist()
        cc_bricks, cc_overlaps, cc_maxes, orig_maxes = cc_results.unzip(4)

        # Persist like crazy... trying to work around a non-deterministic scheduler issue.
        cc_bricks = cc_bricks.persist()
        cc_overlaps = cc_overlaps.persist()
        cc_maxes = cc_maxes.persist()
        orig_maxes = orig_maxes.persist()

        with Timer("Computing blockwise CC", logger):
            max_brick_cc = cc_maxes.max().compute()
        
        with Timer("Saving brick maxes", logger):
            def corner_and_maxes(cc_brick, _cc_overlaps, cc_max, orig_max):
                return (*cc_brick.logical_box[0], cc_max, orig_max)
            brick_maxes = cc_results.starmap(corner_and_maxes).compute()
            brick_maxes_df = pd.DataFrame(brick_maxes, columns=['z', 'y', 'x', 'cc_max', 'orig_max'])
            brick_maxes_df.to_csv('brick-maxes.csv', header=True, index=False)
        
        wall_box = input_wall.bounding_box
        wall_grid = input_wall.grid
        
        def add_cc_offsets(brick, cc_overlaps):
            brick_index = BrickWall.compute_brick_index(brick, wall_box, wall_grid)
            cc_offset = np.uint64(brick_index * (max_brick_cc+1))
            
            # Don't overwrite zero voxels
            offset_cc_vol = np.where(brick.volume, brick.volume + cc_offset, np.uint64(0))
            cc_overlaps = cc_overlaps.copy()
            cc_overlaps.loc[(cc_overlaps['cc'] != 0), 'cc'] += np.uint64(cc_offset)
            
            # Append columns for brick location while we're at it.
            cc_overlaps['lz0'] = cc_overlaps['ly0'] = cc_overlaps['lx0'] = np.int32(0)
            cc_overlaps.loc[:, ['lz0', 'ly0', 'lx0']] = brick.logical_box[0]

            brick.compress()
            new_brick = Brick( brick.logical_box,
                               brick.physical_box,
                               offset_cc_vol,
                               location_id=brick.location_id,
                               compression=brick.compression )
            
            return new_brick, cc_overlaps

        # Now relabel each cc_brick so that label ids in different bricks never coincide
        offset_cc_results = bag_zip(cc_bricks, cc_overlaps).starmap(add_cc_offsets)
        offset_cc_results = offset_cc_results.persist()
        cc_bricks, cc_overlaps = offset_cc_results.unzip(2)

        # Extract halos.        
        # Note: No need to extract halos on all sides: outer-lower will overlap
        #       with inner-upper, which is good enough for computing CC.
        outer_halos = extract_halos(cc_bricks, input_wall.grid, 'outer', 'lower')
        inner_halos = extract_halos(cc_bricks, input_wall.grid, 'inner', 'upper')

        outer_halos_ddf = BrickWall.bricks_as_ddf(outer_halos, logical=False, physical=True, names='long')
        inner_halos_ddf = BrickWall.bricks_as_ddf(inner_halos, logical=False, physical=True, names='long')
        
        # Combine halo DFs along physical boxes, so that each outer halo is paired
        # with its overlapping partner (an inner halo, extracted from a different original brick).
        # Note: This is a pandas 'inner' merge, not to be confused with the 'inner' halos!
        combined_halos_ddf = outer_halos_ddf.merge(inner_halos_ddf, 'inner', PB_COLS, suffixes=['_outer', '_inner'])

        def find_pairwise_links(outer_brick, inner_brick):
            assert (outer_brick.physical_box == inner_brick.physical_box).all()

            # TODO: If this workflow is ever refactored into a set of utility functions,
            #       where each brick's segmentation might have been computed independently, 
            #       we'll probably want to permit the user to specify a minimum required
            #       overlap for neighboring objects to be considered 'linked'.

            table = pd.DataFrame({ 'cc_outer': outer_brick.volume.reshape(-1),
                                   'cc_inner': inner_brick.volume.reshape(-1) })
            table = table.drop_duplicates()
            #table = contingency_table(outer_brick.volume, inner_brick.volume).reset_index()
            #table.rename(columns={'left': 'cc_outer', 'right': 'cc_inner'}, inplace=True)

            outer_brick.compress()
            inner_brick.compress()

            # Omit label 0
            table = table.query('cc_outer != 0 and cc_inner != 0')
            return table

        def find_partition_links(combined_halos_df):
            tables = []
            for row in combined_halos_df.itertuples(index=False):
                table = find_pairwise_links(row.brick_outer, row.brick_inner)
                tables.append(table)

            if tables:
                return pd.concat(tables, ignore_index=True)
            else:
                return pd.DataFrame({'cc_outer': [], 'cc_inner': []}, dtype=np.uint64)
            
        links_meta = { 'cc_outer': np.uint64, 'cc_inner': np.uint64 }
        
        with Timer("Offsetting block CC and computing links", logger):
            links_df = combined_halos_ddf.map_partitions(find_partition_links, meta=links_meta).clear_divisions()
            try:
                links_df = links_df.compute()
            except CancelledError as ex:
                # This is an attempt to get more info when the scheduler is misbehaving.
                task_name = ex.args[0]
                story = self.client.cluster.scheduler.story(task_name)
                logger.error(f"Task cancelled: {task_name}.\nStory:\n{story}")
                raise

            assert links_df.columns.tolist() == ['cc_outer', 'cc_inner'] 
            assert (links_df.dtypes == np.uint64).all()

        with Timer("Writing links_df.pkl", logger):
            pickle.dump(links_df, open('links_df.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

        with Timer("Concatenating cc_overlaps", logger):
            cc_mapping_df = pd.concat(cc_overlaps.compute(), ignore_index=True)
            cc_mapping_df = cc_mapping_df[['lz0', 'ly0', 'lx0', 'orig', 'cc']]
            
        with Timer("Writing cc_mapping_df.pkl", logger):
            pickle.dump(cc_mapping_df, open('cc_mapping_df.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
            
        #
        # Append columns for original labels
        #
        with Timer("Joining link CC column with original labels", logger):
            # The pandas way...
            #cc_mapping = cc_mapping_df.set_index('cc')['orig']
            #links_df = links_df.merge(cc_mapping, 'left', left_on='cc_outer', right_index=True)[['cc_outer', 'cc_inner', 'orig']]
            #links_df = links_df.merge(cc_mapping, 'left', left_on='cc_inner', right_index=True, suffixes=['_outer', '_inner'])

            # The LabelMapper way...
            assert (links_df.dtypes == np.uint64).all()
            assert (cc_mapping_df[['cc', 'orig']].dtypes == np.uint64).all()

            cc = cc_mapping_df['cc'].values
            cc_orig = cc_mapping_df['orig'].values
            cc_mapper = LabelMapper(cc, cc_orig)
            links_df['orig_outer'] = cc_mapper.apply(links_df['cc_outer'].values)
            links_df['orig_inner'] = cc_mapper.apply(links_df['cc_inner'].values)
            
            # If we know the input segmentation source is the same for every brick
            # (i.e. it comes from a pre-computed source, where the halos should exactly match),
            # Then this assertion is true. 
            # It will not be true if we ever change this workflow to a general connected
            # components workflow, where each block of segmentation might be generated
            # independently, and thus halos may not match exactly.
            assert links_df.eval('orig_outer == orig_inner').all(), \
                "Something is wrong -- either the halos are not aligned, or the mapping of CC->orig is wrong."

            links_df = links_df[['cc_outer', 'cc_inner', 'orig_outer', 'orig_inner']]
            assert (links_df.dtypes == np.uint64).all()

        with Timer("Writing links_df.pkl", logger):
            pickle.dump(links_df, open('links_df.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

        with Timer("Dropping 'island' components", logger):
            # Before computing the CC of the whole graph, we can 
            # immediately discard any 'island' objects that only appear once.
            # That should make the CC operation run faster.

            # Note:
            #    Since our original segmentation blocks come from
            #    a common source (not computed independently), there is technically
            #    no need to also check for linked nodes, since by definition both sides
            #    of a link have the same original label.
            #    However, if we ever convert this code into a general connected components
            #    workflow, in which the segmentation blocks (including halos) might have
            #    been generated independently, then this additional check will be necessary.
            #    
            #     linked_nodes_orig = set(links_df['orig_outer'].values) | set(links_df['orig_inner'].values)
            #     node_df = cc_mapping_df.query('orig in @repeated_orig_labels or orig in @linked_nodes_orig')

            multiblock_rows = cc_mapping_df['orig'].duplicated(keep=False)
            node_df = cc_mapping_df.loc[multiblock_rows].copy()

        with Timer("Computing link CC", logger):
            # Compute connected components across all linked objects
            halo_links = links_df[['cc_outer', 'cc_inner']].values
            link_cc = connected_components_nonconsecutive(halo_links, node_df['cc'].values)
            node_df['link_cc'] = link_cc.astype(np.uint64)
            del halo_links, link_cc

        with Timer("Writing node_df_unfiltered.pkl", logger):
            pickle.dump(node_df, open('node_df_unfiltered.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

        with Timer("Dropping unsplit objects", logger):
            # Original objects that ended up with only a single
            # component need not be relabeled. Drop them from the table.

            # Note:
            #    Since our original segmentation blocks come from
            #    a common source (not computed independently), there is technically
            #    no need to also check for merged original objects, since by definition
            #    both sides of a link have the same original label.
            #    However, if we ever convert this code into a general connected components
            #    workflow, in which the segmentation blocks (including halos) might have
            #    been generated independently, then we must also keep objects that were part
            #    of a linked_cc with multiple original components.
            #
            #      original_cc_counts = node_df.groupby('link_cc').agg({'orig': 'nunique'}).rename(columns={'orig': 'num_components'})
            #      final_cc_counts = node_df.groupby('orig').agg({'link_cc': 'nunique'}).rename(columns={'link_cc': 'num_components'})
            #      merged_orig_objects = original_cc_counts.query('num_components > 1').index #@UnusedVariable
            #      split_orig_objects = final_cc_counts.query('num_components > 1').index #@UnusedVariable
            #      node_df = node_df.query('orig in @split_orig_objects or link_cc in @merged_orig_objects').copy()

            final_cc_counts = node_df.groupby('orig').agg({'link_cc': 'nunique'}).rename(columns={'link_cc': 'num_components'}, copy=False)
            split_orig_objects = final_cc_counts.query('num_components > 1').index #@UnusedVariable
            node_df = node_df.query('orig in @split_orig_objects').copy()
            num_final_fragments = len(pd.unique(node_df['link_cc']))

        # Compute the final label for each cc
        # Start by determining where the final label range should start.
        next_label = self.determine_next_label(num_final_fragments, orig_maxes)
        link_ccs = pd.unique(node_df['link_cc'].values)

        with Timer("Computing final mapping", logger):
            # Map link_cc to final_cc
            mapper = LabelMapper(link_ccs, np.arange(next_label, next_label+len(link_ccs), dtype=np.uint64))
            node_df['final_cc'] = mapper.apply(node_df['link_cc'].values)
        
        with Timer("Writing node_df_final.pkl", logger):
            pickle.dump(node_df, open('node_df_final.pkl', 'wb'), protocol=pickle.HIGHEST_PROTOCOL)

        if options["log-relabeled-objects"]:
            # This is mostly for convenient unit testing
            with Timer("Writing relabeled-objects.csv", logger):
                columns = {'final_cc': 'final_label', 'orig': 'orig_label'}
                csv_df = node_df[['final_cc', 'orig']].rename(columns=columns, copy=False).drop_duplicates()
                csv_df.to_csv('relabeled-objects.csv', index=False, header=True)
       
        #
        # Construct a Dask Dataframe holding the bricks we need to write.
        #
        wall_box = input_wall.bounding_box
        wall_grid = input_wall.grid

        with Timer("Preparing brick DataFrame", logging):
            def coords_and_bricks(orig_brick, cc_brick):
                assert (orig_brick.logical_box == cc_brick.logical_box).all()
                brick_index = BrickWall.compute_brick_index(orig_brick, wall_box, wall_grid)
                return (brick_index, orig_brick, cc_brick)
    
            dtypes = {'brick_index': np.int32, 'orig_brick': object, 'cc_brick': object}
            bricks_ddf = (bag_zip(input_wall.bricks, cc_bricks)
                            .starmap(coords_and_bricks)
                            .to_dataframe(dtypes))
    
            with Timer("Setting brick_index", logger):
                bricks_ddf = bricks_ddf.persist()
                bi = bricks_ddf['brick_index'].compute().tolist()
                assert bi == sorted(bi)
    
                bricks_ddf = bricks_ddf.set_index('brick_index', sorted=True)

        # The final mapping (node_df) might be too large to broadcast to all workers entirely.
        # We need to send only the relevant portion to each brick.
        with Timer("Preparing final mapping DataFrame", logger):
            class WrappedDf:
                """
                Trivial wrapper to allow us to store an entire
                DataFrame in every row of the following groupby()
                result."""
                def __init__(self, df):
                    self.df = df.copy()

                def __sizeof__(self):
                    return super().__sizeof__() + sys.getsizeof(self.df)
            
            with Timer("Grouping final CC mapping by brick", logger):
                grouped_mapping_df = (node_df[['lz0', 'ly0', 'lx0', 'cc', 'final_cc']]
                                        .groupby(['lz0', 'ly0', 'lx0'])
                                        .apply(WrappedDf)
                                        .rename('wrapped_brick_mapping_df')
                                        .reset_index())
    
            # We will need to merge according to brick location.
            # We could do that on [lz0,ly0,lx0], but a single-column merge will be faster,
            # so compute the brick_indexes and use that.
            brick_corners = grouped_mapping_df[['lz0', 'ly0', 'lx0']].values
            grouped_mapping_df['brick_index'] = BrickWall.compute_brick_indexes(brick_corners, wall_box, wall_grid)
            
            bi = grouped_mapping_df['brick_index'].tolist()
            assert bi == sorted(bi)
            
            grouped_mapping_df = grouped_mapping_df.set_index('brick_index')[['wrapped_brick_mapping_df']]
            grouped_mapping_ddf = ddf.from_pandas(grouped_mapping_df,
                                                  name='grouped_mapping',
                                                  npartitions=max(1, bricks_ddf.npartitions // 100))
            # Note:
            #   I've seen strange errors here from ddf.DataFrame.repartition() if all partitions are empty.
            #   If that's what you're seeing, make sure the input is reasonable.
            grouped_mapping_ddf = grouped_mapping_ddf.repartition(npartitions=bricks_ddf.npartitions)
            assert None not in grouped_mapping_ddf.divisions

        with Timer("Joining mapping and bricks", logger):
            # This merge associates each brick's part of the mapping with the correct row of bricks_ddf 
            bricks_ddf = bricks_ddf.merge(grouped_mapping_ddf, 'left', left_index=True, right_index=True)
            
            # We're done with these.
            del cc_bricks
            del input_wall

        def remap_cc_to_final(orig_brick, cc_brick, wrapped_brick_mapping_df):
            """
            Given an original brick and the corresponding CC brick,
            relabel the CC brick according to the final label mapping
            provided in wrapped_brick_mapping_df.
            """
            assert (orig_brick.logical_box == cc_brick.logical_box).all()
            assert (orig_brick.physical_box == cc_brick.physical_box).all()

            # Check for NaN, which implies that the mapping for this
            # brick is empty (no objects to relabel).
            if isinstance(wrapped_brick_mapping_df, float):
                assert np.isnan(wrapped_brick_mapping_df)
                final_vol = orig_brick.volume
                orig_brick.compress()
            else:
                # Construct mapper from only the rows we need
                cc_labels = pd.unique(cc_brick.volume.reshape(-1)) # @UnusedVariable
                mapping = wrapped_brick_mapping_df.df.query('cc in @cc_labels')[['cc', 'final_cc']].values
                mapper = LabelMapper(*mapping.transpose())
    
                # Apply mapping to CC vol, writing zeros wherever the CC isn't mapped.
                final_vol = mapper.apply_with_default(cc_brick.volume, 0)
                
                # Fill zero voxels with values from the original segmentation.
                final_vol = np.where(final_vol, final_vol, orig_brick.volume)
            
                orig_brick.compress()
                cc_brick.compress()
            
            final_brick = Brick( orig_brick.logical_box,
                                 orig_brick.physical_box,
                                 final_vol,
                                 location_id=orig_brick.location_id,
                                 compression=orig_brick.compression )
            return final_brick

        collect_stats = options["compute-block-statistics"]

        bw = self.output_service.base_service.block_width
        if bw > 0:
            block_shape = 3*[bw]
        else:
            block_shape = self.output_service.base_service.preferred_message_shape
         
        def write_brick(full_brick):
            brick = clip_to_logical(full_brick, False)
            
            # Don't re-compress; we're done with the brick entirely
            full_brick.destroy()

            vol = brick.volume
            brick.destroy()

            output_service.write_subvolume(vol, brick.physical_box[0], 0)
            if collect_stats:
                stats = block_stats_for_volume(block_shape, vol, brick.physical_box)
                return stats


        with Timer("Relabeling bricks and writing to output", logger):
            final_bricks = bricks_ddf.to_bag().starmap(remap_cc_to_final)
            del bricks_ddf
            all_stats = final_bricks.map(write_brick).compute()

        if collect_stats:
            with Timer("Writing block stats"):
                stats_df = pd.concat(all_stats, ignore_index=True)
                self.write_block_stats(stats_df)
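The offset-and-link strategy above (per-brick connected components, a per-brick
label offset of brick_index * (max_brick_cc + 1), and a link table built from
overlapping halos) can be hard to follow inside the full dask workflow. Below is
a minimal, self-contained sketch of the same idea for two adjacent 2D blocks,
using scipy in place of the Brick/halo machinery; the block contents are made up
for illustration, and the "shared face" stands in for the overlapping halos.

import numpy as np
from scipy.ndimage import label as cc_label
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

# One object (label 7) spans the boundary between the two blocks.
block_a = np.array([[7, 7, 0, 0],
                    [0, 7, 0, 5]], dtype=np.uint64)
block_b = np.array([[0, 7, 0, 5],
                    [0, 7, 7, 0]], dtype=np.uint64)

cc_a, n_a = cc_label(block_a != 0)   # per-block CC, labels start at 1
cc_b, n_b = cc_label(block_b != 0)

# Offset block_b's CC labels so they can never collide with block_a's
# (analogous to brick_index * (max_brick_cc + 1) above).
cc_b = np.where(cc_b, cc_b + n_a, 0)

# Build links across the shared face: same (nonzero) original label on both
# sides, mirroring the halo-matching assumption in the workflow above.
face = (block_a[-1] != 0) & (block_a[-1] == block_b[0])
links = np.unique(np.stack([cc_a[-1][face], cc_b[0][face]], axis=1), axis=0)

# Union the linked CC IDs via a sparse graph and relabel both blocks.
n_nodes = 1 + n_a + n_b
g = coo_matrix((np.ones(len(links)), (links[:, 0], links[:, 1])), shape=(n_nodes, n_nodes))
_, global_cc = connected_components(g, directed=False)

final_a = np.where(cc_a, global_cc[cc_a], 0)
final_b = np.where(cc_b, global_cc[cc_b], 0)
assert final_a[0, 0] == final_b[1, 1]   # label 7 is a single component again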
Exemple #18
0
def compute_focused_bodies(server,
                           uuid,
                           instance,
                           synapse_samples,
                           min_tbars,
                           min_psds,
                           root_sv_sizes,
                           min_body_size,
                           sv_classifications=None,
                           marked_bad_bodies=None,
                           return_table=False,
                           kafka_msgs=None):
    """
    Compute the complete set of focused bodies, based on criteria for
    number of tbars, psds, or overall size, and excluding explicitly
    listed bad bodies.
    
    This function takes ~20 minutes to run on hemibrain inputs, with a ton of RAM.
    
    The procedure is:

    1. Apply synapse-based criteria
      a. Load synapse CSV file
      b. Map synapse SVs -> bodies (if needed)
        b2. If any SVs are 'retired', update those synapses to use the new IDs.
      c. Calculate synapses (tbars, psds) per body
      d. Initialize set with bodies that have enough synapses
    
    2. Apply size-based criteria
      a. Calculate body sizes (based on supervoxel sizes and current mapping)
      b. Add "big" bodies to the set
    
    3. Apply "bad body" criteria
      a. Read the list of "bad bodies"
      b. Remove bad bodies from the set
    
    Example:

        server = 'emdata3:8900'
        uuid = '7254'
        instance = 'segmentation'

        synapse_samples = '/nrs/flyem/bergs/complete-ffn-agglo/sampled-synapses-ef1d-locked.csv'

        min_tbars = 2
        min_psds = 10

        # old repo supervoxels (before server rebase)
        #
        # Note: This was taken from node 5501ae83e31247498303a159eef824d8, which is from a different repo.
        #       But that segmentation was eventually copied to the production repo as root node a776af.
        #       See /groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/copy-fixed-from-emdata2-to-emdata3-20180402.214505
        #
        #root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/8nm/compute-8nm-extended-fixed-STATS-ONLY-20180402.192015'
        #root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes.h5'
        
        root_sv_sizes_dir = '/groups/flyem/data/scratchspace/copyseg-configs/labelmaps/hemibrain/flattened/compute-stats-from-corrupt-20181016.203848'
        root_sv_sizes = f'{root_sv_sizes_dir}/supervoxel-sizes-2884.h5'
        
        min_body_size = int(10e6)

        sv_classifications = '/nrs/flyem/bergs/sv-classifications.h5'
        marked_bad_bodies = '/nrs/flyem/bergs/complete-ffn-agglo/bad-bodies-2019-02-26.csv'
        
        table_description = f'{uuid}-{min_tbars}tbars-{min_psds}psds-{min_body_size / 1e6:.1f}Mv'
        focused_table = compute_focused_bodies(server, uuid, instance, synapse_samples, min_tbars, min_psds, root_sv_sizes, min_body_size, sv_classifications, marked_bad_bodies, return_table=True)

        # As npy:
        np.save(f'focused-{table_description}.npy', focused_table.to_records(index=True))

        # As CSV:
        focused_table.to_csv(f'focused-{table_description}.csv', index=True, header=True)
    
    Args:

        server, uuid, instance:
            labelmap instance

        root_sv_sizes:
            mapping of supervoxel sizes from the root node, as returned by load_supervoxel_sizes(),
            or a path to an hdf5 file from which it can be loaded
        
        synapse_samples:
            A DataFrame with columns 'body' (or 'sv') and 'kind', or a path to a CSV file with those columns.
            The 'kind' column is expected to have only 'PreSyn' and 'PostSyn' entries.
            If the table has an 'sv' column, any "retired" supervoxel IDs will be updated before
            continuing with the analysis.
        
        min_tbars:
            The minimum number of PreSyn entries to pass the filter.
            Bodies with fewer tbars may still be included if they satisfy the min_psds criteria.
        
        min_psds:
            The minimum number of PostSyn entries to pass the filter.
            Bodies with fewer psds may still pass the filter if they satisfy the min_tbars criteria.

        min_body_size:
            Determines which bodies are included on the basis of their size alone,
            regardless of synapse count.
        
        sv_classifications:
            Optional. Path to an hdf5 file containing supervoxel classifications.
            Must have datasets: 'supervoxel_ids', 'classifications', and 'class_names'.
            Used to exclude known-bad supervoxels. The file need not include all supervoxels.
            Any supervoxels MISSING from this file are not considered 'bad'.

        marked_bad_bodies:
            Optional. A list of known-bad bodies to exclude from the results,
            or a path to a .csv file with that list (in the first column),
            or a keyvalue instance name from which the list can be loaded as JSON.
        
        return_table:
            If True, return the focused bodies in a DataFrame, indexed by body,
            with columns for body size and synapse counts.
            If False, simply return the list of bodies (saves about 4 minutes).
    
    Returns:
        A list of body IDs that passed all filters, or a DataFrame with columns:
            ['voxel_count', 'PreSyn', 'PostSyn']
        (See return_table option.)
    """
    split_source = 'dvid'

    # Load full mapping. It's needed for both synapses and body sizes.
    mapping = fetch_complete_mappings(server,
                                      uuid,
                                      instance,
                                      include_retired=True,
                                      kafka_msgs=kafka_msgs)
    mapper = LabelMapper(mapping.index.values, mapping.values)

    ##
    ## Synapses
    ##
    if isinstance(synapse_samples, str):
        synapse_samples = load_synapses(synapse_samples)

    assert set(['sv', 'body']).intersection(set(synapse_samples.columns)), \
        "synapse samples must have either 'body' or 'sv' column"

    # If 'sv' column is present, use it to create (or update) the body column
    if 'sv' in synapse_samples.columns:
        synapse_samples = update_synapse_table(server,
                                               uuid,
                                               instance,
                                               synapse_samples,
                                               split_source=split_source)
        assert synapse_samples['sv'].dtype == np.uint64
        synapse_samples['body'] = mapper.apply(synapse_samples['sv'].values,
                                               True)

    with Timer("Filtering for synapses", logger):
        synapse_body_table = body_synapse_counts(synapse_samples)
        synapse_bodies = synapse_body_table.query(
            'PreSyn >= @min_tbars or PostSyn >= @min_psds').index.values
    logger.info(f"Found {len(synapse_bodies)} with sufficient synapses")

    focused_bodies = set(synapse_bodies)

    ##
    ## Body sizes
    ##
    with Timer("Filtering for body size", logger):
        sv_sizes = load_all_supervoxel_sizes(server,
                                             uuid,
                                             instance,
                                             root_sv_sizes,
                                             split_source=split_source)
        body_stats = compute_body_sizes(sv_sizes, mapping, True)
        big_body_stats = body_stats.query('voxel_count >= @min_body_size')
        big_bodies = big_body_stats.index
    logger.info(f"Found {len(big_bodies)} with sufficient size")

    focused_bodies |= set(big_bodies)

    ##
    ## SV classifications
    ##
    if sv_classifications is not None:
        with Timer(f"Filtering by supervoxel classifications"), h5py.File(
                sv_classifications, 'r') as f:
            sv_classes = pd.DataFrame({
                'sv':
                f['supervoxel_ids'][:],
                'klass':
                f['classifications'][:].astype(np.uint8)
            })

            # Get the set of bad supervoxels
            all_class_names = list(map(bytes.decode, f['class_names'][:]))
            bad_class_names = [
                'unknown', 'blood vessels', 'broken white tissue', 'glia',
                'oob'
            ]
            _bad_class_ids = set(map(all_class_names.index, bad_class_names))
            bad_svs = sv_classes.query('klass in @_bad_class_ids')['sv']

            # Add column for sizes
            bad_sv_sizes = pd.DataFrame(index=bad_svs).merge(
                pd.DataFrame(sv_sizes),
                how='left',
                left_index=True,
                right_index=True)

            # Append body
            bad_sv_sizes['body'] = mapper.apply(bad_sv_sizes.index.values,
                                                True)

            # For bodies that contain at least one bad supervoxel,
            # compute the total size of the bad supervoxels they contain
            body_bad_voxels = bad_sv_sizes.groupby('body').agg(
                {'voxel_count': 'sum'})

            # Append total body size for comparison
            body_bad_voxels = body_bad_voxels.merge(body_stats[['voxel_count'
                                                                ]],
                                                    how='left',
                                                    left_index=True,
                                                    right_index=True,
                                                    suffixes=('_bad',
                                                              '_total'))

            bad_bodies = body_bad_voxels.query(
                'voxel_count_bad > voxel_count_total//2').index

            bad_focused_bodies = focused_bodies & set(bad_bodies)
            logger.info(
                f"Dropping {len(bad_focused_bodies)} bodies with more than 50% bad supervoxels"
            )

            focused_bodies -= bad_focused_bodies

    ##
    ## Marked Bad bodies
    ##
    if marked_bad_bodies is not None:
        if isinstance(marked_bad_bodies, str):
            if marked_bad_bodies.endswith('.csv'):
                marked_bad_bodies = read_csv_col(marked_bad_bodies, 0)
            else:
                # If it ain't a CSV, maybe it's a key-value instance and key to read from.
                raise AssertionError(
                    "FIXME: Need convention for specifying key-value instance and key"
                )
                #marked_bad_bodies = fetch_key(server, uuid, marked_bad_bodies, as_json=True)

        with Timer(
                f"Dropping {len(marked_bad_bodies)} bad bodies (from {len(focused_bodies)})",
                logger):
            focused_bodies -= set(marked_bad_bodies)

    ##
    ## Prepare results
    ##
    logger.info(f"Found {len(focused_bodies)} focused bodies")
    focused_bodies = np.fromiter(focused_bodies, dtype=np.uint64)
    focused_bodies.sort()

    if not return_table:
        return focused_bodies

    with Timer("Computing full focused table", logger):
        # Start with an empty DataFrame (except index)
        focus_table = pd.DataFrame(index=focused_bodies)

        # Merge size/sv_count
        focus_table = focus_table.merge(body_stats,
                                        how='left',
                                        left_index=True,
                                        right_index=True,
                                        copy=False)

        # Add synapse columns
        focus_table = focus_table.merge(synapse_body_table,
                                        how='left',
                                        left_index=True,
                                        right_index=True,
                                        copy=False)

        focus_table.fillna(0, inplace=True)
        focus_table['voxel_count'] = focus_table['voxel_count'].astype(
            np.uint64)
        focus_table['sv_count'] = focus_table['sv_count'].astype(np.uint32)
        focus_table['PreSyn'] = focus_table['PreSyn'].astype(np.uint32)
        focus_table['PostSyn'] = focus_table['PostSyn'].astype(np.uint32)

        # Sort biggest..smallest
        focus_table.sort_values('voxel_count', ascending=False, inplace=True)
        focus_table.index.name = 'body'

    return focus_table
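As a quick illustration of the selection logic documented above (synapse
thresholds OR'd together, unioned with the size-based set, minus the bad-body
list), here is a toy sketch with made-up tables. The column names follow the
docstring, but the thresholds and counts are arbitrary.

import pandas as pd

min_tbars, min_psds, min_body_size = 2, 10, 1000

synapse_body_table = pd.DataFrame(
    {'PreSyn': [5, 0, 1], 'PostSyn': [3, 12, 1]},
    index=pd.Index([10, 20, 30], name='body'))

body_sizes = pd.Series([50, 400, 5000], index=[10, 20, 30], name='voxel_count')

synapse_bodies = set(synapse_body_table.query(
    'PreSyn >= @min_tbars or PostSyn >= @min_psds').index)       # {10, 20}
big_bodies = set(body_sizes[body_sizes >= min_body_size].index)  # {30}
bad_bodies = {20}

focused_bodies = (synapse_bodies | big_bodies) - bad_bodies      # {10, 30}
print(sorted(focused_bodies))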
Exemple #19
0
def distance_transform_watershed(mask,
                                 smoothing=0.0,
                                 seed_mask=None,
                                 seed_labels=None,
                                 flood_from='interior'):
    """
    Compute a watershed over the distance transform within a mask.
    You can either compute the watershed from inside-to-outside or outside-to-inside.
    For the former, the watershed is seeded from the most interior points,
    and the distance transform is inverted so the watershed can proceed from low to high as usual.
    For the latter, the watershed is seeded from the voxels immediately outside the mask,
    using labels as found in the seed_labels volume. In this mode, the results effectively tell
    you which exterior segment (in the seed volume) is closest to any given point within the
    interior of the mask.

    Or you can provide your own seeds if you think you know what you're doing.

    Args:
        mask:
            Only the masked area will be processed
        smoothing:
            If non-zero, run gaussian smoothing on the distance transform with the
            given sigma before defining seed points or running the watershed.

        seed_mask:
        seed_labels:
        flood_from:

    Returns:
        dt, labeled_seeds, ws, max_id

    Notes:
        This function provides a subset of the options that can be found in other
        libraries, such as:
            - https://github.com/ilastik/wsdt/blob/3709b27/wsdt/wsDtSegmentation.py#L26
            - https://github.com/constantinpape/elf/blob/34f5e76/elf/segmentation/watershed.py#L69

        This uses vigra's efficient true euclidean distance transform, which is superior to
        distance transform approximations, e.g. as found in https://imagej.net/Distance_Transform_Watershed

        The imglib2 distance transform uses the same algorithm as vigra:
        https://github.com/imglib/imglib2-algorithm/tree/master/src/main/java/net/imglib2/algorithm/morphology/distance
        http://www.theoryofcomputing.org/articles/v008a019/
    """
    mask = mask.astype(bool, copy=False)
    mask = vigra.taggedView(mask, 'zyx').astype(np.uint32)

    imask = np.logical_not(mask)
    outer_edge_mask = binary_edge_mask(mask, 'outer')

    assert flood_from in ('interior', 'exterior')
    if flood_from == 'interior':
        # Negate the distance transform result,
        # since watershed must start at minima, not maxima.
        # Convert to uint8 to benefit from 'turbo' watershed mode
        # (uses a bucket queue).
        dt = distance_transform(mask, False, smoothing, negate=True)

        if seed_mask is None:
            # requires float32 input for some reason
            if dt.ndim == 2:
                minima = vigra.analysis.localMinima(dt,
                                                    marker=np.nan,
                                                    neighborhood=8,
                                                    allowAtBorder=True,
                                                    allowPlateaus=False)
            else:
                minima = vigra.analysis.localMinima3D(dt,
                                                      marker=np.nan,
                                                      neighborhood=26,
                                                      allowAtBorder=True,
                                                      allowPlateaus=False)
            seed_mask = np.isnan(minima)
            del minima

        dt = normalize_image_range(dt, np.uint8)
    else:
        if seed_labels is None and seed_mask is None:
            logger.warning(
                "Without providing your own seed mask and/or seed labels, "
                "the watershed operation will simply be the same as a "
                "connected components operation.  Is that what you meant?")

        if seed_mask is None:
            seed_mask = outer_edge_mask.copy()

        # Dilate the mask once more.
        outer_edge_mask[:] |= binary_edge_mask(outer_edge_mask | mask, 'outer')

        dt = distance_transform(mask, False, smoothing, negate=False)
        dt = normalize_image_range(dt, np.uint8)

    if seed_labels is None:
        seed_mask = vigra.taggedView(seed_mask, 'zyx')
        labeled_seeds = vigra.analysis.labelMultiArrayWithBackground(
            seed_mask.view('uint8'))
    else:
        labeled_seeds = np.where(seed_mask, seed_labels, 0)

    # Make sure seed_mask matches labeled_seeds,
    # Even if some seed_labels were zero-valued
    seed_mask = (labeled_seeds != 0)

    # Must remap to uint32 before calling vigra's watershed.
    seed_mapper = None
    seed_values = None
    if labeled_seeds.dtype in (np.uint64, np.int64):
        labeled_seeds = labeled_seeds.astype(np.uint64)
        seed_values = np.sort(pd.unique(labeled_seeds.ravel()))
        if seed_values[0] != 0:
            seed_values = np.array([0] + list(seed_values), np.uint64)

        assert seed_values.dtype == np.uint64
        assert labeled_seeds.dtype == np.uint64

        ws_seed_values = np.arange(len(seed_values), dtype=np.uint32)
        seed_mapper = LabelMapper(seed_values, ws_seed_values)
        ws_seeds = seed_mapper.apply(labeled_seeds)
        assert ws_seeds.dtype == np.uint32
    else:
        ws_seeds = labeled_seeds

    # Fill the non-masked area with one big seed,
    # except for a thin border around the mask.
    # This saves time in the watershed step,
    # since these voxels now don't need to be
    # consumed in the watershed.
    dummy_seed = ws_seeds.max() + np.uint32(1)
    ws_seeds[np.logical_not(mask | outer_edge_mask)] = dummy_seed
    ws_seeds[outer_edge_mask & ~seed_mask] = 0

    dt[outer_edge_mask] = 255
    dt[seed_mask] = 0

    dt = vigra.taggedView(dt, 'zyx')
    ws_seeds = vigra.taggedView(ws_seeds, 'zyx')
    ws, max_id = vigra.analysis.watershedsNew(dt,
                                              seeds=ws_seeds,
                                              method='Turbo')

    # Areas that were unreachable without crossing over the border
    # could end up with the dummy seed.
    # We treat such areas as if they are outside of the mask.
    ws[ws == dummy_seed] = 0
    ws_seeds[imask] = 0

    # If we converted from uint64 to uint32 to perform the watershed,
    # convert back before returning.
    if seed_mapper is not None:
        ws = seed_values[ws]
    return dt, labeled_seeds, ws, ws.max()
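For readers without vigra, here is a rough 2D sketch of the 'interior' flooding
mode described in the docstring, using scipy and scikit-image instead. The mask,
the peak-picking parameters, and the seeding are simplified stand-ins for what
the function above does, so the details (and results) will differ.

import numpy as np
from scipy import ndimage
from skimage.feature import peak_local_max
from skimage.segmentation import watershed

mask = np.zeros((40, 40), dtype=bool)
mask[5:20, 5:35] = True
mask[22:35, 10:30] = True

dt = ndimage.distance_transform_edt(mask)

# Seed from the most interior points (local maxima of the DT),
# then flood the *negated* DT so the watershed runs low-to-high.
peaks = peak_local_max(dt, labels=ndimage.label(mask)[0], min_distance=3)
seeds = np.zeros(mask.shape, dtype=np.int32)
seeds[tuple(peaks.T)] = np.arange(1, len(peaks) + 1)

ws = watershed(-dt, markers=seeds, mask=mask)
# ws partitions the mask interior among the seeds, analogous to the
# 'interior' branch above (minus the dummy-seed and dtype tricks).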
Exemple #20
0
def select_hulls_for_mito_bodies(mito_body_ct,
                                 mito_bodies_mask,
                                 mito_binary,
                                 body_seg,
                                 hull_masks,
                                 seed_bodies,
                                 box,
                                 scale,
                                 viewer=None,
                                 res0=8,
                                 progress=False):

    mito_bodies_mito_seg = np.where(mito_bodies_mask & mito_binary, body_seg,
                                    0)
    nonmito_body_seg = np.where(mito_bodies_mask, 0, body_seg)

    hull_cc_overlap_stats = []
    for hull_cc, (mask_box, mask) in tqdm_proxy(hull_masks.items(),
                                                disable=not progress):
        mbms = mito_bodies_mito_seg[box_to_slicing(*mask_box)]
        masked_hull_cc_bodies = np.where(mask, mbms, 0)
        # Faster to check for any non-zero values at all before trying to count them.
        # This early check saves a lot of time in practice.
        if not masked_hull_cc_bodies.any():
            continue

        # This hull was generated from a particular seed body (non-mito body).
        # If it accidentally overlaps with any other non-mito bodies,
        # then delete those voxels from the hull.
        # If that causes the hull to become split apart into multiple connected components,
        # then keep only the component(s) which overlap the seed body.
        seed_body = seed_bodies[hull_cc]
        nmbs = nonmito_body_seg[box_to_slicing(*mask_box)]
        other_bodies = set(pd.unique(nmbs[mask])) - {0, seed_body}
        if other_bodies:
            # Keep only the voxels on mito bodies or on the
            # particular non-mito body for this hull (the "seed body").
            mbm = mito_bodies_mask[box_to_slicing(*mask_box)]
            mask[:] &= (mbm | (nmbs == seed_body))
            mask = vigra.taggedView(mask, 'zyx')
            mask_cc = vigra.analysis.labelMultiArrayWithBackground(
                mask.view(np.uint8))
            if mask_cc.max() > 1:
                mask_ct = contingency_table(mask_cc, nmbs).reset_index()
                keep_ccs = mask_ct['left'].loc[(mask_ct['left'] != 0) &
                                               (mask_ct['right'] == seed_body)]
                mask[:] = mask_for_labels(mask_cc, keep_ccs)

        mito_bodies, counts = np.unique(masked_hull_cc_bodies,
                                        return_counts=True)
        overlaps = pd.DataFrame({
            'mito_body': mito_bodies,
            'overlap': counts,
            'hull_cc': hull_cc,
            'hull_size': mask.sum(),
            'hull_body': seed_body
        })
        hull_cc_overlap_stats.append(overlaps)

    if len(hull_cc_overlap_stats) == 0:
        logger.warning("Could not find any matches for any mito bodies!")
        mito_body_ct['hull_body'] = np.uint64(0)
        return mito_body_ct

    hull_cc_overlap_stats = pd.concat(hull_cc_overlap_stats, ignore_index=True)
    hull_cc_overlap_stats = hull_cc_overlap_stats.query(
        'mito_body != 0').copy()

    # Aggregate the stats for each body and the hull bodies it overlaps with,
    # Select the hull_body with the most overlap, or in the case of ties, the hull body that is largest overall.
    # (Ties are probably more common in the event that two hulls completely encompass a small mito body.)
    hull_body_overlap_stats = hull_cc_overlap_stats.groupby(
        ['mito_body', 'hull_body'])[['overlap', 'hull_size']].sum()
    hull_body_overlap_stats = hull_body_overlap_stats.sort_values(
        ['mito_body', 'overlap', 'hull_size'], ascending=False)
    hull_body_overlap_stats = hull_body_overlap_stats.reset_index()

    mito_hull_selections = (hull_body_overlap_stats.drop_duplicates(
        'mito_body').set_index('mito_body')['hull_body'])
    mito_body_ct = mito_body_ct.merge(mito_hull_selections,
                                      'left',
                                      left_index=True,
                                      right_index=True)
    mito_body_ct['hull_body'] = mito_body_ct['hull_body'].fillna(0)

    dtypes = {col: np.float32 for col in mito_body_ct.columns}
    dtypes['hull_body'] = np.uint64
    mito_body_ct = mito_body_ct.astype(dtypes)

    if viewer:
        assert mito_hull_selections.index.dtype == mito_hull_selections.values.dtype == np.uint64
        mito_hull_mapper = LabelMapper(mito_hull_selections.index.values,
                                       mito_hull_selections.values)
        remapped_body_seg = mito_hull_mapper.apply(body_seg, True)
        remapped_body_seg = apply_mask_for_labels(remapped_body_seg,
                                                  mito_hull_selections.values)
        update_seg_layer(viewer, 'altered-bodies', remapped_body_seg, scale,
                         box)

        # Show the final hull masks (after erasure of non-target bodies)
        assert sorted(hull_masks.keys()) == [*range(1, 1 + len(hull_masks))]
        hull_cc_overlap_stats = hull_cc_overlap_stats.sort_values('hull_size')
        hull_seg = np.zeros_like(remapped_body_seg)
        for row in hull_cc_overlap_stats.itertuples():
            mask_box, mask = hull_masks[row.hull_cc]
            view = hull_seg[box_to_slicing(*mask_box)]
            view[:] = np.where(mask, row.hull_body, view)
        update_seg_layer(viewer, 'final-hull-seg', hull_seg, scale, box)

    return mito_body_ct
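The overlap-based tie-breaking above (sort by overlap, then by hull size, then
take the first row per mito body) is compact but easy to misread. This tiny
sketch with made-up numbers shows which hull body drop_duplicates() selects.

import pandas as pd

stats = pd.DataFrame({
    'mito_body': [1, 1, 2, 2],
    'hull_body': [10, 11, 10, 12],
    'overlap':   [50, 50, 30, 80],
    'hull_size': [900, 1200, 900, 500],
})

stats = stats.sort_values(['mito_body', 'overlap', 'hull_size'], ascending=False)
selections = stats.drop_duplicates('mito_body').set_index('mito_body')['hull_body']
print(selections)
# mito body 1 -> hull 11 (tie on overlap, so the bigger hull wins)
# mito body 2 -> hull 12 (largest overlap)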
Exemple #21
0
def find_edges_in_brick(brick,
                        closest_scale=None,
                        subset_groups=[],
                        subset_requirement=2):
    """
    Find all pairs of adjacent labels in the given brick,
    and find the central-most point along the edge between them.
    
    (Edges to/from label 0 are discarded.)
    
    If closest_scale is not None, then non-adjacent pairs will be considered,
    according to a particular heuristic to decide which pairs to consider.
    
    Args:
        brick:
            A Brick to analyze
        
        closest_scale:
            If None, then consider direct (touching) adjacencies only.
            If not-None, then non-direct "adjacencies" (i.e. close-but-not-touching) are found.
            In that case `closest_scale` should be an integer >=0 indicating the scale at which
            the analysis will be performed.
            Higher scales are faster, but less precise.
            See ``neuclease.util.approximate_closest_approach`` for more information.
        
        subset_groups:
            A DataFrame with columns [label, group].  Only the given labels will be analyzed
            for adjacencies.  Furthermore, edges (pairs) will only be returned if both labels
            in the edge are from the same group.
            
        subset_requirement:
            Whether both labels in each edge must be in subset_groups, or only one of them.
            (Currently, subset_requirement must be 2.)
        
    Returns:
        If the brick contains no edges at all (other than edges to label 0), return None.
        
        Otherwise, returns pd.DataFrame with columns:
            [label_a, label_b, forwardness, z, y, x, axis, edge_area, distance]. # fixme
        
        where label_a < label_b,
        'axis' indicates which axis the edge crosses at the chosen coordinate,
        
        (z,y,x) is always given as the coordinate to the left/above/front of the edge
        (depending on the axis).
        
        If 'forwardness' is True, then the given coordinate falls on label_a and
        label_b is one voxel "after" it (to the right/below/behind the coordinate).
        Otherwise, the coordinate falls on label_b, and label_a is "after".
        
        And 'edge_area' is the total count of the voxels touching both labels.
    """
    # Profiling indicates that query('... in ...') spends
    # most of its time in np.unique, believe it or not.
    # After looking at the implementation, I think it might help a
    # little if we sort the array first.
    brick_labels = np.sort(pd.unique(brick.volume.reshape(-1)))
    if (len(brick_labels) == 1) or (len(brick_labels) == 2 and
                                    (0 in brick_labels)):
        return None  # brick is solid -- no possible edges

    # Drop labels that aren't even present
    subset_groups = subset_groups.query('label in @brick_labels').copy()

    # Drop groups that don't have enough members (usually 2) in this brick.
    group_counts = subset_groups['group'].value_counts()
    _kept_groups = group_counts.loc[(group_counts >= subset_requirement)].index
    subset_groups = subset_groups.query('group in @_kept_groups').copy()

    if len(subset_groups) == 0:
        return None  # No possible edges to find in this brick.

    # Construct a mapper that includes only the labels we'll keep.
    # (Other labels will be mapped to 0).
    # Also, the mapper converts to uint32 (required by _find_and_select_central_edges,
    # but also just good for RAM reasons).
    kept_labels = np.sort(np.unique(subset_groups['label'].values))
    remapped_kept_labels = np.arange(1, len(kept_labels) + 1, dtype=np.uint32)
    mapper = LabelMapper(kept_labels, remapped_kept_labels)
    reverse_mapper = LabelMapper(remapped_kept_labels, kept_labels)

    # Construct RAG -- finds all edges in the volume, on a per-pixel basis.
    remapped_volume = mapper.apply_with_default(brick.volume, 0)
    brick.compress()
    remapped_subset_groups = subset_groups.copy()
    remapped_subset_groups['label'] = mapper.apply(
        subset_groups['label'].values)

    try:
        if closest_scale is None:
            best_edges_df = _find_and_select_central_edges(
                remapped_volume, remapped_subset_groups, subset_requirement)
        else:
            best_edges_df = _find_closest_approaches(remapped_volume,
                                                     closest_scale,
                                                     remapped_subset_groups)
    except:
        brick_name = f"{brick.logical_box[:,::-1].tolist()}"
        np.save(f'problematic-remapped-brick-{brick_name}.npy',
                remapped_volume)
        logger.error(f"Error in brick (XYZ): {brick_name}"
                     )  # This will appear in the worker log.
        raise

    if best_edges_df is None:
        return None

    # Translate coordinates to global space
    best_edges_df.loc[:, ['za', 'ya', 'xa']] += brick.physical_box[0]
    best_edges_df.loc[:, ['zb', 'yb', 'xb']] += brick.physical_box[0]

    # Restore to original label set
    best_edges_df['label_a'] = reverse_mapper.apply(
        best_edges_df['label_a'].values)
    best_edges_df['label_b'] = reverse_mapper.apply(
        best_edges_df['label_b'].values)

    # Normalize
    swap_df_cols(best_edges_df, None, best_edges_df.eval('label_a > label_b'),
                 ['a', 'b'])

    return best_edges_df
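As a complement to the docstring above, here is a minimal sketch of direct
(touching) adjacency detection along a single axis: collect the label pairs
across voxel faces, drop pairs involving 0, normalize so that label_a < label_b,
and count the face area per edge. The real function also handles the remaining
axes and the subset groups, and picks a central coordinate for each edge.

import numpy as np
import pandas as pd

vol = np.array([[1, 1, 2],
                [1, 3, 2],
                [0, 3, 2]], dtype=np.uint64)

a = vol[:, :-1].reshape(-1)   # left voxel of each horizontal face
b = vol[:, 1:].reshape(-1)    # right voxel of each horizontal face

edges = pd.DataFrame({'label_a': np.minimum(a, b), 'label_b': np.maximum(a, b)})
edges = edges.query('label_a != 0 and label_a != label_b')
edge_area = edges.groupby(['label_a', 'label_b']).size().rename('edge_area')
print(edge_area)
# (1, 2): 1    (1, 3): 1    (2, 3): 2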
Exemple #22
0
def edge_weighted_watershed(cleaned_edges, edge_weights, seed_labels):
    """
    Run nifty.graph.edgeWeightedWatershedsSegmentation() on the given graph with N nodes and E edges.
    The graph node IDs must be consecutive, starting with zero, dtype=np.uint32
    
    
    Args:
        cleaned_edges:
            array, (E,2), uint32
            Node IDs should be consecutive (more-or-less).
            To avoid segfaults:
                - Must not contain duplicates.
                - Must not contain 'loops' (no self-edges).
        
        edge_weights:
            array, (E,), float32
        
        seed_labels:
            array (N,), uint32
            All un-seeded nodes should be marked as 0.
        
    Returns:
        (output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
        
            output_labels:
                array (N,), uint32
                Agglomerated node labeling.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
    """
    assert cleaned_edges.dtype == np.uint32
    assert cleaned_edges.ndim == 2
    assert cleaned_edges.shape[1] == 2
    assert edge_weights.shape == (len(cleaned_edges), )
    assert seed_labels.ndim == 1
    assert cleaned_edges.max() < len(seed_labels)

    g = nifty.graph.UndirectedGraph(len(seed_labels))
    g.insertEdges(cleaned_edges)
    output_labels = nifty.graph.edgeWeightedWatershedsSegmentation(
        g, seed_labels, edge_weights)
    contains_unlabeled_components = not output_labels.all()

    mapper = LabelMapper(np.arange(output_labels.shape[0], dtype=np.uint32),
                         output_labels)
    labeled_edges = mapper.apply(cleaned_edges)
    preserved_edges = cleaned_edges[labeled_edges[:, 0] == labeled_edges[:, 1]]

    component_labels = connected_components(preserved_edges,
                                            len(output_labels))
    assert len(component_labels) == len(output_labels) == len(seed_labels)
    cc_df = pd.DataFrame({'label': output_labels, 'cc': component_labels})
    cc_counts = cc_df.groupby('label').nunique()['cc']
    disconnected_cc_counts = cc_counts[cc_counts > 1]
    disconnected_components = set(disconnected_cc_counts.index) - set([0])

    return output_labels, disconnected_components, contains_unlabeled_components
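The last few lines above detect seeds whose final label ended up spanning more
than one connected component of the preserved edges. Here is that check in
isolation, with a made-up labeling (numbers chosen only for illustration).

import pandas as pd

cc_df = pd.DataFrame({'label': [1, 1, 2, 2, 0],
                      'cc':    [0, 3, 1, 1, 2]})
cc_counts = cc_df.groupby('label')['cc'].nunique()
disconnected_components = set(cc_counts[cc_counts > 1].index) - {0}
print(disconnected_components)   # {1}: seed 1 spans two components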
Exemple #23
0
def cleave(edges,
           edge_weights,
           seeds_dict,
           node_ids=None,
           method='seeded-watershed'):
    """
    Cleave the graph with the given edges and edge weights.
    If node_ids is given, it must contain a superset of the ids given in edges.
    Extra ids in node_ids (i.e. not mentioned in 'edges') will be included
    in the results as disconnected components.
    
    Args:
        
        edges:
            array, (E,2), uint32
        
        edge_weights:
            array, (E,), float32
        
        seeds_dict:
            dict, { seed_class : [node_id, node_id, ...] }
        
        node_ids:
            The complete list of node IDs in the graph.
        
        method:
            Either 'seeded-watershed' or 'agglomerative-clustering'

    Returns:
    
        CleaveResults, namedtuple with fields:
        (node_ids, output_labels, disconnected_components, contains_unlabeled_components)
        
        Where:
            node_ids:
                The graph node_ids.
                
            output_labels:
                array (N,), uint32
                Agglomerated node labeling, in the same order as node_ids.
                
            disconnected_components:
                A set of seeds which ended up with more than one component in the result.
            
            contains_unlabeled_components:
                True if the input contains one or more disjoint components that were not seeded
                and thus not labeled during agglomeration. False otherwise.
        
    """
    if node_ids is None:
        node_ids = pd.unique(edges.flat)
        node_ids.sort()

    assert isinstance(node_ids, np.ndarray)
    assert node_ids.dtype in (np.uint32, np.uint64)
    assert node_ids.ndim == 1

    assert method in ('seeded-watershed', 'agglomerative-clustering')

    # Clean the edges (normalized form, no duplicates, no loops)
    edges.sort(axis=1)
    edges_df = pd.DataFrame({
        'u': edges[:, 0],
        'v': edges[:, 1],
        'weight': edge_weights
    })
    edges_df.drop_duplicates(['u', 'v'], keep='last', inplace=True)
    edges_df = edges_df.query('u != v')
    edges = edges_df[['u', 'v']].values
    edge_weights = edges_df['weight'].values

    # Relabel node ids consecutively
    cons_node_ids = np.arange(len(node_ids), dtype=np.uint32)
    mapper = LabelMapper(node_ids, cons_node_ids)
    cons_edges = mapper.apply(edges)
    assert cons_edges.dtype == np.uint32

    # Initialize sparse seed label array
    seed_labels = np.zeros_like(cons_node_ids)
    for seed_class, seed_nodes in seeds_dict.items():
        seed_nodes = np.asarray(seed_nodes, dtype=np.uint64)
        mapper.apply_inplace(seed_nodes)
        seed_labels[seed_nodes] = seed_class

    if method == 'agglomerative-clustering':
        output_labels, disconnected_components, contains_unlabeled_components = agglomerative_clustering(
            cons_edges, edge_weights, seed_labels)
    elif method == 'seeded-watershed':
        output_labels, disconnected_components, contains_unlabeled_components = edge_weighted_watershed(
            cons_edges, edge_weights, seed_labels)

    return CleaveResults(node_ids, output_labels, disconnected_components,
                         contains_unlabeled_components)
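A hypothetical usage sketch for cleave(), assuming the function above is
importable in the current scope: three nodes with arbitrary (non-consecutive)
IDs, two seed classes, and weights chosen so that the low-weight edge should
keep node 2000 with seed class 1 under the seeded-watershed method. The exact
labeling depends on the underlying nifty implementation.

import numpy as np

edges = np.array([[1000, 2000],
                  [2000, 3000]], dtype=np.uint64)
edge_weights = np.array([0.1, 0.9], dtype=np.float32)
seeds = {1: [1000], 2: [3000]}

results = cleave(edges, edge_weights, seeds, method='seeded-watershed')
# results.node_ids   -> array([1000, 2000, 3000], dtype=uint64)
# results.output_labels is aligned with node_ids; we expect something like
# [1, 1, 2], i.e. node 2000 joins class 1 via the cheaper edge.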
Exemple #24
0
def region_features(label_img,
                    grayscale_img=None,
                    features=['Box', 'Count'],
                    ignore_label=None):
    """
    Wrapper around vigra.analysis.extractRegionFeatures() that supports uint64 and
    returns each feature as a pandas Series or DataFrame, indexed by object ID.

    For simple features such as 'Box' and 'Count', most of the time is spent remapping the
    input array from uint64 to uint32, which is the only label image type supported by vigra.

    See vigra docs regarding the supported features:
        - http://ukoethe.github.io/vigra/doc-release/vigranumpy/index.html#vigra.analysis.extractRegionFeatures
        - http://ukoethe.github.io/vigra/doc-release/vigra/group__FeatureAccumulators.html

    Args:
        label_img:
            An integer-valued label image, containing no negative values

        grayscale_img:
            Optional.  If provided, then weighted features are available.
            See GRAYSCALE_FEATURE_NAMES, above.

        features:
            List of strings.  If no grayscale image was provided, you can only
            ask for the features in ``SEGMENTATION_FEATURE_NAMES``, above.

        ignore_label:
            A background label to ignore. If you don't want to ignore anything, pass ``None``.

    Returns:
        dict {name: feature}, where each feature value is indexed by label ID.
        For features that are scalar-valued per label, the value is a Series.
        For features that are a 1D array per label, the value is a DataFrame with one column per axis ('z', 'y', 'x').
        For features that are a 2D array per label (e.g. Box, RegionAxes), the value is a Series with
        dtype=object, in which each item is a 2D array.
        TODO: This might be a good place to use Xarray
    """
    assert label_img.ndim in (2, 3)
    axes = 'zyx'[-label_img.ndim:]

    vfeatures = {*features}

    valid_names = {*SEGMENTATION_FEATURE_NAMES, *GRAYSCALE_FEATURE_NAMES}
    invalid_names = vfeatures - valid_names
    assert not invalid_names, \
        f"Invalid feature names: {invalid_names}"

    if 'Box' in features:
        vfeatures -= {'Box'}
        vfeatures |= {'Coord<Minimum>', 'Coord<Maximum>'}

    if 'Box0' in features:
        vfeatures -= {'Box0'}
        vfeatures |= {'Coord<Minimum>'}

    if 'Box1' in features:
        vfeatures -= {'Box1'}
        vfeatures |= {'Coord<Maximum>'}

    assert np.issubdtype(label_img.dtype, np.integer)

    label_ids = None
    if label_img.dtype == np.uint32:
        label_img32 = label_img
    elif label_img.dtype == np.int32:
        label_img32 = label_img.view(np.uint32)
    elif label_img.dtype in (np.int64, np.uint64):
        label_img = label_img.view(np.uint64)
        label_ids = np.sort(pd.unique(label_img.ravel()))

        # Map from uint64 -> uint32
        label_ids_32 = np.arange(len(label_ids), dtype=np.uint32)
        mapper = LabelMapper(label_ids, label_ids_32)
        label_img32 = mapper.apply(label_img)
        if ignore_label is not None:
            ignore_label = mapper.apply(np.array([ignore_label], np.uint64))[0]
    else:
        label_img32 = label_img.astype(np.uint32)

    assert label_img32.dtype == np.uint32

    if grayscale_img is None:
        invalid_names = vfeatures - {*SEGMENTATION_FEATURE_NAMES}
        assert not invalid_names, \
            f"Invalid segmentation feature names: {invalid_names}"
        grayscale_img = label_img32.view(np.float32)
    else:
        assert grayscale_img.dtype == np.float32, \
            "Grayscale image must be float32"

    grayscale_img = vigra.taggedView(grayscale_img, axes)
    label_img32 = vigra.taggedView(label_img32, axes)

    # TODO: provide histogramRange options
    acc = vigra.analysis.extractRegionFeatures(grayscale_img,
                                               label_img32, [*vfeatures],
                                               ignoreLabel=ignore_label)

    results = {}
    if 'Box0' in features:
        v = acc['Coord<Minimum >'].astype(np.int32)
        results['Box0'] = pd.DataFrame(v, columns=[*axes])
    if 'Box1' in features:
        v = 1 + acc['Coord<Maximum >'].astype(np.int32)
        results['Box1'] = pd.DataFrame(v, columns=[*axes])
    if 'Box' in features:
        box0 = acc['Coord<Minimum >'].astype(np.int32)
        box1 = (1 + acc['Coord<Maximum >']).astype(np.int32)
        boxes = np.stack((box0, box1), axis=1)
        obj_boxes = np.zeros(len(boxes), object)
        obj_boxes[:] = list(boxes)
        results['Box'] = pd.Series(obj_boxes, name='Box')

    for k, v in acc.items():
        k = k.replace(' ', '')

        # Only return the features the user explicitly requested.
        if k not in features:
            continue

        if v.ndim == 1:
            results[k] = pd.Series(v, name=k)
        elif v.ndim == 2:
            results[k] = pd.DataFrame(v, columns=[*axes])
        else:
            # If the data doesn't neatly fit into a 1-d Series
            # or a 2-d DataFrame, then construct a Series with dtype=object
            # and make each row a separate ndarray object.
            obj_v = np.zeros(len(v), dtype=object)
            obj_v[:] = list(v)
            results[k] = pd.Series(obj_v, name=k)

    # Set index to the original uint64 values
    if label_img.dtype == np.uint64:
        for v in results.values():
            v.index = label_ids

    return results
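A short illustrative call, assuming vigra is installed and that 'Count' and 'Box' are among the supported segmentation features (as the default argument above suggests):

import numpy as np

# Two blocks in a small uint64 label volume; IDs beyond the uint32 range are
# fine because region_features() remaps them internally.
label_img = np.zeros((8, 8, 8), np.uint64)
label_img[:4, :4, :4] = 10_000_000_000
label_img[4:, 4:, 4:] = 20_000_000_000

feats = region_features(label_img, features=['Box', 'Count'], ignore_label=0)
print(feats['Count'].loc[10_000_000_000])   # 64 (voxel count of the first block)
print(feats['Box'].loc[20_000_000_000])     # [[4, 4, 4], [8, 8, 8]]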
Example #25
0
def update_merge_table(server,
                       uuid,
                       instance,
                       table_df,
                       complete_mapping=None,
                       split_mapping=None):
    """
    Given a merge table (such as a focused proofreading decision table),
    find rows whose supervoxels no longer exist in the given instance (due to splits).

    For those invalid rows, look up the supervoxel and body ID at the stored coordinates
    and use them as the updated supervoxel/body IDs.
    
    Updates (in-place) the supervoxel and body columns.
    
    Note: If any coordinate appears to be misplaced (i.e. the supervoxel ID at
    the coordinate is not a descendant of the listed supervoxel), the supervoxel is
    left unchanged and the body is mapped to 0.
    
    Args:
        server, uuid, instance:
            Table will be updated with respect to the given segmentation instance info
        
        table_df:
            DataFrame with SV columns and coordinate columns ('xa', 'ya', 'za', etc.)
        
        complete_mapping:
            Optional.  Will be fetched if not provided.
            Must be the complete mapping as returned by fetch_complete_mappings(..., include_retired=True)
        
        split_mapping:
            Optional.  Will be fetched if not provided.
            A mapping from supervoxel fragments to root supervoxel IDs.

    Returns:
        None. (The table is modified in place.)
    """
    seg_info = server, uuid, instance

    # Ensure proper table columns/dtypes
    if 'id_a' in table_df.columns:
        col_sv_a, col_sv_b = 'id_a', 'id_b'
    elif 'sv_a' in table_df.columns:
        col_sv_a, col_sv_b = 'sv_a', 'sv_b'
    else:
        raise RuntimeError("table has no sv columns")

    assert set([col_sv_a, col_sv_b, 'xa', 'ya', 'za', 'xb', 'yb',
                'zb']).issubset(table_df.columns)
    for col in ['xa', 'ya', 'za', 'xb', 'yb', 'zb']:
        table_df[col] = table_df[col].fillna(0).astype(np.int32)

    # Construct mappings if necessary
    kafka_msgs = None
    if complete_mapping is None or split_mapping is None:
        kafka_msgs = read_kafka_messages(*seg_info)

    if complete_mapping is None:
        complete_mapping = fetch_complete_mappings(*seg_info,
                                                   include_retired=True,
                                                   kafka_msgs=kafka_msgs)
    complete_mapper = LabelMapper(complete_mapping.index.values,
                                  complete_mapping.values)

    if split_mapping is None:
        split_events = fetch_supervoxel_splits(*seg_info)
        split_mapping = split_events_to_mapping(split_events,
                                                leaves_only=False)
    split_mapper = LabelMapper(split_mapping.index.values,
                               split_mapping.values)

    # Apply up-to-date body mapping
    # (Retired supervoxels will map to body 0)
    table_df['body_a'] = complete_mapper.apply(table_df[col_sv_a].values, True)
    table_df['body_b'] = complete_mapper.apply(table_df[col_sv_b].values, True)

    successfully_updated = 0
    failed_update = 0

    # Update the rows with invalid bodies.
    for index in tqdm_proxy(
            table_df.query('body_a == 0 or body_b == 0').index):

        def update_row(index, col_body, col_sv, col_z, col_y, col_x):
            nonlocal successfully_updated, failed_update

            # Extract from table
            coord = table_df.loc[index, [col_z, col_y, col_x]]
            old_sv = table_df.loc[index, col_sv]

            # Check current SV in the volume
            new_sv = fetch_label(*seg_info, coord, supervoxels=True)

            # The old/new SV must have come from the same root SV.
            # If not, the coordinate must be misplaced and can't be used here.
            svs = np.asarray([new_sv, old_sv], np.uint64)
            mapped_svs = split_mapper.apply(svs, True)
            if mapped_svs[0] != mapped_svs[1]:
                failed_update += 1
            else:
                body = complete_mapper.apply(np.array([new_sv], np.uint64))[0]
                table_df.loc[index, col_body] = body
                table_df.loc[index, col_sv] = new_sv
                successfully_updated += 1

        # id_a/body_a
        if table_df.loc[index, 'body_a'] == 0:
            update_row(index, 'body_a', col_sv_a, 'za', 'ya', 'xa')

        # id_b/body_b
        if table_df.loc[index, 'body_b'] == 0:
            update_row(index, 'body_b', col_sv_b, 'zb', 'yb', 'xb')

    logger.info(
        f"Updated {successfully_updated}, failed to update {failed_update}")
Example #26
0
def extract_important_merges(speculative_merge_tables,
                             important_bodies,
                             body_mapping=None,
                             mapping_instance_info=None,
                             drop_duplicate_body_pairs=False):
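    """
    Given a list of speculative merge tables (DataFrames with 'id_a'/'id_b'
    supervoxel columns), map each supervoxel to its body, drop identity merges,
    normalize each edge by body ID, and keep only the edges whose two bodies
    both appear in important_bodies.

    Exactly one of body_mapping (a pd.Series of supervoxel -> body) or
    mapping_instance_info (used to fetch the complete mapping) must be given.
    If drop_duplicate_body_pairs is True, only one edge is kept per body pair.

    Returns:
        A list of filtered DataFrames, one per input table.
    """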
    assert (body_mapping is None) ^ (mapping_instance_info is None), \
        "You must set either body_mapping or mapping_instance_info (but not both)"

    if mapping_instance_info is not None:
        body_mapping = fetch_complete_mappings(mapping_instance_info)

    assert isinstance(body_mapping, pd.Series)
    mapper = LabelMapper(body_mapping.index.values, body_mapping.values)

    # pd.Index is faster than builtin set for large sets
    important_bodies = pd.Index(important_bodies)

    results = []
    for spec_merge_table_df in tqdm(speculative_merge_tables):
        logger.info(f"Processing table with {len(spec_merge_table_df)} rows")

        with Timer("Applying mapping", logger):
            spec_merge_table_df['body_a'] = mapper.apply(
                spec_merge_table_df['id_a'].values, allow_unmapped=True)
            spec_merge_table_df['body_b'] = mapper.apply(
                spec_merge_table_df['id_b'].values, allow_unmapped=True)

        with Timer("Dropping identity merges", logger):
            orig_size = len(spec_merge_table_df)
            spec_merge_table_df.query('body_a != body_b', inplace=True)
            logger.info(
                f"Dropped {orig_size-len(spec_merge_table_df)}/{orig_size} edges."
            )

        with Timer("Normalizing edges", logger):
            # Normalize for body ID, not SV ID
            # (This involves a lot of copying, but you've got plenty of RAM, right?)
            a_cols = list(
                filter(lambda s: s[-1] == 'a', spec_merge_table_df.columns))
            b_cols = list(
                filter(lambda s: s[-1] == 'b', spec_merge_table_df.columns))
            spec_merge_table = spec_merge_table_df.to_records(index=False)
            normalize_recarray_inplace(spec_merge_table, 'body_a', 'body_b',
                                       a_cols, b_cols)
            spec_merge_table_df = pd.DataFrame(spec_merge_table)

        with Timer("Filtering edges", logger):
            q = 'body_a in @important_bodies and body_b in @important_bodies'
            orig_len = len(spec_merge_table_df)
            spec_merge_table_df = spec_merge_table_df.query(q)
            logger.info(
                f"Filtered out {orig_len - len(spec_merge_table_df)} non-important edges."
            )

        if drop_duplicate_body_pairs:
            with Timer("Dropping duplicate body pairs", logger):
                orig_len = len(spec_merge_table_df)
                spec_merge_table_df.drop_duplicates(['body_a', 'body_b'],
                                                    inplace=True)
                logger.info(
                    f"Dropped {orig_len - len(spec_merge_table_df)} duplicate body pairs"
                )

        results.append(spec_merge_table_df)

    return results
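A tiny, self-contained sketch with a hand-built mapping and merge table (all IDs below are made up for illustration):

import numpy as np
import pandas as pd

# Hypothetical data: supervoxels 1 and 2 belong to body 100, supervoxel 3 to body 200.
sv_to_body = pd.Series(np.array([100, 100, 200], np.uint64),
                       index=np.array([1, 2, 3], np.uint64))
table = pd.DataFrame({'id_a': np.array([1, 2], np.uint64),
                      'id_b': np.array([2, 3], np.uint64)})

filtered, = extract_important_merges([table], [100, 200],
                                     body_mapping=sv_to_body,
                                     drop_duplicate_body_pairs=True)
# Only the (2, 3) edge survives: the (1, 2) edge is an identity merge within body 100.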
Example #27
0
def fetch_roi_synapses(server,
                       uuid,
                       synapses_instance,
                       rois,
                       fetch_labels=False,
                       return_partners=False,
                       processes=16):
    """
    Fetch the coordinates and (optionally) body labels for 
    all synapses that fall within the given ROIs.
    
    Args:
    
        server:
            DVID server, e.g. 'emdata4:8900'
        
        uuid:
            DVID uuid, e.g. 'abc9'
        
        synapses_instance:
            DVID synapses instance name, e.g. 'synapses'
        
        rois:
            A single DVID ROI instance name, or a list of them, e.g. 'EB' or ['EB', 'FB']
        
        fetch_labels:
            If True, also fetch the supervoxel and body label underneath each synapse,
            returned in columns 'sv' and 'body'.
            
        return_partners:
            If True, also return the partners table.

        processes:
            How many parallel processes to use when fetching synapses and supervoxel labels.
    
    Returns:
        pandas DataFrame with columns:
        ``['z', 'y', 'x', 'kind', 'conf', 'roi_label', 'roi']``,
        plus ``['sv', 'body']`` if ``fetch_labels=True``.
        If ``return_partners=True``, the partners table is also returned.

    Example:
        df = fetch_roi_synapses('emdata4:8900', '3c281', 'synapses',
                                ['PB(L5)', 'PB(L7)'], fetch_labels=True, processes=8)
    """
    # Late imports to avoid circular imports in dvid/__init__
    from neuclease.dvid import fetch_combined_roi_volume, determine_point_rois, fetch_labels_batched, fetch_mapping, fetch_mappings

    assert rois, "No rois provided, result would be empty. Is that what you meant?"

    if isinstance(rois, str):
        rois = [rois]

    # Determine name of the segmentation instance that's
    # associated with the given synapses instance.
    syn_info = fetch_instance_info(server, uuid, synapses_instance)
    seg_instance = syn_info["Base"]["Syncs"][0]

    logger.info(f"Fetching mask for ROIs: {rois}")
    # Fetch the ROI as a low-res array (scale 5, i.e. 32-px resolution)
    roi_vol_s5, roi_box_s5, overlapping_pairs = fetch_combined_roi_volume(
        server, uuid, rois)

    if len(overlapping_pairs) > 0:
        logger.warning(
            "Some ROIs overlapped and are thus not completely represented in the output:\n"
            f"{overlapping_pairs}")

    # Convert to full-res box
    roi_box = (2**5) * roi_box_s5

    # fetch_synapses_in_batches() requires a box that is 64-px-aligned
    roi_box = round_box(roi_box, 64, 'out')

    logger.info("Fetching synapse points")
    # points_df is a DataFrame with columns for [z,y,x]
    points_df, partners_df = fetch_synapses_in_batches(server,
                                                       uuid,
                                                       synapses_instance,
                                                       roi_box,
                                                       processes=processes)

    # Append a 'roi_name' column to points_df
    logger.info("Labeling ROI for each point")
    determine_point_rois(server, uuid, rois, points_df, roi_vol_s5, roi_box_s5)

    logger.info("Discarding points that don't overlap with the roi")
    rois = {*rois}
    points_df = points_df.query('roi in @rois').copy()

    columns = ['z', 'y', 'x', 'kind', 'conf', 'roi_label', 'roi']

    if fetch_labels:
        logger.info("Fetching supervoxel under each point")
        svs = fetch_labels_batched(server,
                                   uuid,
                                   seg_instance,
                                   points_df[['z', 'y', 'x']].values,
                                   supervoxels=True,
                                   processes=processes)

        with Timer("Mapping supervoxels to bodies", logger):
            # Arbitrary heuristic for whether to do the
            # body-lookups on DVID or on the client.
            if len(svs) < 100_000:
                bodies = fetch_mapping(server, uuid, seg_instance, svs)
            else:
                mapping = fetch_mappings(server, uuid, seg_instance)
                mapper = LabelMapper(mapping.index.values, mapping.values)
                bodies = mapper.apply(svs, True)

        points_df['sv'] = svs
        points_df['body'] = bodies
        columns += ['body', 'sv']

    if return_partners:
        # Filter
        #partners_df = partners_df.query('post_id in @points_df.index and pre_id in @points_df.index').copy()

        # Faster filter (via merge)
        partners_df = partners_df.merge(points_df[[]],
                                        'inner',
                                        left_on='pre_id',
                                        right_index=True)
        partners_df = partners_df.merge(points_df[[]],
                                        'inner',
                                        left_on='post_id',
                                        right_index=True)
        return points_df[columns], partners_df
    else:
        return points_df[columns]
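A hedged example of also requesting the partners table (the server, uuid, and instance names are placeholders, mirroring the docstring example):

points_df, partners_df = fetch_roi_synapses('emdata4:8900', '3c281', 'synapses',
                                            ['PB(L5)', 'PB(L7)'],
                                            fetch_labels=True,
                                            return_partners=True,
                                            processes=8)

# partners_df only contains pairs whose pre- and post-synaptic points both
# survived the ROI filter above.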