Example #1
    def test_merge_basics(self):
        voxel_size = (1, 1, 1)
        # creating these keys also registers them on GraphKeys / ArrayKeys for use below
        GraphKey("PRESYN")
        ArrayKey("GT_LABELS")
        graphsource = GraphTestSource(voxel_size)
        arraysource = ArrayTestSoure(voxel_size)
        pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
        window_request = Coordinate((50, 50, 50))
        with build(pipeline):
            # Check basic merging.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            request.add(ArrayKeys.GT_LABELS, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertTrue(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

            # Check that request of only one source also works.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertFalse(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

        # Check that it fails, when having two sources that provide the same type.
        arraysource2 = ArrayTestSoure(voxel_size)
        pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
        with self.assertRaises(PipelineSetupError):
            with build(pipeline_fail):
                pass
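
For orientation: the requests in these examples are built in two styles, request.add(key, shape) as above, or direct assignment of a spec with an explicit ROI as in the next example. Below is a minimal sketch of both, reusing the keys registered in this example and assuming the usual gunpowder names (BatchRequest, ArraySpec, Roi, Coordinate, ArrayKeys) are in scope; the exact ROI placement chosen by add() is left to gunpowder, so the explicit form is the unambiguous one.

# Sketch only: two ways of requesting a 50**3 window for GT_LABELS.
request = BatchRequest()
request.add(ArrayKeys.GT_LABELS, Coordinate((50, 50, 50)))  # shape only, ROI placed by gunpowder

explicit_request = BatchRequest()
explicit_request[ArrayKeys.GT_LABELS] = ArraySpec(
    roi=Roi((0, 0, 0), (50, 50, 50)))  # ROI spelled out
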
Example #2
def test_transpose():
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")
    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)),
                        voxel_size=voxel_size))

    default_pipeline = (
        (GraphSource(graph_key, graph), ArraySource(array_key, array)) +
        MergeProvider() + SimpleAugment(
            mirror_only=[], transpose_only=[0, 1], transpose_probs=[0, 0]))

    transpose_pipeline = (
        (GraphSource(graph_key, graph), ArraySource(array_key, array)) +
        MergeProvider() + SimpleAugment(
            mirror_only=[], transpose_only=[0, 1], transpose_probs=[1, 1]))

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))
    with build(default_pipeline):
        expected_location = [450, 550]
        batch = default_pipeline.request_batch(request)

        assert len(list(batch[graph_key].nodes)) == 1
        node = list(batch[graph_key].nodes)[0]
        assert all(np.isclose(node.location, expected_location))
        node_voxel_index = Coordinate(
            (node.location - batch[array_key].spec.roi.get_offset()) /
            voxel_size)
        assert (
            batch[array_key].data[node_voxel_index] == 1
        ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"

    with build(transpose_pipeline):
        expected_location = [410, 590]
        batch = transpose_pipeline.request_batch(request)

        assert len(list(batch[graph_key].nodes)) == 1
        node = list(batch[graph_key].nodes)[0]
        assert all(np.isclose(node.location, expected_location))
        node_voxel_index = Coordinate(
            (node.location - batch[array_key].spec.roi.get_offset()) /
            voxel_size)
        assert (
            batch[array_key].data[node_voxel_index] == 1
        ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"
Example #3
    def test_output(self):
        a = ArrayKey("A")
        b = ArrayKey("B")
        source_a = TestSourceRandomLocation(a)
        source_b = TestSourceRandomLocation(b)

        pipeline = (source_a, source_b) + \
            MergeProvider() + CustomRandomLocation()

        with build(pipeline):

            for i in range(10):
                batch = pipeline.request_batch(
                    BatchRequest({
                        a: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20))),
                        b: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))
                    }))

                self.assertTrue(np.sum(batch.arrays[a].data) > 0)
                self.assertTrue(np.sum(batch.arrays[b].data) > 0)

                # Request a ROI with the same shape as the entire ROI
                full_roi_a = Roi((0, 0, 0), source_a.roi.get_shape())
                full_roi_b = Roi((0, 0, 0), source_b.roi.get_shape())
                batch = pipeline.request_batch(
                    BatchRequest({
                        a: ArraySpec(roi=full_roi_a),
                        b: ArraySpec(roi=full_roi_b)
                    }))
Example #4
    def test_square(self):
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")

        pipeline = (
            (ArrayTestSource(), TestSource())
            + MergeProvider()
            + Pad(test_array1, None)
            + Pad(test_array2, None)
            + Pad(test_graph, None)
            + SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        )

        request = BatchRequest()
        request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi((0, 50, 65), (100, 100, 100)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 15), (100, 200, 200)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 50, 65), (100, 100, 100)))

        
        with build(pipeline):
            for i in range(100):
                batch = pipeline.request_batch(request)
                assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1

                for (array_key, array) in batch.arrays.items():
                    assert batch.arrays[array_key].data.shape == batch.arrays[array_key].spec.roi.get_shape()
Example #5
def test_realistic_valid_examples(example, use_gurobi):
    penalty_attr = "penalty"
    location_attr = "location"
    example_dir = Path(__file__).parent / "mouselight_examples" / "valid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    inf_roi = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))

    request = BatchRequest()
    request[matched] = PointsSpec(roi=inf_roi)
    request[matched_with_fallback] = PointsSpec(roi=inf_roi)
    request[consensus] = PointsSpec(roi=inf_roi)

    pipeline = (
        (
            GraphSource(example_dir / "graph.obj", [skeletonization]),
            GraphSource(example_dir / "tree.obj", [consensus]),
        )
        + MergeProvider()
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            expected_edge_len=10,
            match_distance_threshold=76,
            max_gap_crossing=48,
            use_gurobi=use_gurobi,
            location_attr=location_attr,
            penalty_attr=penalty_attr,
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            expected_edge_len=10,
            match_distance_threshold=76,
            max_gap_crossing=48,
            use_gurobi=use_gurobi,
            location_attr=location_attr,
            penalty_attr=penalty_attr,
            with_fallback=True,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        consensus_ccs = list(batch[consensus].connected_components)
        matched_with_fallback_ccs = list(batch[matched_with_fallback].connected_components)
        matched_ccs = list(batch[matched].connected_components)

        assert len(matched_ccs) == len(consensus_ccs)
Example #6
    def test_multi_transpose(self):
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")
        point = np.array([50, 70, 100])

        transpose_dims = [0, 1, 2]
        pipeline = (ArrayTestSource(),
                    ExampleSource()) + MergeProvider() + SimpleAugment(
                        mirror_only=[], transpose_only=transpose_dims)

        request = BatchRequest()
        offset = (0, 20, 33)
        request[GraphKeys.TEST_GRAPH] = GraphSpec(
            roi=Roi(offset, (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(
            roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(
            roi=Roi((0, 100, 250), (100, 100, 50)))

        # Create all possible permutations of our transpose dims
        transpose_combinations = list(permutations(transpose_dims, 3))
        possible_loc = np.zeros((len(transpose_combinations), 3))

        # Transpose points in all possible ways
        for i, comb in enumerate(transpose_combinations):
            possible_loc[i] = point[np.array(comb)]

        with build(pipeline):
            seen_transposed = False
            seen_node = False
            for i in range(100):
                batch = pipeline.request_batch(request)

                if len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1:
                    seen_node = True
                    node = list(batch[GraphKeys.TEST_GRAPH].nodes)[0]

                    assert node.location in possible_loc

                    seen_transposed = seen_transposed or any(
                        [node.location[dim] != point[dim] for dim in range(3)])
                    assert Roi((0, 20, 33), (100, 100, 120)).contains(
                        batch[GraphKeys.TEST_GRAPH].spec.roi)
                    assert batch[GraphKeys.TEST_GRAPH].spec.roi.contains(
                        node.location)

                for (array_key, array) in batch.arrays.items():
                    assert batch.arrays[array_key].data.shape == batch.arrays[
                        array_key].spec.roi.get_shape()
            assert seen_transposed
            assert seen_node
Example #7
    def test_pipeline3(self):
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))
        spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {array_key: "testdata"},
                                 array_specs={array_key: spec})
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + MergeProvider() +
                    RandomLocation(ensure_nonempty=points_key) + shift_node)
        with build(pipeline) as b:
            batch = b.request_batch(request)

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(
                    result_val,
                    list(result_points),
                    target_vals,
                    self.fake_points,
                ),
            )
Example #8
    def test_impossible(self):
        a = ArrayKey("A")
        b = ArrayKey("B")
        source_a = TestSourceRandomLocation(a)
        source_b = TestSourceRandomLocation(b)

        pipeline = (source_a, source_b) + \
            MergeProvider() + CustomRandomLocation()

        with build(pipeline):
            with self.assertRaises(AssertionError):
                batch = pipeline.request_batch(
                    BatchRequest({
                        a:
                        ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                        b:
                        ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                    }))
Example #9
    def test_two_transpose(self):
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")

        transpose_dims = [1, 2]
        pipeline = (ArrayTestSource(),
                    ExampleSource()) + MergeProvider() + SimpleAugment(
                        mirror_only=[], transpose_only=transpose_dims)

        request = BatchRequest()
        request[GraphKeys.TEST_GRAPH] = GraphSpec(
            roi=Roi((0, 20, 33), (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(
            roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(
            roi=Roi((0, 100, 250), (100, 100, 50)))

        possible_loc = [[50, 50], [70, 100], [100, 70]]
        with build(pipeline):
            seen_transposed = False
            for i in range(100):
                batch = pipeline.request_batch(request)
                assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
                node = list(batch[GraphKeys.TEST_GRAPH].nodes)[0]
                logging.debug(node.location)
                assert all([
                    node.location[dim] in possible_loc[dim] for dim in range(3)
                ])
                seen_transposed = seen_transposed or any([
                    node.location[dim] != possible_loc[dim][0]
                    for dim in range(3)
                ])
                assert Roi((0, 20, 33), (100, 100, 120)).contains(
                    batch[GraphKeys.TEST_GRAPH].spec.roi)
                assert batch[GraphKeys.TEST_GRAPH].spec.roi.contains(
                    node.location)

                for (array_key, array) in batch.arrays.items():
                    assert batch.arrays[array_key].data.shape == batch.arrays[
                        array_key].spec.roi.get_shape()

            assert seen_transposed
Example #10
    def test_impossible(self):
        a = ArrayKey("A")
        b = ArrayKey("B")
        null_key = ArrayKey("NULL")
        source_a = ExampleSourceRandomLocation(a)
        source_b = ExampleSourceRandomLocation(b)

        pipeline = ((source_a, source_b) + MergeProvider() +
                    CustomRandomLocation(null_key))

        with build(pipeline):
            with self.assertRaises(PipelineRequestError):
                batch = pipeline.request_batch(
                    BatchRequest({
                        a:
                        ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                        b:
                        ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                    }))
Example #11
    def test_two_disjoint_lines_intensity(self):
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        BLEND_SMOOTHNESS = 3

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points.to_nx_graph())

        # create swc sources
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        fused_swc = PointsKey("FUSED_SWC")
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")
        raw_key_names = ("RAW_A", "RAW_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)
        raw_keys = tuple(ArrayKey(name) for name in raw_key_names)

        # add request
        request = BatchRequest()
        request.add(fused, bb.get_shape())
        request.add(fused_labels, bb.get_shape())
        request.add(fused_swc, bb.get_shape())
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(raw_keys[0], bb.get_shape())
        request.add(raw_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = tuple()
        data_sources_a = (
            data_sources_a
            + SwcFileSource(swc_paths[0], [swc_keys[0]], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=labels_keys[0], radii=[LABEL_RADIUS])
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=raw_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=raw_keys[0], radii=[RAW_RADIUS])
            + Normalize(raw_keys[0])
        )

        # data source for swc b
        data_sources_b = tuple()
        data_sources_b = (
            data_sources_b
            + SwcFileSource(swc_paths[1], [swc_keys[1]], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=labels_keys[1], radii=[LABEL_RADIUS])
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=raw_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=raw_keys[1], radii=[RAW_RADIUS])
            + Normalize(raw_keys[1])
        )
        data_sources = tuple([data_sources_a, data_sources_b]) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            swc_keys[0],
            swc_keys[1],
            fused,
            fused_labels,
            fused_swc,
            blend_mode="intensity",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        fused_data = np.pad(fused_data, (1,), "constant", constant_values=(0,))

        a_data = batch[raw_keys[0]].data
        a_data = np.pad(a_data, (1,), "constant", constant_values=(0,))

        b_data = batch[raw_keys[1]].data
        b_data = np.pad(b_data, (1,), "constant", constant_values=(0,))

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
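
Why that final check holds (an explanatory note, not part of the test): the two skeletons are generated 3 * LABEL_RADIUS apart, so after growing each raw volume by RAW_RADIUS = 3 the rasterized lines do not overlap, and with blend_mode="intensity" the fused volume is then expected to equal the voxel-wise sum of the two inputs. An equivalent standalone check of the same property, as a sketch using the padded arrays from above:

# fused_data, a_data, b_data as computed above
assert np.allclose(fused_data, a_data + b_data)
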
Example #12
def test_recenter(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # read arrays
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()

    data_shape = 9

    request.add(points_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(points_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_b, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b

    data = np.array((data_a, data_b, data_c))

    for _, point in batch[points_a].data.items():
        assert (
            data[(0,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1, its {}".format(
            point.location, data[(0,) + tuple(int(x) for x in point.location)]
        )

    for _, point in batch[points_b].data.items():
        assert (
            data[(1,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1".format(point.location)
Example #13
    def test_both(self):
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")
        og_point = np.array([50, 70, 100])

        transpose_dims = [0, 1, 2]
        mirror_dims = [0, 1, 2]
        pipeline = (
            (ArrayTestSource(), TestSource())
            + MergeProvider()
            + Pad(test_array1, None)
            + Pad(test_array2, None)
            + Pad(test_graph, None)
            + SimpleAugment(mirror_only=mirror_dims, transpose_only=transpose_dims)
        )

        request = BatchRequest()
        offset = (0, 20, 33)
        request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=Roi(offset, (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))
        
        # Get all possible mirror locations
        # possible_mirror_loc = [[49, 50], [70, 29], [100, 86]]
        mirror_combs = [[49, 70, 100],
                        [49, 29, 86],
                        [49, 70, 86],
                        [49, 29, 100],
                        [50, 70, 100],
                        [50, 29, 86],
                        [50, 70, 86],
                        [50, 29, 100]]

        # Create all possible permutations of our transpose dims
        transpose_combinations = list(permutations(transpose_dims, 3))

        # Generate all possible transposes of all possible mirrors
        possible_loc = np.zeros((len(mirror_combs), len(transpose_combinations), 3))
        for i, point in enumerate(mirror_combs): 
            for j, comb in enumerate(transpose_combinations):
                possible_loc[i, j] = np.array(point)[np.array(comb)]

        with build(pipeline):
            seen_transposed = False
            seen_node = False
            for i in range(100):
                batch = pipeline.request_batch(request)

                if len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1:
                    seen_node = True
                    node = list(batch[GraphKeys.TEST_GRAPH].nodes)[0]

                    # Check if your location is possible
                    assert node.location in possible_loc 
                    seen_transposed = seen_transposed or any(
                        [node.location[dim] != og_point[dim] for dim in range(3)]
                    )
                    assert Roi((0, 20, 33), (100, 100, 120)).contains(batch[GraphKeys.TEST_GRAPH].spec.roi)
                    assert batch[GraphKeys.TEST_GRAPH].spec.roi.contains(node.location)

                for (array_key, array) in batch.arrays.items():
                    assert batch.arrays[array_key].data.shape == batch.arrays[array_key].spec.roi.get_shape()
            assert seen_transposed
            assert seen_node
Example #14
    def test_two_disjoint_lines_softmask(self):
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        # exaggerated to show the problem
        BLEND_SMOOTHNESS = 10

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(
            Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create swc sources
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")
        raw_key_names = ("RAW_A", "RAW_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)
        raw_keys = tuple(ArrayKey(name) for name in raw_key_names)

        # add request
        request = BatchRequest()
        request.add(fused, bb.get_shape())
        request.add(fused_labels, bb.get_shape())
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(raw_keys[0], bb.get_shape())
        request.add(raw_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = tuple()
        data_sources_a = (data_sources_a + SwcFileSource(
            swc_paths[0], swc_keys[0], PointsSpec(roi=bb)) + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
                radius=LABEL_RADIUS,
            ) + RasterizeSkeleton(
                points=swc_keys[0],
                array=raw_keys[0],
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
                radius=RAW_RADIUS,
            ))

        # data source for swc b
        data_sources_b = tuple()
        data_sources_b = (data_sources_b + SwcFileSource(
            swc_paths[1], swc_keys[1], PointsSpec(roi=bb)) + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
                radius=LABEL_RADIUS,
            ) + RasterizeSkeleton(
                points=swc_keys[1],
                array=raw_keys[1],
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
                radius=RAW_RADIUS,
            ))
        data_sources = tuple([data_sources_a, data_sources_b
                              ]) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            fused,
            fused_labels,
            blend_mode="labels_mask",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        fused_data = np.pad(fused_data, (1, ),
                            "constant",
                            constant_values=(0, ))

        a_data = batch[raw_keys[0]].data
        a_data = np.pad(a_data, (1, ), "constant", constant_values=(0, ))

        b_data = batch[raw_keys[1]].data
        b_data = np.pad(b_data, (1, ), "constant", constant_values=(0, ))

        all_data = np.zeros((5, ) + fused_data.shape)
        all_data[0, :, :, :] = fused_data
        all_data[1, :, :, :] = a_data + b_data
        all_data[2, :, :, :] = fused_data - a_data - b_data
        all_data[3, :, :, :] = a_data
        all_data[4, :, :, :] = b_data

        # Uncomment to visualize problem
        if imported_volshow:
            volshow(all_data)
            # input("Press enter when you are done viewing the data: ")

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
Example #15
    def test_rasterize_speed(self):
        # This is worryingly slow for such a small volume (256**3) with only two
        # straight lines as skeletons.

        LABEL_RADIUS = 3

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create swc sources
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)

        # add request
        request = BatchRequest()
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = tuple()
        data_sources_a = (
            data_sources_a
            + SwcFileSource(swc_paths[0], swc_keys[0], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )

        # data source for swc b
        data_sources_b = tuple()
        data_sources_b = (
            data_sources_b
            + SwcFileSource(swc_paths[1], swc_keys[1], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )
        data_sources = tuple([data_sources_a, data_sources_b]) + MergeProvider()

        pipeline = data_sources

        t1 = time.time()
        with build(pipeline):
            batch = pipeline.request_batch(request)

        a_data = batch[labels_keys[0]].data
        a_data = np.pad(a_data, (1,), "constant", constant_values=(0,))

        b_data = batch[labels_keys[1]].data
        b_data = np.pad(b_data, (1,), "constant", constant_values=(0,))

        t2 = time.time()
        self.assertLess(t2 - t1, 0.1)
Example #16
    def test_get_neuron_pair(self):
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read arrays
        swc_source = PointsKey("SWC_SOURCE")
        ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
        labels_source = ArrayKey("LABELS_SOURCE")
        img_source = ArrayKey("IMG_SOURCE")
        img_swc = PointsKey("IMG_SWC")
        label_swc = PointsKey("LABEL_SWC")
        imgs = ArrayKey("IMGS")
        labels = ArrayKey("LABELS")
        points_a = PointsKey("SKELETON_A")
        points_b = PointsKey("SKELETON_B")
        img_a = ArrayKey("VOLUME_A")
        img_b = ArrayKey("VOLUME_B")
        labels_a = ArrayKey("LABELS_A")
        labels_b = ArrayKey("LABELS_B")

        data_shape = 5
        output_shape = Coordinate((data_shape, data_shape, data_shape))

        # Get points from test swc
        swc_file_source = SwcFileSource(
            path,
            [swc_source, ensure_nonempty],
            [
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
            ],
        )
        # Create an artificial image source by rasterizing the points
        image_source = (SwcFileSource(
            path, [img_swc],
            [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]) +
                        RasterizeSkeleton(
                            points=img_swc,
                            array=img_source,
                            array_spec=ArraySpec(
                                interpolatable=True,
                                dtype=np.uint32,
                                voxel_size=Coordinate((1, 1, 1)),
                            ),
                        ) +
                        BinarizeLabels(labels=img_source, labels_binary=imgs) +
                        GrowLabels(array=imgs, radius=0))
        # Create an artificial label source by rasterizing the points
        label_source = (SwcFileSource(path, [label_swc], [
            PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))
        ]) + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        ) + BinarizeLabels(labels=labels_source, labels_binary=labels) +
                        GrowLabels(array=labels, radius=1))

        skeleton = tuple()
        skeleton += ((swc_file_source, image_source, label_source) +
                     MergeProvider() +
                     RandomLocation(ensure_nonempty=ensure_nonempty,
                                    ensure_centered=True))

        pipeline = skeleton + GetNeuronPair(
            point_source=swc_source,
            nonempty_placeholder=ensure_nonempty,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=(1, 3),
            shift_attempts=100,
            request_attempts=10,
            output_shape=output_shape,
        )

        request = BatchRequest()

        request.add(points_a, output_shape)
        request.add(points_b, output_shape)
        request.add(img_a, output_shape)
        request.add(img_b, output_shape)
        request.add(labels_a, output_shape)
        request.add(labels_b, output_shape)

        with build(pipeline):
            for i in range(10):
                batch = pipeline.request_batch(request)
                assert all([
                    x in batch for x in
                    [points_a, points_b, img_a, img_b, labels_a, labels_b]
                ])

                min_dist = 5
                for a, b in itertools.product(
                        batch[points_a].nodes,
                        batch[points_b].nodes,
                ):
                    min_dist = min(
                        min_dist,
                        np.linalg.norm(a.location - b.location),
                    )

                self.assertLessEqual(min_dist, 3)
                self.assertGreaterEqual(min_dist, 1)