コード例 #1
0
ファイル: simple_augment.py プロジェクト: yajivunev/gunpowder
def test_mirror():
    """Mirror a one-node graph together with a one-hot array.

    Two pipelines share the same sources: one uses ``mirror_probs=[0, 0]``
    and the other ``mirror_probs=[1, 1]``. In both cases the batch must
    contain exactly one node at the expected location, and the array voxel
    underneath that node must hold the value 1.
    """
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )

    data = np.zeros([40, 30])
    data[17, 17] = 1
    array_spec = ArraySpec(roi=Roi((100, 200), (800, 600)),
                           voxel_size=voxel_size)
    array = Array(data, array_spec)

    def augmented_pipeline(mirror_probs):
        # Identical sources each time; only the mirror probabilities vary.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(mirror_only=[0, 1],
                            transpose_only=[],
                            mirror_probs=mirror_probs)
        )

    def voxel_under_single_node(batch, expected_location):
        # Check the batch holds one node at the expected place and return
        # the index of the voxel that node falls into.
        nodes = list(batch[graph_key].nodes)
        assert len(nodes) == 1
        node = nodes[0]
        assert all(np.isclose(node.location, expected_location))
        roi_offset = batch[array_key].spec.roi.get_offset()
        return Coordinate((node.location - roi_offset) / voxel_size)

    default_pipeline = augmented_pipeline([0, 0])
    mirror_pipeline = augmented_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    with build(default_pipeline):
        batch = default_pipeline.request_batch(request)
        node_voxel_index = voxel_under_single_node(batch, [450, 550])
        assert batch[array_key].data[node_voxel_index] == 1

    with build(mirror_pipeline):
        batch = mirror_pipeline.request_batch(request)
        node_voxel_index = voxel_under_single_node(batch, [550, 750])
        assert (
            batch[array_key].data[node_voxel_index] == 1
        ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"
コード例 #2
0
def test_filter_components():
    """Request two regions through FilterComponents and count components.

    The first ROI is expected to yield exactly one connected component, the
    second none. The pipeline is built afresh for each request.
    """
    raw = GraphKey("RAW")

    pipeline = TestSource() + FilterComponents(raw, 100,
                                               Coordinate((10, 10, 10)))

    # (requested roi, expected number of connected components)
    cases = (
        (Roi((0, 0, 0), (20, 20, 20)), 1),
        (Roi((20, 20, 20), (20, 20, 20)), 0),
    )

    for roi, expected_components in cases:
        request = BatchRequest()
        request[raw] = GraphSpec(roi=roi)

        with build(pipeline):
            batch = pipeline.request_batch(request)
            assert raw in batch
            components = list(batch[raw].connected_components)
            assert len(components) == expected_components
コード例 #3
0
ファイル: simple_augment.py プロジェクト: omerbt/gunpowder
    def test_multi_transpose(self):
        """Transpose all three axes at random and check graph and arrays.

        100 batches are requested. Whenever a batch contains exactly one
        node, its location must be found in the table of permuted point
        coordinates and must lie inside the returned graph ROI, which in turn
        must lie inside the requested ROI. Every array's data shape must
        match the shape of its ROI. At the end, at least one batch must have
        contained the node and at least one must have moved it.
        """
        # Create the keys; the test refers to them below through the
        # GraphKeys / ArrayKeys registries.
        GraphKey("TEST_GRAPH")
        ArrayKey("TEST_ARRAY1")
        ArrayKey("TEST_ARRAY2")
        point = np.array([50, 70, 100])

        transpose_dims = [0, 1, 2]
        pipeline = (ArrayTestSource(),
                    ExampleSource()) + MergeProvider() + SimpleAugment(
                        mirror_only=[], transpose_only=transpose_dims)

        request = BatchRequest()
        offset = (0, 20, 33)
        request[GraphKeys.TEST_GRAPH] = GraphSpec(
            roi=Roi(offset, (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(
            roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(
            roi=Roi((0, 100, 250), (100, 100, 50)))

        # Create all possible permutations of our transpose dims
        transpose_combinations = list(permutations(transpose_dims, 3))
        possible_loc = np.zeros((len(transpose_combinations), 3))

        # Transpose points in all possible ways
        for i, comb in enumerate(transpose_combinations):
            possible_loc[i] = point[np.array(comb)]

        with build(pipeline):
            seen_transposed = False
            # Must start out False: otherwise the final `assert seen_node`
            # passes even if no batch ever contained the node.
            seen_node = False
            for _ in range(100):
                batch = pipeline.request_batch(request)
                graph = batch[GraphKeys.TEST_GRAPH]
                nodes = list(graph.nodes)

                if len(nodes) == 1:
                    seen_node = True
                    node = nodes[0]

                    # NOTE(review): `in` on a numpy array is an element-wise
                    # "any" test, so this is weaker than a whole-row match.
                    # Confirm whether a row comparison is intended.
                    assert node.location in possible_loc

                    seen_transposed = seen_transposed or any(
                        node.location[dim] != point[dim] for dim in range(3))
                    assert Roi(offset, (100, 100, 120)).contains(
                        graph.spec.roi)
                    assert graph.spec.roi.contains(node.location)

                for array in batch.arrays.values():
                    assert array.data.shape == array.spec.roi.get_shape()
            assert seen_transposed
            assert seen_node
コード例 #4
0
ファイル: test_match.py プロジェクト: pattonw/neurolight
def test_realistic_invalid_examples(example, use_gurobi):
    """Match an invalid mouselight example with and without fallback.

    Loads ``graph.obj`` and ``tree.obj`` from the example directory, runs two
    TopologicalMatcher nodes that differ only in ``with_fallback``, and
    expects the plain match to be empty while the fallback match is not.
    """
    example_dir = Path(__file__).parent / "mouselight_examples" / "invalid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    inf_roi = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))

    request = BatchRequest()
    request[matched] = PointsSpec(roi=inf_roi)
    request[matched_with_fallback] = PointsSpec(roi=inf_roi)

    # Settings shared by both matchers.
    matcher_settings = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )

    sources = (
        GraphSource(example_dir / "graph.obj", [skeletonization]),
        GraphSource(example_dir / "tree.obj", [consensus]),
    )
    pipeline = (
        sources
        + MergeProvider()
        + TopologicalMatcher(skeletonization, consensus, matched,
                             **matcher_settings)
        + TopologicalMatcher(skeletonization, consensus,
                             matched_with_fallback,
                             with_fallback=True, **matcher_settings)
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        assert matched in batch
        assert len(list(batch[matched].nodes)) == 0
        assert len(list(batch[matched_with_fallback].nodes)) > 0
コード例 #5
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_pipeline3(self):
        """Run ShiftAugment on an array and a point set requested together.

        After the batch is fetched, the array value found at each returned
        node location must be one of the values that the fake data holds at
        the fake points.
        """
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))

        hdf5_source = Hdf5Source(
            self.fake_data_file,
            {array_key: "testdata"},
            array_specs={
                array_key: ArraySpec(voxel_size=voxel_size,
                                     interpolatable=True)
            },
        )
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + MergeProvider() +
                    RandomLocation(ensure_nonempty=points_key) + shift_node)
        with build(pipeline) as built:
            batch = built.request_batch(request)

        # Values the source data holds at the original point positions.
        target_vals = [self.fake_data[p[0]][p[1]] for p in self.fake_points]

        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)

        for node in result_points:
            row, col = int(node.location[0]), int(node.location[1])
            result_val = result_data[row][col]
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(
                    result_val,
                    result_points,
                    target_vals,
                    self.fake_points,
                ),
            )
コード例 #6
0
ファイル: simple_augment.py プロジェクト: yajivunev/gunpowder
def test_mismatched_voxel_multiples():
    """
    Ensure we don't shift by half a voxel when transposing 2 axes.

    If voxel_size = [2, 2], and we transpose an array of shape [4, 6]:

        center = total_roi.get_center() -> [2, 3]

        # Get distance from center, then transpose
        dist_to_center = center - roi.get_offset() -> [2, 3]
        dist_to_center = transpose(dist_to_center)  -> [3, 2]

        # Using the transposed distance to center, get the offset.
        new_offset = center - dist_to_center -> [-1, 1]

        shape = transpose(shape) -> [6, 4]

        original = ((0, 0), (4, 6))
        transposed = ((-1, 1), (6, 4))

    This result is what we would expect from transposing, but it no longer
    fits the voxel grid. dist_to_center should be limited to multiples of the
    lcm_voxel_size.

        instead we should get:
        original = ((0, 0), (4, 6))
        transposed = ((0, 0), (6, 4))
    """

    test_array = ArrayKey("TEST_ARRAY")

    # voxel has Roi((4, 2) (2, 2)). Contained in Roi((0, 0), (6, 4)). at 2, 1
    source_data = np.zeros([3, 3])
    source_data[2, 1] = 1

    source_spec = ArraySpec(roi=Roi((0, 0), (6, 6)), voxel_size=(2, 2))
    source = ArraySource(test_array, Array(source_data, source_spec))

    # Always apply the (1, 0) transposition.
    augment = SimpleAugment(
        mirror_only=[], transpose_only=[0, 1], transpose_probs={(1, 0): 1})
    pipeline = source + augment

    with build(pipeline):
        request = BatchRequest()
        request[test_array] = ArraySpec(roi=Roi((0, 0), (4, 6)))

        augmented = pipeline.request_batch(request)[test_array].data

        assert augmented[1, 2] == 1, f"{augmented}"
コード例 #7
0
ファイル: torch_train.py プロジェクト: sailfish009/gunpowder
    def test_precache(self):
        """Request one batch from a Predict node placed before a PreCache.

        Raises:
            RuntimeError: if CUDA is already initialised in this process,
                as explained in the error message below.
        """

        if torch.cuda.is_initialized():
            raise RuntimeError(
                "Cuda is already initialized in the main process! Will not be able "
                "to reinitialize in forked subprocesses.")

        logging.getLogger("gunpowder.torch.nodes.predict").setLevel(
            logging.INFO)

        a = ArrayKey("A")
        pred = ArrayKey("PRED")

        model = TestModel()

        source = TestTorchTrain2DSource()
        predict = Predict(
            model=model,
            inputs={"a": a},
            outputs={0: pred},
            array_specs={pred: ArraySpec()},
        )
        pipeline = source + predict + PreCache(cache_size=3, num_workers=2)

        request = BatchRequest({
            a: ArraySpec(roi=Roi((0, 0), (17, 17))),
            pred: ArraySpec(roi=Roi((0, 0), (15, 15))),
        })

        # request a single batch and check that the prediction is present
        with build(pipeline):

            batch = pipeline.request_batch(request)
            assert pred in batch
コード例 #8
0
    def test_placeholder(self):
        """Compare point locations with and without a placeholder request.

        For each of 100 seeds, two requests with the same random seed are
        sent through the pipeline: one asks for the labels as a placeholder,
        the other asks for them normally. The points returned by both must
        agree pairwise in location.
        """

        test_labels = ArrayKey("TEST_LABELS")
        test_points = PointsKey("TEST_POINTS")

        snapshot = Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
        pipeline = (
            PointTestSource3D()
            + RandomLocation(ensure_nonempty=test_points)
            + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
            + snapshot
        )

        request_size = Coordinate((40, 40, 40))

        def seeded_request(seed, **label_options):
            # Points first, then labels, all sharing one random seed.
            request = BatchRequest(random_seed=seed)
            request.add(test_points, request_size)
            request.add(test_labels, request_size, **label_options)
            return request

        with build(pipeline):
            for seed in range(100):
                request_a = seeded_request(seed, placeholder=True)
                request_b = seeded_request(seed)

                batch_a = pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                # NOTE(review): zip stops at the shorter sequence, so a
                # difference in the number of points would go unnoticed.
                point_pairs = zip(batch_a[test_points].nodes,
                                  batch_b[test_points].nodes)
                for point_a, point_b in point_pairs:
                    assert all(np.isclose(point_a.location, point_b.location))
コード例 #9
0
ファイル: provider_test.py プロジェクト: pattonw/neurolight
    def setUp(self):
        """Create the shared array keys, a test source and a RAW request."""
        super(ProviderTest, self).setUp()

        # create some common array keys to be used by concrete tests
        common_key_names = (
            "RAW",
            "GT_LABELS",
            "GT_AFFINITIES",
            "GT_AFFINITIES_MASK",
            "GT_MASK",
            "GT_IGNORE",
            "LOSS_SCALE",
        )
        for key_name in common_key_names:
            ArrayKey(key_name)

        self.test_source = TestSource()

        raw_request = BatchRequest()
        raw_request[ArrayKeys.RAW] = ArraySpec(
            roi=Roi((20, 20, 20), (10, 10, 10)))
        self.test_request = raw_request
コード例 #10
0
    def prepare(self, request):
        """Request the ground-truth graph needed for the neighborhood.

        The neighborhood and its mask must be requested with the same ROI.
        The ground-truth graph is requested on that ROI grown in every
        dimension, on both sides, by ``ceil(self.distance)``.

        Args:
            request: the downstream request; must contain
                ``self.neighborhood`` and ``self.neighborhood_mask``.

        Returns:
            A BatchRequest holding the GraphSpec for ``self.gt``.

        Raises:
            AssertionError: if the two requested ROIs differ.
        """
        deps = BatchRequest()

        # The original statement had lost its `assert ... ==` head and was a
        # no-op tuple expression; the comparison is restored here.
        assert (
            request[self.neighborhood].roi
            == request[self.neighborhood_mask].roi
        ), f"Requested {self.neighborhood} and {self.neighborhood_mask} with different roi's"

        request_roi = request[self.neighborhood].roi
        grow_distance = Coordinate(
            (np.ceil(self.distance), ) * len(request_roi.get_shape()))
        request_roi = request_roi.grow(grow_distance, grow_distance)

        deps[self.gt] = GraphSpec(roi=request_roi)

        return deps
コード例 #11
0
def visualize_embedding_pipeline(fusion_pipeline, train_embedding):
    """Run the embedding pipeline once and visualize the resulting snapshot.

    Args:
        fusion_pipeline: value stored under the "FUSION_PIPELINE" config key.
        train_embedding: value stored under the "TRAIN_EMBEDDING" config key.
    """
    # Copy the defaults so the overrides below do not leak into
    # DEFAULT_CONFIG for later callers. Only top-level keys are assigned, so
    # a shallow copy is enough.
    # NOTE(review): assumes DEFAULT_CONFIG is a plain dict — confirm.
    setup_config = dict(DEFAULT_CONFIG)
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_EMBEDDING"] = train_embedding

    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size

    pipeline, raw, output = embedding_pipeline(setup_config,
                                               get_test_data_sources)

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)

    with build(pipeline):
        pipeline.request_batch(request)

    visualize_hdf5(Path("snapshots/snapshot_1.hdf"), tuple(voxel_size))
コード例 #12
0
ファイル: fusion_augment.py プロジェクト: pattonw/neurolight
    def prepare(self, request):
        """Request the "base" and "add" inputs for the fused raw volume.

        Both raw volumes and both label volumes are requested with the spec
        of ``self.raw_fused``. The two point sets are requested on the same
        ROI, but only when ``self.points_fused`` is part of the request.
        """
        fused_spec = request[self.raw_fused]
        deps = BatchRequest()

        # "base" and "add" raw volumes first, then the labels; the labels
        # use the raw spec so they match the raw data for mask generation.
        for key in (self.raw_base, self.raw_add,
                    self.labels_base, self.labels_add):
            deps[key] = fused_spec

        # points are optional
        if self.points_fused in request:
            for points_key in (self.points_base, self.points_add):
                deps[points_key] = PointsSpec(roi=fused_spec.roi)

        return deps
コード例 #13
0
ファイル: emst.py プロジェクト: pattonw/neurolight
    def prepare(self, request: BatchRequest):
        """Request embeddings and mask on the ROI common to all involved specs.

        The common ROI is taken over this node's specs for the embeddings
        and the mask together with the requested spec of ``self.mst``.
        """
        provided_arrays = {
            key: self.spec[key] for key in (self.embeddings, self.mask)
        }
        requested_graphs = {self.mst: request[self.mst]}

        combined = ProviderSpec(array_specs=provided_arrays,
                                graph_specs=requested_graphs)
        upstream_roi = combined.get_common_roi()

        deps = BatchRequest()
        for key in (self.embeddings, self.mask):
            deps[key] = ArraySpec(roi=upstream_roi)

        return deps
コード例 #14
0
    def test_output(self):
        """
        Request 5000 batches through RandomLocation with point balancing and
        check that each point id is returned with roughly the same frequency.

        Note from the original author: fails due to probabilities being
        calculated in advance, rather than after creating each roi. The new
        approach does not account for all possible roi's containing each
        point, some of which may not contain its nearest neighbors.
        """

        GraphKey('TEST_POINTS')
        points_key = GraphKeys.TEST_POINTS

        pipeline = (ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=points_key, point_balance_radius=100))

        # how many times each point id was returned
        counts = {}

        with build(pipeline):

            for _ in range(5000):
                request = BatchRequest(
                    {points_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))})
                batch = pipeline.request_batch(request)

                points = {node.id: node for node in batch[points_key].nodes}

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for node in batch[points_key].nodes:
                    counts[node.id] = counts.get(node.id, 0) + 1

        total = sum(counts.values())
        frequencies = {
            point_id: float(count) / total
            for point_id, count in counts.items()
        }

        # we should get roughly the same count for each point
        for frequency_i in frequencies.values():
            for frequency_j in frequencies.values():
                self.assertAlmostEqual(frequency_i, frequency_j, 1)
コード例 #15
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_prepare1(self):
        """Call ShiftAugment.prepare directly and inspect the derived state.

        With a 2D request and ``shift_axis=0``, the node is expected to
        report ``ndim == 2`` and ``shift_sigmas == (0.0, 1.0)``.
        """

        key = ArrayKey("TEST_ARRAY")
        voxel_size = Coordinate((1, 1))

        hdf5_source = Hdf5Source(
            self.fake_data_file,
            {key: "testdata"},
            array_specs={
                key: ArraySpec(voxel_size=voxel_size, interpolatable=True)
            },
        )

        request = BatchRequest()
        request.add(key, Coordinate((3, 3)), voxel_size=Coordinate((1, 1)))

        shift_node = ShiftAugment(sigma=1, shift_axis=0)
        pipeline = hdf5_source + shift_node
        with build(pipeline):
            shift_node.prepare(request)
            self.assertTrue(shift_node.ndim == 2)
            self.assertTrue(shift_node.shift_sigmas == (0.0, 1.0))
コード例 #16
0
ファイル: swc_file_source.py プロジェクト: pattonw/neurolight
    def test_read_single_swc(self):
        """Write the toy graph to an swc file and read it back via the source.

        Every node of the toy graph must be returned with a location holding
        the same elements as the original.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        toy_graph = self._toy_swc_points().to_nx_graph()
        self._write_swc(swc_path, toy_graph)

        # read arrays
        swc = GraphKey("SWC")
        source = SwcFileSource(swc_path, [swc])

        request = BatchRequest(
            {swc: GraphSpec(roi=Roi((0, 0, 5), (11, 11, 1)))})
        with build(source):
            batch = source.request_batch(request)

        for expected in self._toy_swc_points().nodes:
            returned = batch.points[swc].node(expected.id)
            self.assertCountEqual(expected.location, returned.location)
コード例 #17
0
    def test_impossible(self):
        """Expect an AssertionError for this two-array request.

        The request asks for arrays A and B on the ROIs below; the test
        passes only if fetching the batch raises AssertionError.
        """
        a = ArrayKey("A")
        b = ArrayKey("B")

        sources = (TestSourceRandomLocation(a), TestSourceRandomLocation(b))
        pipeline = sources + MergeProvider() + CustomRandomLocation()

        with build(pipeline), self.assertRaises(AssertionError):
            pipeline.request_batch(
                BatchRequest({
                    a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                    b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                }))
コード例 #18
0
    def test_roi_one_point(self):
        """Request a single-voxel ROI 500 times; each batch must hold one node.

        A BatchTester placed upstream of RandomLocation is configured with
        the ROI ``(500, 500, 500) + (1, 1, 1)`` and ``exact=True``.
        """

        GraphKey("TEST_GRAPH")
        graph_key = GraphKeys.TEST_GRAPH
        upstream_roi = Roi((500, 500, 500), (1, 1, 1))

        pipeline = (SourceGraphLocation() +
                    BatchTester(upstream_roi, exact=True) +
                    RandomLocation(ensure_nonempty=graph_key))

        with build(pipeline):
            for _ in range(500):
                request = BatchRequest(
                    {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))})
                batch = pipeline.request_batch(request)

                returned_nodes = list(batch[graph_key].nodes)
                assert len(returned_nodes) == 1
コード例 #19
0
ファイル: shift_augment.py プロジェクト: yajivunev/gunpowder
    def test_pipeline2(self):
        """Request one batch through ShiftAugment with voxel size (3, 1).

        The test has no assertions; it passes if the request completes
        without raising.
        """

        key = ArrayKey("TEST_ARRAY")

        hdf5_source = Hdf5Source(
            self.fake_data_file,
            {key: "testdata"},
            array_specs={
                key: ArraySpec(voxel_size=Coordinate((3, 1)),
                               interpolatable=True)
            },
        )

        request = BatchRequest()
        request.add(key, Coordinate((3, 3)), voxel_size=Coordinate((3, 1)))

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=1,
                                  shift_axis=0)
        pipeline = hdf5_source + shift_node
        with build(pipeline) as built:
            built.request_batch(request)
コード例 #20
0
    def test_req_full_roi(self):
        """Request the ROI ``(0, 0, 0) + (1000, 1000, 1000)`` once.

        The BatchTester upstream of RandomLocation is given the same ROI
        with ``exact=False``; the batch must contain exactly one node.
        """

        GraphKey("TEST_GRAPH")
        graph_key = GraphKeys.TEST_GRAPH

        possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))

        pipeline = (SourceGraphLocation() +
                    BatchTester(possible_roi, exact=False) +
                    RandomLocation(ensure_nonempty=graph_key))

        with build(pipeline):
            request = BatchRequest({
                graph_key: GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))
            })
            batch = pipeline.request_batch(request)

            returned_nodes = list(batch[graph_key].nodes)
            assert len(returned_nodes) == 1
コード例 #21
0
    def test_dim_size_1(self):
        """Request a ROI of size 1 along the first axis, 500 times.

        The BatchTester upstream of RandomLocation is given the ROI
        ``(500, 401, 401) + (1, 200, 200)`` with ``exact=False``; every
        batch must contain exactly one node.
        """

        GraphKey("TEST_GRAPH")
        graph_key = GraphKeys.TEST_GRAPH
        upstream_roi = Roi((500, 401, 401), (1, 200, 200))

        pipeline = (SourceGraphLocation() +
                    BatchTester(upstream_roi, exact=False) +
                    RandomLocation(ensure_nonempty=graph_key))

        # every request must come back with exactly one node
        with build(pipeline):
            for _ in range(500):
                request = BatchRequest(
                    {graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))})
                batch = pipeline.request_batch(request)

                returned_nodes = list(batch[graph_key].nodes)
                assert len(returned_nodes) == 1
コード例 #22
0
    def test_create_boundary_nodes(self):
        """Read a cropped swc and check the chain of points that is returned.

        The returned points are rebuilt into a directed graph with edges
        pointing from each point to its parent. Starting from the root, the
        chain is followed while the current node has exactly one
        predecessor; the visited locations and ids must match the expected
        ones, regardless of order.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(swc_path, self._toy_swc_points(),
                        {"resolution": np.array([2, 2, 2])})

        # read arrays
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest({swc: PointsSpec(roi=Roi((0, 5, 0), (1, 3, 1)))}))

        points = batch.points[swc].data

        temp_g = nx.DiGraph()
        root = None
        for point in points.values():
            temp_g.add_node(point.point_id,
                            label_id=point.label_id,
                            location=point.location)
            has_parent_in_batch = (point.parent_id != -1
                                   and point.parent_id != point.point_id
                                   and point.parent_id in points)
            if has_parent_in_batch:
                temp_g.add_edge(point.point_id, point.parent_id)
            else:
                # no parent inside the batch: treat this point as the root
                root = point.point_id

        # Fail with a clear message instead of a NameError when no point
        # qualified as root (e.g. an empty batch).
        self.assertIsNotNone(root, "no root point found in batch")

        expected_path = [
            tuple(np.array([0.0, 5.0, 0.0])),
            tuple(np.array([0.0, 6.0, 0.0])),
            tuple(np.array([0.0, 7.0, 0.0])),
        ]
        expected_node_ids = [0, 1, 2]

        visited_locations = []
        visited_ids = []
        current = root
        while current is not None:
            visited_ids.append(current)
            visited_locations.append(tuple(temp_g.nodes[current]["location"]))
            # Edges point child -> parent, so predecessors are the children.
            children = list(temp_g.predecessors(current))
            current = children[0] if len(children) == 1 else None

        self.assertCountEqual(visited_locations, expected_path)
        self.assertCountEqual(visited_ids, expected_node_ids)
コード例 #23
0
    def test_read_single_swc(self):
        """Write the toy points to an swc file and read them back.

        Every toy point must be returned under its id with a location
        holding the same elements as the original.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(swc_path, self._toy_swc_points())

        # read arrays
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_path, swc)

        request = BatchRequest(
            {swc: PointsSpec(roi=Roi((0, 0, 0), (11, 11, 1)))})
        with build(source):
            batch = source.request_batch(request)

        for expected in self._toy_swc_points():
            returned = batch.points[swc].data[expected.point_id]
            self.assertCountEqual(expected.location, returned.location)
コード例 #24
0
    def prepare(self, request: BatchRequest, seed: int,
                direction: Coordinate) -> BatchRequest:
        """
        Only request everything with the given seed

        Builds the upstream request for one shifted copy of ``request``.

        Args:
            request: the downstream request.
            seed: random seed given to the upstream request.
            direction: the shift to apply to every requested ROI. The
                remainder modulo the lcm voxel size of the requested arrays
                is removed before shifting.

        Returns:
            The upstream BatchRequest. (The previous annotation claimed a
            ``Tuple[BatchRequest, int]``, but a single request has always
            been returned.)
        """
        dps = BatchRequest(random_seed=seed)

        if self.nonempty_placeholder is not None:
            # request nonempty placeholder of size request total roi
            # grow such that it can be cropped down to two different locations
            growth = self._get_growth()

            total_roi = request.get_total_roi()
            grown_roi = total_roi.grow(growth, growth)
            dps[self.nonempty_placeholder] = GraphSpec(roi=grown_roi,
                                                       placeholder=True)

        # handle smaller requests
        array_keys = list(request.array_specs.keys())
        voxel_size = self.spec.get_lcm_voxel_size(array_keys)
        direction = Coordinate(direction)
        direction -= Coordinate(
            tuple(np.array(direction) % np.array(voxel_size)))

        def request_shifted(source, candidates):
            # Copy the spec of the first requested candidate to `source`
            # and shift its roi. The original indexed candidates[0]
            # unconditionally, which failed when only a later candidate
            # had been requested.
            requested = next((key for key in candidates if key in request),
                             None)
            if requested is None:
                return
            dps[source] = copy.deepcopy(request[requested])
            dps[source].roi = dps[source].roi.shift(direction)

        request_shifted(self.point_source, self.points)
        request_shifted(self.array_source, self.arrays)
        request_shifted(self.label_source, self.labels)

        # NOTE(review): extra keys are only honoured when their first target
        # is requested; left as is — confirm whether later targets should
        # count too.
        for source, targets in self.extra_keys.items():
            if targets[0] in request:
                dps[source] = copy.deepcopy(request[targets[0]])
                dps[source].roi = dps[source].roi.shift(direction)

        return dps
コード例 #25
0
    def test_impossible(self):
        """Expect a PipelineRequestError for this two-array request.

        The request asks for arrays A and B on the ROIs below; the test
        passes only if fetching the batch raises PipelineRequestError.
        """
        a = ArrayKey("A")
        b = ArrayKey("B")
        null_key = ArrayKey("NULL")

        sources = (ExampleSourceRandomLocation(a),
                   ExampleSourceRandomLocation(b))
        pipeline = (sources + MergeProvider() +
                    CustomRandomLocation(null_key))

        with build(pipeline), self.assertRaises(PipelineRequestError):
            pipeline.request_batch(
                BatchRequest({
                    a: ArraySpec(roi=Roi((0, 0, 0), (200, 20, 20))),
                    b: ArraySpec(roi=Roi((1000, 100, 100), (220, 22, 22))),
                }))
コード例 #26
0
ファイル: simple_augment.py プロジェクト: omerbt/gunpowder
def test_mismatched_voxel_multiples():
    """
    Check that transposing two axes never shifts the ROI by half a voxel.

    Worked example with voxel_size = [2, 2] and an array of shape [4, 6]:

        center = total_roi.get_center() -> [2, 3]

        # Distance from the ROI offset to the center, then transposed
        dist_to_center = center - roi.get_offset() -> [2, 3]
        dist_to_center = transpose(dist_to_center)  -> [3, 2]

        # The transposed distance to the center yields the new offset
        new_offset = center - dist_to_center -> [-1, 1]

        shape = transpose(shape) -> [6, 4]

        original = ((0, 0), (4, 6))
        transposed = ((-1, 1), (6, 4))

    This is the result expected from a plain transpose, but it no longer
    fits the voxel grid.  dist_to_center should therefore be limited to
    multiples of the lcm_voxel_size.
    """

    array_key = ArrayKey("TEST_ARRAY")

    pipeline = (
        CornerSource(array_key, voxel_size=(2, 2))
        + SimpleAugment(transpose_only=[0, 1])
    )

    request = BatchRequest()
    request[array_key] = ArraySpec(roi=Roi((0, 0), (4, 6)))

    max_attempts = 100
    transposed = False

    with build(pipeline):
        # The augmentation is random, so keep requesting (at most
        # ``max_attempts`` times) until a transposed batch shows up.
        for _ in range(max_attempts):
            data = pipeline.request_batch(request)[array_key].data

            # NOTE(review): a first row sum of 1 is treated as the signature
            # of a transposed batch; this relies on what CornerSource writes
            # into the array -- confirm against its definition.
            if data.sum(axis=1)[0] == 1:
                transposed = True
                break
        assert transposed, "Data was never transposed!"
コード例 #27
0
ファイル: simple_augment.py プロジェクト: omerbt/gunpowder
    def test_two_transpose(self):
        """Request 100 batches through a transpose-only ``SimpleAugment``.

        The augmentation is restricted to dimensions 1 and 2 with mirroring
        disabled.  For every batch the test checks that

        * the graph holds exactly one node,
        * each node coordinate is one of the values allowed for its dimension,
        * the returned graph ROI lies inside the ROI that was requested and
          contains the node,
        * every array's data shape equals the shape of its ROI,

        and finally that at least one batch had a node away from the first
        (index 0) candidate position.
        """
        # Bare constructions, kept from the original for their side effect.
        # NOTE(review): presumably these register the keys so that the
        # GraphKeys / ArrayKeys attribute lookups below resolve -- confirm.
        GraphKey("TEST_GRAPH")
        ArrayKey("TEST_ARRAY1")
        ArrayKey("TEST_ARRAY2")

        pipeline = (
            (ArrayTestSource(), ExampleSource())
            + MergeProvider()
            + SimpleAugment(mirror_only=[], transpose_only=[1, 2])
        )

        request = BatchRequest()
        request[GraphKeys.TEST_GRAPH] = GraphSpec(
            roi=Roi((0, 20, 33), (100, 100, 120)))
        request[ArrayKeys.TEST_ARRAY1] = ArraySpec(
            roi=Roi((0, 0, 0), (100, 200, 300)))
        request[ArrayKeys.TEST_ARRAY2] = ArraySpec(
            roi=Roi((0, 100, 250), (100, 100, 50)))

        # Same extent as the requested graph ROI, but a separate object so
        # the containment check does not share state with the request.
        requested_graph_roi = Roi((0, 20, 33), (100, 100, 120))

        # Allowed node coordinates per dimension; entry 0 of each list is
        # the position that counts as "not transposed" below.
        candidates = [[50, 50], [70, 100], [100, 70]]
        dims = range(3)

        with build(pipeline):
            transposed_seen = False
            for _ in range(100):
                batch = pipeline.request_batch(request)
                graph = batch[GraphKeys.TEST_GRAPH]

                nodes = list(graph.nodes)
                assert len(nodes) == 1
                node = nodes[0]
                logging.debug(node.location)

                assert all([
                    node.location[dim] in candidates[dim] for dim in dims
                ])
                if not transposed_seen:
                    transposed_seen = any([
                        node.location[dim] != candidates[dim][0]
                        for dim in dims
                    ])

                assert requested_graph_roi.contains(graph.spec.roi)
                assert graph.spec.roi.contains(node.location)

                for array in batch.arrays.values():
                    assert array.data.shape == array.spec.roi.get_shape()

            assert transposed_seen
コード例 #28
0
    def test_output(self):
        """Check node selection of ``RandomLocation(ensure_nonempty=...)``.

        5000 batches are requested.  Each batch must be non-empty and must
        contain either node 1, or node 2 and/or 3, but never both groups.
        Afterwards the relative frequency of every node id is compared with
        every other one using ``assertAlmostEqual`` with 1 decimal place.
        """
        # Bare construction, kept from the original for its side effect.
        # NOTE(review): presumably this registers the key so that
        # GraphKeys.TEST_GRAPH resolves -- confirm.
        GraphKey("TEST_GRAPH")

        pipeline = TestSourceRandomLocation() + RandomLocation(
            ensure_nonempty=GraphKeys.TEST_GRAPH)

        # node id -> number of batches the node appeared in
        counts = {}

        with build(pipeline):

            for _ in range(5000):
                request = BatchRequest({
                    GraphKeys.TEST_GRAPH:
                    GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                })
                batch = pipeline.request_batch(request)

                nodes = list(batch[GraphKeys.TEST_GRAPH].nodes)
                node_ids = [node.id for node in nodes]

                self.assertTrue(len(nodes) > 0)
                self.assertTrue(
                    (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                    node_ids,
                )

                for node_id in node_ids:
                    counts[node_id] = counts.get(node_id, 0) + 1

        # turn the absolute counts into relative frequencies
        total = sum(counts.values())
        frequencies = {
            node_id: float(count) / total
            for node_id, count in counts.items()
        }

        # every node should have been returned roughly equally often
        for freq_a in frequencies.values():
            for freq_b in frequencies.values():
                self.assertAlmostEqual(freq_a, freq_b, 1)
コード例 #29
0
    def test_equal_probability(self):
        """Check that ``RandomLocation`` returns every point about equally often.

        5000 batches are requested.  Each batch must be non-empty and must
        contain either point 1, or point 2 and/or 3, but never both groups.
        Afterwards the relative frequency of every point id is compared with
        every other one using ``assertAlmostEqual`` with 1 decimal place.
        """
        # Bare construction, kept from the original for its side effect.
        # NOTE(review): presumably this registers the key so that
        # GraphKeys.TEST_POINTS resolves -- confirm.
        GraphKey('TEST_POINTS')

        pipeline = (
            ExampleSourceRandomLocation()
            + RandomLocation(ensure_nonempty=GraphKeys.TEST_POINTS)
        )

        # point id -> number of batches the point appeared in
        counts = {}

        with build(pipeline):

            for _ in range(5000):
                request = BatchRequest({
                    GraphKeys.TEST_POINTS:
                    GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))
                })
                batch = pipeline.request_batch(request)

                nodes = list(batch[GraphKeys.TEST_POINTS].nodes)
                points = {node.id: node for node in nodes}

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for node in nodes:
                    counts[node.id] = counts.get(node.id, 0) + 1

        # turn the absolute counts into relative frequencies
        total = sum(counts.values())
        frequencies = {
            point_id: float(count) / total
            for point_id, count in counts.items()
        }

        # every point should have been returned roughly equally often
        for freq_a in frequencies.values():
            for freq_b in frequencies.values():
                self.assertAlmostEqual(freq_a, freq_b, 1)
コード例 #30
0
    def test_grow_labels_speed(self):
        """Check that the rasterize + grow pipeline stays under 0.1 s per batch.

        A skeleton is written to ``test_swc.swc``, read back through
        ``SwcFileSource``, rasterized into a ``uint32`` label array by
        ``RasterizeSkeleton`` and grown with ``GrowLabels(radius=3)``.  The
        same request is issued ``num_repeats`` times and the mean time per
        request must be below 0.1 seconds.
        """
        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_file = "test_swc.swc"
        swc_path = Path(self.path_to(swc_file))

        swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]),
                                      bb)
        self._write_swc(swc_path, swc_points.graph)

        # create swc sources
        swc_key = PointsKey("SWC")
        labels_key = ArrayKey("LABELS")

        # add request
        request = BatchRequest()
        request.add(labels_key, bb.get_shape())
        request.add(swc_key, bb.get_shape())

        # data source for swc a
        # The leading empty tuple is kept from the original: how it combines
        # with the providers depends on their ``+`` overloads, which are not
        # visible here.
        data_source = tuple()
        data_source = (data_source + SwcFileSource(
            swc_path, [swc_key], [PointsSpec(roi=bb)]) + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
            ) + GrowLabels(array=labels_key, radius=3))

        pipeline = data_source

        num_repeats = 10
        # Use the monotonic high-resolution clock: wall-clock time.time() can
        # jump (e.g. on clock adjustments) and distort the measured interval.
        # The interval deliberately still spans entering and leaving
        # ``build(pipeline)``, as in the original measurement.
        start = time.perf_counter()
        with build(pipeline):
            for _ in range(num_repeats):
                pipeline.request_batch(request)

        elapsed = time.perf_counter() - start
        self.assertLess(elapsed / num_repeats, 0.1)