Example #1
def find_l2_shortest_path(cg, source_l2_id: np.uint64, target_l2_id: np.uint64):
    """
    Find a path of level 2 ids that connects two level 2 node ids through cross chunk edges.
    Return a list of level 2 ids representing this path.
    Return None if the two level 2 ids do not belong to the same object.

    :param cg: ChunkedGraph object
    :param source_l2_id: np.uint64
    :param target_l2_id: np.uint64
    :return: [np.uint64] or None
    """
    # Get the cross-chunk edges that we need to build the graph
    shared_parent_id = cg.get_first_shared_parent(source_l2_id, target_l2_id)
    if shared_parent_id is None:
        return None
    
    edge_array = get_lvl2_edge_list(cg, shared_parent_id)
    # Create a graph-tool graph of the mapped cross-chunk-edges
    weighted_graph, _, _, graph_indexed_l2_ids = flatgraph_utils.build_gt_graph(
        edge_array, is_directed=False
    )

    # Find the shortest path from the source_l2_id to the target_l2_id
    source_graph_id = np.where(graph_indexed_l2_ids == source_l2_id)[0][0]
    target_graph_id = np.where(graph_indexed_l2_ids == target_l2_id)[0][0]
    source_vertex = weighted_graph.vertex(source_graph_id)
    target_vertex = weighted_graph.vertex(target_graph_id)
    vertex_list, _ = graph_tool.topology.shortest_path(
        weighted_graph, source=source_vertex, target=target_vertex
    )

    # Remap the graph-tool ids to lvl2 ids and return the path
    vertex_indices = [weighted_graph.vertex_index[vertex] for vertex in vertex_list]
    l2_traversal_path = graph_indexed_l2_ids[vertex_indices]
    return l2_traversal_path
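
A minimal usage sketch, assuming cg is an already-initialized ChunkedGraph instance and that source_l2_id and target_l2_id are valid level 2 ids from the same dataset (all three are placeholders here):

path = find_l2_shortest_path(cg, source_l2_id, target_l2_id)
if path is None:
    print("the two level 2 ids belong to different objects")
else:
    print(len(path), "level 2 ids on the path from source to target")
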
Example #2
def compute_cross_chunk_connected_components(eh, node_ids, layer):
    """ Computes connected component for next layer

    :param eh: EditHelper
    :param node_ids: list of np.uint64s
    :param layer: int
    :return:
    """
    assert len(node_ids) > 0

    # On each layer we build a graph with all cross chunk edges
    # that involve the nodes on the current layer
    # To do this efficiently, we acquire all candidate same layer nodes
    # that were previously connected to any of the currently assessed
    # nodes. In practice, we (1) gather all relevant parents in the next
    # layer and then (2) acquire their children

    old_this_layer_node_ids, old_next_layer_node_ids, \
        old_this_layer_partner_ids = \
            old_parent_childrens(eh, node_ids, layer)

    # Build network from cross chunk edges
    edge_id_map = {}
    cross_edges_lvl1 = []
    for node_id in node_ids:
        node_cross_edges = eh.read_cross_chunk_edges(node_id)[layer]
        edge_id_map.update(
            dict(zip(node_cross_edges[:, 0],
                     [node_id] * len(node_cross_edges))))
        cross_edges_lvl1.extend(node_cross_edges)

    for old_partner_id in old_this_layer_partner_ids:
        node_cross_edges = eh.read_cross_chunk_edges(old_partner_id)[layer]

        edge_id_map.update(
            dict(
                zip(node_cross_edges[:, 0],
                    [old_partner_id] * len(node_cross_edges))))
        cross_edges_lvl1.extend(node_cross_edges)

    cross_edges_lvl1 = np.array(cross_edges_lvl1)
    edge_id_map_vec = np.vectorize(edge_id_map.get)

    if len(cross_edges_lvl1) > 0:
        cross_edges = edge_id_map_vec(cross_edges_lvl1)
    else:
        cross_edges = np.empty([0, 2], dtype=np.uint64)

    assert np.sum(np.in1d(eh.old_node_ids, cross_edges)) == 0

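    # Add self-edges for all current node ids so that nodes without any
    # remaining cross chunk edge partners still appear as singleton components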
    cross_edges = np.concatenate(
        [cross_edges, np.vstack([node_ids, node_ids]).T])

    graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph(
        cross_edges, make_directed=True)

    ccs = flatgraph_utils.connected_components(graph)

    return ccs, unique_graph_ids
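
The remapping above leans on np.vectorize(edge_id_map.get) to translate every level 1 edge endpoint to its current-layer node id in one shot. A self-contained sketch of that pattern with made-up ids:

import numpy as np

edge_id_map = {1: 10, 2: 10, 3: 20, 4: 20}  # hypothetical child -> layer node map
cross_edges_lvl1 = np.array([[1, 3], [2, 4]], dtype=np.uint64)
edge_id_map_vec = np.vectorize(edge_id_map.get)
print(edge_id_map_vec(cross_edges_lvl1))  # [[10 20]
                                          #  [10 20]]
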
Example #3
def merge_cross_chunk_edges_graph_tool(edges: Iterable[Sequence[np.uint64]],
                                       affs: Sequence[np.float32]):
    """ Merges cross chunk edges
    :param edges: n x 2 array of uint64s
    :param affs: float array of length n
    :return:
    """

    # mask for edges that have to be merged
    cross_chunk_edge_mask = np.isinf(affs)

    # graph with edges that have to be merged
    graph, _, _, unique_supervoxel_ids = flatgraph_utils.build_gt_graph(
        edges[cross_chunk_edge_mask], make_directed=True)

    # connected components in this graph will be combined in one component
    ccs = flatgraph_utils.connected_components(graph)

    remapping = {}
    mapping = []

    for cc in ccs:
        nodes = unique_supervoxel_ids[cc]
        rep_node = np.min(nodes)

        remapping[rep_node] = nodes

        rep_nodes = np.full((len(nodes), 1), rep_node, dtype=np.uint64)
        m = np.concatenate([nodes.reshape(-1, 1), rep_nodes], axis=1)

        mapping.append(m)

    if len(mapping) > 0:
        mapping = np.concatenate(mapping)
    u_nodes = np.unique(edges)
    u_unmapped_nodes = u_nodes[~np.in1d(u_nodes, mapping)]

    unmapped_mapping = np.concatenate(
        [u_unmapped_nodes.reshape(-1, 1),
         u_unmapped_nodes.reshape(-1, 1)],
        axis=1)
    if len(mapping) > 0:
        complete_mapping = np.concatenate([mapping, unmapped_mapping], axis=0)
    else:
        complete_mapping = unmapped_mapping

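    # Vectorized lookup: find each edge endpoint in the sorted first column
    # of complete_mapping and substitute the representative id from column 1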
    sort_idx = np.argsort(complete_mapping[:, 0])
    idx = np.searchsorted(complete_mapping[:, 0], edges, sorter=sort_idx)
    mapped_edges = np.asarray(complete_mapping[:, 1])[sort_idx][idx]

    mapped_edges = mapped_edges[~cross_chunk_edge_mask]
    mapped_affs = affs[~cross_chunk_edge_mask]

    return mapped_edges, mapped_affs, mapping, complete_mapping, remapping
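
The argsort/searchsorted idiom at the end is a vectorized id substitution; a standalone sketch with toy ids:

import numpy as np

# column 0 = original id, column 1 = representative id
complete_mapping = np.array([[5, 5], [7, 5], [9, 9]], dtype=np.uint64)
edges = np.array([[7, 9], [5, 7]], dtype=np.uint64)

sort_idx = np.argsort(complete_mapping[:, 0])
idx = np.searchsorted(complete_mapping[:, 0], edges, sorter=sort_idx)
mapped_edges = complete_mapping[:, 1][sort_idx][idx]
print(mapped_edges)  # [[5 9]
                     #  [5 5]]
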
Example #4
def _get_contact_site_edges(
    cg,
    first_node_unconnected_edges,
    first_node_unconnected_areas,
    second_node_connected_edges,
    second_node_sv_ids,
):
    """
    Given two sets of supervoxels, find all contact sites between the two sets, and for each
    contact site return an edge that represents the site.
    """
    # Retrieve edges that connect first node with second node
    unconnected_edge_mask = np.where(
        np.isin(first_node_unconnected_edges[:, 1], second_node_sv_ids))[0]
    filtered_unconnected_edges = first_node_unconnected_edges[
        unconnected_edge_mask]
    filtered_areas = first_node_unconnected_areas[unconnected_edge_mask]
    contact_sites_svs_area_dict = collections.defaultdict(int)
    for i in range(filtered_unconnected_edges.shape[0]):
        sv_id = filtered_unconnected_edges[i, 1]
        area = filtered_areas[i]
        contact_sites_svs_area_dict[sv_id] += area
    contact_sites_svs_area_dict_vec = np.vectorize(
        contact_sites_svs_area_dict.get)
    unique_contact_sites_svs = np.unique(filtered_unconnected_edges[:, 1])
    # Retrieve edges that connect second node contact sites with other second node contact sites
    connected_edge_test = np.isin(second_node_connected_edges,
                                  unique_contact_sites_svs)
    connected_edges_mask = np.where(np.all(connected_edge_test, axis=1))[0]
    filtered_connected_edges = second_node_connected_edges[
        connected_edges_mask]
    # Make fake edges from contact site svs to themselves to make sure they appear in the created graph
    self_edges = np.stack((unique_contact_sites_svs, unique_contact_sites_svs),
                          axis=-1)
    contact_sites_graph_edges = np.concatenate(
        (filtered_connected_edges, self_edges), axis=0)

    graph, _, _, unique_sv_ids = flatgraph_utils.build_gt_graph(
        contact_sites_graph_edges, make_directed=True)
    connected_components = flatgraph_utils.connected_components(graph)

    contact_site_edges = []
    for cc in connected_components:
        cc_sv_ids = unique_sv_ids[cc]
        contact_sites_areas = contact_sites_svs_area_dict_vec(cc_sv_ids)
        contact_site_edge, contact_site_edge_type = _choose_contact_site_edge(
            cg, filtered_unconnected_edges, cc_sv_ids)
        contact_site_edge_info = (
            contact_site_edge,
            contact_site_edge_type,
            np.sum(contact_sites_areas),
        )
        contact_site_edges.append(contact_site_edge_info)
    return contact_site_edges
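
A sketch of why the self-edges matter, substituting scipy for flatgraph_utils (toy ids; supervoxel 13 belongs to the contact region but has no connected edge of its own):

import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

sv_ids = np.array([11, 12, 13])
connected_edges = np.array([[11, 12]])
self_edges = np.stack((sv_ids, sv_ids), axis=-1)
all_edges = np.concatenate((connected_edges, self_edges), axis=0)

# Relabel ids to 0..n-1 and build an adjacency matrix
_, flat = np.unique(all_edges, return_inverse=True)
flat = flat.reshape(all_edges.shape)
n = flat.max() + 1
adj = coo_matrix((np.ones(len(flat)), (flat[:, 0], flat[:, 1])), shape=(n, n))
n_cc, labels = connected_components(adj, directed=False)
print(n_cc, labels)  # 2 [0 0 1]: {11, 12} is one site, {13} its own
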
Example #5
    def _build_gt_graph(self, edges, affs):
        """
        Create the graph that will be used to compute the mincut.
        """
        self.source_edges = list(itertools.product(self.sources, self.sources))
        self.sink_edges = list(itertools.product(self.sinks, self.sinks))

        # Assemble edges: edges after remapping, combined with fake
        # infinite-affinity edges among the sinks and among the sources
        comb_edges = np.concatenate(
            [edges, self.source_edges, self.sink_edges])
        comb_affs = np.concatenate([
            affs, [float_max] * (len(self.source_edges) + len(self.sink_edges))
        ])

        # To make things easier for everyone involved, we map the ids to
        # [0, ..., len(unique_supervoxel_ids) - 1]
        # Generate weighted graph with graph_tool
        (self.weighted_graph, self.capacities, self.gt_edges,
         self.unique_supervoxel_ids) = flatgraph_utils.build_gt_graph(
             comb_edges, comb_affs, make_directed=True)

        self.source_graph_ids = np.where(
            np.in1d(self.unique_supervoxel_ids, self.sources))[0]
        self.sink_graph_ids = np.where(
            np.in1d(self.unique_supervoxel_ids, self.sinks))[0]

        if self.logger is not None:
            self.logger.debug(f"{self.sinks}, {self.sink_graph_ids}")
            self.logger.debug(f"{self.sources}, {self.source_graph_ids}")
Example #6
def remove_edges(cg, operation_id: np.uint64,
                 atomic_edges: Sequence[Sequence[np.uint64]],
                 time_stamp: datetime.datetime):
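    """ Removes edges from the chunkedgraph

    Computes all new rows to be written to the chunkedgraph

    :param cg: ChunkedGraph instance
    :param operation_id: np.uint64
    :param atomic_edges: list of list of np.uint64
        edges between supervoxels
    :param time_stamp: datetime.datetime
    :return: tuple of (new root ids, new lvl2 node ids, list of rows)
    """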

    # This view of the to-be-removed edges lets us compute the mask
    # of the retained edges in each chunk
    double_atomic_edges = np.concatenate([atomic_edges, atomic_edges[:, ::-1]],
                                         axis=0)
    double_atomic_edges_view = double_atomic_edges.view(dtype='u8,u8')
    n_edges = double_atomic_edges.shape[0]
    double_atomic_edges_view = double_atomic_edges_view.reshape(n_edges)

    rows = []  # list of rows to be written to BigTable
    lvl2_dict = {}
    lvl2_cross_chunk_edge_dict = {}

    # Analyze atomic_edges --> translate them to lvl2 edges and extract cross
    # chunk edges to be removed
    lvl2_edges, old_cross_edge_dict = analyze_atomic_edges(cg, atomic_edges)
    lvl2_node_ids = np.unique(lvl2_edges)

    for lvl2_node_id in lvl2_node_ids:
        chunk_id = cg.get_chunk_id(lvl2_node_id)
        chunk_edges, _, _ = cg.get_subgraph_chunk(lvl2_node_id,
                                                  make_unique=False)

        child_chunk_ids = cg.get_child_chunk_ids(chunk_id)

        assert len(child_chunk_ids) == 1
        child_chunk_id = child_chunk_ids[0]

        children_ids = np.unique(chunk_edges)
        children_chunk_ids = cg.get_chunk_ids_from_node_ids(children_ids)
        children_ids = children_ids[children_chunk_ids == child_chunk_id]

        # These edges still contain the removed edges. For consistency
        # reasons we can only write to BigTable once, so we have to evict
        # the to-be-removed "atomic_edges" from the queried edges.
        retained_edges_mask = ~np.in1d(
            chunk_edges.view(dtype='u8,u8').reshape(chunk_edges.shape[0]),
            double_atomic_edges_view)

        chunk_edges = chunk_edges[retained_edges_mask]

        edge_layers = cg.get_cross_chunk_edges_layer(chunk_edges)
        cross_edge_mask = edge_layers != 1

        cross_edges = chunk_edges[cross_edge_mask]
        cross_edge_layers = edge_layers[cross_edge_mask]
        chunk_edges = chunk_edges[~cross_edge_mask]

        isolated_child_ids = children_ids[~np.in1d(children_ids, chunk_edges)]
        isolated_edges = np.vstack([isolated_child_ids, isolated_child_ids]).T

        graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph(
            np.concatenate([chunk_edges, isolated_edges]), make_directed=True)

        ccs = flatgraph_utils.connected_components(graph)

        new_parent_ids = cg.get_unique_node_id_range(chunk_id, len(ccs))

        for i_cc, cc in enumerate(ccs):
            new_parent_id = new_parent_ids[i_cc]
            cc_node_ids = unique_graph_ids[cc]

            lvl2_dict[new_parent_id] = [lvl2_node_id]

            # Write changes to atomic nodes and new lvl2 parent row
            val_dict = {column_keys.Hierarchy.Child: cc_node_ids}
            rows.append(
                cg.mutate_row(serializers.serialize_uint64(new_parent_id),
                              val_dict,
                              time_stamp=time_stamp))

            for cc_node_id in cc_node_ids:
                val_dict = {column_keys.Hierarchy.Parent: new_parent_id}

                rows.append(
                    cg.mutate_row(serializers.serialize_uint64(cc_node_id),
                                  val_dict,
                                  time_stamp=time_stamp))

            # Cross edges ---
            cross_edge_m = np.in1d(cross_edges[:, 0], cc_node_ids)
            cc_cross_edges = cross_edges[cross_edge_m]
            cc_cross_edge_layers = cross_edge_layers[cross_edge_m]
            u_cc_cross_edge_layers = np.unique(cc_cross_edge_layers)

            lvl2_cross_chunk_edge_dict[new_parent_id] = {}

            for l in range(2, cg.n_layers):
                empty_edges = column_keys.Connectivity.CrossChunkEdge.deserialize(
                    b'')
                lvl2_cross_chunk_edge_dict[new_parent_id][l] = empty_edges

            val_dict = {}
            for cc_layer in u_cc_cross_edge_layers:
                edge_m = cc_cross_edge_layers == cc_layer
                layer_cross_edges = cc_cross_edges[edge_m]

                if len(layer_cross_edges) > 0:
                    val_dict[column_keys.Connectivity.CrossChunkEdge[cc_layer]] = \
                        layer_cross_edges
                    lvl2_cross_chunk_edge_dict[new_parent_id][
                        cc_layer] = layer_cross_edges

            if len(val_dict) > 0:
                rows.append(
                    cg.mutate_row(serializers.serialize_uint64(new_parent_id),
                                  val_dict,
                                  time_stamp=time_stamp))

        if cg.n_layers == 2:
            rows.extend(
                update_root_id_lineage(cg,
                                       new_parent_ids, [lvl2_node_id],
                                       operation_id=operation_id,
                                       time_stamp=time_stamp))

    # Write atomic nodes
    rows.extend(
        _write_atomic_split_edges(cg, atomic_edges, time_stamp=time_stamp))

    # Propagate changes up the tree
    if cg.n_layers > 2:
        new_root_ids, new_rows = propagate_edits_to_root(
            cg,
            lvl2_dict.copy(),
            lvl2_cross_chunk_edge_dict,
            operation_id=operation_id,
            time_stamp=time_stamp)
        rows.extend(new_rows)
    else:
        new_root_ids = np.array(list(lvl2_dict.keys()))

    return new_root_ids, list(lvl2_dict.keys()), rows
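
The dtype='u8,u8' view used above packs each edge row into a single record so that np.in1d can test whole rows at once; a standalone sketch:

import numpy as np

chunk_edges = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.uint64)
atomic_edges = np.array([[3, 4]], dtype=np.uint64)  # edge to remove

# Both directions of each removed edge, as in double_atomic_edges above
doubled = np.concatenate([atomic_edges, atomic_edges[:, ::-1]], axis=0)
doubled_view = doubled.view(dtype='u8,u8').reshape(doubled.shape[0])
edges_view = chunk_edges.view(dtype='u8,u8').reshape(chunk_edges.shape[0])

retained = chunk_edges[~np.in1d(edges_view, doubled_view)]
print(retained)  # [[1 2]
                 #  [5 6]]
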
Example #7
def add_edges(cg,
              operation_id: np.uint64,
              atomic_edges: Sequence[Sequence[np.uint64]],
              time_stamp: datetime.datetime,
              areas: Optional[Sequence[np.uint64]] = None,
              affinities: Optional[Sequence[np.float32]] = None):
    """ Add edges to chunkedgraph

    Computes all new rows to be written to the chunkedgraph

    :param cg: ChunkedGraph instance
    :param operation_id: np.uint64
    :param atomic_edges: list of list of np.uint64
        edges between supervoxels
    :param time_stamp: datetime.datetime
    :param areas: list of np.uint64
    :param affinities: list of np.float32
    :return: tuple of (new root ids, new lvl2 node ids, list of rows)
    """
    def _read_cc_edges_thread(node_ids):
        for node_id in node_ids:
            cc_dict[node_id] = cg.read_cross_chunk_edges(node_id)

    cc_dict = {}

    atomic_edges = np.array(atomic_edges,
                            dtype=column_keys.Connectivity.Partner.basetype)

    # # Comply to resolution of BigTables TimeRange
    # time_stamp = get_google_compatible_time_stamp(time_stamp,
    #                                               round_up=False)

    if affinities is None:
        affinities = np.ones(len(atomic_edges),
                             dtype=column_keys.Connectivity.Affinity.basetype)
    else:
        affinities = np.array(affinities,
                              dtype=column_keys.Connectivity.Affinity.basetype)

    if areas is None:
        areas = np.ones(len(atomic_edges),
                        dtype=column_keys.Connectivity.Area.basetype) * np.inf
    else:
        areas = np.array(areas, dtype=column_keys.Connectivity.Area.basetype)

    assert len(affinities) == len(atomic_edges)

    rows = []  # list of rows to be written to BigTable
    lvl2_dict = {}
    lvl2_cross_chunk_edge_dict = {}

    # Analyze atomic_edges --> translate them to lvl2 edges and extract cross
    # chunk edges
    lvl2_edges, new_cross_edge_dict = analyze_atomic_edges(cg, atomic_edges)

    # Compute connected components on lvl2
    graph, _, _, unique_graph_ids = flatgraph_utils.build_gt_graph(
        lvl2_edges, make_directed=True)

    # Read cross chunk edges efficiently (cc_dict, initialized above, is
    # filled in by the reader threads)
    node_ids = np.unique(lvl2_edges)
    n_threads = int(np.ceil(len(node_ids) / 5))

    node_id_blocks = np.array_split(node_ids, n_threads)

    mu.multithread_func(_read_cc_edges_thread,
                        node_id_blocks,
                        n_threads=n_threads,
                        debug=False)

    ccs = flatgraph_utils.connected_components(graph)
    for cc in ccs:
        lvl2_ids = unique_graph_ids[cc]
        chunk_id = cg.get_chunk_id(lvl2_ids[0])

        new_node_id = cg.get_unique_node_id(chunk_id)
        lvl2_dict[new_node_id] = lvl2_ids

        cross_chunk_edge_dict = {}
        for lvl2_id in lvl2_ids:
            lvl2_id_cross_chunk_edges = cc_dict[lvl2_id]
            cross_chunk_edge_dict = \
                combine_cross_chunk_edge_dicts(
                    cross_chunk_edge_dict,
                    lvl2_id_cross_chunk_edges)

            if lvl2_id in new_cross_edge_dict:
                # Merge the newly created cross chunk edges into the
                # accumulated dict instead of overwriting it
                cross_chunk_edge_dict = \
                    combine_cross_chunk_edge_dicts(
                        cross_chunk_edge_dict,
                        new_cross_edge_dict[lvl2_id])

        lvl2_cross_chunk_edge_dict[new_node_id] = cross_chunk_edge_dict

        if cg.n_layers == 2:
            rows.extend(
                update_root_id_lineage(cg, [new_node_id],
                                       lvl2_ids,
                                       operation_id=operation_id,
                                       time_stamp=time_stamp))

        children_ids = cg.get_children(lvl2_ids, flatten=True)

        rows.extend(
            create_parent_children_rows(cg, new_node_id, children_ids,
                                        cross_chunk_edge_dict, time_stamp))

    # Write atomic nodes
    rows.extend(
        _write_atomic_merge_edges(cg,
                                  atomic_edges,
                                  affinities,
                                  areas,
                                  time_stamp=time_stamp))

    # Propagate changes up the tree
    if cg.n_layers > 2:
        new_root_ids, new_rows = propagate_edits_to_root(
            cg,
            lvl2_dict.copy(),
            lvl2_cross_chunk_edge_dict,
            operation_id=operation_id,
            time_stamp=time_stamp)
        rows.extend(new_rows)
    else:
        new_root_ids = np.array(list(lvl2_dict.keys()))

    return new_root_ids, list(lvl2_dict.keys()), rows
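
The cross chunk edge reads above are batched at roughly five node ids per thread; the block construction on its own, with toy ids:

import numpy as np

node_ids = np.arange(12, dtype=np.uint64)    # stand-ins for level 2 ids
n_threads = int(np.ceil(len(node_ids) / 5))  # -> 3
node_id_blocks = np.array_split(node_ids, n_threads)
print([b.tolist() for b in node_id_blocks])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
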
Example #8
def mincut_graph_tool(edges: Iterable[Sequence[np.uint64]],
                      affs: Sequence[np.float32],
                      sources: Sequence[np.uint64],
                      sinks: Sequence[np.uint64],
                      logger: Optional[logging.Logger] = None) -> np.ndarray:
    """ Computes the min cut on a local graph
    :param edges: n x 2 array of uint64s
    :param affs: float array of length n
    :param sources: list of uint64s
    :param sinks: list of uint64s
    :return: m x 2 array of uint64s
        edges that should be removed
    """
    time_start = time.time()

    original_edges = edges

    # Stitch supervoxels across chunk boundaries and represent those that are
    # connected with a cross chunk edge with a single id. This may cause id
    # changes among sinks and sources that need to be taken care of.
    edges, affs, mapping, remapping = merge_cross_chunk_edges(
        edges.copy(), affs.copy())

    dt = time.time() - time_start
    if logger is not None:
        logger.debug("Cross edge merging: %.2fms" % (dt * 1000))
    time_start = time.time()

    mapping_vec = np.vectorize(lambda a: mapping[a] if a in mapping else a)

    if len(edges) == 0:
        return np.empty((0, 2), dtype=np.uint64)

    if len(mapping) > 0:
        assert np.unique(list(mapping.keys()),
                         return_counts=True)[1].max() == 1

    remapped_sinks = mapping_vec(sinks)
    remapped_sources = mapping_vec(sources)

    sinks = remapped_sinks
    sources = remapped_sources

    # Assemble edges: Edges after remapping combined with edges between sinks
    # and sources
    sink_edges = list(itertools.product(sinks, sinks))
    source_edges = list(itertools.product(sources, sources))

    comb_edges = np.concatenate([edges, sink_edges, source_edges])

    comb_affs = np.concatenate(
        [affs, [
            float_max,
        ] * (len(sink_edges) + len(source_edges))])

    # To make things easier for everyone involved, we map the ids to
    # [0, ..., len(unique_ids) - 1]
    # Generate weighted graph with graph_tool
    weighted_graph, cap, gt_edges, unique_ids = \
        flatgraph_utils.build_gt_graph(comb_edges, comb_affs,
                                       make_directed=True)

    sink_graph_ids = np.where(np.in1d(unique_ids, sinks))[0]
    source_graph_ids = np.where(np.in1d(unique_ids, sources))[0]

    if logger is not None:
        logger.debug(f"{sinks}, {sink_graph_ids}")
        logger.debug(f"{sources}, {source_graph_ids}")

    dt = time.time() - time_start
    if logger is not None:
        logger.debug("Graph creation: %.2fms" % (dt * 1000))
    time_start = time.time()

    # # Get rid of connected components that are not involved in the local
    # # mincut
    # cc_prop, ns = graph_tool.topology.label_components(weighted_graph)
    #
    # if len(ns) > 1:
    #     cc_labels = cc_prop.get_array()
    #
    #     for i_cc in range(len(ns)):
    #         cc_list = np.where(cc_labels == i_cc)[0]
    #
    #         # If connected component contains no sources and/or no sinks,
    #         # remove its nodes from the mincut computation
    #         if not np.any(np.in1d(source_graph_ids, cc_list)) or \
    #                 not np.any(np.in1d(sink_graph_ids, cc_list)):
    #             weighted_graph.delete_vertices(cc) # wrong

    # Compute mincut
    src, tgt = weighted_graph.vertex(source_graph_ids[0]), \
               weighted_graph.vertex(sink_graph_ids[0])

    res = graph_tool.flow.boykov_kolmogorov_max_flow(weighted_graph, src, tgt,
                                                     cap)

    part = graph_tool.flow.min_st_cut(weighted_graph, src, cap, res)

    labeled_edges = part.a[gt_edges]
    cut_edge_set = gt_edges[labeled_edges[:, 0] != labeled_edges[:, 1]]

    dt = time.time() - time_start
    if logger is not None:
        logger.debug("Mincut comp: %.2fms" % (dt * 1000))
    time_start = time.time()

    if len(cut_edge_set) == 0:
        return np.empty((0, 2), dtype=np.uint64)

    # Make sure we did not do something wrong: Check if sinks and sources are
    # among each other and not in different sets
    for i_cc in np.unique(part.a):
        # Make sure to read real ids and not graph ids
        cc_list = unique_ids[np.array(np.where(part.a == i_cc)[0],
                                      dtype=int)]

        # if logger is not None:
        #     logger.debug("CC size = %d" % len(cc_list))

        if np.any(np.in1d(sources, cc_list)):
            assert np.all(np.in1d(sources, cc_list))
            assert ~np.any(np.in1d(sinks, cc_list))

        if np.any(np.in1d(sinks, cc_list)):
            assert np.all(np.in1d(sinks, cc_list))
            assert ~np.any(np.in1d(sources, cc_list))

    dt = time.time() - time_start
    if logger is not None:
        logger.debug("Verifying local graph: %.2fms" % (dt * 1000))

    # Extract original ids
    # This has potential to be optimized
    remapped_cutset = []
    for s, t in flatgraph_utils.remap_ids_from_graph(cut_edge_set, unique_ids):

        if s in remapping:
            s = remapping[s]
        else:
            s = [s]

        if t in remapping:
            t = remapping[t]
        else:
            t = [t]

        remapped_cutset.extend(list(itertools.product(s, t)))
        remapped_cutset.extend(list(itertools.product(t, s)))

    remapped_cutset = np.array(remapped_cutset, dtype=np.uint64)

    remapped_cutset_flattened_view = remapped_cutset.view(dtype='u8,u8')
    edges_flattened_view = original_edges.view(dtype='u8,u8')

    cutset_mask = np.in1d(remapped_cutset_flattened_view, edges_flattened_view)

    return remapped_cutset[cutset_mask]
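
The two graph_tool calls at the heart of the function can be exercised on a four-vertex toy graph (a sketch assuming graph_tool is installed; vertices and capacities are made up, with both edge directions added the way build_gt_graph's make_directed does):

import graph_tool
import graph_tool.flow

g = graph_tool.Graph(directed=True)
g.add_vertex(4)
cap = g.new_edge_property("double")
for s, t, c in [(0, 1, 2.0), (1, 0, 2.0), (1, 2, 0.5), (2, 1, 0.5),
                (2, 3, 2.0), (3, 2, 2.0)]:
    cap[g.add_edge(s, t)] = c

src, tgt = g.vertex(0), g.vertex(3)
res = graph_tool.flow.boykov_kolmogorov_max_flow(g, src, tgt, cap)
part = graph_tool.flow.min_st_cut(g, src, cap, res)
print(part.a)  # [ True  True False False]: the 0.5-capacity edge is cut
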
Example #9
def get_contact_sites(cg,
                      root_id,
                      bounding_box=None,
                      bb_is_coordinate=True,
                      compute_partner=True,
                      end_time=None,
                      voxel_location=True,
                      areas_only=False,
                      as_list=False):
    """
    Given a root id, return two lists: the first contains all the contact sites with other roots in the dataset,
    the second is metadata specifying exactly what data the first list contains.

    If compute_partner=True, the first returned list is a list of tuples of length two. The first element of the tuple
    is a contact site partner root id. The second element is a list of all the contact sites (tuples) root_id makes
    with this contact partner. 
    
    If compute_partner=False, the first returned list is a list of all the contact sites. 

    The voxel_location and areas_only parameters affect the tuples in the list of contact sites mentioned above.
    If voxel_location=False, then the first two entries of the tuple are chunk coordinates that 
    bound part of the contact site; the third entry is the area of the contact site. If voxel_location=True, 
    the first two entries are the positions of those two chunks in global coordinates instead.
    If areas_only=True, then the tuple is just the area and no location is returned.
    """
    contact_sites_graph_edges, contact_sites_svs_area_dict, any_contact_sites = _get_edges_for_contact_site_graph(
        cg, root_id, bounding_box, bb_is_coordinate, end_time)

    if not any_contact_sites:
        # Keep the (contact site list, metadata) return shape even when
        # there are no contact sites
        return [], []

    contact_sites_svs_area_dict_vec = np.vectorize(
        contact_sites_svs_area_dict.get)

    graph, _, _, unique_sv_ids = flatgraph_utils.build_gt_graph(
        contact_sites_graph_edges, make_directed=True)

    connected_components = flatgraph_utils.connected_components(graph)

    contact_site_dict = collections.defaultdict(list)
    # First create intermediary map of supervoxel to contact sites, so we
    # can call cg.get_roots() on all supervoxels at once.
    intermediary_sv_dict = {}
    for cc in connected_components:
        cc_sv_ids = unique_sv_ids[cc]
        contact_sites_areas = contact_sites_svs_area_dict_vec(cc_sv_ids)

        representative_sv = cc_sv_ids[0]
        # Tuple of location and area of contact site
        chunk_coordinates = cg.get_chunk_coordinates(representative_sv)
        if areas_only:
            data_pair = np.sum(contact_sites_areas)
        elif voxel_location:
            voxel_lower_bound = (cg.vx_vol_bounds[:, 0] +
                                 cg.chunk_size * chunk_coordinates)
            voxel_upper_bound = cg.vx_vol_bounds[:, 0] + cg.chunk_size * (
                chunk_coordinates + 1)
            data_pair = (
                voxel_lower_bound * cg.segmentation_resolution,
                voxel_upper_bound * cg.segmentation_resolution,
                np.sum(contact_sites_areas),
            )
        else:
            data_pair = (
                chunk_coordinates,
                chunk_coordinates + 1,
                np.sum(contact_sites_areas),
            )

        if compute_partner:
            # Cast np.uint64 to int for dict key because int is hashable
            intermediary_sv_dict[int(representative_sv)] = data_pair
        else:
            contact_site_dict[len(contact_site_dict)].append(data_pair)
    if compute_partner:
        sv_list = np.array(list(intermediary_sv_dict.keys()), dtype=np.uint64)
        partner_roots = cg.get_roots(sv_list)
        for i in range(len(partner_roots)):
            contact_site_dict[int(partner_roots[i])].append(
                intermediary_sv_dict.get(int(sv_list[i])))

    contact_site_list = []
    for partner_id in contact_site_dict:
        if compute_partner:
            contact_site_list.append(
                (np.uint64(partner_id), contact_site_dict[partner_id]))
        else:
            contact_site_list.extend(contact_site_dict[partner_id])

    if compute_partner:
        contact_site_metadata = [
            'segment id', 'lower bound coordinate', 'upper bound coordinate',
            'area'
        ]
    else:
        contact_site_metadata = [
            'lower bound coordinate', 'upper bound coordinate', 'area'
        ]

    return contact_site_list, contact_site_metadata
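
A usage sketch, assuming cg is an already-initialized ChunkedGraph instance and root_id a valid root id (both placeholders here); with the default arguments, each contact site is a (lower bound, upper bound, area) tuple:

contact_sites, metadata = get_contact_sites(cg, root_id, compute_partner=True)
print(metadata)
for partner_root_id, partner_sites in contact_sites:
    total_area = sum(site[-1] for site in partner_sites)
    print(partner_root_id, len(partner_sites), total_area)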