Ejemplo n.º 1
0
    def _get_points(self, inside: np.ndarray, slope: np.ndarray,
                    bb: Roi) -> Tuple[Dict[int, Node], List[Tuple[int, int]]]:
        """Build a two-node line through ``inside`` along ``slope``.

        ``slope`` is divided by its largest entry, ``inside`` is moved by one
        bounding-box shape (scaled by that slope) in both directions, and
        each far point is passed to ``self._resample_relative`` together
        with ``inside`` and ``bb``. The two results become nodes 0 and 1,
        which are connected by a single edge and handed to
        ``self._graph_points``; its result is returned unchanged.
        """
        unit_slope = slope / max(slope)
        extent = np.array(bb.get_shape()) * unit_slope
        far_below = inside - extent
        far_above = inside + extent
        endpoints = (
            self._resample_relative(inside, far_below, bb),
            self._resample_relative(inside, far_above, bb),
        )

        # line: both endpoints share the same attributes and are kept in a set
        points = {
            Node(id=node_id,
                 location=endpoint,
                 attrs={
                     "node_type": 0,
                     "radius": 0
                 })
            for node_id, endpoint in enumerate(endpoints)
        }
        edges = [Edge(0, 1)]
        return self._graph_points(points, edges)
Ejemplo n.º 2
0
    def __init__(self):
        """Create ``self.graph``: three nodes, no edges, fixed ROI."""
        coordinates = {
            1: [1, 1, 1],
            2: [500, 500, 500],
            3: [550, 550, 550],
        }
        nodes = [
            Node(node_id, np.array(xyz))
            for node_id, xyz in coordinates.items()
        ]
        no_edges = []
        spec = GraphSpec(roi=Roi((-500, -500, -500), (1500, 1500, 1500)))
        self.graph = Graph(nodes, no_edges, spec)
Ejemplo n.º 3
0
    def test_points_equal(self):
        """``points_equal`` is true for identical nodes, false otherwise."""
        matching_left = [Node(id=1, location=np.array([0, 1]))]
        matching_right = [Node(id=1, location=np.array([0, 1]))]
        matching = self.points_equal(matching_left, matching_right)
        self.assertTrue(matching)

        # same id, coordinates swapped
        differing_left = [Node(id=2, location=np.array([1, 2]))]
        differing_right = [Node(id=2, location=np.array([2, 1]))]
        differing = self.points_equal(differing_left, differing_right)
        self.assertFalse(differing)
Ejemplo n.º 4
0
    def nodes(self):
        """Return five nodes with ids 0-4; node ``i`` sits at ``(i, i, i)``.

        Locations use ``self.spec.dtype`` as their array dtype.
        """
        return [
            Node(index,
                 location=np.array([index, index, index],
                                   dtype=self.spec.dtype))
            for index in range(5)
        ]
Ejemplo n.º 5
0
    def __init__(self):
        """Create ``self.graph``: three nodes, no edges, fixed ROI."""
        coordinates = {
            1: [1, 1, 1],
            2: [500, 500, 500],
            3: [550, 550, 550],
        }
        nodes = [
            Node(id=node_id, location=np.array(xyz))
            for node_id, xyz in coordinates.items()
        ]
        no_edges = []
        spec = GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))
        self.graph = Graph(nodes, no_edges, spec)
Ejemplo n.º 6
0
    def __init__(self):
        """Create ``self.graph``: a three-node chain 1-2-3 with float locations."""
        self.dtype = float

        # (node id, value repeated on all three axes)
        placements = ((1, 1), (2, 500), (3, 550))
        self.__vertices = [
            Node(id=node_id,
                 location=np.array([value, value, value], dtype=self.dtype))
            for node_id, value in placements
        ]
        self.__edges = [Edge(u, v) for u, v in ((1, 2), (2, 3))]

        roi = Roi(Coordinate([-500, -500, -500]),
                  Coordinate([1500, 1500, 1500]))
        self.__spec = GraphSpec(roi=roi)
        self.graph = Graph(self.__vertices, self.__edges, self.__spec)
Ejemplo n.º 7
0
    def test_shift_points5(self):
        """Shift five nodes along axis 1 and compare with the expected result.

        The expected result keeps nodes 0, 2 and 4 and carries the request
        ROI as its spec.
        """
        # five nodes at (3, 0), (3, 2), ..., (3, 8)
        source_nodes = [
            Node(id=index, location=np.array([3, 2 * index]))
            for index in range(5)
        ]
        source_spec = GraphSpec(Roi(offset=(0, 0), shape=(15, 10)))
        points = Graph(source_nodes, [], source_spec)
        request_roi = Roi(offset=(3, 0), shape=(9, 10))
        shift_array = np.array(
            [[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]],
            dtype=int,
        )
        lcm_voxel_size = Coordinate((3, 2))

        expected_nodes = [
            Node(id=node_id, location=np.array(xy))
            for node_id, xy in ((0, [6, 0]), (2, [3, 4]), (4, [6, 8]))
        ]

        result = ShiftAugment.shift_points(
            points,
            request_roi,
            shift_array,
            shift_axis=1,
            lcm_voxel_size=lcm_voxel_size,
        )

        nodes_match = self.points_equal(result.nodes, expected_nodes)
        self.assertTrue(nodes_match)
        self.assertTrue(result.spec == GraphSpec(request_roi))
Ejemplo n.º 8
0
    def __init__(self):
        """Create ``self.graph``: one node at (50, 70, 100), no edges."""
        only_node = Node(id=1, location=np.array([50, 70, 100]))
        no_edges = []
        spec = GraphSpec(roi=Roi((-200, -200, -200), (400, 400, 478)))
        self.graph = Graph([only_node], no_edges, spec)
Ejemplo n.º 9
0
def test_nodes():
    """Locations assigned while iterating ``graph.nodes`` are seen afterwards."""

    def repeated(values_by_id):
        # one float32 location per id, the same value on all three axes
        return {
            node_id: np.array([value, value, value], dtype=np.float32)
            for node_id, value in values_by_id
        }

    initial_locations = repeated(((1, 1), (2, 500), (3, 550)))
    replacement_locations = repeated(((1, 0), (2, 50), (3, 55)))

    nodes = [
        Node(id=node_id, location=xyz)
        for node_id, xyz in initial_locations.items()
    ]
    edges = [Edge(1, 2), Edge(2, 3)]
    roi = Roi(Coordinate([-500, -500, -500]), Coordinate([1500, 1500, 1500]))
    graph = Graph(nodes, edges, GraphSpec(roi=roi))

    # first pass: overwrite every location
    for node in graph.nodes:
        node.location = replacement_locations[node.id]

    # second pass: the overwritten locations must still be there
    for node in graph.nodes:
        expected = replacement_locations[node.id]
        assert all(np.isclose(node.location, expected))
Ejemplo n.º 10
0
    def _shift_and_crop(self, points: Graph, array: Array,
                        direction: Coordinate, output_roi: Roi):
        """Shift ``array`` and ``points`` by ``direction`` and crop both.

        A window with the shape of ``output_roi`` is centred on the array's
        ROI centre moved by ``direction``. The array is cropped to that
        window and its ROI is then overwritten with ``output_roi``. Points
        inside the window are copied, translated into ``output_roi``
        coordinates and relabelled per weakly connected component.

        Args:
            points: graph whose ``data`` maps point ids to point objects.
            array: array to crop; its ``spec.roi`` defines the start centre.
            direction: offset added to the array's ROI centre.
            output_roi: ROI assigned to both returned objects.

        Returns:
            ``(points, array)``: the rebuilt graph and the cropped array.
        """
        # Shift and crop the array
        center = array.spec.roi.get_offset() + array.spec.roi.get_shape() // 2
        new_center = center + direction
        new_offset = new_center - output_roi.get_shape() // 2
        new_roi = Roi(new_offset, output_roi.get_shape())
        array = array.crop(new_roi)
        # the cropped data is presented under the requested ROI
        array.spec.roi = output_roi

        # NOTE(review): this initial value is never read; the name is rebound
        # by the dict comprehension further down.
        new_points_data = {}
        # NOTE(review): no copy is made here, so if ``points.spec`` returns
        # the stored spec object, the caller's graph spec is modified by the
        # next line - TODO confirm whether that is intended.
        new_points_spec = points.spec
        new_points_spec.roi = new_roi
        new_points_graph = nx.DiGraph()

        # shift points and add them to a graph
        for point_id, point in points.data.items():
            if new_roi.contains(point.location):
                new_point = point.copy()
                # translate from the cropped window into output_roi coordinates
                new_point.location = (point.location - new_offset +
                                      output_roi.get_begin())
                new_points_graph.add_node(
                    new_point.point_id,
                    point_id=new_point.point_id,
                    parent_id=new_point.parent_id,
                    location=new_point.location,
                    label_id=new_point.label_id,
                    radius=new_point.radius,
                    point_type=new_point.point_type,
                )
                # keep the parent -> child edge only when the parent exists
                # in the input and also lies inside the cropped window
                if points.data.get(
                        new_point.parent_id, False) and new_roi.contains(
                            points.data[new_point.parent_id].location):
                    new_points_graph.add_edge(new_point.parent_id,
                                              new_point.point_id)

        # relabel connected components
        # each weakly connected component gets its enumeration index as label
        for i, connected_component in enumerate(
                nx.weakly_connected_components(new_points_graph)):
            for node in connected_component:
                new_points_graph.nodes[node]["label_id"] = i

        # store new graph data in points
        new_points_data = {
            point_id: Node(
                point["location"],
                point_id=point["point_id"],
                point_type=point["point_type"],
                radius=point["radius"],
                parent_id=point["parent_id"],
                label_id=point["label_id"],
            )
            for point_id, point in new_points_graph.nodes.items()
        }
        points = Graph(new_points_data, new_points_spec)
        # like the array, the returned graph is presented under output_roi
        points.spec.roi = output_roi
        return points, array
Ejemplo n.º 11
0
def test_transpose():
    """Check node and array stay aligned with and without transposition.

    One node and one marked voxel are served through two pipelines that
    differ only in ``transpose_probs``. For each pipeline the single
    returned node must be at the expected location and the voxel it falls
    into must hold the value 1.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    only_node = Node(id=1, location=np.array([450, 550]))
    graph = Graph(
        [only_node],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )

    data = np.zeros([40, 30])
    data[17, 17] = 1
    array_spec = ArraySpec(roi=Roi((100, 200), (800, 600)),
                           voxel_size=voxel_size)
    array = Array(data, array_spec)

    def make_pipeline(transpose_probs):
        # both sources merged, followed by an augment that never mirrors
        sources = (GraphSource(graph_key, graph),
                   ArraySource(array_key, array))
        return sources + MergeProvider() + SimpleAugment(
            mirror_only=[],
            transpose_only=[0, 1],
            transpose_probs=transpose_probs)

    default_pipeline = make_pipeline([0, 0])
    transpose_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    def check(pipeline, expected_location):
        with build(pipeline):
            batch = pipeline.request_batch(request)

            returned_nodes = list(batch[graph_key].nodes)
            assert len(returned_nodes) == 1
            node = returned_nodes[0]
            assert all(np.isclose(node.location, expected_location))

            # voxel index of the node relative to the returned array
            roi_offset = batch[array_key].spec.roi.get_offset()
            node_voxel_index = Coordinate(
                (node.location - roi_offset) / voxel_size)
            assert (
                batch[array_key].data[node_voxel_index] == 1
            ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"

    check(default_pipeline, [450, 550])
    check(transpose_pipeline, [410, 590])
Ejemplo n.º 12
0
    def test_shift_points3(self):
        """Shift one node along axis 0 and compare with the expected graph."""
        source_nodes = [Node(id=1, location=np.array([0, 1]))]
        source_spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5)))
        points = Graph(source_nodes, [], source_spec)

        request_roi = Roi(offset=(0, 1), shape=(5, 3))
        shift_array = np.array(
            [[0, 1], [0, -1], [0, 0], [0, 0], [0, 1]],
            dtype=int,
        )
        lcm_voxel_size = Coordinate((1, 1))

        expected_node = Node(id=1, location=np.array([0, 2]))
        shifted_points = Graph([expected_node], [], GraphSpec(request_roi))

        result = ShiftAugment.shift_points(
            points,
            request_roi,
            shift_array,
            shift_axis=0,
            lcm_voxel_size=lcm_voxel_size,
        )

        nodes_match = self.points_equal(result.nodes, shifted_points.nodes)
        self.assertTrue(nodes_match)
        self.assertTrue(result.spec == GraphSpec(request_roi))
Ejemplo n.º 13
0
    def setup(self):
        """Announce the provided graph keys and build ``self.graph``.

        Every key in ``self.points`` is provided with a ROI that starts at
        the origin and has shape ``self.size``. The graph is a chain of
        ``self.num_points`` nodes; node ``i`` has the value
        ``i * min(self.size) / self.num_points`` on each of three axes.
        """
        origin = Coordinate([0] * len(self.size))
        roi = Roi(origin, self.size)
        for points_key in self.points:
            provided_spec = GraphSpec(roi=roi, directed=self.directed)
            self.provides(points_key, provided_spec)

        shortest_side = min(self.size)
        nodes = []
        for index in range(self.num_points):
            coordinate = index * shortest_side / self.num_points
            nodes.append(Node(id=index, location=np.array([coordinate] * 3)))

        # consecutive nodes are linked: 0-1, 1-2, ...
        edges = [Edge(u, u + 1) for u in range(self.num_points - 1)]

        graph_spec = GraphSpec(roi=roi, directed=self.directed)
        self.graph = Graph(nodes, edges, graph_spec)
Ejemplo n.º 14
0
    def provide(self, request):
        """Serve position-encoding arrays and lattice graphs for ``request``.

        For every requested array, each voxel holds the sum of its three
        voxel coordinates. For every requested graph, nodes are placed at
        lattice positions with strides (100, 10, 10) that lie inside the
        requested ROI; node ids come from ``coordinate_to_id``.
        """
        batch = Batch()

        # have the pixels encode their position
        for array_key, requested_spec in request.array_specs.items():
            roi = requested_spec.roi
            roi_voxel = roi // self.spec[array_key].voxel_size

            # the z,y,x coordinates of the ROI
            begin = roi_voxel.get_begin()
            end = roi_voxel.get_end()
            axis_ranges = [range(begin[axis], end[axis]) for axis in range(3)]
            meshgrids = np.meshgrid(*axis_ranges, indexing='ij')
            data = meshgrids[0] + meshgrids[1] + meshgrids[2]

            array_spec = self.spec[array_key].copy()
            array_spec.roi = roi
            batch.arrays[array_key] = Array(data, array_spec)

        for graph_key, graph_spec in request.graph_specs.items():
            # node at x, y, z if x%100==0, y%10==0, z%10==0
            strides = [100, 10, 10]
            roi_begin = graph_spec.roi.get_begin()
            # round the ROI start down onto the lattice
            start = roi_begin - tuple(
                x % s for x, s in zip(roi_begin, strides))
            lattice_axes = [
                range(a, b, s)
                for a, b, s in zip(start, graph_spec.roi.get_end(), strides)
            ]

            nodes = []
            for ijk in itertools.product(*lattice_axes):
                location = np.array(ijk)
                # rounding down may have produced positions before the ROI
                if not graph_spec.roi.contains(location):
                    continue
                nodes.append(
                    Node(id=coordinate_to_id(*ijk), location=location))
            batch.graphs[graph_key] = Graph(nodes, [], graph_spec)

        return batch
Ejemplo n.º 15
0
    def setup(self):
        """Create 1000 grid nodes and announce the provided graph and array.

        Node ``i`` is located at four times the decimal digits of ``i``
        (hundreds, tens, ones), i.e. on a 10x10x10 grid with spacing 4.
        """

        def grid_location(index):
            # split index into its hundreds / tens / ones digits
            hundreds, remainder = divmod(index, 100)
            tens, ones = divmod(remainder, 10)
            return np.array([hundreds % 10 * 4, tens * 4, ones * 4])

        self.points = [
            Node(index, grid_location(index)) for index in range(1000)
        ]

        points_spec = GraphSpec(roi=Roi((-40, -40, -40), (120, 120, 120)))
        self.provides(GraphKeys.TEST_POINTS, points_spec)

        labels_spec = ArraySpec(
            roi=Roi((-40, -40, -40), (120, 120, 120)),
            voxel_size=Coordinate((4, 1, 1)),
            interpolatable=False,
        )
        self.provides(ArrayKeys.TEST_LABELS, labels_spec)
Ejemplo n.º 16
0
    def provide(self, request):
        """Return the fixed six-node graph cropped to the requested ROI.

        The graph consists of two chains, 0-1-2 and 3-104-5, whose nodes
        lie on the main diagonal. The spec is a copy of this provider's
        ``GraphKeys.RAW`` spec with the ROI taken from ``request``.
        """
        outputs = Batch()

        # (node id, value repeated on all three axes)
        placements = ((0, 1), (1, 10), (2, 19), (3, 21), (104, 30), (5, 39))
        nodes = [
            Node(id=node_id, location=np.array((value, value, value)))
            for node_id, value in placements
        ]
        edges = [Edge(u, v) for u, v in ((0, 1), (1, 2), (3, 104), (104, 5))]

        spec = self.spec[GraphKeys.RAW].copy()
        spec.roi = request[GraphKeys.RAW].roi

        full_graph = Graph(nodes, edges, spec)
        outputs[GraphKeys.RAW] = full_graph.crop(spec.roi)
        return outputs
Ejemplo n.º 17
0
    def setup(self):
        """Create six nodes spaced 10 apart on the second axis and announce
        the provided graph and array."""
        # node i sits at (0, 10 * i, 0)
        self.points = [
            Node(index, np.array([0, 10 * index, 0])) for index in range(6)
        ]

        points_spec = GraphSpec(roi=Roi((-100, -100, -100), (200, 200, 200)))
        self.provides(GraphKeys.TEST_POINTS, points_spec)

        labels_spec = ArraySpec(
            roi=Roi((-100, -100, -100), (200, 200, 200)),
            voxel_size=Coordinate((4, 1, 1)),
            interpolatable=False,
        )
        self.provides(ArrayKeys.TEST_LABELS, labels_spec)
Ejemplo n.º 18
0
    def __get_pre_and_postsyn_locations(self, roi):
        """Randomly place pre- and postsynaptic nodes inside ``roi``.

        One presynaptic location is drawn per cube of 50 voxels edge length
        (voxel size taken from the ``ArrayKeys.RAW`` spec). Each presynaptic
        location gets the same number of postsynaptic partners, each offset
        by 60-119 units per axis with a random sign and re-drawn until it
        lies inside ``roi``.

        Presynaptic locations are rejection-sampled so that they keep a
        minimum distance of 250 from all previously accepted ones. The
        sampling is limited to a fixed number of attempts; if no valid
        candidate is found, the last candidate is used so that a crowded
        ROI cannot stall the caller.

        Args:
            roi: region to place locations in; must provide ``size()``,
                ``get_begin()``, ``get_end()`` and ``contains()``.

        Returns:
            ``(presyn_locs, postsyn_locs)``: two dicts mapping location id
            to ``Node``. Each node's attrs hold ``location_id``,
            ``synapse_id``, ``partner_ids`` and an empty ``props`` dict.
        """
        presyn_locs, postsyn_locs = {}, {}
        min_dist_between_presyn_locs = 250
        # upper bound on rejection-sampling rounds per presynaptic location
        max_presyn_attempts = 100
        voxel_size_points = self.spec[ArrayKeys.RAW].voxel_size
        min_dist_pre_to_postsyn_loc, max_dist_pre_to_postsyn_loc = 60, 120
        # 1 synapse per 50vx^3 cube; int() keeps range() happy even when the
        # voxel size is not an integer type
        num_presyn_locations = int(
            roi.size() // np.prod(50 * np.asarray(voxel_size_points))
        )
        # np.random.randint excludes `high`, so this draws 1 or 2 partners
        num_postsyn_locations = np.random.randint(low=1, high=3)

        roi_begin, roi_end = roi.get_begin(), roi.get_end()

        loc_id = 0
        all_presyn_locs = []
        for nr_presyn_loc in range(num_presyn_locations):
            loc_id = loc_id + 1
            presyn_loc_id = loc_id

            # ensure that partner locations of diff presyn locations are not
            # overlapping: re-draw while the candidate is too close to an
            # already accepted presynaptic location
            for _ in range(max_presyn_attempts):
                presyn_location = np.asarray(
                    [
                        np.random.randint(
                            low=roi_begin[axis], high=roi_end[axis]
                        )
                        for axis in range(3)
                    ]
                )
                far_enough = all(
                    np.linalg.norm(presyn_location - previous_loc)
                    >= min_dist_between_presyn_locs
                    for previous_loc in all_presyn_locs
                )
                if far_enough:
                    break
            # remember the location so later candidates are checked against it
            all_presyn_locs.append(presyn_location)

            syn_id = nr_presyn_loc

            partner_ids = []
            for nr_partner_loc in range(num_postsyn_locations):
                loc_id = loc_id + 1
                partner_ids.append(loc_id)
                postsyn_loc_is_inside = False
                while not postsyn_loc_is_inside:
                    # random sign and random magnitude on every axis
                    postsyn_location = presyn_location + np.random.choice(
                        (-1, 1), size=3, replace=True
                    ) * np.random.randint(
                        min_dist_pre_to_postsyn_loc, max_dist_pre_to_postsyn_loc, size=3
                    )
                    if roi.contains(Coordinate(postsyn_location)):
                        postsyn_loc_is_inside = True

                postsyn_locs[int(loc_id)] = deepcopy(
                    Node(
                        loc_id,
                        location=postsyn_location,
                        attrs={
                            "location_id": loc_id,
                            "synapse_id": syn_id,
                            "partner_ids": [presyn_loc_id],
                            "props": {},
                        },
                    )
                )

            presyn_locs[int(presyn_loc_id)] = deepcopy(
                Node(
                    presyn_loc_id,
                    location=presyn_location,
                    attrs={
                        "location_id": presyn_loc_id,
                        "synapse_id": syn_id,
                        "partner_ids": partner_ids,
                        "props": {},
                    },
                )
            )

        return presyn_locs, postsyn_locs
Ejemplo n.º 19
0
    def test_output(self):
        """Request two ROIs and compare the returned nodes with expectations.

        For each request the multiset of ``original_id`` values must match,
        and after sorting both sides by location the locations must agree
        pairwise.
        """
        GraphKey("TEST_GRAPH")

        pipeline = ExampleGraphSource() + GrowFilter()

        def by_location(vertex):
            return tuple(vertex.location)

        def check_request(roi, expected_vertices):
            request = BatchRequest(
                {GraphKeys.TEST_GRAPH: GraphSpec(roi=roi)})
            batch = pipeline.request_batch(request)

            graph = batch[GraphKeys.TEST_GRAPH]
            seen_vertices = tuple(graph.nodes)
            self.assertCountEqual(
                [v.original_id for v in expected_vertices],
                [v.original_id for v in seen_vertices],
            )
            ordered_pairs = zip(
                sorted(expected_vertices, key=by_location),
                sorted(seen_vertices, key=by_location),
            )
            for expected, actual in ordered_pairs:
                assert all(np.isclose(expected.location, actual.location))

        with build(pipeline):
            check_request(
                Roi((0, 0, 0), (50, 50, 50)),
                (
                    Node(id=1,
                         location=np.array([1.0, 1.0, 1.0], dtype=float)),
                    Node(
                        id=2,
                        location=np.array([50.0, 50.0, 50.0], dtype=float),
                        temporary=True,
                    ),
                ),
            )

            check_request(
                Roi((25, 25, 25), (500, 500, 500)),
                (
                    Node(
                        id=1,
                        location=np.array([25.0, 25.0, 25.0], dtype=float),
                        temporary=True,
                    ),
                    Node(id=2,
                         location=np.array([500.0, 500.0, 500.0],
                                           dtype=float)),
                    Node(
                        id=3,
                        location=np.array([525.0, 525.0, 525.0],
                                          dtype=float),
                        temporary=True,
                    ),
                ),
            )
Ejemplo n.º 20
0
    def _toy_swc_points(self):
        """
        Build a directed toy graph of 41 nodes in the plane z = 5.

        shape:

        -----------
        |
        |
        |----------
        |
        |
        -----------

        The backbone (ids 0-10) runs along the first axis; three lines of
        ten nodes each (ids 11-20, 21-30, 31-40) leave it at backbone
        positions 0, 5 and 10 and run along the second axis. Every node
        carries ``radius`` 0 and ``node_type`` 0.
        """

        def make_node(node_id, xyz):
            # a fresh attrs dict per node, so nodes never share state
            return Node(id=node_id,
                        location=np.array(xyz),
                        attrs={
                            "radius": 0,
                            "node_type": 0
                        })

        # (backbone node the line starts from, id of the line's first node)
        # in the order: bottom line, mid line, top line
        branches = ((0, 11), (5, 21), (10, 31))

        # backbone
        points = [make_node(index, [index, 0, 5]) for index in range(11)]
        for row, first_id in branches:
            points.extend(
                make_node(first_id + step, [row, step + 1, 5])
                for step in range(10))

        # self-loop on node 0, followed by the backbone chain
        edges = [Edge(0, 0)]
        edges.extend(Edge(u, u + 1) for u in range(10))
        for root, first_id in branches:
            # attach the line to the backbone, then chain its ten nodes
            edges.append(Edge(root, first_id))
            edges.extend(
                Edge(u, u + 1) for u in range(first_id, first_id + 9))

        roi = Roi(Coordinate((-100, -100, -100)),
                  Coordinate((300, 300, 300)))
        return Graph(points, edges, GraphSpec(roi=roi, directed=True))
Ejemplo n.º 21
0
    def process(self, batch, request):
        """Fuse the "add" volume into the "base" volume.

        Labels: the base labels are passed through ``self._relabel``; every
        non-zero label of the add volume then receives a fresh id (starting
        at the current maximum + 1) and overwrites whatever was at its
        voxels, so added objects win where they overlap base objects.

        Raw data, depending on ``self.blend_mode``:

        * ``"intensity"``: per-voxel weight ``add / max(add)``; the result
          is ``weight * add + (1 - weight) * base``.
        * ``"add"``: mean of both volumes, each divided by its own maximum,
          clipped to [0, 1].
        * ``"labels_mask"``: the mask of added labels is Gaussian-smoothed,
          normalised, doubled and clipped to [0, 1]; the result is
          ``max(soft_mask * add, base)`` clipped to [0, 1]. The optional
          ``soft_mask``, ``masked_base``, ``masked_add`` and ``mask_maxed``
          outputs are written only when the corresponding key is set.

        If ``self.points_fused`` is requested, the add graph is appended to
        a copy of the base graph with all its node ids shifted past the
        largest base node id.

        Args:
            batch: batch holding the base/add raw and label arrays (and the
                base/add graphs when points are fused).
            request: the request being served; only checked for
                ``self.points_fused``.

        Returns:
            A new ``Batch`` with the fused raw array, fused labels and any
            optional outputs.

        Raises:
            NotImplementedError: if ``self.blend_mode`` is not one of the
                three modes above.
        """
        outputs = Batch()

        raw_base_spec = batch[self.raw_base].spec.copy()

        # Get base arrays
        raw_base_array = batch[self.raw_base].data
        labels_base_array = batch[self.labels_base].data

        # Get add arrays
        raw_add_array = batch[self.raw_add].data
        labels_add_array = batch[self.labels_add].data

        if self.scale_add_volume:
            # shift the add volume so both volumes share the same median
            raw_base_median = np.median(raw_base_array)
            raw_add_median = np.median(raw_add_array)
            diff = raw_base_median - raw_add_median
            raw_add_array = raw_add_array + diff

        # fuse labels
        fused_labels_array = self._relabel(labels_base_array)
        next_label_id = np.max(fused_labels_array) + 1

        add_mask = np.zeros_like(fused_labels_array, dtype=bool)
        for label in np.unique(labels_add_array):
            if label == 0:
                continue
            label_mask = labels_add_array == label

            # assign new label; voxels where the added object overlaps a base
            # object are part of label_mask, so they are overwritten here.
            # (A former intermediate "-1" marker on the overlap was always
            # overwritten by this assignment and could overflow on unsigned
            # label arrays, so it is gone.)
            add_mask[label_mask] = True
            fused_labels_array[label_mask] = next_label_id
            next_label_id += 1

        # fuse raw
        if self.blend_mode == "intensity":

            # NOTE(review): divides by zero when the add volume is all zeros
            intensity_weight = raw_add_array.astype(np.float32) / np.max(
                raw_add_array
            )
            raw_fused_array = (
                intensity_weight * raw_add_array
                + (1 - intensity_weight) * raw_base_array
            )

        elif self.blend_mode == "add":
            raw_fused_array = 0.5*raw_add_array / np.max(
                raw_add_array
            ) + 0.5*raw_base_array / np.max(raw_base_array)
            raw_fused_array = np.clip(raw_fused_array, 0, 1)

        elif self.blend_mode == "labels_mask":

            soft_mask = np.zeros_like(add_mask, dtype="float32")
            # smoothness is divided by the voxel size before it is used as
            # the per-axis sigma
            ndimage.gaussian_filter(
                add_mask.astype("float32"),
                sigma=self.blend_smoothness / np.array(raw_base_spec.voxel_size),
                output=soft_mask,
                mode=self.gaussian_smooth_mode,
            )
            # normalise to a maximum of 1 (guarding against an empty mask),
            # then steepen the edge by doubling and clipping
            soft_mask /= np.clip(np.max(soft_mask), 1e-5, float("inf"))
            soft_mask = np.clip((soft_mask * 2), 0, 1)
            if self.soft_mask is not None:
                outputs.arrays[self.soft_mask] = Array(
                    soft_mask,
                    spec=ArraySpec(
                        roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                    ),
                )
            if self.masked_base is not None:
                outputs.arrays[self.masked_base] = Array(
                    raw_base_array * (soft_mask > 0.25), spec=raw_base_spec.copy()
                )
            if self.masked_add is not None:
                outputs.arrays[self.masked_add] = Array(
                    raw_add_array * soft_mask,
                    spec=ArraySpec(
                        roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                    ),
                )
            if self.mask_maxed is not None:
                outputs.arrays[self.mask_maxed] = Array(
                    np.maximum(
                        raw_base_array * (soft_mask > 0.25), raw_add_array * soft_mask
                    ),
                    spec=ArraySpec(
                        roi=raw_base_spec.roi, voxel_size=raw_base_spec.voxel_size
                    ),
                )

            raw_fused_array = np.maximum(soft_mask * raw_add_array, raw_base_array)
            raw_fused_array = np.clip(raw_fused_array, 0, 1)

        else:
            raise NotImplementedError("Unknown blend mode %s." % self.blend_mode)

        # load specs
        labels_add_spec = batch[self.labels_add].spec.copy()
        raw_base_spec = batch[self.raw_base].spec.copy()
        raw_dtype = batch[self.raw_base].data.dtype
        raw_base_spec.dtype = raw_dtype

        # return raw and labels for "fused" volume; the fused raw data is
        # cast back to the dtype of the base raw data
        outputs.arrays[self.raw_fused] = Array(
            data=raw_fused_array.astype(raw_base_spec.dtype), spec=raw_base_spec
        )
        outputs.arrays[self.labels_fused] = Array(
            data=fused_labels_array, spec=labels_add_spec
        )

        # fuse points:
        if self.points_fused in request:
            node_ids = [node.id for node in batch.graphs[self.points_base].nodes]
            num_nodes = len(node_ids)
            # shift add-graph ids past the largest base id to avoid clashes
            offset = 0 if num_nodes == 0 else max(node_ids) + 1
            fused_graph = batch.graphs[self.points_base].copy()
            for node in batch.graphs[self.points_add].nodes:
                attrs = deepcopy(node.all)
                attrs["id"] += offset
                fused_graph.add_node(Node.from_attrs(attrs))
            for edge in batch.graphs[self.points_add].edges:
                edge = Edge(edge.u + offset, edge.v + offset)
                fused_graph.add_edge(edge)
            outputs.graphs[self.points_fused] = fused_graph

        return outputs