def overlay_segmentation(db_host, db_name, roi_offset, roi_size, selected_attr,
                         solved_attr, edge_collection, segmentation_container,
                         segmentation_dataset, segmentation_number,
                         voxel_size=(40, 4, 4)):
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)

    graph_roi = daisy.Roi(roi_offset, roi_size)
    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    for node_id, data in nx_graph.nodes(data=True):
        # Build the position from all three components (the original closed
        # the tuple after "y", which broke at runtime).
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        nx_graph.nodes[node_id]["segmentation_{}".format(
            segmentation_number)] = segmentation[node_position]

    # write_nodes is a method of the graph returned by get_graph,
    # not of the provider itself.
    nx_graph.write_nodes(intersection_roi)
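# A minimal invocation sketch for overlay_segmentation. The host, database,
# and container paths below are hypothetical placeholders; the ROI values and
# the selected_{n}/edges_g{n} naming follow conventions used elsewhere in
# this code.
overlay_segmentation(
    db_host="mongodb://localhost:27017",
    db_name="micron_test",
    roi_offset=(158000, 121800, 448616),
    roi_size=(7600, 5200, 2200),
    selected_attr="selected_1",
    solved_attr="solved_1",
    edge_collection="edges_g1",
    segmentation_container="segmentation.zarr",
    segmentation_dataset="volumes/segmentation",
    segmentation_number=1,
)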
def setUp(self):
    # y: 1| 0(s)------1(s)------2(s)
    #    2|
    #    3| 3(s)               6(s)
    #    4|  |                  |
    #    5| 4(s)-(ns)-9(s)-(ns)-7(s)
    #    6|  |                  |
    #    7| 5(s)               8(s)
    #     |-------------------------------->
    # x:     1         2        3
    #
    # s = selected
    # ns = not selected
    self.nodes = [
        {'id': 0, 'z': 1, 'y': 1, 'x': 1, 'selected': True, 'solved': True},
        {'id': 1, 'z': 1, 'y': 1, 'x': 2, 'selected': True, 'solved': True},
        {'id': 2, 'z': 1, 'y': 1, 'x': 3, 'selected': True, 'solved': True},
        {'id': 3, 'z': 1, 'y': 3, 'x': 1, 'selected': True, 'solved': True},
        {'id': 4, 'z': 1, 'y': 5, 'x': 1, 'selected': True, 'solved': True},
        {'id': 5, 'z': 1, 'y': 7, 'x': 1, 'selected': True, 'solved': True},
        {'id': 6, 'z': 1, 'y': 3, 'x': 3, 'selected': True, 'solved': True},
        {'id': 7, 'z': 1, 'y': 5, 'x': 3, 'selected': True, 'solved': True},
        {'id': 8, 'z': 1, 'y': 7, 'x': 3, 'selected': True, 'solved': True},
        {'id': 9, 'z': 1, 'y': 5, 'x': 2, 'selected': True, 'solved': True},
    ]
    self.edges = [
        {'u': 0, 'v': 1, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 1, 'v': 2, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 3, 'v': 4, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 4, 'v': 5, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 6, 'v': 7, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 7, 'v': 8, 'evidence': 0.5, 'selected': True, 'solved': True},
        {'u': 4, 'v': 9, 'evidence': 0.5, 'selected': False, 'solved': False},
        {'u': 7, 'v': 9, 'evidence': 0.5, 'selected': False, 'solved': False},
    ]

    self.db_name = 'micron_test_solver'
    config = configparser.ConfigParser()
    config.read(os.path.expanduser("../mongo.ini"))
    self.db_host = "mongodb://{}:{}@{}:{}".format(
        config.get("Credentials", "user"),
        config.get("Credentials", "password"),
        config.get("Credentials", "host"),
        config.get("Credentials", "port"))

    self.graph_provider = MongoDbGraphProvider(
        self.db_name,
        self.db_host,
        mode='w',
        position_attribute=['z', 'y', 'x'])

    self.roi = daisy.Roi((0, 0, 0), (4, 4, 4))
    self.graph = self.graph_provider[self.roi]
    self.graph.add_nodes_from([(node['id'], node) for node in self.nodes])
    self.graph.add_edges_from([(edge['u'], edge['v'], edge)
                               for edge in self.edges])

    self.solve_params = {"graph": self.graph,
                         "evidence_factor": 12,
                         "comb_angle_factor": 14,
                         "start_edge_prior": 180,
                         "selection_cost": -80}
def get_graph(db_host, db_name, roi_offset, roi_size, solve_number,
              edge_collection):
    """
    Get the selected subgraph, with no isolated selected edges:
    all edges that do not have two selected vertices are filtered out.
    """
    selected_attr = "selected_{}".format(solve_number)
    solved_attr = "solved_{}".format(solve_number)

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)

    roi = daisy.Roi(roi_offset, roi_size)
    nx_graph = graph_provider.get_graph(
        roi,
        nodes_filter={selected_attr: True},
        edges_filter={selected_attr: True},
        node_attrs=[selected_attr, solved_attr, "x", "y", "z"],
        edge_attrs=[selected_attr, solved_attr, "x", "y", "z"])

    # Filter out edges that do not have two selected vertices. A vertex that
    # was only written as an edge endpoint has an empty attribute dict, so
    # `not u` identifies dangling endpoints.
    edges_to_remove = set()
    nodes_to_remove = set()
    for u_id, v_id in nx_graph.edges():
        u = nx_graph.nodes()[u_id]
        v = nx_graph.nodes()[v_id]
        if not u:
            edges_to_remove.add((u_id, v_id))
            nodes_to_remove.add(u_id)
        if not v:
            edges_to_remove.add((u_id, v_id))
            nodes_to_remove.add(v_id)

    nx_graph.remove_edges_from(edges_to_remove)
    nx_graph.remove_nodes_from(nodes_to_remove)
    return nx_graph
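# Toy illustration of the pruning above, with plain networkx and no
# database: a node that was only ever written as an edge endpoint has an
# empty attribute dict and is treated as dangling.
import networkx as nx

g = nx.Graph()
g.add_node(0, x=1, y=1, z=1)  # real node with attributes
g.add_edge(0, 1)              # node 1 is created implicitly, no attributes

dangling = [n for n, attrs in g.nodes(data=True) if not attrs]
g.remove_nodes_from(dangling)  # removing node 1 also drops edge (0, 1)
assert list(g.nodes) == [0] and g.number_of_edges() == 0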
def test_solve(self):
    # y: 1| 0--1--2
    # x:    1  2  3
    #
    # Test all selected/solved combinations for every node and edge.
    combinations = [(False, False), (False, True), (True, True)]
    for v0, v1, v2, e0, e1 in itertools.product(combinations, repeat=5):
        self.nodes = [
            {'id': 0, 'z': 1, 'y': 1, 'x': 1,
             'selected': v0[0], 'solved': v0[1]},
            {'id': 1, 'z': 1, 'y': 1, 'x': 2,
             'selected': v1[0], 'solved': v1[1]},
            {'id': 2, 'z': 1, 'y': 1, 'x': 3,
             'selected': v2[0], 'solved': v2[1]},
        ]
        self.edges = [
            {'u': 0, 'v': 1, 'evidence': 0.5,
             'selected': e0[0], 'solved': e0[1]},
            {'u': 1, 'v': 2, 'evidence': 0.5,
             'selected': e1[0], 'solved': e1[1]},
        ]

        self.db_name = 'micron_test_solver'
        config = configparser.ConfigParser()
        config.read(os.path.expanduser("../mongo.ini"))
        self.db_host = "mongodb://{}:{}@{}:{}".format(
            config.get("Credentials", "user"),
            config.get("Credentials", "password"),
            config.get("Credentials", "host"),
            config.get("Credentials", "port"))

        self.graph_provider = MongoDbGraphProvider(
            self.db_name,
            self.db_host,
            mode='w',
            position_attribute=['z', 'y', 'x'])

        self.roi = daisy.Roi((0, 0, 0), (4, 4, 4))
        self.graph = self.graph_provider[self.roi]
        self.graph.add_nodes_from([(node['id'], node)
                                   for node in self.nodes])
        self.graph.add_edges_from([(edge['u'], edge['v'], edge)
                                   for edge in self.edges])

        self.solve_params = {"graph": self.graph,
                             "evidence_factor": 12,
                             "comb_angle_factor": 14,
                             "start_edge_prior": 180,
                             "selection_cost": -80}

        solver = Solver(**self.solve_params)
        solver.initialize()
        solver.solve()

        graph = self.solve_params["graph"]
        selected_edges = [(u, v) for u, v, data in graph.edges(data=True)
                          if data["selected"]]
        selected_nodes = [n for n, data in graph.nodes(data=True)
                          if data["selected"]]

        client = pymongo.MongoClient(self.db_host)
        client.drop_database(self.db_name)
def label_connected_components(db_host, db_name, roi, selected_attr,
                               solved_attr, edge_collection,
                               label_attribute="label"):
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)

    graph = graph_provider.get_graph(roi,
                                     nodes_filter={selected_attr: True},
                                     edges_filter={selected_attr: True})

    lut = find_connected_components(graph,
                                    node_component_attribute=label_attribute,
                                    return_lut=True)
    return graph, lut
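# Usage sketch for label_connected_components (hypothetical host, database,
# and ROI). Each selected node ends up with its component id stored under
# label_attribute, and lut maps node id -> component id.
import daisy

graph, lut = label_connected_components(
    db_host="mongodb://localhost:27017",
    db_name="micron_test",
    roi=daisy.Roi((0, 0, 0), (10000, 10000, 10000)),
    selected_attr="selected_1",
    solved_attr="solved_1",
    edge_collection="edges_g1")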
def get_graph(db_host, db_name, roi_offset=(158000, 121800, 448616),
              roi_size=(7600, 5200, 2200), selected_only=False,
              selected_attr="selected", solved_attr="solved",
              edge_collection="edges"):
    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)

    roi = daisy.Roi(roi_offset, roi_size)
    nodes, edges = graph_provider.read_blockwise(
        roi,
        block_size=daisy.Coordinate((10000, 10000, 10000)),
        num_workers=40)

    if len(edges["u"]) == 0:
        return nodes, []

    if selected_only:
        edges = [(u, v, selected, solved)
                 for u, v, selected, solved
                 in zip(edges["u"], edges["v"],
                        edges[selected_attr], edges[solved_attr])
                 if selected]
        nodes = {node_id: (z, y, x)
                 for z, y, x, node_id, selected
                 in zip(nodes["z"], nodes["y"], nodes["x"],
                        nodes["id"], nodes[selected_attr])
                 if selected}
    else:
        edges = [(u, v, selected, solved)
                 for u, v, selected, solved
                 in zip(edges["u"], edges["v"],
                        edges[selected_attr], edges[solved_attr])]

    return nodes, edges
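# Sketch: assemble the flat node/edge lists returned by this get_graph into
# a networkx graph (assumes selected_only=True, so nodes is an
# id -> (z, y, x) dict; db_host and db_name are placeholders).
import networkx as nx

nodes, edges = get_graph("mongodb://localhost:27017", "micron_test",
                         selected_only=True)

g = nx.Graph()
for node_id, (z, y, x) in nodes.items():
    g.add_node(node_id, z=z, y=y, x=x)
for u, v, selected, solved in edges:
    if u in g.nodes and v in g.nodes:  # skip edges with endpoints outside roi
        g.add_edge(u, v, selected=selected, solved=solved)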
def solve_in_block(
        db_host,
        skeletonization_db,
        subsampled_skeletonization_db,
        time_limit,
        solve_number,
        graph_number,
        location_attr,
        u_name,
        v_name,
        **kwargs,
):
    logger.info("Solve in block")

    subsampled_provider = MongoDbGraphProvider(
        subsampled_skeletonization_db, db_host, mode="r+", directed=True)
    skeletonization_provider = MongoDbGraphProvider(
        skeletonization_db, db_host, mode="r+")

    client = daisy.Client()

    while True:
        logger.info("Acquire block")
        block = client.acquire_block()

        if not block:
            return 0

        logger.debug("Solving in block %s", block)

        if check_function(
                block,
                "solve_s{}".format(solve_number),
                subsampled_skeletonization_db,
                db_host,
        ):
            client.release_block(block, 0)
            continue

        start_time = time.time()
        skeletonization = skeletonization_provider.get_graph(
            block.read_roi, node_inclusion="dangling", edge_inclusion="both")

        # Anything in matched was solved previously and must be maintained.
        pre_solved = subsampled_provider.get_graph(
            block.read_roi, node_inclusion="dangling", edge_inclusion="both")

        # if len(skeletonization.nodes()) > 10_000:
        #     to_remove = set(skeletonization.nodes()) - set(pre_solved.nodes())
        #     skeletonization.remove_nodes_from(to_remove)
        #     logger.info(f"Solving for {len(skeletonization.nodes())} would take too long")
        #     logger.info(f"Ignoring {len(to_remove)} nodes and skipping this block!")

        logger.info(
            f"Reading skeletonization with {len(skeletonization.nodes)} nodes and "
            f"{len(skeletonization.edges)} edges took {time.time() - start_time} seconds")

        if len(skeletonization.nodes) == 0:
            logger.info(f"No consensus nodes in roi {block.read_roi}. Skipping")
            write_done(block, f"solve_s{solve_number}",
                       subsampled_skeletonization_db, db_host)
            client.release_block(block, 0)
            continue

        logger.info("PreProcess...")
        start_time = time.time()
        logger.info(
            f"Skeletonization has {len(skeletonization.nodes)} nodes "
            f"and {len(skeletonization.edges)} edges before subsampling")

        num_removed = remove_line_nodes(skeletonization, location_attr)
        logger.info(f"Removed {num_removed} nodes from skeletonization!")

        num_nodes, num_edges = write_matched(
            db_host,
            subsampled_skeletonization_db,
            block,
            skeletonization,
            pre_solved,
            location_attr,
            u_name,
            v_name,
        )
        logger.info(
            f"Writing matched graph with {num_nodes} nodes and {num_edges} edges "
            f"took {time.time() - start_time} seconds")

        logger.info("Write done")
        write_done(
            block,
            "solve_s{}".format(solve_number),
            subsampled_skeletonization_db,
            db_host,
        )

        logger.info("Release block")
        client.release_block(block, 0)

    return 0
class GlobalConnectedComponents(BatchFilter):
    def __init__(
            self,
            db_name,
            db_host,
            collection,
            graph,
            size_attr,
            component_attr,
            read_size,
            write_size,
            mode="r+",
    ):
        self.db_name = db_name
        self.db_host = db_host
        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"
        self.client = None
        self.mode = mode
        self.graph = graph
        self.size_attr = size_attr
        self.component_attr = component_attr
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2
        assert self.read_size == self.write_size + self.context * 2

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"
        return deps

    def process(self, batch, request):
        g = batch[self.graph].to_nx_graph()
        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes and "
            f"{g.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)
        contained_nodes = [
            node for node, attr in g.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(g.nodes[n][self.component_attr]
                                   for n in contained_nodes)
        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        # Map block-local component ids to the global ids stored in the
        # component graph.
        component_graph = self.client.get_graph(roi=write_roi,
                                                node_inclusion="dangling",
                                                edge_inclusion="either")

        for node in contained_nodes:
            attrs = g.nodes[node]
            block_component_id = attrs[self.component_attr]
            global_component_id = component_graph.nodes[block_component_id][
                self.component_attr]
            attrs[self.component_attr] = global_component_id
            attrs[self.size_attr] = component_graph.nodes[block_component_id][
                self.size_attr]

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())
        return outputs
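# Toy illustration of the relabeling step in process() above: block-local
# component ids are replaced by the global ids stored in the component
# graph (the values here are made up).
block_components = {"node_a": 7, "node_b": 7, "node_c": 9}  # node -> local id
global_lookup = {7: 1, 9: 2}                                # local -> global id
relabeled = {n: global_lookup[c] for n, c in block_components.items()}
assert relabeled == {"node_a": 1, "node_b": 1, "node_c": 2}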
def extract_edges_in_block(db_name, db_host, soft_mask_container,
                           soft_mask_dataset, distance_threshold,
                           evidence_threshold, graph_number, block):
    graph_provider = MongoDbGraphProvider(
        db_name,
        db_host,
        mode='r+',
        position_attribute=['z', 'y', 'x'],
        directed=False,
        edges_collection='edges_g{}'.format(graph_number))

    if check_function(graph_provider.database, block,
                      "edges_g{}".format(graph_number)):
        return 0

    logger.debug("Finding edges in %s, reading from %s", block.write_roi,
                 block.read_roi)

    start = time.time()
    soft_mask_array = daisy.open_ds(soft_mask_container, soft_mask_dataset)
    graph = graph_provider[block.read_roi.intersect(soft_mask_array.roi)]

    if graph.number_of_nodes() == 0:
        logger.info("No nodes in roi %s. Skipping", block.read_roi)
        write_done(graph_provider.database, block,
                   'edges_g{}'.format(graph_number))
        return 0

    logger.debug("Read %d candidates in %.3fs", graph.number_of_nodes(),
                 time.time() - start)

    start = time.time()
    candidates = np.array(
        [[candidate_id] + [data[d] for d in ['z', 'y', 'x']]
         for candidate_id, data in graph.nodes(data=True) if 'z' in data],
        dtype=np.uint64)

    kdtree_start = time.time()
    kdtree = KDTree([[candidate[1], candidate[2], candidate[3]]
                     for candidate in candidates])
    pairs = kdtree.query_pairs(distance_threshold, p=2.0, eps=0)
    logger.debug("Query pairs in %.3fs", time.time() - kdtree_start)

    voxel_size = np.array(soft_mask_array.voxel_size, dtype=np.uint32)
    soft_mask_roi = block.read_roi.snap_to_grid(
        voxel_size=voxel_size).intersect(soft_mask_array.roi)
    soft_mask_array_data = soft_mask_array.to_ndarray(roi=soft_mask_roi)

    sm_dtype = soft_mask_array_data.dtype
    if sm_dtype == np.uint8:
        # Standard pipeline: probability map in [0, 255].
        pass
    elif sm_dtype == np.float32 or sm_dtype == np.float64:
        if not (soft_mask_array_data.min() >= 0
                and soft_mask_array_data.max() <= 1):
            raise ValueError(
                "Provided soft_mask has dtype float but is not in range "
                "[0, 1], abort")
        soft_mask_array_data *= 255
    else:
        raise ValueError("Soft mask dtype {} not understood".format(sm_dtype))

    soft_mask_array_data = soft_mask_array_data.astype(np.float64)
    if evidence_threshold is not None:
        soft_mask_array_data = (soft_mask_array_data >=
                                evidence_threshold * 255).astype(
                                    np.float64) * 255

    offset = np.array(np.array(soft_mask_roi.get_offset()) / voxel_size,
                      dtype=np.uint64)

    evidence_start = time.time()
    if pairs:
        pairs = np.array(list(pairs), dtype=np.uint64)
        evidence_array = cpp_get_evidence(candidates, pairs,
                                          soft_mask_array_data, offset,
                                          voxel_size)
        graph.add_weighted_edges_from(evidence_array, weight='evidence')
        logger.debug("Accumulate evidence in %.3fs",
                     time.time() - evidence_start)
        logger.debug("Found %d edges", graph.number_of_edges())
        logger.debug("Extracted edges in %.3fs", time.time() - start)

        start = time.time()
        graph.write_edges(block.write_roi)
        logger.debug("Wrote edges in %.3fs", time.time() - start)
    else:
        logger.debug("No pairs in block, skip")

    write_done(graph_provider.database, block,
               'edges_g{}'.format(graph_number))
    return 0
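# Toy illustration of the candidate-pair query above: scipy's KDTree returns
# every index pair whose points lie within the given radius.
import numpy as np
from scipy.spatial import KDTree

points = np.array([[0, 0, 0], [0, 0, 3], [10, 10, 10]], dtype=float)
kdtree = KDTree(points)
pairs = kdtree.query_pairs(5.0, p=2.0, eps=0)
assert pairs == {(0, 1)}  # only the first two points are within 5 units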
import itertools
from pathlib import Path
import logging

from tqdm import tqdm
from daisy.persistence import MongoDbGraphProvider

from config_parser import read_data_config

logging.basicConfig(level=logging.INFO)

config = read_data_config(Path("config.ini"))
print(config["sample"])

consensus_provider = MongoDbGraphProvider(config["consensus_db"],
                                          config["db_host"],
                                          mode="r+",
                                          directed=True)

subdivided_provider = MongoDbGraphProvider(config["subdivided_db"],
                                           config["db_host"],
                                           mode="w",
                                           directed=True)

consensus = consensus_provider.get_graph(config["roi"],
                                         node_inclusion="dangling",
                                         edge_inclusion="either")
print(f"Read consensus with {len(consensus.nodes)} nodes")

subdivided = subdivided_provider.get_graph(config["roi"],
                                           node_inclusion="dangling",
                                           edge_inclusion="either")
# edge_inclusion above is an assumption: the snippet was cut off mid-call,
# and "either" mirrors the consensus read.
class DaisyGraphProvider(BatchProvider):
    """
    See documentation for the mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """

    def __init__(
            self,
            dbname: str,
            url: str,
            points: List[GraphKey],
            graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
            directed: bool = False,
            total_roi: Roi = None,
            nodes_collection: str = "nodes",
            edges_collection: str = "edges",
            meta_collection: str = "meta",
            endpoint_names: Tuple[str, str] = ("u", "v"),
            position_attribute: str = "position",
            node_attrs: Optional[List[str]] = None,
            edge_attrs: Optional[List[str]] = None,
            nodes_filter: Optional[Dict[str, Any]] = None,
            edges_filter: Optional[Dict[str, Any]] = None,
            edge_inclusion: str = "either",
            node_inclusion: str = "dangling",
            fail_on_inconsistent_node: bool = False,
    ):
        self.points = points
        graph_specs = (graph_specs if graph_specs is not None else GraphSpec(
            Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
            directed=False))
        specs = (graph_specs
                 if isinstance(graph_specs, list)
                 and len(graph_specs) == len(points)
                 else [graph_specs] * len(points))
        self.specs = {key: spec for key, spec in zip(points, specs)}

        self.directed = directed
        self.nodes_collection = nodes_collection
        self.edges_collection = edges_collection
        self.meta_collection = meta_collection
        self.endpoint_names = endpoint_names
        self.position_attribute = position_attribute
        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter
        self.edge_inclusion = edge_inclusion
        self.node_inclusion = node_inclusion
        self.dbname = dbname
        self.url = url
        self.fail_on_inconsistent_node = fail_on_inconsistent_node
        self.graph_provider = None

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)
        if self.graph_provider is None:
            self.graph_provider = MongoDbGraphProvider(
                self.dbname,
                self.url,
                mode="r+",
                directed=self.directed,
                total_roi=None,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
                meta_collection=self.meta_collection,
                endpoint_names=self.endpoint_names,
                position_attribute=self.position_attribute,
            )

    def provide(self, request):
        timing = Timing(self)
        timing.start()

        batch = Batch()
        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and "
                f"{len(requested_graph.edges)} edges")

            failed_nodes = []
            for node, attrs in requested_graph.nodes.items():
                try:
                    attrs["location"] = np.array(
                        attrs[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node} was written (probably part of an "
                        f"edge), but never given coordinates!")
                    failed_nodes.append(node)
                attrs["id"] = node

            for node in failed_nodes:
                if self.fail_on_inconsistent_node:
                    raise ValueError(
                        f"Mongodb contains node {node} without location! "
                        f"It was probably written as part of an edge")
                requested_graph.remove_node(node)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            points = Graph.from_nx_graph(requested_graph, spec)
            points.relabel_connected_components()
            points.crop(spec.roi)
            batch[key] = points
            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)
        return batch
def solve_in_block(db_host, db_name, evidence_factor, comb_angle_factor,
                   start_edge_prior, selection_cost, time_limit, solve_number,
                   graph_number, selected_attr="selected",
                   solved_attr="solved", **kwargs):
    logger.info("Solve in block")

    graph_provider = MongoDbGraphProvider(
        db_name,
        db_host,
        mode='r+',
        position_attribute=['z', 'y', 'x'],
        edges_collection="edges_g{}".format(graph_number))

    client = daisy.Client()

    while True:
        logger.info("Acquire block")
        block = client.acquire_block()

        if not block:
            return 0

        logger.debug("Solving in block %s", block)

        if check_function(block, 'solve_s{}'.format(solve_number), db_name,
                          db_host):
            client.release_block(block, 0)
            continue

        start_time = time.time()
        graph = graph_provider.get_graph(block.read_roi)

        num_nodes = graph.number_of_nodes()
        num_edges = graph.number_of_edges()
        logger.info(
            "Reading graph with %d nodes and %d edges took %s seconds"
            % (num_nodes, num_edges, time.time() - start_time))

        if num_edges == 0:
            logger.info("No edges in roi %s. Skipping" % block.read_roi)
            write_done(block, 'solve_s{}'.format(solve_number), db_name,
                       db_host)
            client.release_block(block, 0)
            continue

        logger.info("Solve...")
        solver = Solver(graph, evidence_factor, comb_angle_factor,
                        start_edge_prior, selection_cost, time_limit,
                        selected_attr, solved_attr)
        solver.initialize()
        solver.solve()

        start_time = time.time()
        graph.update_edge_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])
        graph.update_node_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])
        logger.info(
            "Updating attributes %s & %s for %d edges took %s seconds"
            % (selected_attr, solved_attr, num_edges,
               time.time() - start_time))

        logger.info("Write done")
        write_done(block, 'solve_s{}'.format(solve_number), db_name, db_host)

        logger.info("Release block")
        client.release_block(block, 0)

    return 0
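# The worker loop above pairs with a daisy scheduler on the master side.
# A rough sketch under daisy 0.x conventions; the ROIs, worker count, and
# the start_worker subprocess helper are placeholders, not the actual
# scheduling code of this project.
import daisy

total_roi = daisy.Roi((0, 0, 0), (40000, 40000, 40000))
block_roi = daisy.Roi((0, 0, 0), (10000, 10000, 10000))

def start_worker():
    # Hypothetical helper: spawn a subprocess that runs solve_in_block,
    # which then acquires blocks via daisy.Client().
    ...

daisy.run_blockwise(
    total_roi,
    read_roi=block_roi,
    write_roi=block_roi,
    process_function=start_worker,
    num_workers=4,
    fit='shrink')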
class FilteredDaisyGraphProvider(BatchProvider):
    """
    See documentation for the mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """

    def __init__(
            self,
            dbname: str,
            url: str,
            points: List[GraphKey],
            graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
            directed: bool = False,
            total_roi: Roi = None,
            nodes_collection: str = "nodes",
            edges_collection: str = "edges",
            meta_collection: str = "meta",
            endpoint_names: Tuple[str, str] = ("u", "v"),
            position_attribute: str = "position",
            node_attrs: Optional[List[str]] = None,
            edge_attrs: Optional[List[str]] = None,
            nodes_filter: Optional[Dict[str, Any]] = None,
            edges_filter: Optional[Dict[str, Any]] = None,
            num_nodes=100000,
            dist_attribute=None,
            min_dist=29000,
    ):
        self.points = points
        graph_specs = (graph_specs if graph_specs is not None else GraphSpec(
            Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
            directed=False))
        specs = (graph_specs
                 if isinstance(graph_specs, list)
                 and len(graph_specs) == len(points)
                 else [graph_specs] * len(points))
        self.specs = {key: spec for key, spec in zip(points, specs)}

        self.position_attribute = position_attribute
        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter
        self.dist_attribute = dist_attribute
        self.min_dist = min_dist
        self.num_nodes = num_nodes

        self.graph_provider = MongoDbGraphProvider(
            dbname,
            url,
            mode="r+",
            directed=directed,
            total_roi=None,
            nodes_collection=nodes_collection,
            edges_collection=edges_collection,
            meta_collection=meta_collection,
            endpoint_names=endpoint_names,
            position_attribute=position_attribute,
        )

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)

    def provide(self, request):
        timing = Timing(self)
        timing.start()

        batch = Batch()
        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and "
                f"{len(requested_graph.edges)} edges")

            # Drop nodes whose dist_attribute falls below min_dist.
            for node, attrs in list(requested_graph.nodes.items()):
                if self.dist_attribute in attrs:
                    if attrs[self.dist_attribute] < self.min_dist:
                        requested_graph.remove_node(node)

            logger.debug(
                f"{len(requested_graph.nodes)} nodes remaining after "
                f"filtering by distance")

            # Subsample to at most num_nodes randomly chosen nodes.
            if len(requested_graph.nodes) > self.num_nodes:
                nodes = list(requested_graph.nodes)
                nodes_to_keep = set(random.sample(nodes, self.num_nodes))
                for node in list(requested_graph.nodes()):
                    if node not in nodes_to_keep:
                        requested_graph.remove_node(node)

            for node, attrs in requested_graph.nodes.items():
                attrs["location"] = np.array(attrs[self.position_attribute],
                                             dtype=np.float32)
                attrs["id"] = node

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            logger.debug(
                f"providing {key} with {len(requested_graph.nodes)} nodes "
                f"and {len(requested_graph.edges)} edges")

            points = Graph.from_nx_graph(requested_graph, spec)
            points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)
        return batch
class MongoWriteGraph(gp.BatchFilter):
    def __init__(
            self,
            mst,
            db_host,
            db_name,
            read_size,
            write_size,
            voxel_size,
            directed,
            mode="r+",
            collection="",
            position_attr="position",
            node_attrs: List[str] = None,
            edge_attrs: List[str] = None,
    ):
        self.mst = mst
        self.db_host = db_host
        self.db_name = db_name
        self.client = None
        self.voxel_size = voxel_size
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2
        assert self.write_size + (self.context * 2) == self.read_size
        self.directed = directed
        self.mode = mode
        self.position_attr = position_attr
        self.nodes_collection = f"{collection}_nodes"
        self.edges_collection = f"{collection}_edges"
        self.node_attrs = node_attrs if node_attrs is not None else []
        self.node_attrs += [position_attr]
        self.edge_attrs = edge_attrs if edge_attrs is not None else []

    def setup(self):
        # Initialize the client. Doesn't the daisy mongodb graph provider
        # handle this?
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=self.directed,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
            )
        self.updates(self.mst, self.spec[self.mst].copy())

    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.mst] = request[self.mst].copy()
        assert (request[self.mst].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"
        return deps

    def process(self, batch, request):
        voxel_size = self.voxel_size
        read_roi = batch[self.mst].spec.roi
        write_roi = read_roi.grow(-self.context, -self.context)
        mst = batch[self.mst].to_nx_graph()

        # Get the saved graph with "dangling" edges; some nodes may not be
        # in write_roi. Edges crossing the boundary should only be saved if
        # the lower id is contained in write_roi.
        mongo_graph = self.client.get_graph(
            write_roi,
            node_attrs=self.node_attrs,
            edge_attrs=self.edge_attrs,
            node_inclusion="dangling",
            edge_inclusion="either",
        )
        for node in list(mongo_graph.nodes()):
            if self.position_attr not in mongo_graph.nodes[node]:
                mongo_graph.remove_node(node)

        num_created_nodes = 0
        num_updated_nodes = 0
        for node, attrs in mst.nodes.items():
            loc = attrs["location"]
            pos = np.floor(loc / voxel_size)
            node_id = int(math.cantor_number(pos))

            # Only update attributes of nodes in the write_roi.
            if write_roi.contains(loc):
                attrs_to_save = {
                    self.position_attr: tuple(float(v) for v in loc)
                }
                for attr in self.node_attrs:
                    if attr in attrs:
                        value = attrs[attr]
                        if isinstance(value, Iterable):
                            value = tuple(float(v) for v in value)
                        if isinstance(value, np.float32) or (
                                isinstance(value, Iterable) and any(
                                    [isinstance(v, np.float32)
                                     for v in value])):
                            raise ValueError(
                                f"value of attr {attr} is a np.float32")
                        attrs_to_save[attr] = value

                if node_id in mongo_graph:
                    num_updated_nodes += 1
                    mongo_attrs = mongo_graph.nodes[node_id]
                    mongo_attrs.update(attrs_to_save)
                else:
                    num_created_nodes += 1
                    mongo_graph.add_node(node_id, **attrs_to_save)

        num_created_edges = 0
        num_updated_edges = 0
        for (u, v), attrs in mst.edges.items():
            u_loc = mst.nodes[u]["location"]
            u_pos = np.floor(u_loc / voxel_size)
            u_id = int(math.cantor_number(u_pos))
            v_loc = mst.nodes[v]["location"]
            v_pos = np.floor(v_loc / voxel_size)
            v_id = int(math.cantor_number(v_pos))

            # Node a is the node with the lower id out of u, v.
            a_loc, a_id, b_loc, b_id = ((u_loc, u_id, v_loc, v_id)
                                        if u_id < v_id else
                                        (v_loc, v_id, u_loc, u_id))

            # Only write the edge if a is contained. This may create a node
            # without any attributes if node creation depends on your field
            # of view and the neighboring block fails to create the same
            # node.
            if write_roi.contains(a_loc):
                attrs_to_save = {}
                for attr in self.edge_attrs:
                    if attr in attrs:
                        value = attrs[attr]
                        if isinstance(value, Iterable):
                            value = tuple(float(v) for v in value)
                        if isinstance(value, np.float32) or (
                                isinstance(value, Iterable) and any(
                                    [isinstance(v, np.float32)
                                     for v in value])):
                            raise ValueError(
                                f"value of attr {attr} is a np.float32")
                        attrs_to_save[attr] = value

                if (u_id, v_id) in mongo_graph.edges:
                    num_updated_edges += 1
                    mongo_attrs = mongo_graph.edges[(u_id, v_id)]
                    mongo_attrs.update(attrs_to_save)
                else:
                    num_created_edges += 1
                    mongo_graph.add_edge(u_id, v_id, **attrs_to_save)

        for node in mst.nodes:
            if node in mongo_graph.nodes and write_roi.contains(
                    mst.nodes[node]["location"]):
                assert all(
                    np.isclose(
                        mongo_graph.nodes[node][self.position_attr],
                        mst.nodes[node]["location"],
                    ))
                if write_roi.contains(
                        mongo_graph.nodes[node][self.position_attr]):
                    assert (mongo_graph.nodes[node]["component_id"] ==
                            mst.nodes[node]["component_id"])

        if len(mongo_graph.nodes) > 0:
            mongo_graph.write_nodes(roi=write_roi,
                                    attributes=self.node_attrs)
        if len(mongo_graph.edges) > 0:
            for edge, attrs in mongo_graph.edges.items():
                for attr in self.edge_attrs:
                    assert attr in attrs
            mongo_graph.write_edges(roi=write_roi,
                                    attributes=self.edge_attrs)
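# The stable node ids above come from a Cantor pairing of the voxel-grid
# position (the math.cantor_number helper). A rough sketch of the idea for
# illustration only, not that helper's exact implementation:
def cantor_pair(a: int, b: int) -> int:
    # Classic Cantor pairing: a unique non-negative integer per (a, b).
    return (a + b) * (a + b + 1) // 2 + b

def cantor_number_3d(z: int, y: int, x: int) -> int:
    # One possible extension to 3D: pair twice.
    return cantor_pair(cantor_pair(z, y), x)

assert cantor_pair(1, 2) != cantor_pair(2, 1)  # ordering matters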
class MongoWriteComponents(BatchFilter):
    def __init__(
            self,
            graph,
            db_host,
            db_name,
            read_size,
            write_size,
            component_attr,
            node_attrs=None,
            collection="",
            mode="r+",
    ):
        self.graph = graph
        self.db_host = db_host
        self.db_name = db_name
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2
        self.mode = mode
        self.component_attr = component_attr
        assert self.write_size + (self.context * 2) == self.read_size

        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"

        if node_attrs is None:
            self.node_attrs = []
        else:
            self.node_attrs = node_attrs

        self.client = None

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), "Got wrong size graph in request"
        return deps

    def process(self, batch, request):
        graph = batch[self.graph].to_nx_graph()
        logger.debug(
            f"{self.name()} got graph with {graph.number_of_nodes()} nodes "
            f"and {graph.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)
        contained_nodes = [
            node for node, attr in graph.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(graph.nodes[n][self.component_attr]
                                   for n in contained_nodes)
        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        mongo_graph = self.client.get_graph(
            write_roi,
            node_attrs=[],
            edge_attrs=[],
            node_inclusion="dangling",
            edge_inclusion="either",
        )

        num_new_components = 0
        num_new_component_edges = 0
        for node, attrs in graph.nodes.items():
            if write_roi.contains(attrs["location"]):
                cc_id = attrs[self.component_attr]
                node_attrs = {}
                for attr in self.node_attrs:
                    node_attrs[attr] = attrs[attr]
                if cc_id not in mongo_graph.nodes:
                    num_new_components += 1
                    mongo_graph.add_node(cc_id,
                                         position=write_roi.get_center(),
                                         **node_attrs)
                else:
                    for k, v in node_attrs.items():
                        assert mongo_graph.nodes[cc_id][k] == v

        # Always write crossing edges if neither component id is None. The
        # other end point won't have a component id unless it has already
        # been processed, thus it is the duty of the second pass to write
        # the edge, regardless of whether the lower or upper end point is
        # contained.
        for u, v in graph.edges:
            u_loc = graph.nodes[u]["location"]
            v_loc = graph.nodes[v]["location"]
            if write_roi.contains(u_loc) or write_roi.contains(v_loc):
                u_cc_id = graph.nodes[u].get(self.component_attr)
                v_cc_id = graph.nodes[v].get(self.component_attr)
                if u_cc_id is None or v_cc_id is None:
                    continue
                elif u_cc_id == v_cc_id:
                    continue
                elif (u_cc_id, v_cc_id) not in mongo_graph.edges:
                    num_new_component_edges += 1
                    mongo_graph.add_edge(u_cc_id, v_cc_id)

        logger.debug(
            f"{self.name()} writing {num_new_components} new components and "
            f"{num_new_component_edges} new component edges!")
        mongo_graph.write_nodes()
        mongo_graph.write_edges(ignore_v=False)