Example #1
def overlay_segmentation(db_host,
                         db_name,
                         roi_offset,
                         roi_size,
                         selected_attr,
                         solved_attr,
                         edge_collection,
                         segmentation_container,
                         segmentation_dataset,
                         segmentation_number,
                         voxel_size=(40, 4, 4)):

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph_roi = daisy.Roi(roi_offset, roi_size)

    segmentation = daisy.open_ds(segmentation_container, segmentation_dataset)
    intersection_roi = segmentation.roi.intersect(graph_roi).snap_to_grid(
        voxel_size)

    nx_graph = graph_provider.get_graph(intersection_roi,
                                        nodes_filter={selected_attr: True},
                                        edges_filter={selected_attr: True})

    for node_id, data in nx_graph.nodes(data=True):
        node_position = daisy.Coordinate((data["z"], data["y"], data["x"]))
        nx_graph.nodes[node_id]["segmentation_{}".format(
            segmentation_number)] = segmentation[node_position]

    # write the nodes (now carrying the new segmentation attribute) back to the database
    nx_graph.write_nodes(intersection_roi)
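A minimal sketch of how the function above might be invoked. Every argument below is a placeholder: the MongoDB URI, database, zarr container/dataset and attribute/collection names are hypothetical (chosen to follow the selected_1 / edges_g1 naming convention used elsewhere on this page), and the ROI values are borrowed from the defaults of a later example.

overlay_segmentation(
    db_host="mongodb://localhost:27017",          # placeholder MongoDB URI
    db_name="micron_candidates",                  # placeholder database name
    roi_offset=(158000, 121800, 448616),          # z, y, x offset (reused from a later example)
    roi_size=(7600, 5200, 2200),                  # z, y, x size
    selected_attr="selected_1",
    solved_attr="solved_1",
    edge_collection="edges_g1",
    segmentation_container="segmentation.zarr",   # placeholder zarr container
    segmentation_dataset="volumes/segmentation",  # placeholder dataset
    segmentation_number=1,
    voxel_size=(40, 4, 4),
)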
Example #4
    def setup(self):

        # Initialize client. Doesn't the daisy mongodb graph provider handle this?
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=self.directed,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
            )

        self.updates(self.mst, self.spec[self.mst].copy())
Example #5
    def setUp(self):
        # y: 1|   0(s)------1(s)------2(s)
        #    2|
        #    3|   3(s)                6(s)
        #    4|    |                  |
        #    5|   4(s)-(ns)-9(s)-(ns)-7(s)
        #    6|    |                  |
        #    7|   5(s)                8(s)
        #     |-------------------------------->
        # x:       1         2         3
        #
        # s = selected
        # ns = not selected
        self.nodes = [
                {'id': 0, 'z': 1, 'y': 1, 'x': 1, 'selected': True, 'solved': True},
                {'id': 1, 'z': 1, 'y': 1, 'x': 2, 'selected': True, 'solved': True},
                {'id': 2, 'z': 1, 'y': 1, 'x': 3, 'selected': True, 'solved': True},
                {'id': 3, 'z': 1, 'y': 3, 'x': 1, 'selected': True, 'solved': True},
                {'id': 4, 'z': 1, 'y': 5, 'x': 1, 'selected': True, 'solved': True},
                {'id': 5, 'z': 1, 'y': 7, 'x': 1, 'selected': True, 'solved': True},
                {'id': 6, 'z': 1, 'y': 3, 'x': 3, 'selected': True, 'solved': True},
                {'id': 7, 'z': 1, 'y': 5, 'x': 3, 'selected': True, 'solved': True},
                {'id': 8, 'z': 1, 'y': 7, 'x': 3, 'selected': True, 'solved': True},
                {'id': 9, 'z': 1, 'y': 5, 'x': 2, 'selected': True, 'solved': True}
                ]

        self.edges = [{'u': 0, 'v': 1, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 1, 'v': 2, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 3, 'v': 4, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 4, 'v': 5, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 6, 'v': 7, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 7, 'v': 8, 'evidence': 0.5, 'selected': True, 'solved': True},
                      {'u': 4, 'v': 9, 'evidence': 0.5, 'selected': False, 'solved': False},
                      {'u': 7, 'v': 9, 'evidence': 0.5, 'selected': False, 'solved': False}
                      ]

        #self.nodes = self.nodes[:3]

        self.db_name = 'micron_test_solver'
        config = configparser.ConfigParser()
        config.read(os.path.expanduser("../mongo.ini"))
        self.db_host = "mongodb://{}:{}@{}:{}".format(config.get("Credentials", "user"),
                                                      config.get("Credentials", "password"),
                                                      config.get("Credentials", "host"),
                                                      config.get("Credentials", "port"))

        self.graph_provider = MongoDbGraphProvider(self.db_name,
                                                   self.db_host,
                                                   mode='w',
                                                   position_attribute=['z', 'y', 'x'])
        self.roi = daisy.Roi((0,0,0), (4,4,4))
        self.graph = self.graph_provider[self.roi]
        self.graph.add_nodes_from([(node['id'], node) for node in self.nodes])
        self.graph.add_edges_from([(edge['u'], edge['v'], edge) for edge in self.edges])
        
        self.solve_params = {"graph": self.graph, 
                             "evidence_factor": 12,
                             "comb_angle_factor": 14,
                             "start_edge_prior": 180,
                             "selection_cost": -80}
Example #7
def get_graph(db_host, db_name, roi_offset, roi_size, solve_number,
              edge_collection):
    """
    Get selected subgraph containing no isolated
    selected edges. (All edges with not 
    2 selected vertices are filtered out)
    """

    selected_attr = "selected_{}".format(solve_number)
    solved_attr = "solved_{}".format(solve_number)

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    roi = daisy.Roi(roi_offset, roi_size)

    nx_graph = graph_provider.get_graph(
        roi,
        nodes_filter={selected_attr: True},
        edges_filter={selected_attr: True},
        node_attrs=[selected_attr, solved_attr, "x", "y", "z"],
        edge_attrs=[selected_attr, solved_attr, "x", "y", "z"])

    # NOTE: We filter edges that do not have
    # two selected vertices.
    edges_to_remove = set()
    nodes_to_remove = set()
    for e in nx_graph.edges():
        u = nx_graph.nodes()[e[0]]
        v = nx_graph.nodes()[e[1]]

        if not u:
            edges_to_remove.add((e[0], e[1]))
            nodes_to_remove.add(e[0])
        if not v:
            edges_to_remove.add((e[0], e[1]))
            nodes_to_remove.add(e[1])

    nx_graph.remove_edges_from(edges_to_remove)
    nx_graph.remove_nodes_from(nodes_to_remove)

    return nx_graph
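A hedged sketch of calling the reader above; the host, database and collection names are placeholders, and the ROI values are again borrowed from the defaults of a later example on this page.

nx_graph = get_graph(
    db_host="mongodb://localhost:27017",   # placeholder
    db_name="micron_candidates",           # placeholder
    roi_offset=(158000, 121800, 448616),
    roi_size=(7600, 5200, 2200),
    solve_number=1,                        # reads the selected_1 / solved_1 attributes
    edge_collection="edges_g1",
)
print("kept %d nodes and %d edges after filtering" %
      (nx_graph.number_of_nodes(), nx_graph.number_of_edges()))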
Example #8
    def test_solve(self):
        # y: 1|   0--1--2
        # x:      1  2 3
        # Test all combinations of
        # s/ns

        combinations = [(False, False), (False, True), (True, True)]

        for v0 in combinations:
            for v1 in combinations:
                for v2 in combinations:
                    for e0 in combinations:
                        for e1 in combinations:
                            self.nodes = [
                                    {'id': 0, 'z': 1, 'y': 1, 'x': 1, 'selected': v0[0], 'solved': v0[1]},
                                    {'id': 1, 'z': 1, 'y': 1, 'x': 2, 'selected': v1[0], 'solved': v1[1]},
                                    {'id': 2, 'z': 1, 'y': 1, 'x': 3, 'selected': v2[0], 'solved': v2[1]},
                                    ]

                            self.edges = [{'u': 0, 'v': 1, 'evidence': 0.5, 'selected': e0[0], 'solved': e0[1]},
                                          {'u': 1, 'v': 2, 'evidence': 0.5, 'selected': e1[0], 'solved': e1[1]}]

                            self.db_name = 'micron_test_solver'
                            config = configparser.ConfigParser()
                            config.read(os.path.expanduser("../mongo.ini"))
                            self.db_host = "mongodb://{}:{}@{}:{}".format(config.get("Credentials", "user"),
                                                                          config.get("Credentials", "password"),
                                                                          config.get("Credentials", "host"),
                                                                          config.get("Credentials", "port"))

                            self.graph_provider = MongoDbGraphProvider(self.db_name,
                                                                       self.db_host,
                                                                       mode='w',
                                                                       position_attribute=['z', 'y', 'x'])
                            self.roi = daisy.Roi((0,0,0), (4,4,4))
                            self.graph = self.graph_provider[self.roi]
                            self.graph.add_nodes_from([(node['id'], node) for node in self.nodes])
                            self.graph.add_edges_from([(edge['u'], edge['v'], edge) for edge in self.edges])
                            
                            self.solve_params = {"graph": self.graph, 
                                                 "evidence_factor": 12,
                                                 "comb_angle_factor": 14,
                                                 "start_edge_prior": 180,
                                                 "selection_cost": -80}

                            solver = Solver(**self.solve_params)
                            solver.initialize()
                            solver.solve()
                            graph = self.solve_params["graph"]
                            selected_edges = [(e[0], e[1]) for e in graph.edges(data=True) if e[2]["selected"]]
                            selected_nodes = [v[0] for v in graph.nodes(data=True) if v[1]["selected"]]

                            client = pymongo.MongoClient(self.db_host)
                            client.drop_database(self.db_name)
Example #9
def label_connected_components(db_host,
                               db_name,
                               roi,
                               selected_attr,
                               solved_attr,
                               edge_collection,
                               label_attribute="label"):

    graph_provider = MongoDbGraphProvider(db_name,
                                          db_host,
                                          directed=False,
                                          position_attribute=['z', 'y', 'x'],
                                          edges_collection=edge_collection)
    graph = graph_provider.get_graph(roi,
                                     nodes_filter={selected_attr: True},
                                     edges_filter={selected_attr: True})

    lut = find_connected_components(graph,
                                    node_component_attribute=label_attribute,
                                    return_lut=True)

    return graph, lut
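A short usage sketch with placeholder connection details; this function takes a daisy.Roi directly, so one is constructed up front, and the component count is read back from the label attribute that find_connected_components writes onto the nodes.

import daisy

roi = daisy.Roi((158000, 121800, 448616), (7600, 5200, 2200))  # placeholder ROI
graph, lut = label_connected_components(
    db_host="mongodb://localhost:27017",   # placeholder
    db_name="micron_candidates",           # placeholder
    roi=roi,
    selected_attr="selected_1",
    solved_attr="solved_1",
    edge_collection="edges_g1",
)
labels = {data["label"] for _, data in graph.nodes(data=True) if "label" in data}
print("%d selected nodes in %d connected components" %
      (graph.number_of_nodes(), len(labels)))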
Example #10
def get_graph(db_host,
              db_name,
              roi_offset=(158000, 121800, 448616),
              roi_size=(7600, 5200, 2200),
              selected_only=False,
              selected_attr="selected",
              solved_attr="solved",
              edge_collection="edges"):

    graph = MongoDbGraphProvider(db_name,
                                 db_host,
                                 directed=False,
                                 position_attribute=['z', 'y', 'x'],
                                 edges_collection=edge_collection)
    roi = daisy.Roi(roi_offset, roi_size)
    nodes, edges = graph.read_blockwise(roi,
                                        block_size=daisy.Coordinate(
                                            (10000, 10000, 10000)),
                                        num_workers=40)

    if len(edges["u"]) != 0:
        if selected_only:
            edges = [(u, v, selected, solved) for u, v, selected, solved in
                     zip(edges["u"], edges["v"], edges[selected_attr],
                         edges[solved_attr]) if selected]
            nodes = {
                node_id: (z, y, x)
                for z, y, x, node_id, selected in zip(
                    nodes["z"], nodes["y"], nodes["x"], nodes["id"],
                    nodes[selected_attr])
                if selected
            }

        else:
            edges = [(u, v, selected, solved) for u, v, selected, solved in
                     zip(edges["u"], edges["v"], edges[selected_attr],
                         edges[solved_attr])]
    else:
        edges = []

    return nodes, edges
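A sketch of calling the blockwise reader above with its default ROI; only the connection details and the attribute/collection names are placeholders.

nodes, edges = get_graph(
    db_host="mongodb://localhost:27017",   # placeholder
    db_name="micron_candidates",           # placeholder
    selected_only=True,
    selected_attr="selected_1",
    solved_attr="solved_1",
    edge_collection="edges_g1",
)
print("%d selected nodes, %d selected edges" % (len(nodes), len(edges)))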
Example #11
def solve_in_block(
    db_host,
    skeletonization_db,
    subsampled_skeletonization_db,
    time_limit,
    solve_number,
    graph_number,
    location_attr,
    u_name,
    v_name,
    **kwargs,
):
    logger.info("Solve in block")

    subsampled_provider = MongoDbGraphProvider(subsampled_skeletonization_db,
                                               db_host,
                                               mode="r+",
                                               directed=True)

    skeletonization_provider = MongoDbGraphProvider(skeletonization_db,
                                                    db_host,
                                                    mode="r+")

    client = daisy.Client()

    while True:
        logger.info("Acquire block")
        block = client.acquire_block()

        if not block:
            return 0

        logger.debug("Solving in block %s", block)

        if check_function(
                block,
                "solve_s{}".format(solve_number),
                subsampled_skeletonization_db,
                db_host,
        ):
            client.release_block(block, 0)
            continue

        start_time = time.time()
        skeletonization = skeletonization_provider.get_graph(
            block.read_roi, node_inclusion="dangling", edge_inclusion="both")
        # anything in matched was solved previously and must be maintained.
        pre_solved = subsampled_provider.get_graph(block.read_roi,
                                                   node_inclusion="dangling",
                                                   edge_inclusion="both")

        # if len(skeletonization.nodes()) > 10_000:
        #     to_remove = set(skeletonization.nodes()) - set(pre_solved.nodes())
        #     skeletonization.remove_nodes_from(to_remove)
        #     logger.info(f"Solving for {len(skeletonization.nodes())} would take too long")
        #     logger.info(f"Ignoring {len(to_remove)} nodes and skipping this block!")

        logger.info(
            f"Reading skeletonization with {len(skeletonization.nodes)} nodes and "
            +
            f"{len(skeletonization.edges)} edges took {time.time() - start_time} seconds"
        )

        if len(skeletonization.nodes) == 0:
            logger.info(
                f"No consensus nodes in roi {block.read_roi}. Skipping")
            write_done(block, f"solve_s{solve_number}",
                       subsampled_skeletonization_db, db_host)
            client.release_block(block, 0)
            continue

        logger.info("PreProcess...")
        start_time = time.time()

        logger.info(
            f"Skeletoniation has {len(skeletonization.nodes)} nodes "
            f"and {len(skeletonization.edges)} edges before subsampling")

        num_removed = remove_line_nodes(skeletonization, location_attr)
        logger.info(f"Removed {num_removed} nodes from skeletonization!")

        num_nodes, num_edges = write_matched(
            db_host,
            subsampled_skeletonization_db,
            block,
            skeletonization,
            pre_solved,
            location_attr,
            u_name,
            v_name,
        )

        logger.info(
            f"Writing matched graph with {num_nodes} nodes and {num_edges} edges "
            f"took {time.time()-start_time} seconds")

        logger.info("Write done")
        write_done(
            block,
            "solve_s{}".format(solve_number),
            subsampled_skeletonization_db,
            db_host,
        )

        logger.info("Release block")
        client.release_block(block, 0)

    return 0
class GlobalConnectedComponents(BatchFilter):
    def __init__(
        self,
        db_name,
        db_host,
        collection,
        graph,
        size_attr,
        component_attr,
        read_size,
        write_size,
        mode="r+",
    ):
        self.db_name = db_name
        self.db_host = db_host
        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"
        self.client = None
        self.mode = mode

        self.graph = graph
        self.size_attr = size_attr
        self.component_attr = component_attr
        self.read_size = read_size
        self.write_size = write_size

        self.context = (read_size - write_size) / 2

        assert self.read_size == self.write_size + self.context * 2

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())

        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), f"Got wrong size graph in request"

        return deps

    def process(self, batch, request):
        g = batch[self.graph].to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)

        contained_nodes = [
            node for node, attr in g.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(g.nodes[n][self.component_attr]
                                   for n in contained_nodes)

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        component_graph = self.client.get_graph(roi=write_roi,
                                                node_inclusion="dangling",
                                                edge_inclusion="either")

        for node in contained_nodes:
            attrs = g.nodes[node]
            block_component_id = attrs[self.component_attr]
            global_component_id = component_graph.nodes[block_component_id][
                self.component_attr]
            attrs[self.component_attr] = global_component_id
            attrs[self.size_attr] = component_graph.nodes[block_component_id][
                self.size_attr]

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())

        return outputs
Example #13
def extract_edges_in_block(db_name, db_host, soft_mask_container,
                           soft_mask_dataset, distance_threshold,
                           evidence_threshold, graph_number, block):

    graph_provider = MongoDbGraphProvider(
        db_name,
        db_host,
        mode='r+',
        position_attribute=['z', 'y', 'x'],
        directed=False,
        edges_collection='edges_g{}'.format(graph_number))

    if check_function(graph_provider.database, block,
                      "edges_g{}".format(graph_number)):
        return 0

    logger.debug("Finding edges in %s, reading from %s", block.write_roi,
                 block.read_roi)

    start = time.time()

    soft_mask_array = daisy.open_ds(soft_mask_container, soft_mask_dataset)

    graph = graph_provider[block.read_roi.intersect(soft_mask_array.roi)]

    if graph.number_of_nodes() == 0:
        logger.info("No nodes in roi %s. Skipping", block.read_roi)
        write_done(graph_provider.database, block,
                   'edges_g{}'.format(graph_number))
        return 0

    logger.debug("Read %d candidates in %.3fs", graph.number_of_nodes(),
                 time.time() - start)

    start = time.time()
    """
    candidates = [(candidate_id, 
                   np.array([data[d] for d in ['z', 'y', 'x']])) 
                   for candidate_id, data in graph.nodes(data=True) if 'z' in data]
    """
    candidates = np.array(
        [[candidate_id] + [data[d] for d in ['z', 'y', 'x']]
         for candidate_id, data in graph.nodes(data=True) if 'z' in data],
        dtype=np.uint64)

    kdtree_start = time.time()
    kdtree = KDTree([[candidate[1], candidate[2], candidate[3]]
                     for candidate in candidates])
    #kdtree = KDTree(candidates[])
    pairs = kdtree.query_pairs(distance_threshold, p=2.0, eps=0)
    logger.debug("Query pairs in %.3fs", time.time() - kdtree_start)

    soft_mask_array = daisy.open_ds(soft_mask_container, soft_mask_dataset)

    voxel_size = np.array(soft_mask_array.voxel_size, dtype=np.uint32)
    soft_mask_roi = block.read_roi.snap_to_grid(
        voxel_size=voxel_size).intersect(soft_mask_array.roi)
    soft_mask_array_data = soft_mask_array.to_ndarray(roi=soft_mask_roi)

    sm_dtype = soft_mask_array_data.dtype
    if sm_dtype == np.uint8:  # standard pipeline pm 0-255
        pass
    elif sm_dtype == np.float32 or sm_dtype == np.float64:
        if not (soft_mask_array_data.min() >= 0
                and soft_mask_array_data.max() <= 1):
            raise ValueError(
                "Provided soft_mask has dtype float but not in range [0,1], abort"
            )
        else:
            soft_mask_array_data *= 255
    else:
        raise ValueError("Soft mask dtype {} not understood".format(sm_dtype))

    soft_mask_array_data = soft_mask_array_data.astype(np.float64)

    if evidence_threshold is not None:
        soft_mask_array_data = (soft_mask_array_data >= evidence_threshold *
                                255).astype(np.float64) * 255

    offset = np.array(np.array(soft_mask_roi.get_offset()) / voxel_size,
                      dtype=np.uint64)
    evidence_start = time.time()

    if pairs:
        pairs = np.array(list(pairs), dtype=np.uint64)
        evidence_array = cpp_get_evidence(candidates, pairs,
                                          soft_mask_array_data, offset,
                                          voxel_size)
        graph.add_weighted_edges_from(evidence_array, weight='evidence')

        logger.debug("Accumulate evidence in %.3fs",
                     time.time() - evidence_start)

        logger.debug("Found %d edges", graph.number_of_edges())

        logger.debug("Extracted edges in %.3fs", time.time() - start)

        start = time.time()

        graph.write_edges(block.write_roi)

        logger.debug("Wrote edges in %.3fs", time.time() - start)
    else:
        logger.debug("No pairs in block, skip")

    write_done(graph_provider.database, block,
               'edges_g{}'.format(graph_number))
    return 0
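extract_edges_in_block processes a single daisy block, so it is normally driven by a blockwise scheduler. The driver below is only a sketch, assuming the pre-1.0 daisy API (run_blockwise taking explicit read/write block ROIs); the geometry, thresholds, database and container names are placeholders.

from functools import partial

import daisy

total_roi = daisy.Roi((158000, 121800, 448616), (7600, 5200, 2200))  # placeholder
block_read_roi = daisy.Roi((0, 0, 0), (2200, 2200, 2200))            # placeholder block shape
block_write_roi = daisy.Roi((200, 200, 200), (1800, 1800, 1800))     # read roi minus context

daisy.run_blockwise(
    total_roi,
    block_read_roi,
    block_write_roi,
    # everything except the block itself is bound up front; daisy passes each
    # block to the process function
    process_function=partial(
        extract_edges_in_block,
        "micron_candidates",           # db_name (placeholder)
        "mongodb://localhost:27017",   # db_host (placeholder)
        "soft_mask.zarr",              # soft_mask_container (placeholder)
        "volumes/soft_mask",           # soft_mask_dataset (placeholder)
        150,                           # distance_threshold in world units (placeholder)
        None,                          # evidence_threshold
        1),                            # graph_number
    num_workers=8,
    read_write_conflict=False,
    fit="shrink",
)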
import itertools
from pathlib import Path
import logging
from tqdm import tqdm

from config_parser import read_data_config

logging.basicConfig(level=logging.INFO)

config = read_data_config(Path("config.ini"))

print(config["sample"])

consensus_provider = MongoDbGraphProvider(config["consensus_db"],
                                          config["db_host"],
                                          mode="r+",
                                          directed=True)

subdivided_provider = MongoDbGraphProvider(config["subdivided_db"],
                                           config["db_host"],
                                           mode="w",
                                           directed=True)

consensus = consensus_provider.get_graph(config["roi"],
                                         node_inclusion="dangling",
                                         edge_inclusion="either")

print(f"Read consensus with {len(consensus.nodes)} nodes")

# the call below is truncated in the original snippet; edge_inclusion is assumed
# to mirror the consensus read above
subdivided = subdivided_provider.get_graph(config["roi"],
                                           node_inclusion="dangling",
                                           edge_inclusion="either")
class DaisyGraphProvider(BatchProvider):
    """
    See documentation for mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """
    def __init__(
        self,
        dbname: str,
        url: str,
        points: List[GraphKey],
        graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
        directed: bool = False,
        total_roi: Roi = None,
        nodes_collection: str = "nodes",
        edges_collection: str = "edges",
        meta_collection: str = "meta",
        endpoint_names: Tuple[str, str] = ("u", "v"),
        position_attribute: str = "position",
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        nodes_filter: Optional[Dict[str, Any]] = None,
        edges_filter: Optional[Dict[str, Any]] = None,
        edge_inclusion: str = "either",
        node_inclusion: str = "dangling",
        fail_on_inconsistent_node: bool = False,
    ):
        self.points = points
        graph_specs = (graph_specs if graph_specs is not None else GraphSpec(
            Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
            directed=False))
        specs = (graph_specs if isinstance(graph_specs, list)
                 and len(graph_specs) == len(points) else [graph_specs] *
                 len(points))
        self.specs = {key: spec for key, spec in zip(points, specs)}

        self.directed = directed
        self.nodes_collection = nodes_collection
        self.edges_collection = edges_collection
        self.meta_collection = meta_collection
        self.endpoint_names = endpoint_names
        self.position_attribute = position_attribute

        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter

        self.edge_inclusion = edge_inclusion
        self.node_inclusion = node_inclusion

        self.dbname = dbname
        self.url = url

        self.fail_on_inconsistent_node = fail_on_inconsistent_node

        self.graph_provider = None

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)

        if self.graph_provider is None:
            self.graph_provider = MongoDbGraphProvider(
                self.dbname,
                self.url,
                mode="r+",
                directed=self.directed,
                total_roi=None,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
                meta_collection=self.meta_collection,
                endpoint_names=self.endpoint_names,
                position_attribute=self.position_attribute,
            )

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            failed_nodes = []

            for node, attrs in requested_graph.nodes.items():
                try:
                    attrs["location"] = np.array(
                        attrs[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node} was written (probably part of an edge), but never given coordinates!"
                    )
                    failed_nodes.append(node)
                attrs["id"] = node

            for node in failed_nodes:
                if self.fail_on_inconsistent_node:
                    raise ValueError(
                        f"Mongodb contains node {node} without location! "
                        f"It was probably written as part of an edge")
                requested_graph.remove_node(node)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            points = Graph.from_nx_graph(requested_graph, spec)
            points.relabel_connected_components()
            points.crop(spec.roi)
            batch[key] = points

            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
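A rough sketch of requesting a batch from this provider in gunpowder; the GraphKey, database name, URL and ROI are placeholders, and position_attribute is assumed to match whatever attribute the nodes were written with.

import gunpowder as gp

CANDIDATES = gp.GraphKey("CANDIDATES")   # hypothetical key

source = DaisyGraphProvider(
    "micron_candidates",                 # placeholder database
    "mongodb://localhost:27017",         # placeholder URL
    points=[CANDIDATES],
    position_attribute="position",
)

request = gp.BatchRequest()
request[CANDIDATES] = gp.GraphSpec(
    roi=gp.Roi((158000, 121800, 448616), (7600, 5200, 2200)))

with gp.build(source):
    batch = source.request_batch(request)
    print("got %d candidate nodes" % len(list(batch[CANDIDATES].nodes)))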
Example #16
def solve_in_block(db_host,
                   db_name,
                   evidence_factor,
                   comb_angle_factor,
                   start_edge_prior,
                   selection_cost,
                   time_limit,
                   solve_number,
                   graph_number,
                   selected_attr="selected",
                   solved_attr="solved",
                   **kwargs):

    print("Solve in block")

    graph_provider = MongoDbGraphProvider(
        db_name,
        db_host,
        mode='r+',
        position_attribute=['z', 'y', 'x'],
        edges_collection="edges_g{}".format(graph_number))

    client = daisy.Client()

    while True:
        print("Acquire block")
        block = client.acquire_block()

        if not block:
            return

        logger.debug("Solving in block %s", block)

        if check_function(block, 'solve_s{}'.format(solve_number), db_name,
                          db_host):
            client.release_block(block, 0)
            continue

        start_time = time.time()
        graph = graph_provider.get_graph(block.read_roi)

        num_nodes = graph.number_of_nodes()
        num_edges = graph.number_of_edges()
        logger.info(
            "Reading graph with %d nodes and %d edges took %s seconds" %
            (num_nodes, num_edges, time.time() - start_time))

        if num_edges == 0:
            logger.info("No edges in roi %s. Skipping" % block.read_roi)
            write_done(block, 'solve_s{}'.format(solve_number), db_name,
                       db_host)
            client.release_block(block, 0)
            continue

        print("solve")
        solver = Solver(graph, evidence_factor, comb_angle_factor,
                        start_edge_prior, selection_cost, time_limit,
                        selected_attr, solved_attr)

        solver.initialize()
        solver.solve()

        start_time = time.time()
        graph.update_edge_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])

        graph.update_node_attrs(block.write_roi,
                                attributes=[selected_attr, solved_attr])

        logger.info(
            "Updating attributes %s & %s for %d edges took %s seconds" %
            (selected_attr, solved_attr, num_edges, time.time() - start_time))

        print("Write done")
        write_done(block, 'solve_s{}'.format(solve_number), db_name, db_host)

        print("Release block")
        client.release_block(block, 0)

    return 0
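solve_in_block acquires and releases blocks through daisy.Client(), i.e. it runs inside worker processes. Below is only a rough sketch of the corresponding master-side call, assuming the pre-1.0 daisy pattern in which run_blockwise is given a zero-argument process_function that spawns such workers; the worker script name, geometry and worker count are all placeholders.

import subprocess

import daisy

total_roi = daisy.Roi((158000, 121800, 448616), (7600, 5200, 2200))  # placeholder
block_read_roi = daisy.Roi((0, 0, 0), (2200, 2200, 2200))            # placeholder
block_write_roi = daisy.Roi((200, 200, 200), (1800, 1800, 1800))     # placeholder


def start_solve_worker():
    # hypothetical worker entry point that ends up calling solve_in_block(...);
    # daisy.Client() inside the worker connects back to the scheduler started
    # by run_blockwise
    subprocess.check_call(["python", "solve_worker.py"])


daisy.run_blockwise(
    total_roi,
    block_read_roi,
    block_write_roi,
    process_function=start_solve_worker,
    num_workers=4,
    fit="shrink",
)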
class FilteredDaisyGraphProvider(BatchProvider):
    """
    See documentation for mongo graph provider at
    https://github.com/funkelab/daisy/blob/0.3-dev/daisy/persistence/mongodb_graph_provider.py#L17
    """
    def __init__(
        self,
        dbname: str,
        url: str,
        points: List[GraphKey],
        graph_specs: Optional[Union[GraphSpec, List[GraphSpec]]] = None,
        directed: bool = False,
        total_roi: Roi = None,
        nodes_collection: str = "nodes",
        edges_collection: str = "edges",
        meta_collection: str = "meta",
        endpoint_names: Tuple[str, str] = ("u", "v"),
        position_attribute: str = "position",
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        nodes_filter: Optional[Dict[str, Any]] = None,
        edges_filter: Optional[Dict[str, Any]] = None,
        num_nodes=100000,
        dist_attribute=None,
        min_dist=29000,
    ):
        self.points = points
        graph_specs = (graph_specs if graph_specs is not None else GraphSpec(
            Roi(Coordinate([None] * 3), Coordinate([None] * 3)),
            directed=False))
        specs = (graph_specs if isinstance(graph_specs, list)
                 and len(graph_specs) == len(points) else [graph_specs] *
                 len(points))
        self.specs = {key: spec for key, spec in zip(points, specs)}

        self.position_attribute = position_attribute
        self.node_attrs = node_attrs
        self.edge_attrs = edge_attrs
        self.nodes_filter = nodes_filter
        self.edges_filter = edges_filter
        self.dist_attribute = dist_attribute
        self.min_dist = min_dist

        self.num_nodes = num_nodes

        self.graph_provider = MongoDbGraphProvider(
            dbname,
            url,
            mode="r+",
            directed=directed,
            total_roi=None,
            nodes_collection=nodes_collection,
            edges_collection=edges_collection,
            meta_collection=meta_collection,
            endpoint_names=endpoint_names,
            position_attribute=position_attribute,
        )

    def setup(self):
        for key, spec in self.specs.items():
            self.provides(key, spec)

    def provide(self, request):

        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )
            for node, attrs in list(requested_graph.nodes.items()):
                if self.dist_attribute in attrs:
                    if attrs[self.dist_attribute] < self.min_dist:
                        requested_graph.remove_node(node)

            logger.debug(
                f"{len(requested_graph.nodes)} nodes remaining after filtering by distance"
            )

            if len(requested_graph.nodes) > self.num_nodes:
                nodes = list(requested_graph.nodes)
                nodes_to_keep = set(random.sample(nodes, self.num_nodes))

                for node in list(requested_graph.nodes()):
                    if node not in nodes_to_keep:
                        requested_graph.remove_node(node)

            for node, attrs in requested_graph.nodes.items():
                attrs["location"] = np.array(attrs[self.position_attribute],
                                             dtype=np.float32)
                attrs["id"] = node

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            logger.debug(
                f"providing {key} with {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            points = Graph.from_nx_graph(requested_graph, spec)
            points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
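For completeness, a brief sketch of constructing the filtered variant; dist_attribute, min_dist and num_nodes drive the filtering and subsampling done in provide(). All names and thresholds are placeholders, and the distance attribute in particular is hypothetical.

import gunpowder as gp

CANDIDATES = gp.GraphKey("CANDIDATES")   # hypothetical key

filtered_source = FilteredDaisyGraphProvider(
    "micron_candidates",                 # placeholder database
    "mongodb://localhost:27017",         # placeholder URL
    points=[CANDIDATES],
    position_attribute="position",
    dist_attribute="distance",           # hypothetical node attribute
    min_dist=29000,
    num_nodes=100000,
)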
Example #18
class MongoWriteGraph(gp.BatchFilter):
    def __init__(
        self,
        mst,
        db_host,
        db_name,
        read_size,
        write_size,
        voxel_size,
        directed,
        mode="r+",
        collection="",
        position_attr="position",
        node_attrs: List[str] = None,
        edge_attrs: List[str] = None,
    ):

        self.mst = mst
        self.db_host = db_host
        self.db_name = db_name
        self.client = None

        self.voxel_size = voxel_size
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2

        assert self.write_size + (self.context * 2) == self.read_size

        self.directed = directed
        self.mode = mode
        self.position_attr = position_attr

        self.nodes_collection = f"{collection}_nodes"
        self.edges_collection = f"{collection}_edges"

        self.node_attrs = node_attrs if node_attrs is not None else []
        self.node_attrs += [position_attr]
        self.edge_attrs = edge_attrs if edge_attrs is not None else []

    def setup(self):

        # Initialize client. Doesn't the daisy mongodb graph provider handle this?
        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=self.directed,
                nodes_collection=self.nodes_collection,
                edges_collection=self.edges_collection,
            )

        self.updates(self.mst, self.spec[self.mst].copy())

    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.mst] = request[self.mst].copy()
        assert (request[self.mst].roi.get_shape() == self.read_size
                ), f"Got wrong size graph in request"
        return deps

    def process(self, batch, request):

        voxel_size = self.voxel_size

        read_roi = batch[self.mst].spec.roi
        write_roi = read_roi.grow(-self.context, -self.context)

        mst = batch[self.mst].to_nx_graph()

        # get saved graph with 'dangling' edges. Some nodes may not be in write_roi
        # Edges crossing boundary should only be saved if lower id is contained in write_roi
        mongo_graph = self.client.get_graph(
            write_roi,
            node_attrs=self.node_attrs,
            edge_attrs=self.edge_attrs,
            node_inclusion="dangling",
            edge_inclusion="either",
        )

        for node in list(mongo_graph.nodes()):
            if self.position_attr not in mongo_graph.nodes[node]:
                mongo_graph.remove_node(node)

        num_created_nodes = 0
        num_updated_nodes = 0
        for node, attrs in mst.nodes.items():
            loc = attrs["location"]
            pos = np.floor(loc / voxel_size)
            node_id = int(math.cantor_number(pos))

            # only update attributes of nodes in the write_roi
            if write_roi.contains(loc):
                attrs_to_save = {
                    self.position_attr: tuple(float(v) for v in loc)
                }
                for attr in self.node_attrs:
                    if attr in attrs:
                        value = attrs[attr]
                        if isinstance(value, Iterable):
                            value = tuple(float(v) for v in value)
                        if isinstance(value, np.float32) or (isinstance(
                                value, Iterable) and any(
                                    [isinstance(v, np.float32)
                                     for v in value])):
                            raise ValueError(
                                f"value of attr {attr} is a np.float32")
                        attrs_to_save[attr] = value

                if node_id in mongo_graph:
                    num_updated_nodes += 1
                    mongo_attrs = mongo_graph.nodes[node_id]
                    mongo_attrs.update(attrs_to_save)

                else:
                    num_created_nodes += 1
                    mongo_graph.add_node(node_id, **attrs_to_save)

        num_created_edges = 0
        num_updated_edges = 0
        for (u, v), attrs in mst.edges.items():
            u_loc = mst.nodes[u]["location"]
            u_pos = np.floor(u_loc / voxel_size)
            u_id = int(math.cantor_number(u_pos))

            v_loc = mst.nodes[v]["location"]
            v_pos = np.floor(v_loc / voxel_size)
            v_id = int(math.cantor_number(v_pos))

            # node a is the node with the lower id out of u, v
            a_loc, a_id, b_loc, b_id = ((u_loc, u_id, v_loc,
                                         v_id) if u_id < v_id else
                                        (v_loc, v_id, u_loc, u_id))
            # only write edge if a is contained
            # may create a node without any attributes if node creation is
            # dependent on your field of view and the neighboring block
            # fails to create the same node.
            if write_roi.contains(a_loc):

                attrs_to_save = {}
                for attr in self.edge_attrs:
                    if attr in attrs:
                        value = attrs[attr]
                        if isinstance(value, Iterable):
                            value = tuple(float(v) for v in value)
                        if isinstance(value, np.float32) or (isinstance(
                                value, Iterable) and any(
                                    [isinstance(v, np.float32)
                                     for v in value])):
                            raise ValueError(
                                f"value of attr {attr} is a np.float32")
                        attrs_to_save[attr] = value

                if (u_id, v_id) in mongo_graph.edges:
                    num_updated_edges += 1
                    mongo_attrs = mongo_graph.edges[(u_id, v_id)]
                    mongo_attrs.update(attrs_to_save)
                else:
                    num_created_edges += 1
                    mongo_graph.add_edge(u_id, v_id, **attrs_to_save)

        for node in mst.nodes:
            if node in mongo_graph.nodes and write_roi.contains(
                    mst.nodes[node]["location"]):
                assert all(
                    np.isclose(
                        mongo_graph.nodes[node][self.position_attr],
                        mst.nodes[node]["location"],
                    ))
                if write_roi.contains(
                        mongo_graph.nodes[node][self.position_attr]):
                    assert (mongo_graph.nodes[node]["component_id"] ==
                            mst.nodes[node]["component_id"])

        if len(mongo_graph.nodes) > 0:
            mongo_graph.write_nodes(roi=write_roi, attributes=self.node_attrs)

        if len(mongo_graph.edges) > 0:
            for edge, attrs in mongo_graph.edges.items():
                for attr in self.edge_attrs:
                    assert attr in attrs
            mongo_graph.write_edges(roi=write_roi, attributes=self.edge_attrs)
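A loose sketch of where a writer like this sits in a gunpowder pipeline: an upstream graph source (the DaisyGraphProvider from earlier on this page is used purely as a stand-in), the writer itself, and a Scan node that tiles a large request into read_size-shaped chunks, which is exactly the shape MongoWriteGraph asserts on in prepare(). All keys, sizes and database names are placeholders.

import gunpowder as gp

MST = gp.GraphKey("MST")                          # hypothetical key

read_size = gp.Coordinate((2200, 2200, 2200))     # placeholder sizes
write_size = gp.Coordinate((1800, 1800, 1800))
voxel_size = gp.Coordinate((40, 4, 4))

source = DaisyGraphProvider(
    "micron_mst_in",                              # placeholder input database
    "mongodb://localhost:27017",                  # placeholder URL
    points=[MST],
    position_attribute="position",
)

writer = MongoWriteGraph(
    mst=MST,
    db_host="mongodb://localhost:27017",          # placeholder URL
    db_name="micron_mst_out",                     # placeholder output database
    read_size=read_size,
    write_size=write_size,
    voxel_size=voxel_size,
    directed=False,
    collection="mst",
)

# each Scan chunk requests exactly read_size, as the writer expects
reference = gp.BatchRequest()
reference[MST] = gp.GraphSpec(roi=gp.Roi((0, 0, 0), read_size))

pipeline = source + writer + gp.Scan(reference)

request = gp.BatchRequest()
request[MST] = gp.GraphSpec(
    roi=gp.Roi((158000, 121800, 448616), (7600, 5200, 2200)))  # placeholder total ROI

with gp.build(pipeline):
    pipeline.request_batch(request)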
class MongoWriteComponents(BatchFilter):
    def __init__(
        self,
        graph,
        db_host,
        db_name,
        read_size,
        write_size,
        component_attr,
        node_attrs=None,
        collection="",
        mode="r+",
    ):
        self.graph = graph
        self.db_host = db_host
        self.db_name = db_name
        self.read_size = read_size
        self.write_size = write_size
        self.context = (read_size - write_size) / 2
        self.mode = mode
        self.component_attr = component_attr

        assert self.write_size + (self.context * 2) == self.read_size

        self.collection = collection
        self.component_collection = f"{self.collection}_components"
        self.component_edge_collection = f"{self.collection}_component_edges"

        if node_attrs is None:
            self.node_attrs = []
        else:
            self.node_attrs = node_attrs

        self.client = None

    def setup(self):
        self.updates(self.graph, self.spec[self.graph].copy())

        if self.client is None:
            self.client = MongoDbGraphProvider(
                self.db_name,
                self.db_host,
                mode=self.mode,
                directed=False,
                nodes_collection=self.component_collection,
                edges_collection=self.component_edge_collection,
            )

    def prepare(self, request):
        deps = BatchRequest()
        deps[self.graph] = request[self.graph].copy()
        assert (request[self.graph].roi.get_shape() == self.read_size
                ), f"Got wrong size graph in request"

        return deps

    def process(self, batch, request):

        graph = batch[self.graph].to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {graph.number_of_nodes()} nodes and "
            f"{graph.number_of_edges()} edges!")

        write_roi = batch[self.graph].spec.roi.grow(-self.context,
                                                    -self.context)

        contained_nodes = [
            node for node, attr in graph.nodes.items()
            if write_roi.contains(attr["location"])
        ]
        contained_components = set(graph.nodes[n][self.component_attr]
                                   for n in contained_nodes)

        logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                     f"{len(contained_components)} components in write_roi")

        mongo_graph = self.client.get_graph(
            write_roi,
            node_attrs=[],
            edge_attrs=[],
            node_inclusion="dangling",
            edge_inclusion="either",
        )

        num_new_components = 0
        num_new_component_edges = 0
        for node, attrs in graph.nodes.items():
            if write_roi.contains(attrs["location"]):
                cc_id = attrs[self.component_attr]
                node_attrs = {}
                for attr in self.node_attrs:
                    node_attrs[attr] = attrs[attr]
                if cc_id not in mongo_graph.nodes:
                    num_new_components += 1
                    mongo_graph.add_node(cc_id,
                                         position=write_roi.get_center(),
                                         **node_attrs)
                else:
                    for k, v in node_attrs.items():
                        assert mongo_graph.nodes[cc_id][k] == v

        # Always write crossing edges if neither component id is None. The other
        # end point won't have a component id unless it has already been processed,
        # so it is the duty of the second pass to write the edge, regardless of
        # whether the lower or upper end point is contained.
        for u, v in graph.edges:
            u_loc = graph.nodes[u]["location"]
            v_loc = graph.nodes[v]["location"]
            if write_roi.contains(u_loc) or write_roi.contains(v_loc):
                u_cc_id = graph.nodes[u].get(self.component_attr)
                v_cc_id = graph.nodes[v].get(self.component_attr)

                if u_cc_id is None or v_cc_id is None:
                    continue
                elif u_cc_id == v_cc_id:
                    continue
                elif (u_cc_id, v_cc_id) not in mongo_graph.edges:
                    num_new_component_edges += 1
                    mongo_graph.add_edge(u_cc_id, v_cc_id)

        logger.debug(
            f"{self.name()} writing {num_new_components} new components and "
            f"{num_new_component_edges} new component edges!")

        mongo_graph.write_nodes()
        mongo_graph.write_edges(ignore_v=False)