예제 #1
0
 def __init__(self, graph):
     """Store ``graph`` and define the spec, nodes and edges of two components.

     Component 1 consists of two 10-node chains (ids 0-9 located at
     ``[0, i, 0]`` and ids 10-19 located at ``[i, 5, 0]``) that are joined
     by the single edge ``(5, 15)``. Component 2 is one 10-node chain
     (ids 20-29 located at ``[i, 4, 0]``).
     """
     self.graph = graph
     self.graph_spec = GraphSpec(
         roi=Roi((-10, -10, -10), (30, 30, 30)), directed=False
     )

     # Component 1: two chains plus one bridging edge
     first_chain = [Node(k, np.array([0, k, 0])) for k in range(10)]
     second_chain = [Node(k + 10, np.array([k, 5, 0])) for k in range(10)]
     self.component_1_nodes = first_chain + second_chain

     first_links = [Edge(k, k + 1) for k in range(9)]
     second_links = [Edge(k + 10, k + 11) for k in range(9)]
     bridge = [Edge(5, 15)]
     self.component_1_edges = first_links + second_links + bridge

     # Component 2: a single chain
     self.component_2_nodes = [
         Node(k + 20, np.array([k, 4, 0])) for k in range(10)
     ]
     self.component_2_edges = [Edge(k + 20, k + 21) for k in range(9)]
예제 #2
0
    def array_to_graph(self, array):
        """Turn the masked voxels of ``array`` into an undirected ``Graph``.

        One node is created per entry returned by ``compute_voxel_locs`` (node
        ids follow enumeration order) and one edge per off-diagonal entry of
        the adjacency matrix returned by ``grid_to_graph``. The duration of
        each stage is logged at debug level.

        Returns:
            A ``Graph`` whose spec uses the roi of ``array.spec`` and
            ``directed=False``.
        """
        # Replace sklearn's edge builder with the local implementation
        sklearn.feature_extraction.image._make_edges_3d = _make_edges_3d

        mask = array.data
        shape = mask.shape

        # Stage 1: connectivity between masked voxels
        t1 = time.time()
        adjacency = grid_to_graph(
            n_x=shape[0], n_y=shape[1], n_z=shape[2], mask=mask
        )
        t2 = time.time()
        logger.debug(f"GRID TO GRAPH TOOK {t2-t1} SECONDS!")

        # Stage 2: location of every masked voxel, in node-id order
        t1 = time.time()
        voxel_locs = compute_voxel_locs(
            mask=array.data,
            offset=array.spec.roi.get_begin(),
            scale=array.spec.voxel_size,
        )
        t2 = time.time()
        logger.debug(f"COMPUTING VOXEL LOCS TOOK {t2-t1} SECONDS!")

        # Stage 3: assemble nodes and edges
        t1 = time.time()
        nodes = [Node(index, loc) for index, loc in enumerate(voxel_locs)]

        connections = list(zip(adjacency.row, adjacency.col))

        # Sanity check: connected voxels differ by at most one voxel size per axis
        for a, b in connections:
            assert all(
                abs(voxel_locs[a] - voxel_locs[b]) <= array.spec.voxel_size
            ), f"{voxel_locs[a] - voxel_locs[b]}, {array.spec.voxel_size}"

        # Diagonal entries (a == b) are not turned into edges
        edges = [Edge(a, b) for a, b in connections if a != b]
        graph = Graph(nodes, edges, GraphSpec(array.spec.roi, directed=False))
        t2 = time.time()
        logger.debug(f"BUILDING GRAPH TOOK {t2-t1} SECONDS!")
        return graph
예제 #3
0
    def provide(self, request: BatchRequest) -> Batch:
        """Build a batch with one trimmed graph per requested points key.

        For every key in ``self.points`` that is present in ``request``, the
        points inside the requested roi are looked up through the kd-tree,
        turned into a ``Graph`` and trimmed to the roi.

        All requested keys are collected in the same ``Batch``; the batch is
        created once, before the loop, so earlier keys are not discarded when
        several keys are requested.

        Args:
            request: The batch request; only keys in ``self.points`` are served.

        Returns:
            The batch holding the graphs, with the timing of this call added
            to its profiling stats.
        """
        timing = Timing(self, "provide")
        timing.start()

        batch = Batch()

        for points_key in self.points:

            if points_key not in request:
                continue

            request_roi = request[points_key].roi

            # Retrieve all points in the requested region using a kdtree for speed
            point_ids = self._query_kdtree(
                self.data.tree,
                (
                    np.array(request_roi.get_begin()),
                    np.array(request_roi.get_end()),
                ),
            )

            # Neighbors of the queried points are needed to account for edges
            # crossing the roi boundary. They are only fetched when the query
            # covers fewer than half of the graph's nodes, because fetching
            # them for large queries is too slow.
            points_subgraph = self._subgraph_points(
                point_ids,
                with_neighbors=len(point_ids) < len(self._graph.nodes) // 2)
            nodes = [
                Node(id=node, location=attrs["location"], attrs=attrs)
                for node, attrs in points_subgraph.nodes.items()
            ]
            edges = [Edge(u, v) for u, v in points_subgraph.edges]
            return_graph = Graph(nodes, edges, GraphSpec(roi=request_roi))

            # Handle boundary cases
            return_graph = return_graph.trim(request_roi)

            batch.points[points_key] = return_graph

            logger.debug(
                "Graph points source provided {} points for roi: {}".format(
                    len(list(batch.points[points_key].nodes)),
                    request[points_key].roi))

            logger.debug(
                f"Providing {len(list(points_subgraph.nodes))} nodes to {points_key}"
            )

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
    def _parse_swc(self, filename: Path):
        """Parse one SWC file and add its nodes and edges to this source.

        Files whose name contains ``"cube"`` are skipped. The parsed tree must
        consist of exactly one weakly connected component.

        When ``self.ignore_human_nodes`` is set, only nodes whose
        ``"human_placed"`` attribute is truthy are kept, and every edge with an
        endpoint that was not kept is dropped. When ``self.directed`` is false,
        the reverse of every kept edge is added as well.

        Args:
            filename: Path of the SWC file to read.

        Returns:
            ``0`` for skipped files, otherwise the number of weakly connected
            components of the parsed tree.

        Raises:
            AssertionError: If the parsed tree does not have exactly one
                weakly connected component.
        """
        if "cube" in filename.name:
            return 0

        tree = parse_swc(
            filename,
            self.transform_file,
            resolution=[self.scale[i] for i in self.transpose],
            transpose=self.transpose,
        )

        num_components = len(list(nx.weakly_connected_components(tree)))
        assert num_components == 1

        kept = [
            (node, attrs)
            for node, attrs in tree.nodes.items()
            if not self.ignore_human_nodes or attrs["human_placed"]
        ]
        points = [
            Node(id=node, location=attrs["location"], attrs=attrs)
            for node, attrs in kept
        ]

        # Membership is tested against node ids: the tree's edges hold raw
        # ids, not Node objects, so the test must use a set of ids.
        kept_ids = {node for node, _ in kept}
        kept_edges = [
            (u, v) for u, v in tree.edges if u in kept_ids and v in kept_ids
        ]

        edges = {Edge(u, v) for u, v in kept_edges}
        if not self.directed:
            # Reverse edges are derived from the filtered edges so that they
            # never reference a dropped node either.
            edges |= {Edge(v, u) for u, v in kept_edges}
        self._add_points_to_source(points, edges)

        return num_components
예제 #5
0
    def __init__(self):
        """Set up a two-node, one-edge graph together with its spec.

        The nodes have ids 1 and 2 and are located at ``(0, 4, 4)`` and
        ``(9, 4, 4)``; a single edge connects them.
        """
        self.voxel_size = Coordinate((1, 1, 1))

        first_node = Node(id=1, location=np.array((0, 4, 4)))
        second_node = Node(id=2, location=np.array((9, 4, 4)))
        self.nodes = [first_node, second_node]
        self.edges = [Edge(1, 2)]

        self.graph_spec = GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))
        self.graph = Graph(self.nodes, self.edges, self.graph_spec)
예제 #6
0
파일: emst.py 프로젝트: pattonw/neurolight
    def process(self, batch, request: BatchRequest):
        """Compute a minimum spanning tree over the masked embedding voxels.

        Three coordinate channels (voxel index times ``self.coordinate_scale``)
        are appended to the embedding channels. Voxels where the mask equals 1
        are passed to ``mlp.emst``, and each resulting ``(u, v, distance)``
        triple becomes two nodes and one edge carrying ``distance`` under
        ``self.distance_attr``.

        Args:
            batch: Batch holding the ``self.embeddings`` and ``self.mask``
                arrays.
            request: Request whose ``self.mst`` spec is used (and marked as
                undirected) for the output graph.

        Returns:
            A new ``Batch`` containing the graph under ``self.mst``.
        """
        outputs = Batch()

        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()
        embeddings = batch[self.embeddings].data
        candidates = batch[self.mask].data
        _, depth, height, width = embeddings.shape

        # np.meshgrid returns a tuple on NumPy >= 2.0, so the float32 copies
        # are collected in a new list instead of being assigned in place.
        coordinates = [
            axis_grid.astype(np.float32)
            for axis_grid in np.meshgrid(
                np.arange(0, (depth - 0.5) * self.coordinate_scale[0],
                          self.coordinate_scale[0]),
                np.arange(0, (height - 0.5) * self.coordinate_scale[1],
                          self.coordinate_scale[1]),
                np.arange(0, (width - 0.5) * self.coordinate_scale[2],
                          self.coordinate_scale[2]),
                indexing="ij",
            )
        ]

        # One row per voxel: embedding channels followed by the 3 coordinates
        embedding = np.concatenate([embeddings, coordinates], 0)
        embedding = np.transpose(embedding, axes=[1, 2, 3, 0])
        embedding = embedding.reshape(depth * width * height, -1)
        candidates = candidates.reshape(depth * width * height)
        embedding = embedding[candidates == 1, :]

        emst = mlp.emst(embedding)["output"]

        nodes = set()
        edges = []
        for u, v, distance in emst:
            u = int(u)
            v = int(v)
            # The last three columns are the scaled voxel coordinates; undo
            # the scaling and convert to world units.
            pos_u = embedding[u][-3:] / self.coordinate_scale * voxel_size
            pos_v = embedding[v][-3:] / self.coordinate_scale * voxel_size
            nodes.add(Node(u, location=pos_u + offset))
            nodes.add(Node(v, location=pos_v + offset))
            edges.append(Edge(u, v, attrs={self.distance_attr: distance}))

        graph_spec = request[self.mst]
        graph_spec.directed = False

        outputs[self.mst] = Graph(nodes, edges, graph_spec)
        logger.debug(
            f"OUTPUTS CONTAINS MST WITH {len(list(outputs[self.mst].nodes))} NODES"
        )

        return outputs
예제 #7
0
    def process(self, batch, request: BatchRequest):
        """Build an undirected graph from ``maximin.maximin_tree_query_hd``.

        The embeddings are flattened to ``(-1, *spatial)`` and the mask to a
        single spatial volume before the query. Every returned
        ``(a, b, cost)`` triple becomes two nodes, located at
        ``a * voxel_size + offset`` and ``b * voxel_size + offset``, and one
        edge carrying ``cost`` under ``self.distance_attr``.

        Args:
            batch: Batch holding the ``self.embeddings`` and ``self.mask``
                arrays.
            request: Request whose ``self.mst`` spec is used (and marked as
                undirected) for the output graph.

        Returns:
            A new ``Batch`` containing the graph under ``self.mst``.
        """
        outputs = Batch()

        embeddings_spec = batch[self.embeddings].spec
        voxel_size = embeddings_spec.voxel_size
        roi = embeddings_spec.roi
        offset = roi.get_begin()
        n_spatial = len(voxel_size)

        embeddings = batch[self.embeddings].data
        embeddings = embeddings.reshape((-1,) + embeddings.shape[-n_spatial:])
        maxima = batch[self.mask].data
        maxima = maxima.reshape((-1,) + maxima.shape[-n_spatial:])[0]

        try:
            minimax_edges = maximin.maximin_tree_query_hd(
                embeddings.astype(np.float64),
                maxima.astype(np.uint8),
                decimate=self.decimate,
            )
        except OSError as e:
            logger.warning(
                f"embeddings have shape: {embeddings.shape} and mask has shape: {maxima.shape}"
            )
            raise e

        id_counter = itertools.count(start=0)
        voxel_to_id = {}
        mst_nodes = set()
        mst_edges = []

        for a, b, cost in minimax_edges:
            # next() is evaluated before setdefault runs, so a counter value
            # is consumed on every lookup, even when the voxel already has
            # an id. Ids are therefore unique but not contiguous.
            a_id = voxel_to_id.setdefault(a, next(id_counter))
            b_id = voxel_to_id.setdefault(b, next(id_counter))
            a_loc = np.array(a) * voxel_size + offset
            b_loc = np.array(b) * voxel_size + offset
            assert roi.contains(a_loc), f"Roi {roi} does not contain {a_loc}"
            assert roi.contains(b_loc), f"Roi {roi} does not contain {b_loc}"

            mst_nodes.add(Node(a_id, location=a_loc))
            mst_nodes.add(Node(b_id, location=b_loc))
            mst_edges.append(
                Edge(a_id, b_id, attrs={self.distance_attr: cost})
            )

        mst_spec = request[self.mst]
        mst_spec.directed = False
        outputs[self.mst] = Graph(mst_nodes, mst_edges, mst_spec)

        return outputs
예제 #8
0
    def process(self, batch, request: BatchRequest):
        """Build the maximin tree graph, and optionally its dense variant.

        Intensities and mask are each reduced to a single spatial volume.
        With fewer than two maxima in the mask no query is run and the
        graphs are empty. Otherwise ``maximin.maximin_tree_query`` is used,
        or ``maximin.maximin_tree_query_plus_decimated`` when
        ``self.dense_mst`` is set. Edges carry ``1 - cost`` under
        ``self.distance_attr``.

        Args:
            batch: Batch holding the ``self.intensities`` and ``self.mask``
                arrays.
            request: Request whose ``self.mst`` spec (and ``self.dense_mst``
                spec, if configured) is used and marked as undirected.

        Returns:
            A new ``Batch`` containing the graph under ``self.mst`` and, if
            configured, the dense graph under ``self.dense_mst``.
        """
        outputs = Batch()

        voxel_size = batch[self.intensities].spec.voxel_size
        roi = batch[self.intensities].spec.roi
        offset = batch[self.intensities].spec.roi.get_begin()
        spatial_dims = len(voxel_size)
        intensities = batch[self.intensities].data
        intensities = intensities.reshape((-1,) + intensities.shape[-spatial_dims:])[0]
        maxima = batch[self.mask].data
        maxima = maxima.reshape((-1,) + maxima.shape[-spatial_dims:])[0]

        logger.warning(f"{self.mask} has {maxima.sum()} maxima")

        build_dense = self.dense_mst is not None

        # With fewer than two maxima there is nothing to connect
        minimax_edges = []
        dense_minimax_edges = []
        if maxima.sum() >= 2:
            if build_dense:
                dense_minimax_edges, minimax_edges = maximin.maximin_tree_query_plus_decimated(
                    intensities.astype(np.float64),
                    maxima.astype(np.uint8),
                    threshold=self.threshold,
                )
            else:
                minimax_edges = maximin.maximin_tree_query(
                    intensities.astype(np.float64),
                    maxima.astype(np.uint8),
                    decimate=self.decimate,
                    threshold=self.threshold,
                )

        nodes, edges = self._maximin_edges_to_nodes_and_edges(
            minimax_edges, voxel_size, offset, roi
        )
        graph_spec = request[self.mst]
        graph_spec.directed = False
        outputs[self.mst] = Graph(nodes, edges, graph_spec)

        if build_dense:
            nodes, edges = self._maximin_edges_to_nodes_and_edges(
                dense_minimax_edges, voxel_size, offset, roi
            )
            graph_spec = request[self.dense_mst]
            graph_spec.directed = False
            outputs[self.dense_mst] = Graph(nodes, edges, graph_spec)

        return outputs

    def _maximin_edges_to_nodes_and_edges(self, maximin_edges, voxel_size, offset, roi):
        """Convert ``(a, b, cost)`` triples into a node set and an edge list.

        Each endpoint is located at ``endpoint * voxel_size + offset`` and
        must lie inside ``roi``. Edges carry ``1 - cost`` under
        ``self.distance_attr``. Ids start at 0 for every call.

        Raises:
            AssertionError: If an endpoint location is not contained in
                ``roi``.
        """
        maximin_id = itertools.count(start=0)

        nodes = set()
        edges = []
        ids = {}
        for a, b, cost in maximin_edges:
            # next() is evaluated before setdefault runs, so a counter value
            # is consumed on every lookup; ids are unique but not contiguous.
            a_id = ids.setdefault(a, next(maximin_id))
            b_id = ids.setdefault(b, next(maximin_id))
            a_loc = np.array(a) * voxel_size + offset
            b_loc = np.array(b) * voxel_size + offset
            assert roi.contains(a_loc), f"Roi {roi} does not contain {a_loc}"
            assert roi.contains(b_loc), f"Roi {roi} does not contain {b_loc}"

            nodes.add(Node(a_id, location=a_loc))
            nodes.add(Node(b_id, location=b_loc))
            edges.append(Edge(a_id, b_id, attrs={self.distance_attr: 1 - cost}))

        return nodes, edges
예제 #9
0
    def process(self, batch, request: BatchRequest):
        """Compute one minimum spanning tree per connected candidate component.

        Three coordinate channels (voxel index times ``self.coordinate_scale``)
        are appended to the embedding channels. For every connected component
        of the candidate graph, the voxels of its nodes are selected, passed
        to ``mlp.emst``, and the resulting ``(u, v, distance)`` triples are
        mapped back to the candidate node ids through their voxel location.

        Args:
            batch: Batch holding the ``self.embeddings`` array and the
                ``self.candidates`` graph.
            request: Request whose ``self.mst`` spec is used (and marked as
                undirected) for the output graph.

        Returns:
            A new ``Batch`` containing the graph under ``self.mst``.
        """
        outputs = Batch()

        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()
        embeddings = batch[self.embeddings].data
        candidates = batch[self.candidates].to_nx_graph()
        _, depth, height, width = embeddings.shape

        # np.meshgrid returns a tuple on NumPy >= 2.0, so the float32 copies
        # are collected in a new list instead of being assigned in place.
        coordinates = [
            axis_grid.astype(np.float32)
            for axis_grid in np.meshgrid(
                np.arange(
                    0, (depth - 0.5) * self.coordinate_scale[0], self.coordinate_scale[0]
                ),
                np.arange(
                    0, (height - 0.5) * self.coordinate_scale[1], self.coordinate_scale[1]
                ),
                np.arange(
                    0, (width - 0.5) * self.coordinate_scale[2], self.coordinate_scale[2]
                ),
                indexing="ij",
            )
        ]

        # One row per voxel: embedding channels followed by the 3 coordinates
        embedding = np.concatenate([embeddings, coordinates], 0)
        embedding = np.transpose(embedding, axes=[1, 2, 3, 0])
        embedding = embedding.reshape(depth * width * height, -1)

        nodes = set()
        edges = []

        for i, component in enumerate(nx.connected_components(candidates)):
            # Mark the voxels of this component and remember which candidate
            # node sits at each voxel.
            candidates_array = np.zeros((depth, height, width), dtype=bool)
            locs_to_ids = {}
            for node in component:
                attrs = candidates.nodes[node]
                location = attrs["location"]
                voxel_location = tuple(
                    int(x) for x in ((location - offset) // voxel_size)
                )
                locs_to_ids[voxel_location] = node
                candidates_array[voxel_location] = True
            candidates_array = candidates_array.reshape(-1)
            component_embedding = embedding[candidates_array, :]

            logger.info(
                f"processing component {i} with "
                f"{len(component)} candidates"
            )

            component_emst = mlp.emst(component_embedding)["output"]

            for u, v, distance in component_emst:
                u = int(u)
                v = int(v)
                # The last three columns are the scaled voxel coordinates;
                # undo the scaling to recover the location, then look up the
                # candidate node id stored for that voxel.
                pos_u = component_embedding[u][-3:] / self.coordinate_scale * voxel_size
                u_index = locs_to_ids[
                    tuple(int(np.round(x)) for x in (pos_u / voxel_size))
                ]
                pos_v = component_embedding[v][-3:] / self.coordinate_scale * voxel_size
                v_index = locs_to_ids[
                    tuple(int(np.round(x)) for x in (pos_v / voxel_size))
                ]
                nodes.add(Node(u_index, location=pos_u + offset))
                nodes.add(Node(v_index, location=pos_v + offset))
                edges.append(
                    Edge(u_index, v_index, attrs={self.distance_attr: distance})
                )

        graph_spec = request[self.mst]
        graph_spec.directed = False

        logger.info(
            f"candidates has {candidates.number_of_nodes()} nodes and "
            f"{candidates.number_of_edges()} edges and "
            f"{len(list(nx.connected_components(candidates)))} components"
        )

        outputs[self.mst] = Graph(nodes, edges, graph_spec)
        output_graph = outputs[self.mst].to_nx_graph()

        logger.info(
            f"output_graph has {output_graph.number_of_nodes()} nodes and "
            f"{output_graph.number_of_edges()} edges and "
            f"{len(list(nx.connected_components(output_graph)))} components"
        )

        logger.debug(
            f"OUTPUTS CONTAINS MST WITH {len(list(outputs[self.mst].nodes))} NODES"
        )

        return outputs