예제 #1
0
def get_test_data_sources(setup_config):
    """Assemble a synthetic raw + skeleton pipeline for tests.

    An image source and a point source are merged, a random location that
    contains ``matched`` nodes is chosen, the ``matched`` graph is rasterized
    into ``labels`` and ``raw`` is normalized.

    Args:
        setup_config: mapping providing ``"INPUT_SHAPE"`` (request shape in
            voxels) and ``"VOXEL_SIZE"`` (world units per voxel).

    Returns:
        ``(pipeline, raw, labels, nonempty_placeholder, matched)``.
    """
    input_shape = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    # world-unit extent of one request; both sources cover three times that
    input_size = input_shape * voxel_size
    source_size = input_size * 3

    # scale for the balance radius, taken from the first voxel-size component
    micron_scale = voxel_size[0]

    # New keys; they are intended to be requested with size ``input_size``.
    raw = ArrayKey("RAW")
    matched = GraphKey("MATCHED")
    nonempty_placeholder = GraphKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    image_source = TestImageSource(
        array=raw,
        array_specs={
            raw: ArraySpec(
                interpolatable=True, voxel_size=voxel_size, dtype=np.uint16),
        },
        size=source_size,
        voxel_size=voxel_size,
    )
    point_source = TestPointSource(
        points=[matched, nonempty_placeholder],
        directed=False,
        size=source_size,
        num_points=333,
    )

    pipeline = (image_source, point_source) + MergeProvider()
    pipeline = pipeline + RandomLocation(
        ensure_nonempty=matched,
        ensure_centered=True,
        point_balance_radius=10 * micron_scale,
    )
    pipeline = pipeline + RasterizeSkeleton(
        points=matched,
        array=labels,
        array_spec=ArraySpec(
            interpolatable=False, voxel_size=voxel_size, dtype=np.uint64),
    )
    pipeline = pipeline + Normalize(raw)

    return pipeline, raw, labels, nonempty_placeholder, matched
예제 #2
0
    def test_merge_basics(self):
        """MergeProvider combines disjoint sources and rejects duplicates."""
        voxel_size = (1, 1, 1)
        presyn = GraphKey("PRESYN")
        gt_labels = ArrayKey("GT_LABELS")
        graphsource = GraphTestSource(voxel_size)
        arraysource = ArrayTestSoure(voxel_size)
        pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
        window_request = Coordinate((50, 50, 50))

        with build(pipeline):
            # both keys requested: each comes from its own source
            both_request = BatchRequest()
            both_request.add(presyn, window_request)
            both_request.add(gt_labels, window_request)
            merged = pipeline.request_batch(both_request)
            self.assertTrue(gt_labels in merged.arrays)
            self.assertTrue(presyn in merged.graphs)

            # only the graph requested: the array must be absent
            graph_request = BatchRequest()
            graph_request.add(presyn, window_request)
            graph_only = pipeline.request_batch(graph_request)
            self.assertFalse(gt_labels in graph_only.arrays)
            self.assertTrue(presyn in graph_only.graphs)

        # two sources providing the same key must fail during setup
        arraysource2 = ArrayTestSoure(voxel_size)
        pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
        with self.assertRaises(PipelineSetupError):
            with build(pipeline_fail):
                pass
예제 #3
0
def test_6_neighborhood():
    """Neighborhood with k=6 sets the mask at a fixed set of voxels."""
    graph = GraphKey("GRAPH")
    neighborhood = ArrayKey("NEIGHBORHOOD")
    neighborhood_mask = ArrayKey("NEIGHBORHOOD_MASK")

    distance = 1

    pipeline = TestSource(graph) + Neighborhood(
        graph,
        neighborhood,
        neighborhood_mask,
        distance,
        array_specs={
            neighborhood: ArraySpec(voxel_size=Coordinate((1, 1, 1))),
            neighborhood_mask: ArraySpec(voxel_size=Coordinate((1, 1, 1))),
        },
        k=6,
    )

    request = BatchRequest()
    request[neighborhood] = ArraySpec(roi=Roi((0, 0, 0), (10, 10, 10)))
    request[neighborhood_mask] = ArraySpec(roi=Roi((0, 0, 0), (10, 10, 10)))

    with build(pipeline):
        batch = pipeline.request_batch(request)
        n_data = batch[neighborhood].data  # fetched, not checked further
        n_mask = batch[neighborhood_mask].data

        # voxel indices (all with last index 0) where the mask must be set
        expected = (
            {(0, i, 0) for i in range(10) if i not in (0, 4)}
            | {(i, 5, 0) for i in range(10)}
            | {(i, 4, 0) for i in range(10) if i != 0}
        )
        masked_ind = list(expected)
        assert all(
            n_mask[tuple(zip(*masked_ind))]
        ), f"expected {masked_ind} but saw {np.where(n_mask==1)}"
예제 #4
0
    def test_mirror(self):
        """Mirroring keeps the single node inside the request and eventually
        produces a mirrored coordinate."""
        test_graph = GraphKey("TEST_GRAPH")

        pipeline = TestSource() + SimpleAugment(
            mirror_only=[0, 1, 2], transpose_only=[]
        )

        request = BatchRequest()
        request[test_graph] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))

        # per axis: [unmirrored, mirrored] coordinate of the node
        possible_loc = [[50, 49], [70, 29], [100, 86]]

        with build(pipeline):
            seen_mirrored = False
            for _ in range(100):
                batch = pipeline.request_batch(request)
                graph = batch[test_graph]

                nodes = list(graph.nodes)
                assert len(nodes) == 1
                node = nodes[0]
                logging.debug(node.location)

                assert all(
                    node.location[dim] in possible_loc[dim] for dim in range(3)
                )
                if not seen_mirrored:
                    seen_mirrored = any(
                        node.location[dim] == possible_loc[dim][1]
                        for dim in range(3)
                    )

                assert Roi((0, 20, 33), (100, 100, 120)).contains(graph.spec.roi)
                assert graph.spec.roi.contains(node.location)

            assert seen_mirrored
예제 #5
0
    def test_relabel_components(self):
        """relabel_connected_components gives every component one unique label.

        Reads the toy SWC back through ``SwcFileSource``, relabels it and
        checks that:

        * there are three components of ten nodes each,
        * all nodes of a component share the same ``"component"`` label,
        * no two components share a label.
        """
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read the graph back
        swc = GraphKey("SWC")
        source = SwcFileSource(path, [swc])

        with build(source):
            batch = source.request_batch(
                BatchRequest({swc:
                              GraphSpec(roi=Roi((0, 1, 5), (11, 10, 1)))}))

        temp_g = batch.points[swc]
        temp_g.relabel_connected_components()

        ccs = list(temp_g.connected_components)
        self.assertEqual(len(ccs), 3)

        # Labels of all components checked so far. Every component is
        # compared against all of them, not only the previous one. A list
        # (not a set) is used so the label type need not be hashable.
        seen_labels = []
        for cc in ccs:
            self.assertEqual(len(cc), 10)
            label = None
            for point_id in cc:
                node_label = temp_g.node(point_id).attrs["component"]
                if label is None:
                    label = node_label
                    self.assertNotIn(label, seen_labels)
                self.assertEqual(node_label, label)
            seen_labels.append(label)
예제 #6
0
    def test_without_placeholder(self):
        """A points-only request through ElasticAugment raises, since no array
        provides a voxel size; adding the label array makes it succeed."""
        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D()
            + RandomLocation(ensure_nonempty=test_points)
            + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
            + Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            )
        )

        request_size = Coordinate((40, 40, 40))

        with build(pipeline):
            for seed in range(2):
                points_only = BatchRequest(random_seed=seed)
                points_only.add(test_points, request_size)

                points_and_labels = BatchRequest(random_seed=seed)
                points_and_labels.add(test_points, request_size)
                points_and_labels.add(test_labels, request_size)

                # No array to provide a voxel size to ElasticAugment
                with pytest.raises(PipelineRequestError):
                    pipeline.request_batch(points_only)
                batch = pipeline.request_batch(points_and_labels)

                self.assertIn(test_labels, batch)
예제 #7
0
    def test_output(self):
        """Crop restricts provided ROIs, either to an explicit ROI or by
        fractions of the upstream extent."""
        cropped_roi_raw = Roi((400, 40, 40), (1000, 100, 100))
        cropped_roi_presyn = Roi((800, 80, 80), (800, 80, 80))

        presyn = GraphKey("PRESYN")

        explicit_crop = (
            ExampleSourceCrop()
            + Crop(ArrayKeys.RAW, cropped_roi_raw)
            + Crop(presyn, cropped_roi_presyn)
        )
        with build(explicit_crop):
            raw_roi = explicit_crop.spec[ArrayKeys.RAW].roi
            presyn_roi = explicit_crop.spec[presyn].roi
            self.assertTrue(raw_roi == cropped_roi_raw)
            self.assertTrue(presyn_roi == cropped_roi_presyn)

        # fractional crop of 0.25 on both ends of the first axis only
        fractional_crop = ExampleSourceCrop() + Crop(
            ArrayKeys.RAW,
            fraction_negative=(0.25, 0, 0),
            fraction_positive=(0.25, 0, 0),
        )
        expected_roi_raw = Roi((650, 20, 20), (900, 180, 180))

        with build(fractional_crop):
            logger.info(fractional_crop.spec[ArrayKeys.RAW].roi)
            logger.info(expected_roi_raw)
            self.assertTrue(
                fractional_crop.spec[ArrayKeys.RAW].roi == expected_roi_raw)
예제 #8
0
    def test_square(self):
        """Mirror/transpose of padded, merged sources keeps one node and
        array data shapes equal to their ROI shapes."""
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")

        pipeline = (
            (ArrayTestSource(), TestSource())
            + MergeProvider()
            + Pad(test_array1, None)
            + Pad(test_array2, None)
            + Pad(test_graph, None)
            + SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        )

        request = BatchRequest()
        request[test_graph] = GraphSpec(roi=Roi((0, 50, 65), (100, 100, 100)))
        request[test_array1] = ArraySpec(roi=Roi((0, 0, 15), (100, 200, 200)))
        request[test_array2] = ArraySpec(roi=Roi((0, 50, 65), (100, 100, 100)))

        with build(pipeline):
            for _ in range(100):
                batch = pipeline.request_batch(request)
                assert len(list(batch[test_graph].nodes)) == 1

                # data shape vs ROI shape in world units; presumably the
                # test arrays use unit voxel size -- confirm
                for array in batch.arrays.values():
                    assert array.data.shape == array.spec.roi.get_shape()
예제 #9
0
    def test_placeholder(self):
        """A placeholder label request must not change the augmented points.

        Two requests share a random seed. One asks for the labels as a
        placeholder, the other asks for them for real. Both must yield the
        same number of points at the same locations.
        """
        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D() + RandomLocation(ensure_nonempty=test_points) +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            ))

        with build(pipeline):
            for i in range(2):

                request_size = Coordinate((40, 40, 40))

                request_a = BatchRequest(random_seed=i)
                request_a.add(test_points, request_size)
                request_a.add(test_labels, request_size, placeholder=True)

                request_b = BatchRequest(random_seed=i)
                request_b.add(test_points, request_size)
                request_b.add(test_labels, request_size)

                batch_a = pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                # Materialize both node collections (``nodes`` may be a
                # one-shot iterator). Compare lengths, because zip() below
                # would silently ignore surplus points in either batch.
                points_a = list(batch_a[test_points].nodes)
                points_b = list(batch_b[test_points].nodes)
                assert len(points_a) == len(points_b), (
                    f"placeholder batch has {len(points_a)} points, "
                    f"full batch has {len(points_b)}"
                )

                for a, b in zip(points_a, points_b):
                    assert all(np.isclose(a.location, b.location))
예제 #10
0
def test_transpose():
    """SimpleAugment transposition moves graph nodes together with array data.

    A single node and a single marked voxel start at the same place. The
    request is sent through a pipeline that never transposes and through
    one that always transposes. Each time, the node must be at the expected
    location and the array voxel under it must be the marked one.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")
    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    # 800 x 600 world units at 20 per voxel give a 40 x 30 voxel array.
    # Voxel [17, 17] is the one containing the node at (450, 550).
    data = np.zeros([40, 30])
    data[17, 17] = 1
    array = Array(
        data, ArraySpec(roi=Roi((100, 200), (800, 600)),
                        voxel_size=voxel_size))

    def _make_pipeline(transpose_prob):
        # Merge both sources. SimpleAugment transposes both axes with the
        # given probability and never mirrors.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array)) +
            MergeProvider() + SimpleAugment(
                mirror_only=[],
                transpose_only=[0, 1],
                transpose_probs=[transpose_prob, transpose_prob]))

    def _check_node_on_marked_voxel(pipeline, request, expected_location):
        # Request one batch. Its single node must be at expected_location
        # and must lie on the marked voxel.
        with build(pipeline):
            batch = pipeline.request_batch(request)

            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert all(np.isclose(node.location, expected_location))
            node_voxel_index = Coordinate(
                (node.location - batch[array_key].spec.roi.get_offset()) /
                voxel_size)
            assert (
                batch[array_key].data[node_voxel_index] == 1
            ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"

    default_pipeline = _make_pipeline(0)
    transpose_pipeline = _make_pipeline(1)

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    # never transposed: the node keeps its original location
    _check_node_on_marked_voxel(default_pipeline, request, [450, 550])
    # always transposed
    _check_node_on_marked_voxel(transpose_pipeline, request, [410, 590])
예제 #11
0
    def test_3d_basics(self):
        """ElasticAugment moves points together with the label voxels."""
        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")
        test_raster = ArrayKey("TEST_RASTER")

        pipeline = (
            PointTestSource3D()
            + ElasticAugment(
                [10, 10, 10],
                [0.1, 0.1, 0.1],
                # [0, 0, 0], # no jitter
                [0, 2.0 * math.pi],
            )
            + RasterizeGraph(
                test_points,
                test_raster,
                settings=RasterizationSettings(radius=2, mode="peak"),
            )
            + Snapshot(
                {test_labels: "volumes/labels", test_raster: "volumes/raster"},
                dataset_dtypes={test_raster: np.float32},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            )
        )

        for _ in range(5):
            with build(pipeline):
                request_roi = Roi((-20, -20, -20), (40, 40, 40))

                request = BatchRequest()
                request[test_labels] = ArraySpec(roi=request_roi)
                request[test_points] = GraphSpec(roi=request_roi)
                request[test_raster] = ArraySpec(roi=request_roi)

                batch = pipeline.request_batch(request)
                labels = batch[test_labels]
                points = batch[test_points]

                # Node 0 is presumably the point at the origin, which should
                # not move. Only its presence is checked here.
                self.assertTrue(points.contains(0))

                begin = labels.spec.roi.get_begin()
                label_voxel_size = labels.spec.voxel_size
                # the labels ROI expressed as voxel indices starting at 0
                labels_data_roi = (labels.spec.roi - begin) / label_voxel_size

                # A point that falls inside the label array must sit on a
                # voxel carrying its own id.
                for point in points.nodes:
                    scaled = (point.location - begin) / label_voxel_size
                    voxel = Coordinate(int(round(x)) for x in scaled)
                    if labels_data_roi.contains(voxel):
                        self.assertEqual(labels.data[voxel], point.id)
예제 #12
0
    def test_3d(self):
        """Snapshot with ``every=2`` writes the first batch and skips the second.

        The first snapshot must contain the requested array as well as the
        graph that is only part of the snapshot's ``additional_request``.
        """
        test_graph = GraphKey("TEST_GRAPH")
        graph_spec = GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5)))
        test_array = ArrayKey("TEST_ARRAY")
        array_spec = ArraySpec(
            roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
        )
        test_array2 = ArrayKey("TEST_ARRAY2")
        array2_spec = ArraySpec(
            roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
        )

        # the graph is not in the main request; only the snapshot asks for it
        snapshot_request = BatchRequest()
        snapshot_request.add(test_graph, Coordinate((5, 5, 5)))

        pipeline = ExampleSource(
            [test_graph, test_array, test_array2], [graph_spec, array_spec, array2_spec]
        ) + Snapshot(
            {
                test_graph: "graphs/graph",
                test_array: "volumes/array",
                test_array2: "volumes/array2",
            },
            output_dir=str(self.test_dir),
            every=2,
            additional_request=snapshot_request,
            output_filename="snapshot.hdf",
        )

        snapshot_file_path = Path(self.test_dir, "snapshot.hdf")

        with build(pipeline):

            request = BatchRequest()
            roi = Roi((0, 0, 0), (5, 5, 5))

            request[test_array] = ArraySpec(roi=roi)
            request[test_array2] = ArraySpec(roi=roi)

            # first batch: a snapshot is written
            pipeline.request_batch(request)

            assert snapshot_file_path.exists()
            # Open read-only in a context manager so the handle is closed
            # before the file is unlinked. An open handle would leak and can
            # block the removal on some platforms.
            with h5py.File(snapshot_file_path, "r") as f:
                assert "volumes/array" in f
                assert "graphs/graph-ids" in f

            snapshot_file_path.unlink()

            # second batch: skipped because of every=2
            pipeline.request_batch(request)

            assert not snapshot_file_path.exists()
예제 #13
0
    def test_multi_transpose(self):
        """Transposing all three axes yields permutations of the node location.

        Over 100 batches, the single node (when present) must be located at
        some permutation of ``point``. It must be seen at least once, and at
        least once at a location different from ``point``. Array data shapes
        must match their ROI shapes.
        """
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")
        point = np.array([50, 70, 100])

        transpose_dims = [0, 1, 2]
        pipeline = (ArrayTestSource(),
                    ExampleSource()) + MergeProvider() + SimpleAugment(
                        mirror_only=[], transpose_only=transpose_dims)

        request = BatchRequest()
        offset = (0, 20, 33)
        request[GraphKeys.TEST_GRAPH] = GraphSpec(
            roi=Roi(offset, (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(
            roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(
            roi=Roi((0, 100, 250), (100, 100, 50)))

        # one row per permutation of the transposed axes applied to ``point``
        possible_loc = np.array(
            [point[np.array(comb)] for comb in permutations(transpose_dims, 3)]
        )

        with build(pipeline):
            seen_transposed = False
            # must start False, otherwise the final check is vacuous
            seen_node = False
            for _ in range(100):
                batch = pipeline.request_batch(request)
                graph = batch[GraphKeys.TEST_GRAPH]
                nodes = list(graph.nodes)

                # A transposed location can fall outside the requested graph
                # ROI (e.g. a first coordinate of 100), so the node may be
                # absent from a batch.
                if len(nodes) == 1:
                    seen_node = True
                    node = nodes[0]

                    # The location must equal one whole permutation row. A
                    # bare ``in`` on a 2-D ndarray is true as soon as any
                    # single entry matches.
                    assert np.isclose(
                        possible_loc, node.location).all(axis=1).any(), (
                            f"{node.location} is not a permutation of {point}")

                    seen_transposed = seen_transposed or any(
                        node.location[dim] != point[dim] for dim in range(3))
                    assert Roi((0, 20, 33), (100, 100, 120)).contains(
                        graph.spec.roi)
                    assert graph.spec.roi.contains(node.location)

                for array in batch.arrays.values():
                    assert array.data.shape == array.spec.roi.get_shape()
            assert seen_transposed
            assert seen_node
예제 #14
0
    def test_pipeline3(self):
        """After ShiftAugment, points still sit on data values that occur at
        the original point positions."""
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))
        spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = Hdf5Source(
            self.fake_data_file,
            {array_key: "testdata"},
            array_specs={array_key: spec},
        )
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = ShiftAugment(
            prob_slip=0.2, prob_shift=0.2, sigma=5, shift_axis=0)
        pipeline = (
            (hdf5_source, csv_source)
            + MergeProvider()
            + RandomLocation(ensure_nonempty=points_key)
            + shift_node
        )
        with build(pipeline) as b:
            batch = b.request_batch(request)

        # data values at the original point positions
        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        # data values at the (possibly shifted) point positions in the batch
        result_vals = [
            result_data[int(node.location[0])][int(node.location[1])]
            for node in result_points
        ]

        for result_val in result_vals:
            message = (
                "result value {} at points {} not in target values {} at points {}"
                .format(result_val, list(result_points), target_vals, self.fake_points)
            )
            self.assertTrue(result_val in target_vals, msg=message)
예제 #15
0
    def test_output(self):
        """Check that ``point_balance_radius`` samples points about equally.

        Known to fail: probabilities are computed in advance rather than
        after creating each ROI. That approach does not account for every
        ROI containing a point, some of which may not contain the point's
        nearest neighbors.
        """
        points_key = GraphKey('TEST_POINTS')

        pipeline = ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=points_key, point_balance_radius=100)

        # occurrences of each point id over all batches
        histogram = {}

        with build(pipeline):
            for _ in range(5000):
                request = BatchRequest(
                    {points_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))})
                batch = pipeline.request_batch(request)

                points = {node.id: node for node in batch[points_key].nodes}

                self.assertTrue(len(points) > 0)
                # exactly one of: point 1, or point 2/3, is in the batch
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for node in batch[points_key].nodes:
                    histogram[node.id] = histogram.get(node.id, 0) + 1

        total = sum(histogram.values())
        histogram = {key: float(count) / total for key, count in histogram.items()}

        # every point should be sampled with roughly the same frequency
        for first in histogram:
            for second in histogram:
                self.assertAlmostEqual(histogram[first], histogram[second], 1)
예제 #16
0
    def test_read_single_swc(self):
        """Nodes read back from an SWC file keep their ids and locations."""
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read the graph back
        swc = GraphKey("SWC")
        source = SwcFileSource(path, [swc])

        with build(source):
            batch = source.request_batch(
                BatchRequest({swc:
                              GraphSpec(roi=Roi((0, 0, 5), (11, 11, 1)))}))

        read_graph = batch.points[swc]
        for node in self._toy_swc_points().nodes:
            # Compare per axis and in order. An order-insensitive comparison
            # (e.g. assertCountEqual) would accept permuted coordinates.
            np.testing.assert_allclose(
                read_graph.node(node.id).location, node.location)
예제 #17
0
    def test_req_full_roi(self):
        """A request as large as ``possible_roi`` yields exactly one node."""
        graph_key = GraphKey("TEST_GRAPH")

        possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))

        pipeline = (
            SourceGraphLocation()
            + BatchTester(possible_roi, exact=False)
            + RandomLocation(ensure_nonempty=graph_key)
        )
        with build(pipeline):
            request = BatchRequest(
                {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))})
            batch = pipeline.request_batch(request)

            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
예제 #18
0
    def test_roi_one_point(self):
        """Unit-sized requests always contain exactly one node."""
        graph_key = GraphKey("TEST_GRAPH")
        upstream_roi = Roi((500, 500, 500), (1, 1, 1))

        pipeline = (
            SourceGraphLocation()
            + BatchTester(upstream_roi, exact=True)
            + RandomLocation(ensure_nonempty=graph_key)
        )

        with build(pipeline):
            for _ in range(500):
                request = BatchRequest(
                    {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))})
                batch = pipeline.request_batch(request)

                nodes = list(batch[graph_key].nodes)
                assert len(nodes) == 1
예제 #19
0
    def test_dim_size_1(self):
        """Requests of extent 1 along the first axis contain exactly one node."""
        graph_key = GraphKey("TEST_GRAPH")
        upstream_roi = Roi((500, 401, 401), (1, 200, 200))
        pipeline = (
            SourceGraphLocation()
            + BatchTester(upstream_roi, exact=False)
            + RandomLocation(ensure_nonempty=graph_key)
        )

        with build(pipeline):
            for _ in range(500):
                request = BatchRequest(
                    {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))})
                batch = pipeline.request_batch(request)

                nodes = list(batch[graph_key].nodes)
                assert len(nodes) == 1
예제 #20
0
    def test_two_transpose(self):
        """Transposing axes 1 and 2 keeps the node on the expected per-axis
        coordinates, eventually moves it, and keeps array shapes consistent."""
        test_graph = GraphKey("TEST_GRAPH")
        test_array1 = ArrayKey("TEST_ARRAY1")
        test_array2 = ArrayKey("TEST_ARRAY2")

        transpose_dims = [1, 2]
        pipeline = (
            (ArrayTestSource(), ExampleSource())
            + MergeProvider()
            + SimpleAugment(mirror_only=[], transpose_only=transpose_dims)
        )

        request = BatchRequest()
        request[test_graph] = GraphSpec(roi=Roi((0, 20, 33), (100, 100, 120)))
        request[test_array1] = ArraySpec(roi=Roi((0, 0, 0), (100, 200, 300)))
        request[test_array2] = ArraySpec(roi=Roi((0, 100, 250), (100, 100, 50)))

        # allowed coordinates per axis; entry 0 is the untransposed value
        possible_loc = [[50, 50], [70, 100], [100, 70]]
        with build(pipeline):
            seen_transposed = False
            for _ in range(100):
                batch = pipeline.request_batch(request)
                graph = batch[test_graph]

                nodes = list(graph.nodes)
                assert len(nodes) == 1
                node = nodes[0]
                logging.debug(node.location)

                assert all(
                    node.location[dim] in possible_loc[dim] for dim in range(3)
                )
                if not seen_transposed:
                    seen_transposed = any(
                        node.location[dim] != possible_loc[dim][0]
                        for dim in range(3)
                    )
                assert Roi((0, 20, 33), (100, 100, 120)).contains(graph.spec.roi)
                assert graph.spec.roi.contains(node.location)

                for array in batch.arrays.values():
                    assert array.data.shape == array.spec.roi.get_shape()

            assert seen_transposed
예제 #21
0
    def test_output(self):
        """RandomLocation with ensure_nonempty samples every node about
        equally often."""
        graph_key = GraphKey("TEST_GRAPH")

        pipeline = TestSourceRandomLocation() + RandomLocation(
            ensure_nonempty=graph_key)

        # occurrences of each node id over all batches
        histogram = {}

        with build(pipeline):
            for _ in range(5000):
                request = BatchRequest(
                    {graph_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))})
                batch = pipeline.request_batch(request)

                node_ids = [node.id for node in batch[graph_key].nodes]

                self.assertTrue(len(node_ids) > 0)
                # exactly one of: node 1, or node 2/3, is in the batch
                self.assertTrue(
                    (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                    node_ids,
                )

                for node_id in node_ids:
                    histogram[node_id] = histogram.get(node_id, 0) + 1

        total = sum(histogram.values())
        histogram = {key: float(count) / total for key, count in histogram.items()}

        # every node should be sampled with roughly the same frequency
        for first in histogram:
            for second in histogram:
                self.assertAlmostEqual(histogram[first], histogram[second], 1)
예제 #22
0
    def test_equal_probability(self):
        """RandomLocation with ensure_nonempty samples every point about
        equally often for small requests."""
        points_key = GraphKey('TEST_POINTS')

        pipeline = ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=points_key)

        # occurrences of each point id over all batches
        histogram = {}

        with build(pipeline):
            for _ in range(5000):
                request = BatchRequest(
                    {points_key: GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))})
                batch = pipeline.request_batch(request)

                points = {node.id: node for node in batch[points_key].nodes}

                self.assertTrue(len(points) > 0)
                # exactly one of: point 1, or point 2/3, is in the batch
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for point in batch[points_key].nodes:
                    histogram[point.id] = histogram.get(point.id, 0) + 1

        total = sum(histogram.values())
        histogram = {key: float(count) / total for key, count in histogram.items()}

        # every point should be sampled with roughly the same frequency
        for first in histogram:
            for second in histogram:
                self.assertAlmostEqual(histogram[first], histogram[second], 1)
예제 #23
0
def test_filter_components():
    """FilterComponents output depends on where the request is placed.

    A 20^3 request at the origin must yield exactly one connected
    component, while a 20^3 request at (20, 20, 20) must yield none.
    """
    raw = GraphKey("RAW")

    pipeline = TestSource() + FilterComponents(raw, 100,
                                               Coordinate((10, 10, 10)))

    # (request offset, expected number of connected components)
    cases = (
        ((0, 0, 0), 1),
        ((20, 20, 20), 0),
    )

    for offset, expected_components in cases:
        request = BatchRequest()
        request[raw] = GraphSpec(roi=Roi(offset, (20, 20, 20)))

        with build(pipeline):
            batch = pipeline.request_batch(request)
            assert raw in batch
            components = list(batch[raw].connected_components)
            assert len(components) == expected_components
예제 #24
0
    def test_keep_node_ids(self):
        """SwcFileSource with ``keep_ids=True`` returns the expected path.

        Writes the toy SWC skeleton with resolution (2, 2, 2), reads it
        back through a small ROI, walks from the root node along
        single-successor chains, and compares the visited locations with
        the expected path.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write the toy skeleton to disk
        self._write_swc(
            swc_path,
            self._toy_swc_points().to_nx_graph(),
            {"resolution": np.array([2, 2, 2])},
        )

        # read it back as a graph
        swc = GraphKey("SWC")
        source = SwcFileSource(swc_path, [swc], keep_ids=True)

        with build(source):
            request = BatchRequest(
                {swc: GraphSpec(roi=Roi((0, 5, 10), (1, 2, 1)))})
            batch = source.request_batch(request)

        graph = batch.points[swc]

        # the root is the only node without incoming edges
        roots = [node_id for node_id, degree in graph.in_degree() if degree == 0]
        current = roots[0]

        # Nodes on the ROI boundary can't keep their original id (one node
        # may have several children inside the ROI), so the walk is compared
        # by location rather than by id.
        expected_path = [
            tuple(np.array([0.0, 5.0, 10.0])),
            tuple(np.array([0.0, 6.0, 10.0])),
            tuple(np.array([0.0, 7.0, 10.0])),
        ]

        visited = []
        while current is not None:
            node = graph.node(current)
            visited.append(tuple(node.location))
            children = list(graph.successors(node))
            # follow the chain only while it is unbranched
            current = children[0] if len(children) == 1 else None

        for actual, expected in zip(visited, expected_path):
            assert all(np.isclose(actual, expected))
예제 #25
0
    def test_output(self):
        """Pad grows the provider ROIs of an array and a graph.

        Labels are padded by 20 (fill value 1) and the graph by 10 in each
        dimension. The test checks the resulting provider ROIs and the label
        sum in a 20^3 request at the padded corner.
        """

        graph = GraphKey("TEST_GRAPH")
        labels = ArrayKey("TEST_LABELS")

        source = TestSourcePad()
        pad_labels = Pad(labels, Coordinate((20, 20, 20)), value=1)
        pad_graph = Pad(graph, Coordinate((10, 10, 10)))
        pipeline = source + pad_labels + pad_graph

        with build(pipeline):

            expected_labels_roi = Roi((180, 0, 0), (1840, 220, 220))
            expected_graph_roi = Roi((190, 10, 10), (1820, 200, 200))
            self.assertTrue(pipeline.spec[labels].roi == expected_labels_roi)
            self.assertTrue(pipeline.spec[graph].roi == expected_graph_roi)

            corner_request = BatchRequest(
                {labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))})
            batch = pipeline.request_batch(corner_request)

            label_sum = np.sum(batch.arrays[labels].data)
            self.assertEqual(label_sum, 1 * 10 * 10)
예제 #26
0
    def test_overlap(self):
        """Three overlapping SWC files are read as three separate components.

        Writes the toy skeleton three times, offset by (0, i, 0), reads
        the directory through one ROI, relabels connected components, and
        checks that there are three components of 41 nodes each. Every node
        in a component must share one ``component`` label, and the label
        must differ from the previous component's label.
        """
        swc_dir = Path(self.path_to("test_swc_sources"))
        swc_dir.mkdir(parents=True, exist_ok=True)

        # write one shifted copy of the toy skeleton per file
        for i in range(3):
            self._write_swc(
                swc_dir / f"{i}.swc",
                self._toy_swc_points().to_nx_graph(),
                {"offset": np.array([0, i, 0])},
            )

        # read all files back as one graph
        swc = GraphKey("SWC")
        source = SwcFileSource(swc_dir, [swc])

        with build(source):
            request = BatchRequest(
                {swc: GraphSpec(roi=Roi((0, 0, 5), (11, 13, 1)))})
            batch = source.request_batch(request)

        graph = batch.points[swc]
        graph.relabel_connected_components()

        components = list(graph.connected_components)
        self.assertEqual(len(components), 3)

        previous_label = None
        for component in components:
            self.assertEqual(len(component), 41)
            label = None
            for node_id in component:
                node_label = graph.node(node_id).attrs["component"]
                if label is None:
                    # first node fixes the label for this component
                    label = node_label
                    self.assertNotEqual(label, previous_label)
                self.assertEqual(node_label, label)
            previous_label = label
예제 #27
0
    def test_output_min_distance(self):
        """Check AddVectorMap output for the 'min_distance' and 'all' criteria.

        One request is built, with every ROI shifted by 1000 in each
        dimension, and sent through two pipelines that differ only in
        ``partner_criterion``. Consider each PRESYN node inside the requested
        vector-map ROI. Every voxel that is strictly closer than
        ``radius_phys`` to the node (in physical units) and holds a non-zero
        vector must point, via voxel position + vector, at one of the node's
        POSTSYN partners present in the batch:

        * ``'min_distance'``: exactly the closest partner;
        * ``'all'``: any partner, with per-partner vector counts differing
          by at most the number of partners.
        """

        voxel_size = Coordinate((20, 2, 2))

        ArrayKey("GT_VECTORS_MAP_PRESYN")
        GraphKey("PRESYN")
        GraphKey("POSTSYN")

        arraytypes_to_source_target_pointstypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN: (GraphKeys.PRESYN, GraphKeys.POSTSYN)
        }
        arraytypes_to_stayinside_arraytypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS
        }
        radius_phys = 30

        def make_pipeline(partner_criterion):
            # The two sub-tests share every setting except the criterion.
            return AddVectorMapTestSource() + AddVectorMap(
                src_and_trg_points=arraytypes_to_source_target_pointstypes,
                voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
                radius_phys=radius_phys,
                partner_criterion=partner_criterion,
                stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
                pad_for_partners=(0, 0, 0),
            )

        def nodes_by_id(batch, graph_key):
            # id -> node mapping for one graph of the batch
            return {n.id: n for n in batch.graphs[graph_key].nodes}

        def requested_presyn_nodes(batch, request):
            """Yield PRESYN nodes inside the requested vector-map ROI.

            Also asserts that the vector map returned in ``batch`` covers
            every such node.
            """
            requested_roi = request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi
            provided_roi = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi
            for point in nodes_by_id(batch, GraphKeys.PRESYN).values():
                if requested_roi.contains(Coordinate(point.location)):
                    self.assertTrue(
                        provided_roi.contains(Coordinate(point.location)))
                    yield point

        def vector_targets_near(point, vector_map, offset):
            """Yield physical target locations of non-zero vectors near ``point``.

            Scans a box of ``radius_phys // voxel_size`` voxels around the
            voxel that contains ``point``, clipped to the array. For each
            voxel that is strictly closer than ``radius_phys`` to
            ``point.location`` and holds a non-zero vector, yields
            ``voxel position + vector`` in physical units.
            """
            point_loc = np.asarray(point.location)
            origin = np.asarray(offset)
            voxel_size_arr = np.asarray(voxel_size)

            presyn_loc_shifted_vx = (point.location - offset) // voxel_size
            radius_vx = [radius_phys // vx_dim for vx_dim in voxel_size]
            region_to_check = np.clip(
                [
                    presyn_loc_shifted_vx - radius_vx,
                    presyn_loc_shifted_vx + radius_vx,
                ],
                a_min=(0, 0, 0),
                a_max=vector_map.shape[-3:],
            )
            ranges = [
                range(int(lo), int(hi))
                for lo, hi in zip(region_to_check[0], region_to_check[1])
            ]
            for x, y, z in itertools.product(*ranges):
                voxel_loc = origin + voxel_size_arr * np.array([x, y, z])
                # Measure the distance in physical units. Comparing raw voxel
                # indices with point.location (which lies beyond the 1000
                # shift) puts every voxel outside the radius, so nothing
                # would ever be checked.
                if np.linalg.norm(voxel_loc - point_loc) >= radius_phys:
                    continue
                vector = np.asarray(vector_map[:, x, y, z])
                # np.any rather than np.sum: the components of a non-zero
                # vector can sum to 0.
                if np.any(vector):
                    yield voxel_loc + vector

        # --- partner criterion 'min_distance' ---
        pipeline_min_distance = make_pipeline("min_distance")

        with build(pipeline_min_distance):

            request = BatchRequest()
            raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi
            gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi
            presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi

            request.add(ArrayKeys.RAW, raw_roi.get_shape())
            request.add(ArrayKeys.GT_LABELS, gt_labels_roi.get_shape())
            request.add(GraphKeys.PRESYN, presyn_roi.get_shape())
            request.add(GraphKeys.POSTSYN, presyn_roi.get_shape())
            request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN, presyn_roi.get_shape())
            for _, spec in request.items():
                spec.roi = spec.roi.shift((1000, 1000, 1000))

            batch = pipeline_min_distance.request_batch(request)

        # Physical position of voxel (0, 0, 0) of the requested vector map;
        # the checks assume the returned array starts at the requested offset.
        offset_vector_map_presyn = request[
            ArrayKeys.GT_VECTORS_MAP_PRESYN
        ].roi.get_offset()

        postsyn_locs = nodes_by_id(batch, GraphKeys.POSTSYN)
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        self.assertTrue(len(nodes_by_id(batch, GraphKeys.PRESYN)) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for point in requested_presyn_nodes(batch, request):
            dist_to_loc = {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs:
                    partner_location = postsyn_locs[partner_id].location
                    dist_to_loc[
                        np.linalg.norm(partner_location - point.location)
                    ] = partner_location
            if not dist_to_loc:
                # None of this node's partners is in the batch, so there is
                # no expected target to compare against.
                continue
            relevant_partner_loc = dist_to_loc[min(dist_to_loc)]

            for trg_loc in vector_targets_near(
                    point, vector_map_presyn, offset_vector_map_presyn):
                self.assertTrue(np.array_equal(trg_loc, relevant_partner_loc))

        # --- partner criterion 'all' (same request) ---
        pipeline_all = make_pipeline("all")

        with build(pipeline_all):
            batch = pipeline_all.request_batch(request)

        postsyn_locs = nodes_by_id(batch, GraphKeys.POSTSYN)
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        self.assertTrue(len(nodes_by_id(batch, GraphKeys.PRESYN)) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for point in requested_presyn_nodes(batch, request):
            partner_ids_to_locs_per_src = {}
            count_vectors_per_partner = {}
            for partner_id in point.attrs["partner_ids"]:
                if partner_id in postsyn_locs:
                    partner_ids_to_locs_per_src[partner_id] = postsyn_locs[
                        partner_id
                    ].location.tolist()
                    count_vectors_per_partner[partner_id] = 0

            for trg_loc in vector_targets_near(
                    point, vector_map_presyn, offset_vector_map_presyn):
                self.assertTrue(
                    trg_loc.tolist() in partner_ids_to_locs_per_src.values()
                )
                for partner_id, partner_loc in partner_ids_to_locs_per_src.items():
                    if np.array_equal(trg_loc, partner_loc):
                        count_vectors_per_partner[partner_id] += 1

            # vectors should be spread roughly evenly across the partners
            if count_vectors_per_partner:
                counts = np.array(list(count_vectors_per_partner.values()))
                self.assertTrue(((counts - counts.min()) <= len(counts)).all())
예제 #28
0
    def test_fast_transform_no_recompute(self):
        """Compare ElasticAugment's fast points transform with the default one.

        For five seeds, the same request goes through two pipelines. The
        first uses ``use_fast_points_transform=True`` and
        ``recompute_missing_points=False``; the reference pipeline uses the
        defaults. Points present in both outputs must lie within distance 1
        of each other. Each seed must leave at least one reference point
        missing from the fast output. Averaged over all seeds, the fast
        pipeline must be quicker.
        """
        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")
        test_raster = ArrayKey("TEST_RASTER")
        fast_pipeline = (DensePointTestSource3D() + ElasticAugment(
            [10, 10, 10],
            [0.1, 0.1, 0.1],
            [0, 2.0 * math.pi],
            use_fast_points_transform=True,
            recompute_missing_points=False,
        ) + RasterizeGraph(
            test_points,
            test_raster,
            settings=RasterizationSettings(radius=2, mode="peak"),
        ))

        reference_pipeline = (
            DensePointTestSource3D() +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            RasterizeGraph(
                test_points,
                test_raster,
                settings=RasterizationSettings(radius=2, mode="peak"),
            ))

        request_roi = Roi((0, 0, 0), (40, 40, 40))

        def make_request(seed):
            # Both pipelines get identical requests (including the seed) so
            # their outputs are directly comparable.
            request = BatchRequest(random_seed=seed)
            request[test_labels] = ArraySpec(roi=request_roi)
            request[test_points] = GraphSpec(roi=request_roi)
            request[test_raster] = ArraySpec(roi=request_roi)
            return request

        def timed_points(pipeline, seed):
            """Build ``pipeline`` and request one batch.

            Returns a tuple: the ``request_batch`` wall time in seconds, and
            a dict mapping node id to node for the returned graph.
            """
            with build(pipeline):
                request = make_request(seed)
                start = time.time()
                batch = pipeline.request_batch(request)
                elapsed = time.time() - start
                points = {node.id: node for node in batch[test_points].nodes}
            return elapsed, points

        timings = []
        for i in range(5):
            # seed chosen specifically to make this test fail
            seed = i + 15
            fast_time, points_fast = timed_points(fast_pipeline, seed)
            ref_time, points_reference = timed_points(reference_pipeline, seed)
            timings.append((fast_time, ref_time))

            missing = 0
            for point_id, point in points_reference.items():
                if point_id not in points_fast:
                    missing += 1
                    continue
                fast_location = points_fast[point_id].location
                self.assertAlmostEqual(
                    np.linalg.norm(point.location - fast_location),
                    0,
                    delta=1,
                    msg=
                    "fast transform returned location {} but expected {} for point {}"
                    .format(fast_location, point.location, point_id),
                )

            self.assertGreater(missing, 0)

        # Compare mean run times only after all seeds have run; asserting on
        # the running mean inside the loop judged a single noisy sample.
        t_fast, t_ref = [np.mean(x) for x in zip(*timings)]
        self.assertLess(t_fast, t_ref)
예제 #29
0
    def test_random_seed(self):
        """ElasticAugment output is reproducible for a fixed request seed.

        Builds the pipeline five times, each time requesting a batch with
        ``random_seed=10``. Each run must:

        * keep node 0;
        * place every node that falls inside the label array on a voxel
          labelled with that node's id.

        Across runs, all five must return the identical sequence of
        ``(node id, location)`` pairs.
        """

        test_labels = ArrayKey('TEST_LABELS')
        test_points = GraphKey('TEST_POINTS')
        test_raster = ArrayKey('TEST_RASTER')

        pipeline = (
            PointTestSource3D() + ElasticAugment(
                [10, 10, 10],
                [0.1, 0.1, 0.1],
                [0, 2.0 * math.pi]) +  # rotate randomly
            RasterizeGraph(test_points,
                           test_raster,
                           settings=RasterizationSettings(radius=2,
                                                          mode='peak')) +
            Snapshot(
                {
                    test_labels: 'volumes/labels',
                    test_raster: 'volumes/raster'
                },
                dataset_dtypes={test_raster: np.float32},
                output_dir=self.path_to(),
                output_filename='elastic_augment_test{id}-{iteration}.hdf'))

        batch_points = []
        for _ in range(5):

            with build(pipeline):

                request_roi = Roi((-20, -20, -20), (40, 40, 40))

                request = BatchRequest(random_seed=10)
                request[test_labels] = ArraySpec(roi=request_roi)
                request[test_points] = GraphSpec(roi=request_roi)
                request[test_raster] = ArraySpec(roi=request_roi)
                batch = pipeline.request_batch(request)
                labels = batch[test_labels]
                points = batch[test_points]
                batch_points.append(
                    tuple((node.id, tuple(node.location))
                          for node in points.nodes))

                # Node 0 must still be present. (The intent is presumably that
                # the point at (0, 0, 0) does not move; only its presence is
                # checked here.)
                self.assertIn(0, {node.id for node in points.nodes})

                labels_data_roi = (
                    labels.spec.roi -
                    labels.spec.roi.get_begin()) / labels.spec.voxel_size

                # Points should have moved together with the voxels: a node
                # inside the label array sits on a voxel labelled with its id.
                for node in points.nodes:
                    loc = node.location - labels.spec.roi.get_begin()
                    loc = loc / labels.spec.voxel_size
                    loc = Coordinate(int(round(x)) for x in loc)
                    if labels_data_roi.contains(loc):
                        self.assertEqual(labels.data[loc], node.id)

        # Every run must match exactly: same ids, locations, order and count.
        # A position-wise zip over the runs would silently ignore extra
        # points returned by only some runs.
        self.assertEqual(len(set(batch_points)), 1)
예제 #30
0
    def test_ensure_centered(self):
        """
        Expected failure due to emergent behavior of two desired rules:
        1) Points on the upper bound of a Roi are not considered contained.
        2) When considering a point as the center of a random location,
            scale by the number of points within some delta distance.

        Suppose two points are equally likely to be chosen, and centering a
        roi on either of them puts the other on the roi's bounding box. If
        the roi is centered on one of them, it may contain only that point.
        If it is centered on the second, both points may be considered
        contained. This breaks the equal likelihood of picking each point.
        """

        GraphKey("TEST_POINTS")
        points_key = GraphKeys.TEST_POINTS

        pipeline = ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=points_key, ensure_centered=True)

        # count the number of times we get each point
        histogram = {}

        with build(pipeline):

            for _ in range(5000):
                batch = pipeline.request_batch(
                    BatchRequest({
                        points_key:
                        GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                    }))

                # id -> node, built from ``.nodes`` like the other
                # RandomLocation tests in this file
                points = {node.id: node for node in batch[points_key].nodes}

                locations = tuple(
                    [Coordinate(point.location) for point in points.values()])
                self.assertTrue(
                    Coordinate([50, 50, 50]) in locations,
                    f"locations: {tuple([point.location for point in points.values()])}"
                )

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for point_id in points:
                    # previously `histogram[node.id] += 1`: NameError, no
                    # `node` is bound in this scope
                    histogram[point_id] = histogram.get(point_id, 0) + 1

        total = sum(histogram.values())
        for k, v in histogram.items():
            histogram[k] = float(v) / total

        # we should get roughly the same count for each point
        for i in histogram.keys():
            for j in histogram.keys():
                self.assertAlmostEqual(histogram[i], histogram[j], 1,
                                       histogram)