예제 #1
0
    def __init__(self):
        # Three nodes with no edges, inside a ROI with offset -500 and
        # shape 1500 on every axis.
        node_locations = {
            1: np.array([1, 1, 1]),
            2: np.array([500, 500, 500]),
            3: np.array([550, 550, 550]),
        }
        nodes = [
            Node(node_id, location)
            for node_id, location in node_locations.items()
        ]
        bounds = Roi((-500, -500, -500), (1500, 1500, 1500))
        self.graph = Graph(nodes, [], GraphSpec(roi=bounds))
예제 #2
0
    def __init__(self):
        # A single edge-less node inside a ROI that is longer on the last axis.
        bounds = Roi((-200, -200, -200), (400, 400, 478))
        lone_node = Node(id=1, location=np.array([50, 70, 100]))
        self.graph = Graph([lone_node], [], GraphSpec(roi=bounds))
예제 #3
0
    def __init__(self):
        # One edge-less node at the centre of a 1000^3 ROI anchored at the origin.
        bounds = Roi((0, 0, 0), (1000, 1000, 1000))
        center_node = Node(id=1, location=np.array([500, 500, 500]))
        self.graph = Graph([center_node], [], GraphSpec(roi=bounds))
예제 #4
0
    def test_crop(self):
        graph = Graph(self.nodes, self.edges, self.spec)

        cropped = graph.crop(Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3])))
        expected_roi = Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3]))

        # Cropping keeps the source graph's ROI and gives the cropped graph
        # exactly the requested ROI.
        self.assertEqual(graph.spec.roi, self.spec.roi)
        self.assertEqual(cropped.spec.roi, expected_roi)

        # Changing the cropped graph's spec must not leak into the source.
        cropped.spec.directed = False
        self.assertTrue(graph.spec.directed)
        self.assertFalse(cropped.spec.directed)
예제 #5
0
    def __init__(self):
        # Node locations are stored as float arrays.
        self.dtype = float
        positions = ([1, 1, 1], [500, 500, 500], [550, 550, 550])
        self.__vertices = [
            Node(id=node_id, location=np.array(position, dtype=self.dtype))
            for node_id, position in enumerate(positions, start=1)
        ]
        # Edges 1-2 and 2-3 form a simple chain.
        self.__edges = [Edge(src, dst) for src, dst in ((1, 2), (2, 3))]
        bounds = Roi(Coordinate([-500] * 3), Coordinate([1500] * 3))
        self.__spec = GraphSpec(roi=bounds)
        self.graph = Graph(self.__vertices, self.__edges, self.__spec)
예제 #6
0
    def setup(self):
        # Every configured key is provided over [0, size) with the same
        # directedness as the stored graph.
        origin = Coordinate([0] * len(self.size))
        roi = Roi(origin, self.size)
        for key in self.points:
            self.provides(key, GraphSpec(roi=roi, directed=self.directed))

        # Nodes are evenly spaced along the main diagonal and chained in id order.
        side = min(self.size)
        nodes = []
        for i in range(self.num_points):
            coord = i * side / self.num_points
            nodes.append(Node(id=i, location=np.array([coord] * 3)))
        edges = [Edge(i, i + 1) for i in range(self.num_points - 1)]

        graph_spec = GraphSpec(roi=roi, directed=self.directed)
        self.graph = Graph(nodes, edges, graph_spec)
예제 #7
0
    def provide(self, request):
        batch = Batch()

        points_roi = request[GraphKeys.TEST_POINTS].roi
        labels_roi = request[ArrayKeys.TEST_LABELS].roi
        labels_spec = self.spec[ArrayKeys.TEST_LABELS]
        voxel_roi = labels_roi // labels_spec.voxel_size

        # Every second index along axis 1 starts at 100; each point then
        # writes its id into the voxel returned by point_to_voxel.
        labels = np.zeros(voxel_roi.get_shape(), dtype=np.uint32)
        labels[:, ::2] = 100
        for node in self.points:
            labels[self.point_to_voxel(labels_roi, node.location)] = node.id

        array_spec = labels_spec.copy()
        array_spec.roi = labels_roi
        batch.arrays[ArrayKeys.TEST_LABELS] = Array(labels, spec=array_spec)

        # Only points inside the requested graph ROI are returned.
        inside = [
            node for node in self.points if points_roi.contains(node.location)
        ]
        batch.graphs[GraphKeys.TEST_POINTS] = Graph(
            inside, [], GraphSpec(roi=points_roi))

        return batch
class TestSourceRandomLocation(BatchProvider):
    """Provide a fixed three-node, edge-less graph, cropped and trimmed to each request."""

    def __init__(self):
        locations = ([1, 1, 1], [500, 500, 500], [550, 550, 550])
        nodes = [
            Node(id=node_id, location=np.array(location))
            for node_id, location in enumerate(locations, start=1)
        ]
        bounds = Roi((0, 0, 0), (1000, 1000, 1000))
        self.graph = Graph(nodes, [], GraphSpec(roi=bounds))

    def setup(self):
        # Advertise exactly the spec of the stored graph.
        self.provides(GraphKeys.TEST_GRAPH, self.graph.spec)

    def provide(self, request):
        batch = Batch()
        requested_roi = request[GraphKeys.TEST_GRAPH].roi
        cropped = self.graph.crop(requested_roi)
        batch[GraphKeys.TEST_GRAPH] = cropped.trim(requested_roi)
        return batch
예제 #9
0
class TestPointSource(BatchProvider):
    """Provide a chain of evenly spaced diagonal nodes under each key in ``points``.

    Args:
        points: graph keys to provide.
        directed: directedness of the provided graph specs.
        size: shape of the provided ROI, which starts at the origin.
        num_points: number of nodes in the chain.
    """

    def __init__(self, points: List[GraphKey], directed: bool,
                 size: Coordinate, num_points: int):
        self.points = points
        self.directed = directed
        self.size = size
        self.num_points = num_points

    def setup(self):
        origin = Coordinate([0] * len(self.size))
        roi = Roi(origin, self.size)
        for key in self.points:
            self.provides(key, GraphSpec(roi=roi, directed=self.directed))

        side = min(self.size)
        nodes = []
        for i in range(self.num_points):
            coord = i * side / self.num_points
            nodes.append(Node(id=i, location=np.array([coord] * 3)))
        edges = [Edge(i, i + 1) for i in range(self.num_points - 1)]

        graph_spec = GraphSpec(roi=roi, directed=self.directed)
        self.graph = Graph(nodes, edges, graph_spec)

    def provide(self, request: BatchRequest) -> Batch:
        batch = Batch()
        requested_keys = (key for key in self.points if key in request)
        for key in requested_keys:
            requested_roi = request[key].copy().roi
            cropped = self.graph.crop(roi=requested_roi)
            # Cropping can split the chain; refresh component labels.
            cropped.relabel_connected_components()
            batch[key] = cropped
        return batch
예제 #10
0
def test_nodes():
    """Node locations reassigned through ``graph.nodes`` are visible on re-iteration."""

    def as_f32(xyz):
        return np.array(xyz, dtype=np.float32)

    initial_locations = {
        1: as_f32([1, 1, 1]),
        2: as_f32([500, 500, 500]),
        3: as_f32([550, 550, 550]),
    }
    replacement_locations = {
        1: as_f32([0, 0, 0]),
        2: as_f32([50, 50, 50]),
        3: as_f32([55, 55, 55]),
    }

    nodes = [
        Node(id=node_id, location=location)
        for node_id, location in initial_locations.items()
    ]
    bounds = Roi(Coordinate([-500] * 3), Coordinate([1500] * 3))
    graph = Graph(nodes, [Edge(1, 2), Edge(2, 3)], GraphSpec(roi=bounds))

    for node in graph.nodes:
        node.location = replacement_locations[node.id]

    for node in graph.nodes:
        expected = replacement_locations[node.id]
        assert np.isclose(node.location, expected).all()
예제 #11
0
class ExampleGraphSource(BatchProvider):
    """Provide a fixed three-node chain graph, cropped to the requested ROI."""

    def __init__(self):
        # Node locations are stored as float arrays.
        self.dtype = float
        positions = ([1, 1, 1], [500, 500, 500], [550, 550, 550])
        self.__vertices = [
            Node(id=node_id, location=np.array(position, dtype=self.dtype))
            for node_id, position in enumerate(positions, start=1)
        ]
        self.__edges = [Edge(src, dst) for src, dst in ((1, 2), (2, 3))]
        bounds = Roi(Coordinate([-500] * 3), Coordinate([1500] * 3))
        self.__spec = GraphSpec(roi=bounds)
        self.graph = Graph(self.__vertices, self.__edges, self.__spec)

    def setup(self):
        # Advertise the spec the graph was built with.
        self.provides(GraphKeys.TEST_GRAPH, self.__spec)

    def provide(self, request):
        batch = Batch()
        requested_roi = request[GraphKeys.TEST_GRAPH].roi
        batch[GraphKeys.TEST_GRAPH] = self.graph.crop(requested_roi)
        return batch
예제 #12
0
    def process(self, batch, request):
        """Cut edges at every node of degree > 2, keeping one edge per branch point.

        Branch points are collected once, up front. For an undirected graph
        only the edge to the smallest neighbour id is kept. For a directed
        graph, only edges to or from the smallest id among the node's
        successors and predecessors are kept.
        """
        outputs = Batch()
        source = batch[self.graph]
        g = source.to_nx_graph()
        directed = g.is_directed()

        branch_points = [n for n in g.nodes if g.degree(n) > 2]

        for bp in branch_points:
            if not directed:
                for neighbor in sorted(g.neighbors(bp))[1:]:
                    g.remove_edge(bp, neighbor)
                continue

            out_neighbors = list(g.successors(bp))
            in_neighbors = list(g.predecessors(bp))
            keep = min(out_neighbors + in_neighbors)
            for succ in out_neighbors:
                if succ != keep:
                    g.remove_edge(bp, succ)
            for pred in in_neighbors:
                if pred != keep:
                    g.remove_edge(pred, bp)

        outputs[self.graph] = Graph.from_nx_graph(g, source.spec.copy())
        return outputs
예제 #13
0
    def test_shift_points5(self):
        # Five nodes at x=3, spaced 2 apart along the second axis.
        data = [Node(id=i, location=np.array([3, 2 * i])) for i in range(5)]
        spec = GraphSpec(Roi(offset=(0, 0), shape=(15, 10)))
        points = Graph(data, [], spec)
        request_roi = Roi(offset=(3, 0), shape=(9, 10))
        shift_array = np.array(
            [[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int)
        lcm_voxel_size = Coordinate((3, 2))

        # Only nodes 0, 2 and 4 are expected to survive, at these locations.
        expected_nodes = [
            Node(id=node_id, location=np.array(location))
            for node_id, location in ((0, [6, 0]), (2, [3, 4]), (4, [6, 8]))
        ]
        result = ShiftAugment.shift_points(
            points,
            request_roi,
            shift_array,
            shift_axis=1,
            lcm_voxel_size=lcm_voxel_size,
        )

        self.assertTrue(self.points_equal(result.nodes, expected_nodes))
        self.assertTrue(result.spec == GraphSpec(request_roi))
예제 #14
0
 def _empty_copy(self, base: Batch):
     """Return a new Batch with the same keys and specs as ``base`` but no content.

     Each array in ``base`` becomes an all-zero array with the same shape and
     dtype, and each graph becomes a graph with no nodes or edges. Specs are
     deep-copied, so the result never shares spec objects with ``base``.

     Args:
         base: batch to mirror.

     Returns:
         The new, empty-content ``Batch``.
     """
     add = Batch()
     for key, array in base.arrays.items():
         add[key] = Array(np.zeros_like(array.data),
                          spec=copy.deepcopy(array.spec))
     # Graphs are stored in ``Batch.graphs`` (the providers in this file
     # write ``batch.graphs[key] = Graph(...)``). The previous
     # ``base.points`` lookup skipped them or raised AttributeError.
     for key, graph in base.graphs.items():
         add[key] = Graph([], [], spec=copy.deepcopy(graph.spec))
     return add
예제 #15
0
    def provide(self, request):
        outputs = Batch()

        # Two chains, 0-1-2 and 3-104-5; every node lies on the main diagonal.
        diagonal_positions = [
            (0, 1), (1, 10), (2, 19), (3, 21), (104, 30), (5, 39),
        ]
        nodes = [
            Node(id=node_id, location=np.array((c, c, c)))
            for node_id, c in diagonal_positions
        ]
        edges = [Edge(u, v) for u, v in ((0, 1), (1, 2), (3, 104), (104, 5))]

        spec = self.spec[GraphKeys.RAW].copy()
        spec.roi = request[GraphKeys.RAW].roi
        full_graph = Graph(nodes, edges, spec)

        outputs[GraphKeys.RAW] = full_graph.crop(spec.roi)
        return outputs
예제 #16
0
    def process(self, batch, request):
        """Label nodes in the write ROI with their component id and cable length.

        The write ROI is the batch graph's ROI shrunk by ``self.context`` on
        both sides. For each connected component (weakly connected if
        directed) with at least one node in the write ROI:

        - its id is the smallest id among those contained nodes;
        - its size is the summed edge length, where edges with both ends in
          the write ROI count fully and edges with one end inside count half.

        Both values are written to the contained nodes under
        ``self.component_attr`` and ``self.size_attr``.
        """
        source = batch[self.graph]
        g = source.to_nx_graph()
        assert source.spec.roi.get_shape() == self.read_size

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = source.spec.roi.grow(-self.context, -self.context)

        def cable_length_in_roi(subgraph):
            # Full length for edges inside write_roi, half for edges crossing it.
            total = 0
            for u, v in subgraph.edges:
                u_loc = subgraph.nodes[u]["location"]
                v_loc = subgraph.nodes[v]["location"]
                length = np.linalg.norm(u_loc - v_loc)
                u_inside = write_roi.contains(u_loc)
                v_inside = write_roi.contains(v_loc)
                if u_inside and v_inside:
                    total += length
                elif u_inside or v_inside:
                    total += length / 2
            return total

        components = (nx.weakly_connected_components
                      if g.is_directed() else nx.connected_components)

        for cc in components(g):
            contained_nodes = [
                n for n in cc if write_roi.contains(g.nodes[n]["location"])
            ]
            if not contained_nodes:
                continue

            cc_id = min(contained_nodes)
            cc_subgraph = g.subgraph(cc)
            total_edge_len = cable_length_in_roi(cc_subgraph)

            for n in contained_nodes:
                node_attrs = cc_subgraph.nodes[n]
                node_attrs[self.component_attr] = int(cc_id)
                node_attrs[self.size_attr] = float(total_edge_len)

        # Sanity check: every node inside write_roi was labelled.
        count = 0
        for _, attrs in g.nodes(data=True):
            if not write_roi.contains(attrs["location"]):
                continue
            assert self.component_attr in attrs
            count += 1

        logger.debug(
            f"{self.name()} updated component id of {count} nodes in write_roi"
        )

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(g, source.spec.copy())

        return outputs
예제 #17
0
    def _shift_and_crop(self, points: Graph, array: Array,
                        direction: Coordinate, output_roi: Roi):
        """Shift ``array`` and ``points`` by ``direction`` and place them in ``output_roi``.

        A window of ``output_roi``'s shape is cut from ``array``. The window is
        centred on the array ROI's centre displaced by ``direction``, and the
        cropped array is relabelled to ``output_roi``. Points inside the window
        are translated into ``output_roi``'s frame. A parent->child edge is kept
        when the parent is also inside the window. Each weakly connected
        component then gets a fresh ``label_id`` (its enumeration index).

        Args:
            points: source points; its spec is not modified.
            array: array to crop.
            direction: displacement applied to the array ROI's centre.
            output_roi: ROI assigned to the returned array and points.

        Returns:
            Tuple ``(points, array)`` with the shifted and cropped data.
        """
        # Shift and crop the array
        center = array.spec.roi.get_offset() + array.spec.roi.get_shape() // 2
        new_center = center + direction
        new_offset = new_center - output_roi.get_shape() // 2
        new_roi = Roi(new_offset, output_roi.get_shape())
        array = array.crop(new_roi)
        array.spec.roi = output_roi

        # Work on a copy: the original assigned ``.roi`` directly on
        # ``points.spec``, which silently changed the caller's points.
        new_points_spec = points.spec.copy()
        new_points_spec.roi = new_roi
        new_points_graph = nx.DiGraph()

        # shift points into output_roi's frame and add them to a graph,
        # linking each point to its parent if the parent is in the window too
        for point_id, point in points.data.items():
            if new_roi.contains(point.location):
                new_point = point.copy()
                new_point.location = (point.location - new_offset +
                                      output_roi.get_begin())
                new_points_graph.add_node(
                    new_point.point_id,
                    point_id=new_point.point_id,
                    parent_id=new_point.parent_id,
                    location=new_point.location,
                    label_id=new_point.label_id,
                    radius=new_point.radius,
                    point_type=new_point.point_type,
                )
                if points.data.get(
                        new_point.parent_id, False) and new_roi.contains(
                            points.data[new_point.parent_id].location):
                    new_points_graph.add_edge(new_point.parent_id,
                                              new_point.point_id)

        # relabel connected components: label_id becomes the component index
        for i, connected_component in enumerate(
                nx.weakly_connected_components(new_points_graph)):
            for node in connected_component:
                new_points_graph.nodes[node]["label_id"] = i

        # store new graph data in points
        # NOTE(review): this section uses the legacy Points-style API
        # (``points.data``, ``point_id``/``parent_id``, and
        # ``Node(location, point_id=...)``). It also calls
        # ``Graph(nodes, spec)``, while other Graph constructions in this file
        # use ``Graph(nodes, edges, spec)``. Confirm against the current
        # Graph/Node signatures.
        new_points_data = {
            point_id: Node(
                point["location"],
                point_id=point["point_id"],
                point_type=point["point_type"],
                radius=point["radius"],
                parent_id=point["parent_id"],
                label_id=point["label_id"],
            )
            for point_id, point in new_points_graph.nodes.items()
        }
        points = Graph(new_points_data, new_points_spec)
        points.spec.roi = output_roi
        return points, array
예제 #18
0
 def provide(self, request):
     outputs = Batch()
     # Both components are returned together, restricted to the requested ROI.
     spec = self.graph_spec.copy()
     spec.roi = request[self.graph].roi
     all_nodes = self.component_1_nodes + self.component_2_nodes
     all_edges = self.component_1_edges + self.component_2_edges
     outputs[self.graph] = Graph(all_nodes, all_edges, spec)
     return outputs
예제 #19
0
def test_transpose():
    """SimpleAugment transposition moves graph nodes and array voxels consistently."""
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")
    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    # Mark voxel (17, 17), the voxel containing the node at (450, 550).
    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)),
                        voxel_size=voxel_size))

    def make_pipeline(transpose_probs):
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(mirror_only=[], transpose_only=[0, 1],
                            transpose_probs=transpose_probs)
        )

    default_pipeline = make_pipeline([0, 0])
    transpose_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    def check(pipeline, expected_location):
        # The single node lands at expected_location, and the marked voxel
        # sits under it.
        with build(pipeline):
            batch = pipeline.request_batch(request)

            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert all(np.isclose(node.location, expected_location))
            node_voxel_index = Coordinate(
                (node.location - batch[array_key].spec.roi.get_offset()) /
                voxel_size)
            assert (
                batch[array_key].data[node_voxel_index] == 1
            ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"

    check(default_pipeline, [450, 550])
    check(transpose_pipeline, [410, 590])
예제 #20
0
    def test_shift_points1(self):
        # A single node that is expected to be shifted out of the request.
        points = Graph(
            [Node(id=1, location=np.array([0, 1]))],
            [],
            GraphSpec(Roi(offset=(0, 0), shape=(5, 5))),
        )
        request_roi = Roi(offset=(0, 1), shape=(5, 3))
        shift_array = np.array(
            [[0, -1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int)
        lcm_voxel_size = Coordinate((1, 1))

        expected = Graph([], [], GraphSpec(request_roi))
        result = ShiftAugment.shift_points(
            points,
            request_roi,
            shift_array,
            shift_axis=0,
            lcm_voxel_size=lcm_voxel_size,
        )

        self.assertTrue(self.points_equal(result.nodes, expected.nodes))
        self.assertTrue(result.spec == GraphSpec(request_roi))
예제 #21
0
    def provide(self, request):
        """Build synthetic RAW / GT_LABELS arrays and PRESYN / POSTSYN graphs.

        - ``ArrayKeys.RAW``: each voxel holds the sum of its z, y, x voxel
          coordinates within the requested ROI.
        - ``ArrayKeys.GT_LABELS``: ones, with values 2 and 3 written into the
          second half along axis 0 (see the slices below).
        - ``GraphKeys.PRESYN`` / ``GraphKeys.POSTSYN``: the node mappings
          returned by ``self.__get_pre_and_postsyn_locations``. It is called
          once, with the PRESYN ROI if PRESYN is requested, otherwise with the
          POSTSYN ROI.

        Raises:
            ValueError: if a graph key other than PRESYN or POSTSYN is
                requested. Previously such keys silently reused whatever
                ``data`` last held (e.g. the label ndarray).
        """
        batch = Batch()

        # have the pixels encode their position
        if ArrayKeys.RAW in request:

            # the z,y,x coordinates of the ROI
            roi = request[ArrayKeys.RAW].roi
            roi_voxel = roi // self.spec[ArrayKeys.RAW].voxel_size
            meshgrids = np.meshgrid(
                range(roi_voxel.get_begin()[0], roi_voxel.get_end()[0]),
                range(roi_voxel.get_begin()[1], roi_voxel.get_end()[1]),
                range(roi_voxel.get_begin()[2], roi_voxel.get_end()[2]),
                indexing="ij",
            )
            raw = meshgrids[0] + meshgrids[1] + meshgrids[2]

            spec = self.spec[ArrayKeys.RAW].copy()
            spec.roi = roi
            batch.arrays[ArrayKeys.RAW] = Array(raw, spec)

        if ArrayKeys.GT_LABELS in request:
            roi = request[ArrayKeys.GT_LABELS].roi
            roi_voxel_shape = (
                roi // self.spec[ArrayKeys.GT_LABELS].voxel_size
            ).get_shape()
            labels = np.ones(roi_voxel_shape)
            labels[roi_voxel_shape[0] // 2 :, roi_voxel_shape[1] // 2 :, :] = 2
            # NOTE(review): for an even roi_voxel_shape[1], this slice covers
            # the same columns as the previous one, so label 2 is fully
            # overwritten by 3. Confirm this is intended.
            labels[roi_voxel_shape[0] // 2 :, -(roi_voxel_shape[1] // 2) :, :] = 3
            spec = self.spec[ArrayKeys.GT_LABELS].copy()
            spec.roi = roi
            batch.arrays[ArrayKeys.GT_LABELS] = Array(labels, spec)

        if GraphKeys.PRESYN in request:
            data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations(
                roi=request[GraphKeys.PRESYN].roi
            )
        elif GraphKeys.POSTSYN in request:
            data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations(
                roi=request[GraphKeys.POSTSYN].roi
            )

        for (graph_key, spec) in request.graph_specs.items():
            if graph_key == GraphKeys.PRESYN:
                node_map = data_presyn
            elif graph_key == GraphKeys.POSTSYN:
                node_map = data_postsyn
            else:
                raise ValueError(
                    f"{type(self).__name__} cannot provide graph {graph_key}"
                )
            batch.graphs[graph_key] = Graph(
                list(node_map.values()), [], GraphSpec(spec.roi)
            )

        return batch
예제 #22
0
    def process(self, batch, request):
        """Drop every node whose ``self.size_attr`` is below ``self.size_threshold``."""
        outputs = Batch()

        source = batch[self.graph]
        g = source.to_nx_graph()

        too_small = [
            node for node, attrs in g.nodes.items()
            if attrs[self.size_attr] < self.size_threshold
        ]
        for node in too_small:
            g.remove_node(node)

        outputs[self.graph] = Graph.from_nx_graph(g, source.spec.copy())
        return outputs
예제 #23
0
 def provide(self, request):
     outputs = Batch()
     # TEST_GRAPH must be requested exactly on every ``self.every``-th call.
     graph_expected = self.n % self.every == 0
     if graph_expected:
         assert GraphKeys.TEST_GRAPH in request
     else:
         assert GraphKeys.TEST_GRAPH not in request

     # Answer every request with empty graphs and zero-filled arrays.
     for key, spec in request.items():
         if isinstance(key, GraphKey):
             outputs[key] = Graph([], [], spec)
         elif isinstance(key, ArrayKey):
             spec.voxel_size = self.spec[key].voxel_size
             zeros = np.zeros(spec.roi.get_shape(), dtype=spec.dtype)
             outputs[key] = Array(zeros, spec)

     self.n += 1
     return outputs
예제 #24
0
    def test_neighbors(self):
        """Directed and undirected graphs report the same neighbours for the first node."""
        # Copy the shared fixture spec. Previously both names aliased
        # ``self.spec``, so setting ``ud_spec.directed = False`` also made
        # the "directed" graph undirected (and mutated the fixture). That
        # made the comparison below trivial.
        # directed
        d_spec = self.spec.copy()
        # undirected
        ud_spec = self.spec.copy()
        ud_spec.directed = False

        directed = Graph(self.nodes, self.edges, d_spec)
        undirected = Graph(self.nodes, self.edges, ud_spec)

        self.assertCountEqual(directed.neighbors(self.nodes[0]),
                              undirected.neighbors(self.nodes[0]))
예제 #25
0
파일: scan.py 프로젝트: yajivunev/gunpowder
    def provide(self, request):
        batch = Batch()

        # Arrays: each voxel stores the sum of its z, y, x voxel indices.
        for array_key, requested_spec in request.array_specs.items():
            roi = requested_spec.roi
            voxel_roi = roi // self.spec[array_key].voxel_size
            begin = voxel_roi.get_begin()
            end = voxel_roi.get_end()
            zz, yy, xx = np.meshgrid(
                *(range(begin[axis], end[axis]) for axis in range(3)),
                indexing='ij')

            spec = self.spec[array_key].copy()
            spec.roi = roi
            batch.arrays[array_key] = Array(zz + yy + xx, spec)

        # Graphs: one node at every (x, y, z) inside the ROI with
        # x % 100 == 0, y % 10 == 0 and z % 10 == 0.
        strides = [100, 10, 10]
        for graph_key, spec in request.graph_specs.items():
            roi_begin = spec.roi.get_begin()
            roi_end = spec.roi.get_end()
            start = roi_begin - tuple(
                x % s for x, s in zip(roi_begin, strides))
            axes = [range(a, b, s) for a, b, s in zip(start, roi_end, strides)]

            nodes = []
            for i, j, k in itertools.product(*axes):
                location = np.array([i, j, k])
                if not spec.roi.contains(location):
                    continue
                nodes.append(
                    Node(id=coordinate_to_id(i, j, k), location=location))
            batch.graphs[graph_key] = Graph(nodes, [], spec)

        return batch
예제 #26
0
    def process(self, batch, request):
        """Iteratively prune short side branches from every connected component.

        For each connected component (weakly connected if the graph is
        directed), repeat until a pass removes nothing:
        find the component's nodes with degree > 2. For each such branch
        point, split the component by deleting that node, and remove every
        resulting piece whose ``self.cable_len`` (over the piece plus the
        branch point) is <= ``self.node_threshold``. The branch point itself
        is kept.

        Returns:
            A Batch holding the pruned graph under ``self.graph``, with a
            copy of the input graph's spec.
        """
        outputs = Batch()

        g = batch[self.graph].to_nx_graph()
        logger.debug(f"g has {len(g.nodes())} nodes pre filtering")

        cc_func = (nx.weakly_connected_components
                   if g.is_directed() else nx.connected_components)

        # Components are materialised up front; ``cc`` keeps listing node ids
        # even after they are removed from ``g`` below.
        ccs = cc_func(g)
        for cc in list(ccs):
            finished = False
            while not finished:
                finished = True
                # networkx subgraphs are views of ``g``: ids in ``cc`` that were
                # already removed from ``g`` are not part of the view, and
                # later removals from ``g`` show up in it.
                g_component = g.subgraph(cc)

                branch_points = [
                    n for n in g_component.nodes if g_component.degree(n) > 2
                ]
                logger.debug(
                    f"Connected component has {len(g_component.nodes)} nodes and {len(branch_points)} branch points"
                )
                removed = 0
                for i, branch_point in enumerate(branch_points):
                    # Pieces the component falls into once branch_point is gone.
                    remaining = [n for n in cc if n != branch_point]
                    remaining_g = g_component.subgraph(remaining)

                    remaining_ccs = list(cc_func(remaining_g))
                    logger.debug(
                        f"After removing branch point {i}, cc is broken into pieces sized: {[len(x) for x in remaining_ccs]}"
                    )
                    for remaining_cc in list(remaining_ccs):
                        # NOTE(review): an earlier iteration of this loop may
                        # already have removed branch_point from ``g`` (as part
                        # of another short piece). It is still passed to
                        # cable_len here; confirm cable_len tolerates that.
                        if (self.cable_len(g,
                                           list(remaining_cc) + [branch_point])
                                <= self.node_threshold):
                            # Short piece: delete it and schedule another pass.
                            for n in remaining_cc:
                                g.remove_node(n)
                                finished = False
                                removed += 1
                logger.debug(f"Removed {removed} nodes from this cc")
        logger.debug(f"g has {len(g.nodes())} nodes post filtering")

        outputs[self.graph] = Graph.from_nx_graph(
            g, batch[self.graph].spec.copy())
        return outputs
예제 #27
0
class ExampleSourceRandomLocation(BatchProvider):
    """Provide a fixed three-node, edge-less graph, cropped and trimmed to each request."""

    def __init__(self):
        locations = ([1, 1, 1], [500, 500, 500], [550, 550, 550])
        nodes = [
            Node(node_id, np.array(location))
            for node_id, location in enumerate(locations, start=1)
        ]
        bounds = Roi((-500, -500, -500), (1500, 1500, 1500))
        self.graph = Graph(nodes, [], GraphSpec(roi=bounds))

    def setup(self):
        # Advertise exactly the spec of the stored graph.
        self.provides(GraphKeys.TEST_POINTS, self.graph.spec)

    def provide(self, request):
        batch = Batch()
        requested_roi = request[GraphKeys.TEST_POINTS].roi
        cropped = self.graph.crop(requested_roi)
        batch[GraphKeys.TEST_POINTS] = cropped.trim(requested_roi)
        return batch
예제 #28
0
    def process(self, batch, request):
        """Replace block-local component ids in the write ROI with global ones.

        For each node inside the write ROI (the graph ROI shrunk by
        ``self.context``), its ``self.component_attr`` value is looked up as a
        node of the graph returned by ``self.client.get_graph``. That node's
        ``self.component_attr`` and ``self.size_attr`` are copied back onto
        the node.
        """
        source = batch[self.graph]
        g = source.to_nx_graph()

        logger.debug(
            f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
            f"{g.number_of_edges()} edges!")

        write_roi = source.spec.roi.grow(-self.context, -self.context)

        contained_nodes = [
            node for node, attrs in g.nodes.items()
            if write_roi.contains(attrs["location"])
        ]
        contained_components = {
            g.nodes[node][self.component_attr] for node in contained_nodes
        }

        def log_summary():
            logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                         f"{len(contained_components)} components in write_roi")

        log_summary()

        component_graph = self.client.get_graph(roi=write_roi,
                                                node_inclusion="dangling",
                                                edge_inclusion="either")

        for node in contained_nodes:
            attrs = g.nodes[node]
            component_node = component_graph.nodes[attrs[self.component_attr]]
            attrs[self.component_attr] = component_node[self.component_attr]
            attrs[self.size_attr] = component_node[self.size_attr]

        log_summary()

        outputs = Batch()
        outputs[self.graph] = Graph.from_nx_graph(g, source.spec.copy())

        return outputs
예제 #29
0
    def process(self, batch, request):
        """Store a statistic of the embeddings along each MST edge's chain.

        For every ``(u, v), chain`` pair yielded by ``self.get_edge_chains``,
        the embedding vector at each chain node's voxel in ``dense_mst`` is
        gathered. ``self.get_stat`` is applied to the gathered vectors, and
        the result is stored on MST edge ``(u, v)`` under
        ``self.distance_attr``.

        Returns:
            A Batch holding the annotated MST under ``self.mst``.
        """
        mst = batch[self.mst].to_nx_graph()
        dense_mst = batch[self.dense_mst].to_nx_graph()
        embeddings = batch[self.embeddings].data
        voxel_size = batch[self.embeddings].spec.voxel_size
        offset = batch[self.embeddings].spec.roi.get_begin()

        # Keep every leading axis (presumably channels — confirm); index only
        # the trailing three spatial axes.
        leading = (slice(None), ) * (len(embeddings.shape) - 3)

        for (u, v), chain in self.get_edge_chains(mst, dense_mst):
            chain_embeddings = []
            for n in chain:
                n_loc = dense_mst.nodes[n]["location"]
                # voxel index of n relative to the embeddings' ROI begin
                n_ind = tuple(int(x) for x in ((n_loc - offset) // voxel_size))
                chain_embeddings.append(embeddings[leading + n_ind])

            # Summarise the gathered embeddings. Previously the raw node ids
            # in ``chain`` were passed, and ``chain_embeddings`` was computed
            # but never used.
            mst.edges[(u, v)][self.distance_attr] = self.get_stat(
                chain_embeddings)

        outputs = Batch()
        # Copy the spec, as the other graph nodes do, so the output does not
        # share a spec object with the input.
        outputs[self.mst] = Graph.from_nx_graph(mst,
                                                batch[self.mst].spec.copy())

        return outputs
예제 #30
0
class TestSource(BatchProvider):
    """Provide a single-node graph, cropped and trimmed to each request."""

    def __init__(self):
        bounds = Roi((-200, -200, -200), (400, 400, 478))
        lone_node = Node(id=1, location=np.array([50, 70, 100]))
        self.graph = Graph([lone_node], [], GraphSpec(roi=bounds))

    def setup(self):
        # Advertise exactly the spec of the stored graph.
        self.provides(GraphKeys.TEST_GRAPH, self.graph.spec)

    def prepare(self, request):
        # Nothing upstream; pass the request through unchanged.
        return request

    def provide(self, request):
        batch = Batch()
        requested_roi = request[GraphKeys.TEST_GRAPH].roi
        cropped = self.graph.crop(requested_roi)
        batch[GraphKeys.TEST_GRAPH] = cropped.trim(requested_roi)
        return batch