示例#1
0
    def process(self, batch, request):
        """Detach every branch point from all but its lowest-id neighbour.

        Branch points are the nodes whose degree exceeds 2 in the incoming
        graph (collected once, before any edge is removed). For each of them,
        every incident edge is dropped except the one(s) to the smallest
        neighbouring node id. For directed graphs, successors and predecessors
        are pooled when choosing that neighbour. Neighbours are read at the
        moment each branch point is handled, so edges already removed while
        handling an earlier branch point are not considered.
        """
        source = batch[self.graph]
        nx_graph = source.to_nx_graph()

        # Snapshot the branch points before the graph starts changing.
        candidates = [node for node in nx_graph.nodes if nx_graph.degree(node) > 2]

        for node in candidates:
            if not nx_graph.is_directed():
                ordered = sorted(nx_graph.neighbors(node))
                for other in ordered[1:]:
                    nx_graph.remove_edge(node, other)
                continue

            outgoing = list(nx_graph.successors(node))
            incoming = list(nx_graph.predecessors(node))
            keep = min(outgoing + incoming)
            for target in outgoing:
                if target != keep:
                    nx_graph.remove_edge(node, target)
            for origin in incoming:
                if origin != keep:
                    nx_graph.remove_edge(origin, node)

        result = Batch()
        result[self.graph] = Graph.from_nx_graph(nx_graph, source.spec.copy())
        return result
    def process(self, batch, request):
        """Assign per-component ids and sizes to the nodes inside the write ROI.

        The incoming graph must cover exactly ``self.read_size`` (checked with
        an ``assert``). The write ROI is that ROI shrunk by ``self.context`` on
        both sides. For each connected component (weakly connected when the
        graph is directed) with at least one node inside the write ROI, every
        such node receives:

        * ``self.component_attr`` -- ``int`` of the smallest node id among the
          component's nodes inside the write ROI;
        * ``self.size_attr`` -- ``float`` total edge length of the component
          inside the write ROI: an edge counts fully when both endpoints are
          inside, and half when exactly one is.

        Nodes outside the write ROI are left unlabelled. A final ``assert``
        checks that every node inside the write ROI was labelled.
        """
        source = batch[self.graph]
        g = source.to_nx_graph()
        assert source.spec.roi.get_shape() == self.read_size

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = source.spec.roi.grow(-self.context, -self.context)

        def inside(location):
            return write_roi.contains(location)

        if g.is_directed():
            components = nx.weakly_connected_components(g)
        else:
            components = nx.connected_components(g)

        for component in components:
            members = [n for n in component if inside(g.nodes[n]["location"])]
            if not members:
                continue

            label = int(min(members))
            view = g.subgraph(component)

            # Edge length of this component that falls inside write_roi.
            length = 0
            for a, b in view.edges:
                loc_a = view.nodes[a]["location"]
                loc_b = view.nodes[b]["location"]
                segment = np.linalg.norm(loc_a - loc_b)
                a_in = inside(loc_a)
                b_in = inside(loc_b)
                if a_in and b_in:
                    length += segment
                elif a_in or b_in:
                    # Edge crosses the ROI boundary: attribute half of it here.
                    length += segment / 2

            size = float(length)
            for n in members:
                data = view.nodes[n]
                data[self.component_attr] = label
                data[self.size_attr] = size

        labelled = 0
        for _, data in g.nodes(data=True):
            if inside(data["location"]):
                assert self.component_attr in data
                labelled += 1

        logger.debug(
            f"{self.name()} updated component id of {labelled} nodes in write_roi"
        )

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(g, source.spec.copy())
        return outputs
示例#3
0
    def process(self, batch, request):
        """Return the graph without nodes whose ``self.size_attr`` is too small.

        A node is dropped when its ``self.size_attr`` value is strictly below
        ``self.size_threshold``. Every node must carry ``self.size_attr``;
        otherwise a ``KeyError`` is raised.
        """
        source = batch[self.graph]
        nx_graph = source.to_nx_graph()

        too_small = [
            node for node, data in nx_graph.nodes(data=True)
            if data[self.size_attr] < self.size_threshold
        ]
        nx_graph.remove_nodes_from(too_small)

        result = Batch()
        result[self.graph] = Graph.from_nx_graph(nx_graph, source.spec.copy())
        return result
示例#4
0
    def process(self, batch, request):
        """Iteratively prune short pieces hanging off branch points.

        Within each connected component (weakly connected for directed
        graphs), each branch point (degree > 2 inside the component) is
        tentatively taken out, splitting the component into pieces. Every piece
        whose ``self.cable_len`` -- measured together with the branch point --
        is ``<= self.node_threshold`` is deleted from the graph. Passes are
        repeated on the component until one deletes nothing.
        """
        outputs = Batch()

        g = batch[self.graph].to_nx_graph()
        logger.debug(f"g has {len(g.nodes())} nodes pre filtering")

        cc_func = (nx.weakly_connected_components
                   if g.is_directed() else nx.connected_components)

        for cc in list(cc_func(g)):
            finished = False
            while not finished:
                finished = True
                g_component = g.subgraph(cc)

                branch_points = [
                    n for n in g_component.nodes if g_component.degree(n) > 2
                ]
                logger.debug(
                    f"Connected component has {len(g_component.nodes)} nodes and {len(branch_points)} branch points"
                )
                removed = 0
                for i, branch_point in enumerate(branch_points):
                    # branch_points is computed once per pass, so an earlier
                    # branch point may already have deleted this one as part of
                    # a short piece. Evaluating it anyway would measure pieces
                    # that are no longer attached to it (deleting any short
                    # remainder of the component) and hand cable_len a node
                    # that is gone from g. The next pass re-detects branch
                    # points on the updated graph.
                    if branch_point not in g:
                        continue

                    remaining = [n for n in cc if n != branch_point]
                    remaining_g = g_component.subgraph(remaining)

                    remaining_ccs = list(cc_func(remaining_g))
                    logger.debug(
                        f"After removing branch point {i}, cc is broken into pieces sized: {[len(x) for x in remaining_ccs]}"
                    )
                    for remaining_cc in remaining_ccs:
                        if (self.cable_len(g,
                                           list(remaining_cc) + [branch_point])
                                <= self.node_threshold):
                            g.remove_nodes_from(remaining_cc)
                            removed += len(remaining_cc)
                            # Something changed: re-scan this component.
                            finished = False
                logger.debug(f"Removed {removed} nodes from this cc")
        logger.debug(f"g has {len(g.nodes())} nodes post filtering")

        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())
        return outputs
    def process(self, batch, request):
        """Replace block-local component ids with the ones stored globally.

        Nodes inside the write ROI (the graph's ROI shrunk by ``self.context``
        on both sides) carry a block-local ``self.component_attr``. That id is
        looked up as a node of the graph returned by ``self.client.get_graph``
        for the write ROI. The node's ``self.component_attr`` and
        ``self.size_attr`` are then overwritten with the values stored on that
        node. Nodes outside the write ROI are left untouched.
        """
        g = batch[self.graph].to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)

        contained_nodes = [
            node for node, attr in g.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = {
            g.nodes[n][self.component_attr] for n in contained_nodes
        }

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        component_graph = self.client.get_graph(roi=write_roi,
                                                node_inclusion="dangling",
                                                edge_inclusion="either")

        for node in contained_nodes:
            attrs = g.nodes[node]
            # Look up with the block-local id before it is overwritten.
            global_attrs = component_graph.nodes[attrs[self.component_attr]]
            attrs[self.component_attr] = global_attrs[self.component_attr]
            attrs[self.size_attr] = global_attrs[self.size_attr]

        # The block-local -> global mapping need not be one-to-one, so recount
        # instead of re-reporting the pre-relabel number.
        global_components = {
            g.nodes[n][self.component_attr] for n in contained_nodes
        }
        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(global_components)} global components in "
                     f"write_roi after relabeling")

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())

        return outputs
示例#6
0
    def process(self, batch, request):
        """Store, on each MST edge, a statistic of the embeddings along it.

        For each edge ``(u, v)`` of ``self.mst``, ``self.get_edge_chains``
        yields the chain of ``self.dense_mst`` nodes that edge corresponds to.
        Each chain node's location is converted to a voxel index into
        ``self.embeddings``, and the embedding found there is gathered.
        ``self.get_stat`` of those embeddings is written to the edge under
        ``self.distance_attr``.
        """
        mst = batch[self.mst].to_nx_graph()
        dense_mst = batch[self.dense_mst].to_nx_graph()
        embeddings = batch[self.embeddings].data
        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()

        # Keep all leading axes whole and index only the trailing ones.
        # Presumably the last three axes are spatial -- TODO confirm.
        leading = (slice(None), ) * (len(embeddings.shape) - 3)

        for (u, v), chain in self.get_edge_chains(mst, dense_mst):
            chain_embeddings = []
            for n in chain:
                n_loc = dense_mst.nodes[n]["location"]
                # World location -> voxel index relative to the array origin.
                n_ind = tuple(int(x) for x in ((n_loc - offset) // voxel_size))
                chain_embeddings.append(embeddings[leading + n_ind])

            # NOTE(review): the original passed ``chain`` (node ids) here and
            # discarded ``chain_embeddings``. The gathered embeddings are what
            # the statistic is meant to summarize -- confirm against get_stat.
            mst.edges[(u, v)][self.distance_attr] = self.get_stat(
                chain_embeddings)

        outputs = Batch()
        outputs[self.mst] = Graph.from_nx_graph(mst,
                                                batch[self.mst].spec.copy())

        return outputs
示例#7
0
    def process(self, batch, request):
        """Prune long edges and spatially small components from an MST.

        Steps:

        1. Remove every edge of ``self.msts[0]`` whose ``self.distance_attr``
           exceeds ``self.edge_threshold``.
        2. Unless ``self.component_threshold`` is close to 0, remove all edges
           of each connected component whose node locations span a bounding
           box with diagonal length < ``self.component_threshold``.
        3. Drop nodes left with degree 0.

        When ``self.msts_dense`` is set and its output key
        (``self.msts_dense[1]``) is requested, each removed MST edge also
        removes the dense-MST edges along its chain (from
        ``self.get_edge_chains``), and the dense MST gets the same degree-0
        cleanup. Outputs have their connected components relabelled.
        """
        mst = batch[self.msts[0]].to_nx_graph()
        use_dense = (self.msts_dense is not None) and (self.msts_dense[1] in request)

        logger.info(
            f"mst has {mst.number_of_nodes()} nodes and "
            f"{mst.number_of_edges()} edges and "
            f"{len(list(nx.connected_components(mst)))} components"
        )

        dense_mst = None
        if use_dense:
            dense_mst = batch[self.msts_dense[0]].to_nx_graph()

        # threshold out edges.
        # Edge views (and chain iterators over them) are materialised with
        # list() first: removing edges while iterating a live networkx view
        # raises "dictionary changed size during iteration".
        if use_dense:
            for (u, v), chain in list(self.get_edge_chains(mst, dense_mst)):
                if mst.edges[(u, v)][self.distance_attr] > self.edge_threshold:
                    mst.remove_edge(u, v)
                    for a, b in zip(chain[:-1], chain[1:]):
                        dense_mst.remove_edge(a, b)
        else:
            for (u, v), attrs in list(mst.edges.items()):
                if attrs[self.distance_attr] > self.edge_threshold:
                    mst.remove_edge(u, v)

        # threshold out small components (by bounding-box diagonal)
        components_to_remove = []
        if not np.isclose(self.component_threshold, 0):
            for component in nx.connected_components(mst):
                locations = np.array(
                    [mst.nodes[node]["location"] for node in component])
                extent = locations.max(axis=0) - locations.min(axis=0)
                if np.linalg.norm(extent) < self.component_threshold:
                    components_to_remove.append(component)

        for component in components_to_remove:
            mst_subgraph = mst.subgraph(component)

            if use_dense:
                for (u, v), chain in list(
                        self.get_edge_chains(mst_subgraph, dense_mst)):
                    mst.remove_edge(u, v)
                    for a, b in zip(chain[:-1], chain[1:]):
                        dense_mst.remove_edge(a, b)
            else:
                # mst_subgraph is a view of mst, so snapshot its edges first.
                mst.remove_edges_from(list(mst_subgraph.edges))

        mst.remove_nodes_from([n for n in mst.nodes if mst.degree(n) == 0])

        if use_dense:
            dense_mst.remove_nodes_from(
                [n for n in dense_mst.nodes if dense_mst.degree(n) == 0])

        logger.info(
            f"mst has {mst.number_of_nodes()} nodes and "
            f"{mst.number_of_edges()} edges and "
            f"{len(list(nx.connected_components(mst)))} components"
        )

        outputs = Batch()
        outputs[self.msts[1]] = Graph.from_nx_graph(
            mst, batch[self.msts[0]].spec.copy())
        outputs[self.msts[1]].relabel_connected_components()

        if use_dense:
            outputs[self.msts_dense[1]] = Graph.from_nx_graph(
                dense_mst, batch[self.msts_dense[0]].spec.copy()
            )
            outputs[self.msts_dense[1]].relabel_connected_components()

        return outputs