Ejemplo n.º 1
0
    def __init__(self):
        """Build a fixed 3D test graph: eight corner nodes plus two central ones.

        Corner nodes (ids 1-8) take every combination of -200 / 199 along the
        three axes, enumerated with the last axis varying fastest. Nodes 9 and
        10 sit at the origin and at (-1, -1, -1). The graph spec uses
        ``Roi((-100,)*3, (300,)*3)`` while the array spec uses
        ``Roi((-200,)*3, (400,)*3)`` with an anisotropic voxel size.
        """
        self.voxel_size = Coordinate((40, 4, 4))

        extremes = (-200, 199)
        corner_positions = [
            (first, second, third)
            for first in extremes
            for second in extremes
            for third in extremes
        ]
        corner_nodes = [
            Node(id=node_id, location=np.array(position))
            for node_id, position in enumerate(corner_positions, start=1)
        ]
        central_nodes = [
            Node(id=9, location=np.array((0, 0, 0))),
            Node(id=10, location=np.array((-1, -1, -1))),
        ]
        self.nodes = corner_nodes + central_nodes

        self.graph_spec = GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300)))
        self.array_spec = ArraySpec(
            roi=Roi((-200, -200, -200), (400, 400, 400)),
            voxel_size=self.voxel_size,
        )

        self.graph = Graph(self.nodes, [], self.graph_spec)
Ejemplo n.º 2
0
    def provide(self, request: BatchRequest) -> Batch:
        """Provide a graph for every requested key served by this source.

        For each key in ``self.points`` that appears in ``request``, the node
        ids inside the requested ROI are found with ``self._query_kdtree``.
        The matching subgraph is extracted with ``self._subgraph_points``,
        including neighbours only for small queries. It is then converted to
        a ``Graph`` and trimmed to the ROI.

        Args:
            request: the batch request; keys not in ``self.points`` are ignored.

        Returns:
            A ``Batch`` holding one graph per requested key, plus timing stats.
        """
        timing = Timing(self, "provide")
        timing.start()

        # One batch collects the graphs of *all* requested keys; it must not be
        # re-created per key, or earlier keys would be dropped.
        batch = Batch()

        for points_key in self.points:

            if points_key not in request:
                continue

            roi = request[points_key].roi

            # Retrieve all points in the requested region using a kdtree for speed
            point_ids = self._query_kdtree(
                self.data.tree,
                (np.array(roi.get_begin()), np.array(roi.get_end())),
            )

            # To account for boundary crossings we must retrieve neighbors of all points
            # in the graph. This is too slow for large queries and less important
            points_subgraph = self._subgraph_points(
                point_ids,
                with_neighbors=len(point_ids) < len(self._graph.nodes) // 2)
            nodes = [
                Node(id=node, location=attrs["location"], attrs=attrs)
                for node, attrs in points_subgraph.nodes.items()
            ]
            edges = [Edge(u, v) for u, v in points_subgraph.edges]
            return_graph = Graph(nodes, edges, GraphSpec(roi=roi))

            # Handle boundary cases
            return_graph = return_graph.trim(roi)

            batch.points[points_key] = return_graph

            logger.debug(
                "Graph points source provided {} points for roi: {}".format(
                    len(list(batch.points[points_key].nodes)),
                    roi))

            logger.debug(
                f"Providing {len(list(points_subgraph.nodes))} nodes to {points_key}"
            )

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Ejemplo n.º 3
0
    def __init__(self):
        """Set up a minimal graph: two nodes joined by a single edge.

        Node 1 is at (0, 4, 4) and node 2 at (9, 4, 4). The graph spec's ROI
        is ``Roi((0, 0, 0), (10, 10, 10))`` and voxels are unit sized.
        """
        self.voxel_size = Coordinate((1, 1, 1))

        endpoints = ((1, (0, 4, 4)), (2, (9, 4, 4)))
        self.nodes = [
            Node(id=node_id, location=np.array(position))
            for node_id, position in endpoints
        ]
        self.edges = [Edge(1, 2)]

        self.graph_spec = GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))
        self.graph = Graph(self.nodes, self.edges, self.graph_spec)
Ejemplo n.º 4
0
    def provide(self, request):
        """Serve TEST_POINTS and/or TEST_LABELS, whichever the request contains.

        TEST_POINTS: deep copies of the points in ``self.points`` whose
        location lies inside the requested ROI, with no edges.
        TEST_LABELS: a uint32 array over the requested ROI. It is zero except
        at even indices along axis 1 (set to 100). Each point then writes its
        id at the index returned by ``self.point_to_voxel``.
        """
        batch = Batch()

        points_key = GraphKeys.TEST_POINTS
        if points_key in request:
            points_roi = request[points_key].roi
            inside = [
                copy.deepcopy(point)
                for point in self.points
                if points_roi.contains(point.location)
            ]
            batch[points_key] = Graph(inside, [], GraphSpec(roi=points_roi))

        labels_key = ArrayKeys.TEST_LABELS
        if labels_key in request:
            labels_roi = request[labels_key].roi
            labels_spec = self.spec[labels_key]
            voxel_roi = labels_roi // labels_spec.voxel_size

            labels = np.zeros(voxel_roi.get_shape(), dtype=np.uint32)
            labels[:, ::2] = 100
            for point in self.points:
                labels[self.point_to_voxel(labels_roi, point.location)] = point.id

            out_spec = labels_spec.copy()
            out_spec.roi = labels_roi
            batch.arrays[labels_key] = Array(labels, spec=out_spec)

        return batch
Ejemplo n.º 5
0
    def provide(self, request):
        """Serve TEST_LABELS and/or TEST_GRAPH, whichever keys are requested.

        TEST_LABELS: a uint32 array over the requested ROI. It is zero except
        at even indices along axis 1 (set to 100). Each node in ``self.nodes``
        then writes its id at the index returned by ``self.node_to_voxel``.
        TEST_GRAPH: the nodes of ``self.nodes`` whose location lies inside the
        requested ROI, with no edges.

        Keys absent from ``request`` are skipped rather than raising KeyError,
        so either output can be requested on its own.
        """
        batch = Batch()

        if ArrayKeys.TEST_LABELS in request:
            roi_array = request[ArrayKeys.TEST_LABELS].roi
            roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size

            data = np.zeros(roi_voxel.get_shape(), dtype=np.uint32)
            data[:, ::2] = 100

            for node in self.nodes:
                loc = self.node_to_voxel(roi_array, node.location)
                data[loc] = node.id

            spec = self.spec[ArrayKeys.TEST_LABELS].copy()
            spec.roi = roi_array
            batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec)

        if GraphKeys.TEST_GRAPH in request:
            roi_graph = request[GraphKeys.TEST_GRAPH].roi
            # NOTE(review): these Node objects are shared with self.nodes (not
            # copied); confirm downstream nodes never mutate them in place.
            nodes = [node for node in self.nodes if roi_graph.contains(node.location)]
            batch.graphs[GraphKeys.TEST_GRAPH] = Graph(
                nodes=nodes, edges=[], spec=GraphSpec(roi=roi_graph))

        return batch
Ejemplo n.º 6
0
class GraphTestSourceWithEdge(BatchProvider):
    """Test provider serving a two-node, one-edge graph under TEST_GRAPH_WITH_EDGE.

    The nodes sit at 0 and 9 along the first axis (both at 4, 4 in the other
    two axes) inside a ``Roi((0, 0, 0), (10, 10, 10))`` graph spec.
    """

    def __init__(self):
        self.voxel_size = Coordinate((1, 1, 1))

        # (id, first-axis position) for the two endpoints of the edge.
        endpoints = ((1, 0), (2, 9))
        self.nodes = [
            Node(id=node_id, location=np.array((first_axis, 4, 4)))
            for node_id, first_axis in endpoints
        ]
        self.edges = [Edge(1, 2)]

        self.graph_spec = GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))
        self.graph = Graph(self.nodes, self.edges, self.graph_spec)

    def setup(self):
        """Advertise the graph key together with the full graph spec."""
        self.provides(GraphKeys.TEST_GRAPH_WITH_EDGE, self.graph_spec)

    def provide(self, request):
        """Return the stored graph cropped, then trimmed, to the requested ROI."""
        batch = Batch()

        key = GraphKeys.TEST_GRAPH_WITH_EDGE
        roi = request[key].roi
        cropped = self.graph.crop(roi)
        batch.graphs[key] = cropped.trim(roi)

        return batch
Ejemplo n.º 7
0
    def array_to_graph(self, array):
        """Turn the voxels selected by a 3D mask array into an undirected graph.

        Connectivity comes from sklearn's ``grid_to_graph`` after
        ``sklearn.feature_extraction.image._make_edges_3d`` is replaced by
        this module's ``_make_edges_3d``. That replacement is a process-wide
        patch. Node ids are the indices of the locations returned by
        ``compute_voxel_locs`` (presumably one per masked voxel — confirm).
        Every adjacency entry is asserted to be at most one voxel size apart
        per axis, and self-loops are left out of the edge list. Each phase's
        duration is logged at debug level.
        """
        # Patch sklearn's private edge builder with the local implementation.
        sklearn.feature_extraction.image._make_edges_3d = _make_edges_3d

        mask = array.data
        dims = mask.shape
        voxel_size = array.spec.voxel_size

        # Connectivity between neighbouring masked voxels.
        tic = time.time()
        adjacency = grid_to_graph(n_x=dims[0], n_y=dims[1], n_z=dims[2], mask=mask)
        toc = time.time()
        logger.debug(f"GRID TO GRAPH TOOK {toc-tic} SECONDS!")

        # Location of each selected voxel, offset by the ROI begin and scaled
        # by the voxel size; list position == node id.
        tic = time.time()
        voxel_locs = compute_voxel_locs(
            mask=mask,
            offset=array.spec.roi.get_begin(),
            scale=voxel_size,
        )
        toc = time.time()
        logger.debug(f"COMPUTING VOXEL LOCS TOOK {toc-tic} SECONDS!")

        tic = time.time()
        nodes = [Node(idx, loc) for idx, loc in enumerate(voxel_locs)]

        # Sanity check: reported neighbours differ by at most one voxel per axis.
        for src, dst in zip(adjacency.row, adjacency.col):
            assert all(
                abs(voxel_locs[src] - voxel_locs[dst]) <= voxel_size
            ), f"{voxel_locs[src] - voxel_locs[dst]}, {voxel_size}"

        edges = [
            Edge(src, dst)
            for src, dst in zip(adjacency.row, adjacency.col)
            if src != dst
        ]
        graph = Graph(nodes, edges, GraphSpec(array.spec.roi, directed=False))
        toc = time.time()
        logger.debug(f"BUILDING GRAPH TOOK {toc-tic} SECONDS!")
        return graph
Ejemplo n.º 8
0
    def provide(self, request):
        """Return the CSV points whose coordinates lie inside the requested ROI.

        A row of ``self.data`` is kept when, for each of its first
        ``self.ndims`` columns, ``begin <= value < end`` (half-open ROI). The
        selected rows are turned into nodes by ``self._get_points``.
        """
        timing = Timing(self)
        timing.start()

        roi = request[self.points].roi
        min_bb = roi.get_begin()
        max_bb = roi.get_end()

        logger.debug("CSV points source got request for %s", roi)

        # Builtin `bool`: the `np.bool` alias is missing in NumPy 1.24-1.26.
        point_filter = np.ones((self.data.shape[0],), dtype=bool)
        for d in range(self.ndims):
            coords = self.data[:, d]
            point_filter &= (coords >= min_bb[d]) & (coords < max_bb[d])

        points_data = self._get_points(point_filter)
        points_spec = GraphSpec(roi=roi.copy())

        batch = Batch()
        batch.graphs[self.points] = Graph(points_data, [], points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Ejemplo n.º 9
0
    def __setup_batch(self, batch_spec, chunk):
        '''Allocate an empty batch shaped like ``batch_spec``.

        Each array is zero-filled with this node's dtype. Its leading
        (non-spatial) dimensions are copied from the matching array in
        ``chunk``, and its spatial dimensions are the requested ROI shape
        divided by the voxel size. Each graph starts with no nodes and no
        edges.
        '''
        batch = Batch()

        for array_key, requested in batch_spec.array_specs.items():
            roi = requested.roi
            voxel_size = self.spec[array_key].voxel_size

            template = chunk.arrays[array_key]
            leading_shape = template.data.shape[:-roi.dims()]
            shape = leading_shape + (roi.get_shape() // voxel_size)

            out_spec = self.spec[array_key].copy()
            out_spec.roi = roi
            logger.info("allocating array of shape %s for %s", shape, array_key)
            batch.arrays[array_key] = Array(
                data=np.zeros(shape, dtype=out_spec.dtype), spec=out_spec
            )

        for graph_key, requested in batch_spec.graph_specs.items():
            out_spec = self.spec[graph_key].copy()
            out_spec.roi = requested.roi
            batch.graphs[graph_key] = Graph(nodes=[], edges=[], spec=out_spec)

        logger.debug("setup batch to fill %s", batch)

        return batch
    def provide(self, request):
        """Fetch, filter and subsample a graph for every key in ``request``.

        For each requested key:
          * the graph inside the key's ROI is read from ``self.graph_provider``;
          * nodes whose ``self.dist_attribute`` is below ``self.min_dist`` are
            removed (nodes without that attribute are kept);
          * if more than ``self.num_nodes`` nodes remain, a random subset of
            that size is kept (global ``random`` module);
          * ``self.position_attribute`` is copied into a float32 ``location``
            and each node's ``id`` attribute is set to its key;
          * directedness is matched to the request spec and the result is
            cropped to the ROI.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion="either",
                node_inclusion="dangling",
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            too_close = [
                node
                for node, attrs in requested_graph.nodes.items()
                if self.dist_attribute in attrs
                and attrs[self.dist_attribute] < self.min_dist
            ]
            for node in too_close:
                requested_graph.remove_node(node)

            logger.debug(
                f"{len(requested_graph.nodes)} nodes remaining after filtering by distance"
            )

            if len(requested_graph.nodes) > self.num_nodes:
                nodes = list(requested_graph.nodes)
                nodes_to_keep = set(random.sample(nodes, self.num_nodes))

                for node in list(requested_graph.nodes()):
                    if node not in nodes_to_keep:
                        requested_graph.remove_node(node)

            for node, attrs in requested_graph.nodes.items():
                attrs["location"] = np.array(attrs[self.position_attribute],
                                             dtype=np.float32)
                attrs["id"] = node

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            logger.debug(
                f"providing {key} with {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            points = Graph.from_nx_graph(requested_graph, spec)
            # `crop` returns the cropped graph; keep the result instead of
            # relying on it to modify `points` in place.
            points = points.crop(spec.roi)
            batch[key] = points

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Ejemplo n.º 11
0
    def provide(self, request):
        """Answer a PRESYN request with an empty graph spanning the requested ROI."""
        batch = Batch()

        requested_roi = request[GraphKeys.PRESYN].roi
        empty_graph = Graph([], [], GraphSpec(roi=requested_roi))
        batch.graphs[GraphKeys.PRESYN] = empty_graph

        return batch
Ejemplo n.º 12
0
class GraphTestSource3D(BatchProvider):
    """Test provider serving a fixed ten-node graph and a two-valued label volume.

    ``GraphKeys.TEST_GRAPH`` holds eight nodes at the -200 / 199 corners plus
    two nodes at (0, 0, 0) and (-1, -1, -1).
    ``ArrayKeys.GT_LABELS`` is a uint64 volume. Its first half (rounded down)
    along axis 0 is labelled 2 and the rest is labelled 1.
    """

    def __init__(self):

        self.voxel_size = Coordinate((40, 4, 4))

        self.nodes = [
            # corners
            Node(id=1, location=np.array((-200, -200, -200))),
            Node(id=2, location=np.array((-200, -200, 199))),
            Node(id=3, location=np.array((-200, 199, -200))),
            Node(id=4, location=np.array((-200, 199, 199))),
            Node(id=5, location=np.array((199, -200, -200))),
            Node(id=6, location=np.array((199, -200, 199))),
            Node(id=7, location=np.array((199, 199, -200))),
            Node(id=8, location=np.array((199, 199, 199))),
            # center
            Node(id=9, location=np.array((0, 0, 0))),
            Node(id=10, location=np.array((-1, -1, -1))),
        ]

        self.graph_spec = GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300)))
        self.array_spec = ArraySpec(
            roi=Roi((-200, -200, -200), (400, 400, 400)), voxel_size=self.voxel_size
        )

        self.graph = Graph(self.nodes, [], self.graph_spec)

    def setup(self):
        """Advertise the graph and label keys with their specs."""
        self.provides(
            GraphKeys.TEST_GRAPH,
            self.graph_spec,
        )

        self.provides(
            ArrayKeys.GT_LABELS,
            self.array_spec,
        )

    def provide(self, request):
        """Serve TEST_GRAPH and/or GT_LABELS, whichever keys are requested.

        Keys absent from ``request`` are skipped rather than raising KeyError,
        so either output can be requested on its own.
        """
        batch = Batch()

        if GraphKeys.TEST_GRAPH in request:
            graph_roi = request[GraphKeys.TEST_GRAPH].roi
            batch.graphs[GraphKeys.TEST_GRAPH] = self.graph.crop(graph_roi).trim(
                graph_roi
            )

        if ArrayKeys.GT_LABELS in request:
            roi_array = request[ArrayKeys.GT_LABELS].roi

            # Shape in voxels; `//` keeps the division explicitly integral.
            image = np.ones(roi_array.get_shape() // self.voxel_size, dtype=np.uint64)
            # label half of GT_LABELS differently
            depth = image.shape[0]
            image[0 : depth // 2] = 2

            spec = self.spec[ArrayKeys.GT_LABELS].copy()
            spec.roi = roi_array
            batch.arrays[ArrayKeys.GT_LABELS] = Array(image, spec=spec)

        return batch
Ejemplo n.º 13
0
    def provide(self, request):
        """Fetch a graph for every key in ``request`` from ``self.graph_provider``.

        Each node's ``self.position_attribute`` is copied into a float32
        ``location`` and its ``id`` attribute is set to its key. Nodes missing
        the position attribute are logged and removed. If
        ``self.fail_on_inconsistent_node`` is set, the first such node raises
        instead. Directedness is matched to the request spec. Connected
        components are relabelled, and the graph is cropped to the ROI.

        Raises:
            ValueError: if a node lacks a position and
                ``self.fail_on_inconsistent_node`` is true.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for key, spec in request.items():
            logger.debug(f"fetching {key} in roi {spec.roi}")
            requested_graph = self.graph_provider.get_graph(
                spec.roi,
                edge_inclusion=self.edge_inclusion,
                node_inclusion=self.node_inclusion,
                node_attrs=self.node_attrs,
                edge_attrs=self.edge_attrs,
                nodes_filter=self.nodes_filter,
                edges_filter=self.edges_filter,
            )
            logger.debug(
                f"got {len(requested_graph.nodes)} nodes and {len(requested_graph.edges)} edges"
            )

            failed_nodes = []

            for node, attrs in requested_graph.nodes.items():
                try:
                    attrs["location"] = np.array(
                        attrs[self.position_attribute], dtype=np.float32)
                except KeyError:
                    logger.warning(
                        f"node: {node} was written (probably part of an edge), but never given coordinates!"
                    )
                    failed_nodes.append(node)
                attrs["id"] = node

            for node in failed_nodes:
                if self.fail_on_inconsistent_node:
                    raise ValueError(
                        f"Mongodb contains node {node} without location! "
                        f"It was probably written as part of an edge")
                requested_graph.remove_node(node)

            if spec.directed:
                requested_graph = requested_graph.to_directed()
            else:
                requested_graph = requested_graph.to_undirected()

            points = Graph.from_nx_graph(requested_graph, spec)
            points.relabel_connected_components()
            # `crop` returns the cropped graph; keep the result instead of
            # relying on it to modify `points` in place.
            points = points.crop(spec.roi)
            batch[key] = points

            logger.debug(f"{key} with {len(list(points.nodes))} nodes")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
    def provide(self, request):
        """Provide PRESYN and/or POSTSYN synapse-location graphs.

        Pre- and postsynaptic locations are read together by
        ``self.__read_syn_points`` so that their ids are unique and partner
        locations can be found across both graphs. When both keys are
        requested, they must therefore share the same ROI.

        Raises:
            RuntimeError: if PRESYN and POSTSYN are requested with different
                ROIs, if a requested key is not provided by this source, or if
                a requested ROI is not contained in the provided ROI.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        presyn_points = None
        postsyn_points = None

        # Check membership before indexing so that requesting only one of the
        # two synapse keys works (indexing a missing key raised KeyError).
        wants_presyn = GraphKey.PRESYN in request.points
        wants_postsyn = GraphKey.POSTSYN in request.points

        if wants_presyn and wants_postsyn:
            presyn_roi = request.points[GraphKey.PRESYN]
            postsyn_roi = request.points[GraphKey.POSTSYN]
            if presyn_roi != postsyn_roi:
                raise RuntimeError(
                    "PRESYN and POSTSYN must be requested with the same ROI, "
                    "got %s and %s" % (presyn_roi, postsyn_roi))

        if wants_presyn:
            presyn_points, postsyn_points = self.__read_syn_points(
                roi=request.points[GraphKey.PRESYN])
        elif wants_postsyn:
            presyn_points, postsyn_points = self.__read_syn_points(
                roi=request.points[GraphKey.POSTSYN])

        for (points_key, roi) in request.points.items():
            # check if requested points can be provided
            if points_key not in self.spec:
                raise RuntimeError(
                    "Asked for %s which this source does not provide" %
                    points_key)
            # check if request roi lies within provided roi
            if not self.spec[points_key].roi.contains(roi):
                raise RuntimeError(
                    "%s's ROI %s outside of my ROI %s" %
                    (points_key, roi, self.spec[points_key].roi))

            logger.debug("Reading %s in %s...", points_key, roi)
            id_to_point = {
                GraphKey.PRESYN: presyn_points,
                GraphKey.POSTSYN: postsyn_points
            }[points_key]

            batch.points[points_key] = Graph(data=id_to_point,
                                             spec=GraphSpec(roi=roi))

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Ejemplo n.º 15
0
    def downsample(self, graph, sample_distance=None):
        """Thin out runs of degree-2 nodes to roughly one node per ``sample_distance``.

        Nodes whose degree is not 2 (endpoints, branch points, isolated nodes)
        are always kept. For every connected run of at least two degree-2
        nodes, the cable length is measured by walking the run from one end.
        Then ``cable_len // sample_distance`` nodes are kept at roughly even
        spacing ``cable_len / (num_cuts + 1)`` along the run.

        The result is the subgraph induced by the kept nodes, so edges that
        passed through discarded nodes are not present. Runs of a single
        degree-2 node are discarded.

        NOTE(review): for a run that forms a cycle, head == tail, so both walks
        stop at the first node and no node of the cycle is kept — confirm this
        is intended.

        Args:
            graph: graph exposing ``to_nx_graph()`` and ``spec``; nodes must
                carry a ``location`` attribute.
            sample_distance: target spacing along the cable. ``None`` falls
                back to ``self.sample_distance``.

        Returns:
            A ``Graph`` built from the kept nodes with ``graph.spec``.
        """
        # Previously this argument was ignored in favour of the attribute.
        if sample_distance is None:
            sample_distance = self.sample_distance

        g = graph.to_nx_graph()
        sampled_nodes = [n for n in g.nodes if g.degree(n) != 2]
        chain_nodes = [n for n in g.nodes if g.degree(n) == 2]
        chain_g = g.subgraph(chain_nodes)
        for cc in nx.connected_components(chain_g):

            if len(cc) < 2:
                continue
            cc_graph = chain_g.subgraph(cc)
            ends = [n for n in cc_graph.nodes if cc_graph.degree(n) == 1]
            if len(ends) == 2:
                head, tail = ends
            else:
                # No two free ends (the run is a cycle): start anywhere.
                head = next(iter(cc_graph.nodes))
                tail = head

            # First pass: total cable length from head to tail.
            cable_len = 0
            previous_location = None
            for node in nx.algorithms.dfs_preorder_nodes(cc_graph, source=head):
                current_loc = cc_graph.nodes[node]["location"]
                if previous_location is not None:
                    cable_len += np.linalg.norm(previous_location - current_loc)

                previous_location = current_loc
                if node == tail:
                    break

            num_cuts = cable_len // sample_distance
            if num_cuts > 0:
                every = cable_len / (num_cuts + 1)
            else:
                every = float("inf")

            # Second pass: keep a node each time another `every` of cable passes.
            seen_cable = 0
            previous_location = None
            for node in nx.algorithms.dfs_preorder_nodes(cc_graph, source=head):
                current_loc = cc_graph.nodes[node]["location"]
                if previous_location is not None:
                    seen_cable += np.linalg.norm(previous_location - current_loc)
                previous_location = current_loc

                if seen_cable > every:
                    sampled_nodes.append(node)
                    seen_cable -= every

                if node == tail:
                    break
        downsampled = g.subgraph(sampled_nodes)
        return Graph.from_nx_graph(downsampled, graph.spec)
Ejemplo n.º 16
0
    def process(self, batch, request: BatchRequest):
        """Build a Euclidean minimum spanning tree over the masked voxels.

        Each voxel's feature vector is its embedding channels followed by its
        voxel coordinates scaled by ``self.coordinate_scale``. Only voxels
        whose ``self.mask`` value equals 1 are used. ``mlp.emst`` computes the
        tree. Each tree edge becomes an ``Edge`` storing its distance under
        ``self.distance_attr``. Node ids are row indices among the kept voxels.
        Node locations are voxel index * voxel size + ROI offset.

        Note: ``request[self.mst]`` is marked undirected in place and reused as
        the output graph spec.
        """
        outputs = Batch()

        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()
        embeddings = batch[self.embeddings].data
        candidates = batch[self.mask].data
        _, depth, height, width = embeddings.shape
        coordinates = np.meshgrid(
            np.arange(0, (depth - 0.5) * self.coordinate_scale[0],
                      self.coordinate_scale[0]),
            np.arange(0, (height - 0.5) * self.coordinate_scale[1],
                      self.coordinate_scale[1]),
            np.arange(0, (width - 0.5) * self.coordinate_scale[2],
                      self.coordinate_scale[2]),
            indexing="ij",
        )
        # np.meshgrid returns a tuple since NumPy 2.0, so build a new list
        # instead of assigning into the result.
        coordinates = [c.astype(np.float32) for c in coordinates]

        # One row per voxel (C order over the spatial axes): embedding channels
        # followed by the scaled coordinates.
        embedding = np.concatenate([embeddings, coordinates], 0)
        embedding = np.transpose(embedding, axes=[1, 2, 3, 0])
        embedding = embedding.reshape(depth * width * height, -1)
        candidates = candidates.reshape(depth * width * height)
        # emst row indices refer to this masked subset.
        embedding = embedding[candidates == 1, :]

        emst = mlp.emst(embedding)["output"]

        nodes = set()
        edges = []
        for u, v, distance in emst:
            u = int(u)
            # Last three columns: scaled voxel coordinates -> world offset.
            pos_u = embedding[u][-3:] / self.coordinate_scale * voxel_size
            v = int(v)
            pos_v = embedding[v][-3:] / self.coordinate_scale * voxel_size
            nodes.add(Node(u, location=pos_u + offset))
            nodes.add(Node(v, location=pos_v + offset))
            edges.append(Edge(u, v, attrs={self.distance_attr: distance}))

        graph_spec = request[self.mst]
        graph_spec.directed = False

        outputs[self.mst] = Graph(nodes, edges, graph_spec)
        logger.debug(
            f"OUTPUTS CONTAINS MST WITH {len(list(outputs[self.mst].nodes))} NODES"
        )

        return outputs
Ejemplo n.º 17
0
    def process(self, batch, request: BatchRequest):
        """Build a maximin tree between the masked maxima of the embeddings.

        The embeddings are reshaped to ``(-1, *spatial)``, and the first
        slice of the reshaped mask marks the maxima. Both are passed to
        ``maximin.maximin_tree_query_hd``. Each returned ``(a, b, cost)``
        triple becomes an ``Edge`` with ``cost`` under ``self.distance_attr``.
        Its end points are placed at voxel index * voxel size + ROI offset and
        asserted to lie inside the ROI. Node ids come from a counter that
        advances on every endpoint lookup, so they are unique but may have
        gaps.

        Note: ``request[self.mst]`` is marked undirected in place and reused as
        the output graph spec.
        """
        outputs = Batch()

        spec = batch[self.embeddings].spec
        voxel_size = spec.voxel_size
        roi = spec.roi
        offset = roi.get_begin()
        ndim = len(voxel_size)

        raw_embeddings = batch[self.embeddings].data
        embeddings = raw_embeddings.reshape((-1,) + raw_embeddings.shape[-ndim:])
        raw_mask = batch[self.mask].data
        maxima = raw_mask.reshape((-1,) + raw_mask.shape[-ndim:])[0]

        try:
            minimax_edges = maximin.maximin_tree_query_hd(
                embeddings.astype(np.float64),
                maxima.astype(np.uint8),
                decimate=self.decimate,
            )
        except OSError:
            logger.warning(
                f"embeddings have shape: {embeddings.shape} and mask has shape: {maxima.shape}"
            )
            raise

        def to_world(voxel):
            return np.array(voxel) * voxel_size + offset

        counter = itertools.count(start=0)
        ids = {}
        nodes = set()
        edges = []
        for a, b, cost in minimax_edges:
            a_id = ids.setdefault(a, next(counter))
            b_id = ids.setdefault(b, next(counter))
            a_loc, b_loc = to_world(a), to_world(b)
            for loc in (a_loc, b_loc):
                assert roi.contains(loc), f"Roi {roi} does not contain {loc}"

            nodes.update((Node(a_id, location=a_loc), Node(b_id, location=b_loc)))
            edges.append(Edge(a_id, b_id, attrs={self.distance_attr: cost}))

        mst_spec = request[self.mst]
        mst_spec.directed = False

        outputs[self.mst] = Graph(nodes, edges, mst_spec)

        return outputs
Ejemplo n.º 18
0
    def provide(self, request: BatchRequest) -> Batch:
        """Generate ``self.n_obj`` random trees inside the requested ROI.

        The global ``random`` and ``np.random`` generators are re-seeded from
        ``request.random_seed``. Each tree is grown from a random root by
        ``self._grow_tree``. Up to 100 attempts are made to reach a node count
        within ``self.num_nodes`` (inclusive bounds). After that, the last
        attempt is used regardless. Node ``"pos"`` values are floored and
        shifted by the ROI begin to give node locations.
        """
        random.seed(request.random_seed)
        np.random.seed(request.random_seed)

        timing = Timing(self, "provide")
        timing.start()

        batch = Batch()

        roi = request[self.points].roi

        region_shape = roi.get_shape()

        trees = []
        for _ in range(self.n_obj):
            for _ in range(100):
                root = np.random.random(len(region_shape)) * region_shape
                tree = self._grow_tree(
                    root, Roi((0,) * len(region_shape), region_shape)
                )
                if self.num_nodes[0] <= len(tree.nodes) <= self.num_nodes[1]:
                    break
            trees.append(tree)

        trees_graph = nx.disjoint_union_all(trees)

        # Node's first positional argument is the id, so pass both explicitly.
        nodes = [
            Node(id=node_id, location=np.floor(node_attrs["pos"]) + roi.get_begin())
            for node_id, node_attrs in trees_graph.nodes.items()
        ]
        edges = [Edge(u, v) for u, v in trees_graph.edges]

        # Graph takes (nodes, edges, spec).
        batch[self.points] = Graph(nodes, edges, request[self.points])

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Ejemplo n.º 19
0
    def process(self, batch, request: BatchRequest):
        """Build a Euclidean MST per connected component of the candidate graph.

        Every voxel gets a feature vector: its embedding channels followed by
        its voxel coordinates scaled by ``self.coordinate_scale``. For each
        connected component of ``self.candidates``, the feature vectors of the
        voxels holding the component's nodes are passed to ``mlp.emst``. The
        tree edges are mapped back to candidate node ids via those voxels.
        Each edge stores its distance under ``self.distance_attr``. Output node
        locations are voxel index * voxel size + ROI offset.

        Note: ``request[self.mst]`` is marked undirected in place and reused as
        the output graph spec.
        """
        outputs = Batch()

        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()
        embeddings = batch[self.embeddings].data
        candidates = batch[self.candidates].to_nx_graph()
        _, depth, height, width = embeddings.shape
        coordinates = np.meshgrid(
            np.arange(
                0, (depth - 0.5) * self.coordinate_scale[0], self.coordinate_scale[0]
            ),
            np.arange(
                0, (height - 0.5) * self.coordinate_scale[1], self.coordinate_scale[1]
            ),
            np.arange(
                0, (width - 0.5) * self.coordinate_scale[2], self.coordinate_scale[2]
            ),
            indexing="ij",
        )
        # np.meshgrid returns a tuple since NumPy 2.0, so build a new list
        # instead of assigning into the result.
        coordinates = [c.astype(np.float32) for c in coordinates]

        # One row per voxel (C order over the spatial axes).
        embedding = np.concatenate([embeddings, coordinates], 0)
        embedding = np.transpose(embedding, axes=[1, 2, 3, 0])
        embedding = embedding.reshape(depth * width * height, -1)

        nodes = set()
        edges = []

        for i, component in enumerate(nx.connected_components(candidates)):
            # Mark this component's voxels and remember which node occupies
            # each voxel so emst row indices can be mapped back to node ids.
            # NOTE(review): if two nodes share a voxel, only the last is kept.
            candidates_array = np.zeros((depth, height, width), dtype=bool)
            locs_to_ids = {}
            for node in component:
                location = candidates.nodes[node]["location"]
                voxel_location = tuple(
                    int(x) for x in ((location - offset) // voxel_size)
                )
                locs_to_ids[voxel_location] = node
                candidates_array[voxel_location] = True
            # Row k of component_embedding is the k-th marked voxel in C order.
            component_embedding = embedding[candidates_array.reshape(-1), :]

            logger.info(
                f"processing component {i} with "
                f"{len(component)} candidates"
            )

            component_emst = mlp.emst(component_embedding)["output"]

            for u, v, distance in component_emst:
                u = int(u)
                pos_u = component_embedding[u][-3:] / self.coordinate_scale * voxel_size
                u_index = locs_to_ids[
                    tuple(int(np.round(x)) for x in (pos_u / voxel_size))
                ]
                v = int(v)
                pos_v = component_embedding[v][-3:] / self.coordinate_scale * voxel_size
                v_index = locs_to_ids[
                    tuple(int(np.round(x)) for x in (pos_v / voxel_size))
                ]
                nodes.add(Node(u_index, location=pos_u + offset))
                nodes.add(Node(v_index, location=pos_v + offset))
                edges.append(
                    Edge(u_index, v_index, attrs={self.distance_attr: distance})
                )

        graph_spec = request[self.mst]
        graph_spec.directed = False

        logger.info(
            f"candidates has {candidates.number_of_nodes()} nodes and "
            f"{candidates.number_of_edges()} edges and "
            f"{nx.number_connected_components(candidates)} components"
        )

        outputs[self.mst] = Graph(nodes, edges, graph_spec)
        output_graph = outputs[self.mst].to_nx_graph()

        logger.info(
            f"output_graph has {output_graph.number_of_nodes()} nodes and "
            f"{output_graph.number_of_edges()} edges and "
            f"{nx.number_connected_components(output_graph)} components"
        )

        logger.debug(
            f"OUTPUTS CONTAINS MST WITH {len(list(outputs[self.mst].nodes))} NODES"
        )

        return outputs
Ejemplo n.º 20
0
    def process(self, batch, request: BatchRequest):
        """Build maximin tree(s) between intensity maxima and emit them as graphs.

        The first slice of the reshaped intensities and mask is used. With
        fewer than two maxima, the outputs are empty graphs. Otherwise
        ``maximin.maximin_tree_query`` produces the tree for ``self.mst``.
        When ``self.dense_mst`` is set, ``maximin_tree_query_plus_decimated``
        is used instead and yields both the dense tree and the decimated tree.

        Note: the requested graph specs are marked undirected in place and
        reused as the output specs.
        """
        outputs = Batch()

        voxel_size = batch[self.intensities].spec.voxel_size
        roi = batch[self.intensities].spec.roi
        offset = batch[self.intensities].spec.roi.get_begin()
        spatial_dims = len(voxel_size)
        intensities = batch[self.intensities].data
        intensities = intensities.reshape((-1,) + intensities.shape[-spatial_dims:])[0]
        maxima = batch[self.mask].data
        maxima = maxima.reshape((-1,) + maxima.shape[-spatial_dims:])[0]

        num_maxima = maxima.sum()
        logger.warning(f"{self.mask} has {num_maxima} maxima")

        minimax_edges = []
        dense_minimax_edges = []
        if num_maxima >= 2:
            if self.dense_mst is not None:
                dense_minimax_edges, minimax_edges = maximin.maximin_tree_query_plus_decimated(
                    intensities.astype(np.float64),
                    maxima.astype(np.uint8),
                    threshold=self.threshold,
                )
            else:
                minimax_edges = maximin.maximin_tree_query(
                    intensities.astype(np.float64),
                    maxima.astype(np.uint8),
                    decimate=self.decimate,
                    threshold=self.threshold,
                )

        outputs[self.mst] = self._maximin_edges_to_graph(
            minimax_edges, voxel_size, offset, roi, request[self.mst]
        )

        if self.dense_mst is not None:
            outputs[self.dense_mst] = self._maximin_edges_to_graph(
                dense_minimax_edges, voxel_size, offset, roi, request[self.dense_mst]
            )

        return outputs

    def _maximin_edges_to_graph(self, minimax_edges, voxel_size, offset, roi, graph_spec):
        """Convert ``(a, b, cost)`` voxel-index edges into an undirected ``Graph``.

        Nodes are placed at ``a * voxel_size + offset`` and asserted to lie in
        ``roi``. Each edge stores ``1 - cost`` under ``self.distance_attr``.
        Node ids come from a counter that advances on every endpoint lookup,
        so they are unique but may have gaps. ``graph_spec`` is marked
        undirected in place and used as the graph's spec.
        """
        maximin_id = itertools.count(start=0)

        nodes = set()
        edges = []
        ids = {}
        for a, b, cost in minimax_edges:
            a_id = ids.setdefault(a, next(maximin_id))
            b_id = ids.setdefault(b, next(maximin_id))
            a_loc = np.array(a) * voxel_size + offset
            b_loc = np.array(b) * voxel_size + offset
            assert roi.contains(a_loc), f"Roi {roi} does not contain {a_loc}"
            assert roi.contains(b_loc), f"Roi {roi} does not contain {b_loc}"

            nodes.add(Node(a_id, location=a_loc))
            nodes.add(Node(b_id, location=b_loc))
            edges.append(Edge(a_id, b_id, attrs={self.distance_attr: 1 - cost}))

        graph_spec.directed = False
        return Graph(nodes, edges, graph_spec)