Example #1
0
def overlay_segmentation(db_host,
                         db_name,
                         roi_offset,
                         roi_size,
                         selected_attr,
                         solved_attr,
                         edge_collection,
                         segmentation_container,
                         segmentation_dataset,
                         segmentation_number,
                         voxel_size=(40, 4, 4)):
    """Annotate selected graph nodes with the segmentation value at their position.

    Nodes and edges with ``selected_attr == True`` are read from MongoDB
    inside the intersection of the requested ROI and the segmentation's ROI
    (snapped to ``voxel_size``). For every node that has a ``z``/``y``/``x``
    position, the segmentation is sampled at that position and stored on the
    node as ``segmentation_<segmentation_number>``; the nodes are then
    written back.

    Args:
        db_host: MongoDB host.
        db_name: MongoDB database holding the graph.
        roi_offset: offset of the graph ROI to process (world units).
        roi_size: size of the graph ROI to process (world units).
        selected_attr: boolean node/edge attribute used as read filter.
        solved_attr: unused; kept for interface compatibility.
        edge_collection: name of the edge collection.
        segmentation_container: container of the segmentation dataset.
        segmentation_dataset: dataset name inside the container.
        segmentation_number: suffix of the node attribute that is written.
        voxel_size: grid the processed ROI is snapped to.
    """
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph_roi = daisy.Roi(roi_offset, roi_size)

    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    segmentation_attr = "segmentation_{}".format(segmentation_number)
    for node_id, data in nx_graph.nodes(data=True):
        # Nodes that were only pulled in as the far end of an edge carry no
        # position attributes, so there is nothing to sample for them.
        if not all(dim in data for dim in ("z", "y", "x")):
            continue
        # One 3D (z, y, x) coordinate -- not a 2-tuple plus a stray argument.
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        data[segmentation_attr] = segmentation[node_position]

    # write_nodes is a method of the sub-graph returned by get_graph (as used
    # elsewhere with these graphs), not of the provider itself.
    nx_graph.write_nodes(intersection_roi)
Example #2
0
def get_graph(db_host, db_name, roi_offset, roi_size, solve_number,
              edge_collection):
    """Return the selected sub-graph of a solve run without dangling edges.

    Nodes and edges are read with ``selected_<solve_number> == True``. Any
    edge with an endpoint whose attribute dict is empty (i.e. an endpoint
    that was not itself read as a selected node) is removed, together with
    those attribute-less endpoints. Every remaining edge therefore connects
    two selected vertices.
    """
    selected = f"selected_{solve_number}"
    solved = f"solved_{solve_number}"

    provider = MongoDbGraphProvider(db_name,
                                    db_host,
                                    directed=False,
                                    position_attribute=['z', 'y', 'x'],
                                    edges_collection=edge_collection)

    graph = provider.get_graph(
        daisy.Roi(roi_offset, roi_size),
        nodes_filter={selected: True},
        edges_filter={selected: True},
        node_attrs=[selected, solved, "x", "y", "z"],
        edge_attrs=[selected, solved, "x", "y", "z"])

    node_data = graph.nodes
    bad_edges = set()
    bad_nodes = set()
    for first, second in graph.edges():
        # Endpoints with no attributes were never read as selected nodes.
        empty_ends = [n for n in (first, second) if not node_data[n]]
        if empty_ends:
            bad_edges.add((first, second))
            bad_nodes.update(empty_ends)

    graph.remove_edges_from(bad_edges)
    graph.remove_nodes_from(bad_nodes)
    return graph
Example #3
0
def label_connected_components(db_host,
                               db_name,
                               roi,
                               selected_attr,
                               solved_attr,
                               edge_collection,
                               label_attribute="label"):
    """Read the selected sub-graph in ``roi`` and label its connected components.

    Only nodes and edges with ``selected_attr == True`` are read.
    ``find_connected_components`` is then called on that graph with
    ``node_component_attribute=label_attribute`` and ``return_lut=True``.
    ``solved_attr`` is accepted but not used.

    Returns:
        tuple: ``(graph, lut)`` -- the graph read from MongoDB (after
        ``find_connected_components`` ran on it) and the lookup table that
        call returned.
    """
    provider = MongoDbGraphProvider(
        db_name,
        db_host,
        directed=False,
        position_attribute=['z', 'y', 'x'],
        edges_collection=edge_collection,
    )
    selected_graph = provider.get_graph(
        roi,
        nodes_filter={selected_attr: True},
        edges_filter={selected_attr: True},
    )

    component_lut = find_connected_components(
        selected_graph,
        node_component_attribute=label_attribute,
        return_lut=True,
    )
    return selected_graph, component_lut
class DaisyGraphProvider(BatchProvider):
    """Gunpowder source that reads graphs from a daisy MongoDB graph store.

    For every requested graph key, the sub-graph inside the requested ROI is
    read with the configured inclusion rules, attribute projections and
    filters. Each node's position attribute is copied into a float32
    ``location`` array and its id into ``id``. Nodes without a position are
    dropped with a warning, or raise ``ValueError`` if
    ``fail_on_inconsistent_node`` is set. The graph is then converted to the
    requested directedness, connected components are relabeled, and the
    result is cropped to the request.

    ``graph_specs`` may be a single spec shared by all ``points`` or a list
    with exactly one spec per key; a list of any other length raises
    ``ValueError``. ``total_roi`` is accepted for interface compatibility but
    is not used.

    See documentation for mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """
    def __init__(
        self,
        dbname: str,
        url: str,
        points: List[GraphKey],
        graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
        directed: bool = False,
        total_roi: Roi = None,
        nodes_collection: str = "nodes",
        edges_collection: str = "edges",
        meta_collection: str = "meta",
        endpoint_names: Tuple[str, str] = ("u", "v"),
        position_attribute: str = "position",
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        nodes_filter: Optional[Dict[str, Any]] = None,
        edges_filter: Optional[Dict[str, Any]] = None,
        edge_inclusion: str = "either",
        node_inclusion: str = "dangling",
        fail_on_inconsistent_node: bool = False,
    ):
        self.points = points

        if graph_specs is None:
            graph_specs = GraphSpec(
                Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
                directed=False)
        if isinstance(graph_specs, list):
            # A list must pair up one-to-one with the graph keys; otherwise
            # each key would silently receive the whole list as its spec.
            if len(graph_specs) != len(points):
                raise ValueError(
                    f"Got {len(graph_specs)} graph specs for "
                    f"{len(points)} graph keys")
            specs = graph_specs
        else:
            # A single spec is shared by every graph key.
            specs = [graph_specs] * len(points)
        self.specs = dict(zip(points, specs))

        # Connection settings, used in setup() to create the provider.
        self.dbname = dbname
        self.url = url
        self.directed = directed
        self.nodes_collection = nodes_collection
        self.edges_collection = edges_collection
        self.meta_collection = meta_collection
        self.endpoint_names = endpoint_names
        self.position_attribute = position_attribute

        # Read options forwarded to get_graph() in provide().
        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter
        self.edge_inclusion = edge_inclusion
        self.node_inclusion = node_inclusion

        self.fail_on_inconsistent_node = fail_on_inconsistent_node

        # Created lazily in setup().
        self.graph_provider = None

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)

        if self.graph_provider is None:
            self.graph_provider = MongoDbGraphProvider(
                self.dbname,
                self.url,
                mode="r+",
                directed=self.directed,
                total_roi=None,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
                meta_collection=self.meta_collection,
                endpoint_names=self.endpoint_names,
                position_attribute=self.position_attribute,
            )

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            failed_nodes = []
            for node, attrs in requested_graph.nodes.items():
                try:
                    attrs["location"] = np.array(
                        attrs[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node} was written (probably part of an edge), but never given coordinates!"
                    )
                    failed_nodes.append(node)
                attrs["id"] = node

            if failed_nodes and self.fail_on_inconsistent_node:
                raise ValueError(
                    f"Mongodb contains node {failed_nodes[0]} without location! "
                    f"It was probably written as part of an edge")
            requested_graph.remove_nodes_from(failed_nodes)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            points = Graph.from_nx_graph(requested_graph, spec)
            points.relabel_connected_components()
            points.crop(spec.roi)
            batch[key] = points

            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #5
0
def solve_in_block(db_host,
                   db_name,
                   evidence_factor,
                   comb_angle_factor,
                   start_edge_prior,
                   selection_cost,
                   time_limit,
                   solve_number,
                   graph_number,
                   selected_attr="selected",
                   solved_attr="solved",
                   **kwargs):
    """Daisy worker: run the solver on every block handed out by the scheduler.

    For each acquired block that is not already marked done for step
    ``solve_s<solve_number>``, the graph in ``block.read_roi`` is read from
    the ``edges_g<graph_number>`` collection and solved with ``Solver``.
    ``selected_attr`` and ``solved_attr`` are written back for edges and
    nodes in ``block.write_roi``, and the block is marked done. Blocks whose
    graph has no edges are marked done without solving.

    Returns:
        0 once the scheduler has no more blocks to hand out.
    """
    logger.info("Solve in block")

    step_name = 'solve_s{}'.format(solve_number)

    graph_provider = MongoDbGraphProvider(
        db_name,
        db_host,
        mode='r+',
        position_attribute=['z', 'y', 'x'],
        edges_collection="edges_g{}".format(graph_number))

    client = daisy.Client()

    while True:
        logger.info("Acquire block")
        block = client.acquire_block()

        if not block:
            # All blocks processed; this is the loop's only exit.
            return 0

        logger.debug("Solving in block %s", block)

        if check_function(block, step_name, db_name, db_host):
            client.release_block(block, 0)
            continue

        start_time = time.time()
        graph = graph_provider.get_graph(block.read_roi)

        num_nodes = graph.number_of_nodes()
        num_edges = graph.number_of_edges()
        logger.info(
            "Reading graph with %d nodes and %d edges took %s seconds",
            num_nodes, num_edges, time.time() - start_time)

        if num_edges == 0:
            logger.info("No edges in roi %s. Skipping", block.read_roi)
            write_done(block, step_name, db_name, db_host)
            client.release_block(block, 0)
            continue

        logger.info("solve")
        solver = Solver(graph, evidence_factor, comb_angle_factor,
                        start_edge_prior, selection_cost, time_limit,
                        selected_attr, solved_attr)

        solver.initialize()
        solver.solve()

        start_time = time.time()
        graph.update_edge_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])

        graph.update_node_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])

        logger.info(
            "Updating attributes %s & %s for %d edges took %s seconds",
            selected_attr, solved_attr, num_edges, time.time() - start_time)

        logger.info("Write done")
        write_done(block, step_name, db_name, db_host)

        logger.info("Release block")
        client.release_block(block, 0)
Example #6
0
class MongoWriteGraph(gp.BatchFilter):
    """Write the graph ``mst`` of each batch into a daisy MongoDB graph store.

    Requests for ``mst`` must have exactly ``read_size`` shape. Only the
    central region of the read ROI -- shrunk by
    ``context = (read_size - write_size) / 2`` on each side -- is written.

    Node ids are not taken from the batch graph. Each node location is
    divided by ``voxel_size``, floored and passed to
    ``math.cantor_number``, so a node's id depends only on the voxel it
    falls into.

    Args:
        mst: graph key of the graph to write.
        db_host: MongoDB host.
        db_name: MongoDB database name.
        read_size: shape of the requested ROI.
        write_size: shape of the central ROI that is written.
        voxel_size: grid used to derive node ids from locations.
        directed: whether the stored graph is directed.
        mode: mode passed to ``MongoDbGraphProvider``.
        collection: prefix of the ``<collection>_nodes`` and
            ``<collection>_edges`` collections.
        position_attr: node attribute that stores the location.
        node_attrs: node attributes copied into MongoDB; ``position_attr``
            is always appended (the caller's list is not modified).
        edge_attrs: edge attributes copied into MongoDB.
    """

    def __init__(
        self,
        mst,
        db_host,
        db_name,
        read_size,
        write_size,
        voxel_size,
        directed,
        mode="r+",
        collection="",
        position_attr="position",
        node_attrs: List[str] = None,
        edge_attrs: List[str] = None,
    ):

        self.mst = mst
        self.db_host = db_host
        self.db_name = db_name
        self.client = None

        self.voxel_size = voxel_size
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2

        assert self.write_size + (self.context * 2) == self.read_size

        self.directed = directed
        self.mode = mode
        self.position_attr = position_attr

        self.nodes_collection = f"{collection}_nodes"
        self.edges_collection = f"{collection}_edges"

        # Copy the caller's lists: appending position_attr in place would
        # otherwise mutate (and, on reuse, keep growing) the argument.
        self.node_attrs = list(node_attrs) if node_attrs is not None else []
        self.node_attrs.append(position_attr)
        self.edge_attrs = list(edge_attrs) if edge_attrs is not None else []

    def setup(self):

        # Initialize client. Doesn't the daisy mongodb graph provider handle this?
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=self.directed,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
            )

        self.updates(self.mst, self.spec[self.mst].copy())

    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.mst] = request[self.mst].copy()
        assert (request[self.mst].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"
        return deps

    def _node_id(self, loc):
        """Return the MongoDB node id for ``loc``: cantor number of its voxel."""
        return int(math.cantor_number(np.floor(loc / self.voxel_size)))

    @staticmethod
    def _sanitize_attrs(attrs, attr_names):
        """Return the entries of ``attrs`` named in ``attr_names``, ready for Mongo.

        Iterable values are converted to tuples of Python floats. Scalar
        ``np.float32`` values raise ``ValueError``. Other values are passed
        through unchanged; names missing from ``attrs`` are skipped.
        """
        sanitized = {}
        for attr in attr_names:
            if attr not in attrs:
                continue
            value = attrs[attr]
            if isinstance(value, Iterable):
                value = tuple(float(v) for v in value)
            elif isinstance(value, np.float32):
                raise ValueError(f"value of attr {attr} is a np.float32")
            sanitized[attr] = value
        return sanitized

    def process(self, batch, request):

        read_roi = batch[self.mst].spec.roi
        write_roi = read_roi.grow(-self.context, -self.context)

        mst = batch[self.mst].to_nx_graph()

        # get saved graph with 'dangling' edges. Some nodes may not be in write_roi
        # Edges crossing boundary should only be saved if lower id is contained in write_roi
        mongo_graph = self.client.get_graph(
            write_roi,
            node_attrs=self.node_attrs,
            edge_attrs=self.edge_attrs,
            node_inclusion="dangling",
            edge_inclusion="either",
        )

        # Drop stored nodes that came back without a position attribute.
        for node in list(mongo_graph.nodes()):
            if self.position_attr not in mongo_graph.nodes[node]:
                mongo_graph.remove_node(node)

        # Create or update every batch node located inside write_roi.
        for _node, attrs in mst.nodes.items():
            loc = attrs["location"]
            if not write_roi.contains(loc):
                continue

            node_id = self._node_id(loc)
            attrs_to_save = {
                self.position_attr: tuple(float(v) for v in loc)
            }
            attrs_to_save.update(self._sanitize_attrs(attrs, self.node_attrs))

            if node_id in mongo_graph:
                mongo_graph.nodes[node_id].update(attrs_to_save)
            else:
                mongo_graph.add_node(node_id, **attrs_to_save)

        for (u, v), attrs in mst.edges.items():
            u_loc = mst.nodes[u]["location"]
            v_loc = mst.nodes[v]["location"]
            u_id = self._node_id(u_loc)
            v_id = self._node_id(v_loc)

            # The endpoint with the lower id owns the edge: only write the
            # edge if that endpoint is contained in write_roi.
            # May create a node without any attributes if node creation is
            # dependent on your field of view and the neighboring block
            # fails to create the same node.
            owner_loc = u_loc if u_id < v_id else v_loc
            if not write_roi.contains(owner_loc):
                continue

            attrs_to_save = self._sanitize_attrs(attrs, self.edge_attrs)
            if (u_id, v_id) in mongo_graph.edges:
                mongo_graph.edges[(u_id, v_id)].update(attrs_to_save)
            else:
                mongo_graph.add_edge(u_id, v_id, **attrs_to_save)

        # Consistency check between the batch graph and the stored graph.
        # NOTE(review): this indexes mongo_graph with batch-graph node ids,
        # while stored nodes are keyed by cantor ids -- it is only
        # meaningful if upstream nodes already use cantor ids. Confirm.
        for node in mst.nodes:
            if node in mongo_graph.nodes and write_roi.contains(
                    mst.nodes[node]["location"]):
                assert all(
                    np.isclose(
                        mongo_graph.nodes[node][self.position_attr],
                        mst.nodes[node]["location"],
                    ))
                if write_roi.contains(
                        mongo_graph.nodes[node][self.position_attr]):
                    assert (mongo_graph.nodes[node]["component_id"] ==
                            mst.nodes[node]["component_id"])

        if len(mongo_graph.nodes) > 0:
            mongo_graph.write_nodes(roi=write_roi, attributes=self.node_attrs)

        if len(mongo_graph.edges) > 0:
            for edge, attrs in mongo_graph.edges.items():
                for attr in self.edge_attrs:
                    assert attr in attrs
            mongo_graph.write_edges(roi=write_roi, attributes=self.edge_attrs)
Example #7
0
def solve_in_block(
    db_host,
    skeletonization_db,
    subsampled_skeletonization_db,
    time_limit,
    solve_number,
    graph_number,
    location_attr,
    u_name,
    v_name,
    **kwargs,
):
    """Daisy worker that builds the subsampled skeletonization block by block.

    For every acquired block not yet marked done for step
    ``solve_s<solve_number>`` (tracked in ``subsampled_skeletonization_db``):

    * read the skeletonization graph in ``block.read_roi``;
    * read the graph already present in the subsampled database (previously
      solved results that must be kept);
    * if the skeletonization is empty, just mark the block done;
    * otherwise remove line nodes (``remove_line_nodes``), write the result
      with ``write_matched`` and mark the block done.

    Returns:
        0 once the scheduler has no more blocks.
    """
    logger.info("Solve in block")

    step = f"solve_s{solve_number}"

    subsampled_provider = MongoDbGraphProvider(subsampled_skeletonization_db,
                                               db_host,
                                               mode="r+",
                                               directed=True)
    skeletonization_provider = MongoDbGraphProvider(skeletonization_db,
                                                    db_host,
                                                    mode="r+")

    client = daisy.Client()

    while True:
        logger.info("Acquire block")
        block = client.acquire_block()
        if not block:
            return 0

        logger.debug("Solving in block %s", block)

        already_done = check_function(
            block,
            step,
            subsampled_skeletonization_db,
            db_host,
        )
        if already_done:
            client.release_block(block, 0)
            continue

        tic = time.time()
        skeletonization = skeletonization_provider.get_graph(
            block.read_roi, node_inclusion="dangling", edge_inclusion="both")
        # Anything already in the subsampled graph was solved previously and
        # must be maintained.
        pre_solved = subsampled_provider.get_graph(block.read_roi,
                                                   node_inclusion="dangling",
                                                   edge_inclusion="both")

        logger.info(
            f"Reading skeletonization with {len(skeletonization.nodes)} nodes and "
            f"{len(skeletonization.edges)} edges took {time.time() - tic} seconds"
        )

        if not len(skeletonization.nodes):
            logger.info(
                f"No consensus nodes in roi {block.read_roi}. Skipping")
            write_done(block, step, subsampled_skeletonization_db, db_host)
            client.release_block(block, 0)
            continue

        logger.info("PreProcess...")
        tic = time.time()

        logger.info(
            f"Skeletoniation has {len(skeletonization.nodes)} nodes "
            f"and {len(skeletonization.edges)} edges before subsampling")

        removed_count = remove_line_nodes(skeletonization, location_attr)
        logger.info(f"Removed {removed_count} nodes from skeletonization!")

        written_nodes, written_edges = write_matched(
            db_host,
            subsampled_skeletonization_db,
            block,
            skeletonization,
            pre_solved,
            location_attr,
            u_name,
            v_name,
        )

        logger.info(
            f"Writing matched graph with {written_nodes} nodes and {written_edges} edges "
            f"took {time.time()-tic} seconds")

        logger.info("Write done")
        write_done(block, step, subsampled_skeletonization_db, db_host)

        logger.info("Release block")
        client.release_block(block, 0)
class FilteredDaisyGraphProvider(BatchProvider):
    """Gunpowder source that reads and thins graphs from a daisy MongoDB store.

    For every requested graph key, the sub-graph in the requested ROI is read
    (``edge_inclusion="either"``, ``node_inclusion="dangling"``) with the
    configured attribute projections and filters. Then:

    1. nodes whose ``dist_attribute`` is present and below ``min_dist`` are
       removed;
    2. if more than ``num_nodes`` nodes remain, a uniform random subset of
       ``num_nodes`` nodes is kept (``random.sample``);
    3. nodes without a position attribute are dropped with a warning; for the
       rest, the position is copied into a float32 ``location`` array and the
       node id into ``id``.

    The graph is finally converted to the requested directedness and cropped
    to the request ROI.

    ``graph_specs`` may be a single spec shared by all ``points`` or a list
    with exactly one spec per key; a list of any other length raises
    ``ValueError``. ``total_roi`` is accepted for interface compatibility but
    is not used.

    See documentation for mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """
    def __init__(
        self,
        dbname: str,
        url: str,
        points: List[GraphKey],
        graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
        directed: bool = False,
        total_roi: Roi = None,
        nodes_collection: str = "nodes",
        edges_collection: str = "edges",
        meta_collection: str = "meta",
        endpoint_names: Tuple[str, str] = ("u", "v"),
        position_attribute: str = "position",
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        nodes_filter: Optional[Dict[str, Any]] = None,
        edges_filter: Optional[Dict[str, Any]] = None,
        num_nodes=100000,
        dist_attribute=None,
        min_dist=29000,
    ):
        self.points = points

        if graph_specs is None:
            graph_specs = GraphSpec(
                Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
                directed=False)
        if isinstance(graph_specs, list):
            # A list must pair up one-to-one with the graph keys; otherwise
            # each key would silently receive the whole list as its spec.
            if len(graph_specs) != len(points):
                raise ValueError(
                    f"Got {len(graph_specs)} graph specs for "
                    f"{len(points)} graph keys")
            specs = graph_specs
        else:
            specs = [graph_specs] * len(points)
        self.specs = dict(zip(points, specs))

        self.position_attribute = position_attribute
        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter
        self.dist_attribute = dist_attribute
        self.min_dist = min_dist

        self.num_nodes = num_nodes

        # Connection settings; the provider is created lazily in setup(),
        # consistent with DaisyGraphProvider.
        self.dbname = dbname
        self.url = url
        self.directed = directed
        self.nodes_collection = nodes_collection
        self.edges_collection = edges_collection
        self.meta_collection = meta_collection
        self.endpoint_names = endpoint_names
        self.graph_provider = None

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)

        if self.graph_provider is None:
            self.graph_provider = MongoDbGraphProvider(
                self.dbname,
                self.url,
                mode="r+",
                directed=self.directed,
                total_roi=None,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
                meta_collection=self.meta_collection,
                endpoint_names=self.endpoint_names,
                position_attribute=self.position_attribute,
            )

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            too_close = [
                node for node, attrs in requested_graph.nodes.items()
                if self.dist_attribute in attrs
                and attrs[self.dist_attribute] < self.min_dist
            ]
            requested_graph.remove_nodes_from(too_close)

            logger.debug(
                f"{len(requested_graph.nodes)} nodes remaining after filtering by distance"
            )

            if len(requested_graph.nodes) > self.num_nodes:
                nodes = list(requested_graph.nodes)
                nodes_to_keep = set(random.sample(nodes, self.num_nodes))
                requested_graph.remove_nodes_from(
                    [node for node in nodes if node not in nodes_to_keep])

            missing_position = []
            for node, attrs in requested_graph.nodes.items():
                try:
                    position = attrs[self.position_attribute]
                except KeyError:
                    # Probably only written as the endpoint of an edge.
                    missing_position.append(node)
                    continue
                attrs["location"] = np.array(position, dtype=np.float32)
                attrs["id"] = node

            if missing_position:
                logger.warning(
                    f"dropping {len(missing_position)} nodes without "
                    f"coordinates (probably written only as part of an edge)")
                requested_graph.remove_nodes_from(missing_position)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            logger.debug(
                f"providing {key} with {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            points = Graph.from_nx_graph(requested_graph, spec)
            points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
# Script section: load the config, read the consensus graph inside
# config["roi"] and open an (expected to be empty) "subdivided" graph store.
# NOTE(review): the code that consumes these names appears to lie outside
# this view (e.g. `node_id`, `target_edge_len`, `pos_attr`).
config = read_data_config(Path("config.ini"))

print(config["sample"])

# Source graph. Opened "r+" although only read in the lines shown here.
consensus_provider = MongoDbGraphProvider(config["consensus_db"],
                                          config["db_host"],
                                          mode="r+",
                                          directed=True)

# Target graph. NOTE(review): mode="w" presumably clears any existing graph
# in this database -- confirm against the daisy docs before pointing this at
# data you want to keep.
subdivided_provider = MongoDbGraphProvider(config["subdivided_db"],
                                           config["db_host"],
                                           mode="w",
                                           directed=True)

# node_inclusion / edge_inclusion control which nodes and edges crossing the
# ROI boundary are included (see daisy's get_graph documentation).
consensus = consensus_provider.get_graph(config["roi"],
                                         node_inclusion="dangling",
                                         edge_inclusion="either")

print(f"Read consensus with {len(consensus.nodes)} nodes")

subdivided = subdivided_provider.get_graph(config["roi"],
                                           node_inclusion="dangling",
                                           edge_inclusion="either")
print(f"Read subdivided with {len(subdivided.nodes)} nodes")

target_edge_len = config["target_edge_len"]
pos_attr = config["location_attr"]

# Sanity check: the target graph must contain no nodes in the ROI yet.
assert len(subdivided.nodes) == 0

# Monotonic integer counter, presumably used below to mint fresh node ids.
node_id = itertools.count()
class GlobalConnectedComponents(BatchFilter):
    """Replace block-local component ids with global ones stored in MongoDB.

    Requests for ``graph`` must have exactly ``read_size`` shape. For every
    node of ``graph`` whose location lies in the write ROI (the request ROI
    shrunk by ``context = (read_size - write_size) / 2`` on each side), the
    node's current ``component_attr`` value is used as a node id in the
    ``<collection>_components`` collection. That node's ``component_attr``
    and ``size_attr`` are copied onto the graph node. Nodes outside the write
    ROI are left unchanged.
    """
    def __init__(
        self,
        db_name,
        db_host,
        collection,
        graph,
        size_attr,
        component_attr,
        read_size,
        write_size,
        mode="r+",
    ):
        self.db_name = db_name
        self.db_host = db_host
        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"
        self.client = None
        self.mode = mode

        self.graph = graph
        self.size_attr = size_attr
        self.component_attr = component_attr
        self.read_size = read_size
        self.write_size = write_size

        self.context = (read_size - write_size) / 2

        assert self.read_size == self.write_size + self.context * 2

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())

        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"

        return deps

    def process(self, batch, request):
        g = batch[self.graph].to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)

        contained_nodes = [
            node for node, attr in g.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(g.nodes[n][self.component_attr]
                                   for n in contained_nodes)

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        component_graph = self.client.get_graph(roi=write_roi,
                                                node_inclusion="dangling",
                                                edge_inclusion="either")

        for node in contained_nodes:
            attrs = g.nodes[node]
            block_component_id = attrs[self.component_attr]
            try:
                component = component_graph.nodes[block_component_id]
            except KeyError as err:
                raise KeyError(
                    f"Component {block_component_id} of node {node} not "
                    f"found in collection {self.component_collection}"
                ) from err
            attrs[self.component_attr] = component[self.component_attr]
            attrs[self.size_attr] = component[self.size_attr]

        global_components = set(g.nodes[n][self.component_attr]
                                for n in contained_nodes)
        logger.debug(f"Relabeled {len(contained_nodes)} nodes in write_roi "
                     f"into {len(global_components)} global components")

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())

        return outputs
class MongoWriteComponents(BatchFilter):
    """Record per-block connected components and their adjacencies in MongoDB.

    Requests for ``graph`` must have exactly ``read_size`` shape; the write
    ROI is the request ROI shrunk by ``context = (read_size - write_size) / 2``
    on each side.

    * For every graph node in the write ROI, its ``component_attr`` value is
      stored as a node in ``<collection>_components``. New component nodes
      get ``position`` = center of the write ROI plus the node's
      ``node_attrs``. For components that already exist, ``node_attrs`` are
      asserted to match.
    * For every graph edge with at least one endpoint in the write ROI whose
      endpoints both carry (different) component ids, an edge between the two
      components is added to ``<collection>_component_edges``.
    """
    def __init__(
        self,
        graph,
        db_host,
        db_name,
        read_size,
        write_size,
        component_attr,
        node_attrs=None,
        collection="",
        mode="r+",
    ):
        self.graph = graph
        self.db_host = db_host
        self.db_name = db_name
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2
        self.mode = mode
        self.component_attr = component_attr

        assert self.write_size + (self.context * 2) == self.read_size

        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"

        self.node_attrs = [] if node_attrs is None else node_attrs

        self.client = None

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())

        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"

        return deps

    def process(self, batch, request):

        graph = batch[self.graph].to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {graph.number_of_nodes()} nodes and "
            f"{graph.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)

        contained_nodes = [
            node for node, attr in graph.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(graph.nodes[n][self.component_attr]
                                   for n in contained_nodes)

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        mongo_graph = self.client.get_graph(
            write_roi,
            # Read the stored component attributes: the consistency check
            # below compares them with the batch values, and an empty
            # projection would leave them out of existing components.
            node_attrs=list(self.node_attrs),
            edge_attrs=[],
            node_inclusion="dangling",
            edge_inclusion="either",
        )

        num_new_components = 0
        num_new_component_edges = 0
        for node in contained_nodes:
            attrs = graph.nodes[node]
            cc_id = attrs[self.component_attr]
            component_attrs = {attr: attrs[attr] for attr in self.node_attrs}
            if cc_id not in mongo_graph.nodes:
                num_new_components += 1
                mongo_graph.add_node(cc_id,
                                     position=write_roi.get_center(),
                                     **component_attrs)
            else:
                for k, v in component_attrs.items():
                    assert mongo_graph.nodes[cc_id][k] == v

        # Always write crossing edges if neither component id is None. Other end
        # point wont have a component id unless it has already been processed,
        # thus it is the duty of the second pass to write the edge, regardless
        # of whether the lower or upper end point is contained.
        for u, v in graph.edges:
            u_loc = graph.nodes[u]["location"]
            v_loc = graph.nodes[v]["location"]
            if not (write_roi.contains(u_loc) or write_roi.contains(v_loc)):
                continue
            u_cc_id = graph.nodes[u].get(self.component_attr)
            v_cc_id = graph.nodes[v].get(self.component_attr)

            if u_cc_id is None or v_cc_id is None or u_cc_id == v_cc_id:
                continue
            if (u_cc_id, v_cc_id) not in mongo_graph.edges:
                num_new_component_edges += 1
                mongo_graph.add_edge(u_cc_id, v_cc_id)

        logger.debug(
            f"{self.name()} writing {num_new_components} new components and "
            f"{num_new_component_edges} new component edges!")

        mongo_graph.write_nodes()
        mongo_graph.write_edges(ignore_v=False)