Beispiel #1
0
def test_realistic_valid_examples(example, use_gurobi):
    """Match a skeletonization against a consensus tree on a stored example.

    The example directory ``mouselight_examples/valid/<example>`` next to
    this file provides ``graph.obj`` (skeletonization) and ``tree.obj``
    (consensus).  Two ``TopologicalMatcher`` nodes are run, one of them with
    ``with_fallback=True``; the test asserts that the plain match has as
    many connected components as the consensus.

    Args:
        example: name of the sub-directory holding the example graphs.
        use_gurobi: forwarded unchanged to both ``TopologicalMatcher`` nodes.
    """
    data_dir = Path(__file__).parent / "mouselight_examples" / "valid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    # ROI with ``None`` offset and shape on all three axes.
    unbounded = Roi(Coordinate((None,) * 3), Coordinate((None,) * 3))

    request = BatchRequest()
    for key in (matched, matched_with_fallback, consensus):
        request[key] = PointsSpec(roi=unbounded)

    # Settings shared by both matchers.
    matcher_settings = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )

    pipeline = (
        (
            GraphSource(data_dir / "graph.obj", [skeletonization]),
            GraphSource(data_dir / "tree.obj", [consensus]),
        )
        + MergeProvider()
        + TopologicalMatcher(
            skeletonization, consensus, matched, **matcher_settings
        )
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            with_fallback=True,
            **matcher_settings,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        consensus_ccs = list(batch[consensus].connected_components)
        # Computed but not asserted on; kept so the components of the
        # fallback result are still enumerated.
        matched_with_fallback_ccs = list(
            batch[matched_with_fallback].connected_components
        )
        matched_ccs = list(batch[matched].connected_components)

        assert len(matched_ccs) == len(consensus_ccs)
Beispiel #2
0
    def test_relabel_components(self):
        """Check that every connected component read from the SWC has its own label.

        The toy SWC is written to disk, read back through ``SwcFileSource``
        for the ROI offset (0, 1, 0) / shape (11, 10, 1), and rebuilt as a
        directed graph (child -> parent).  The graph must fall into three
        weakly connected components of ten nodes each; all nodes of one
        component share a ``label_id`` and that label differs from the one
        of the previously visited component.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write the toy skeleton to disk
        self._write_swc(swc_path, self._toy_swc_points())

        # read it back through the source under test
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_path, swc)

        with build(source):
            request = BatchRequest(
                {swc: PointsSpec(roi=Roi((0, 1, 0), (11, 10, 1)))})
            batch = source.request_batch(request)

        points = batch.points[swc].data

        graph = nx.DiGraph()
        for point in points.values():
            graph.add_node(point.point_id, label_id=point.label_id)
            parent = point.parent_id
            # only link to parents that are real, distinct and inside the batch
            if parent not in (-1, point.point_id) and parent in points:
                graph.add_edge(point.point_id, parent)

        components = list(nx.weakly_connected_components(graph))
        self.assertEqual(len(components), 3)

        last_label = None
        for component in components:
            self.assertEqual(len(component), 10)
            component_label = None
            for node in component:
                node_label = graph.nodes[node]["label_id"]
                if component_label is None:
                    component_label = node_label
                    self.assertNotEqual(component_label, last_label)
                self.assertEqual(node_label, component_label)
            last_label = component_label
    def test_without_placeholder(self):
        """Requesting only points through ElasticAugment must fail.

        For 100 seeds, a request containing just the points key is expected
        to raise ``PipelineRequestError``, while the same request extended
        with the labels array succeeds and returns that array.
        """
        test_labels = ArrayKey("TEST_LABELS")
        test_points = PointsKey("TEST_POINTS")

        source = PointTestSource3D()
        random_location = RandomLocation(ensure_nonempty=test_points)
        augment = ElasticAugment(
            [10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        snapshot = Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
        pipeline = source + random_location + augment + snapshot

        with build(pipeline):
            for seed in range(100):
                size = Coordinate((40, 40, 40))

                points_only = BatchRequest(random_seed=seed)
                points_only.add(test_points, size)

                points_and_labels = BatchRequest(random_seed=seed)
                points_and_labels.add(test_points, size)
                points_and_labels.add(test_labels, size)

                # No array to provide a voxel size to ElasticAugment
                with pytest.raises(PipelineRequestError):
                    pipeline.request_batch(points_only)

                full_batch = pipeline.request_batch(points_and_labels)
                self.assertIn(test_labels, full_batch)
Beispiel #4
0
    def test_ensure_center_non_zero(self):
        """With ``ensure_centered=True`` the centre voxel of the raster is non-zero.

        The toy SWC graph is written to disk, read through ``SwcFileSource``,
        passed through ``RandomLocation(ensure_nonempty=..., ensure_centered=True)``
        and rasterized.  A 5x5x5 request must contain at least one vertex and
        a non-zero value at the central voxel of the rasterized array.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write the toy skeleton to disk
        self._write_swc(swc_path, self._toy_swc_points().to_nx_graph())

        swc = PointsKey("SWC")
        img = ArrayKey("IMG")

        source = SwcFileSource(
            swc_path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))])
        centered_location = RandomLocation(
            ensure_nonempty=swc, ensure_centered=True)
        rasterize = RasterizeSkeleton(
            points=swc,
            array=img,
            array_spec=ArraySpec(
                interpolatable=False,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
        pipeline = source + centered_location + rasterize

        request = BatchRequest()
        for key in (img, swc):
            request.add(key, Coordinate((5, 5, 5)))

        with build(pipeline):
            batch = pipeline.request_batch(request)

            data = batch[img].data
            g = batch[swc]
            assert g.num_vertices() > 0

            center = tuple(np.array(data.shape) // 2)
            self.assertNotEqual(data[center], 0)
Beispiel #5
0
    def test_read_single_swc(self):
        """Points read back from a single SWC file keep their locations.

        The toy SWC is written to disk and requested through
        ``SwcFileSource`` for the ROI offset (0, 0, 0) / shape (11, 11, 1).
        Every toy point must be present in the batch under its
        ``point_id`` with the same location components.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write the toy skeleton to disk
        self._write_swc(swc_path, self._toy_swc_points())

        # read it back through the source under test
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_path, swc)

        with build(source):
            roi = Roi((0, 0, 0), (11, 11, 1))
            batch = source.request_batch(
                BatchRequest({swc: PointsSpec(roi=roi)}))

        for expected in self._toy_swc_points():
            actual = batch.points[swc].data[expected.point_id]
            self.assertCountEqual(expected.location, actual.location)
Beispiel #6
0
    def test_create_boundary_nodes(self):
        """Request a small ROI from an SWC and check the chain of nodes returned.

        The toy SWC is written with ``resolution`` (2, 2, 2) and requested
        for the ROI offset (0, 5, 0) / shape (1, 3, 1).  The returned points
        are rebuilt as a directed graph (child -> parent) and walked from the
        root along single-child links.  The walk must visit node ids 0, 1, 2
        at locations (0, 5, 0), (0, 6, 0) and (0, 7, 0).

        NOTE(review): judging by the test name, the source presumably inserts
        nodes where edges cross the ROI boundary - confirm against
        ``SwcFileSource``.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(swc_path, self._toy_swc_points(),
                        {"resolution": np.array([2, 2, 2])})

        # read points
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest({swc: PointsSpec(roi=Roi((0, 5, 0), (1, 3, 1)))}))

        points = batch.points[swc].data

        temp_g = nx.DiGraph()
        # Stays None when no point qualifies as a root, so the test fails
        # with a clear assertion instead of a NameError further down.
        root = None
        for point in points.values():
            temp_g.add_node(point.point_id,
                            label_id=point.label_id,
                            location=point.location)
            if (point.parent_id != -1 and point.parent_id != point.point_id
                    and point.parent_id in points):
                temp_g.add_edge(point.point_id, point.parent_id)
            else:
                # a point without a usable parent; the last one seen wins
                root = point.point_id

        self.assertIsNotNone(root, "no root node found in the requested roi")

        expected_path = [
            tuple(np.array([0.0, 5.0, 0.0])),
            tuple(np.array([0.0, 6.0, 0.0])),
            tuple(np.array([0.0, 7.0, 0.0])),
        ]
        expected_node_ids = [0, 1, 2]

        # Edges point child -> parent, so the predecessors of a node are its
        # children.  Follow them while exactly one child exists.
        visited_locations = []
        visited_ids = []
        current = root
        while current is not None:
            visited_ids.append(current)
            visited_locations.append(tuple(temp_g.nodes[current]["location"]))
            children = list(temp_g.predecessors(current))
            current = children[0] if len(children) == 1 else None

        self.assertCountEqual(visited_locations, expected_path)
        self.assertCountEqual(visited_ids, expected_node_ids)
Beispiel #7
0
    def test_grow_labels_speed(self):
        """Rasterizing and growing labels on a 256^3 volume stays under 0.1 s per batch.

        A skeleton produced by ``_get_points`` is written to an SWC file,
        read back, rasterized and grown with radius 3.  The same request is
        issued ``num_repeats`` times; the average time per request
        (pipeline build included) must be below 0.1 seconds.
        """
        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_file = "test_swc.swc"
        swc_path = Path(self.path_to(swc_file))

        swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]),
                                      bb)
        self._write_swc(swc_path, swc_points.graph)

        # create swc sources
        swc_key = PointsKey("SWC")
        labels_key = ArrayKey("LABELS")

        # add request
        request = BatchRequest()
        request.add(labels_key, bb.get_shape())
        request.add(swc_key, bb.get_shape())

        # data source for swc a
        data_source = tuple()
        data_source = (data_source + SwcFileSource(
            swc_path, [swc_key], [PointsSpec(roi=bb)]) + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
            ) + GrowLabels(array=labels_key, radius=3))

        pipeline = data_source

        num_repeats = 10
        # perf_counter is monotonic and high resolution, so the measured
        # duration cannot be skewed by system clock adjustments.
        start = time.perf_counter()
        with build(pipeline):
            for _ in range(num_repeats):
                pipeline.request_batch(request)
        elapsed = time.perf_counter() - start

        self.assertLess(elapsed / num_repeats, 0.1)
Beispiel #8
0
    def test_overlap(self):
        """Three identical, overlapping SWC files stay three separate components.

        The toy SWC is written three times (offset (0, 0, 0)) into one
        directory, which is then read as a whole through ``SwcFileSource``.
        The rebuilt graph (child -> parent) must consist of three weakly
        connected components of 41 nodes each; nodes of one component share
        a ``label_id`` that differs from the previously visited component's.
        """
        swc_dir = Path(self.path_to("test_swc_sources"))
        swc_dir.mkdir(parents=True, exist_ok=True)

        # write the same toy skeleton three times
        for index in range(3):
            self._write_swc(
                swc_dir / "{}.swc".format(index),
                self._toy_swc_points(),
                {"offset": np.array([0, 0, 0])},
            )

        # read the whole directory through the source under test
        swc = PointsKey("SWC")
        source = SwcFileSource(swc_dir, swc)

        with build(source):
            request = BatchRequest(
                {swc: PointsSpec(roi=Roi((0, 0, 0), (11, 11, 1)))})
            batch = source.request_batch(request)

        points = batch.points[swc].data

        graph = nx.DiGraph()
        for point in points.values():
            graph.add_node(point.point_id, label_id=point.label_id)
            parent = point.parent_id
            # only link to parents that are real, distinct and inside the batch
            if parent not in (-1, point.point_id) and parent in points:
                graph.add_edge(point.point_id, parent)

        components = list(nx.weakly_connected_components(graph))
        self.assertEqual(len(components), 3)

        last_label = None
        for component in components:
            self.assertEqual(len(component), 41)
            component_label = None
            for node in component:
                node_label = graph.nodes[node]["label_id"]
                if component_label is None:
                    component_label = node_label
                    self.assertNotEqual(component_label, last_label)
                self.assertEqual(node_label, component_label)
            last_label = component_label
    def test_placeholder(self):
        """A placeholder array request yields the same points as a real one.

        For 100 seeds two requests with identical ``random_seed`` are sent:
        one adds the labels array with ``placeholder=True``, the other adds
        it normally.  The nodes returned for the points key must have
        pairwise close locations.
        """
        test_labels = ArrayKey("TEST_LABELS")
        test_points = PointsKey("TEST_POINTS")

        source = PointTestSource3D()
        random_location = RandomLocation(ensure_nonempty=test_points)
        augment = ElasticAugment(
            [10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi])
        snapshot = Snapshot(
            {test_labels: "volumes/labels"},
            output_dir=self.path_to(),
            output_filename="elastic_augment_test{id}-{iteration}.hdf",
        )
        pipeline = source + random_location + augment + snapshot

        with build(pipeline):
            for seed in range(100):
                size = Coordinate((40, 40, 40))

                with_placeholder = BatchRequest(random_seed=seed)
                with_placeholder.add(test_points, size)
                with_placeholder.add(test_labels, size, placeholder=True)

                with_array = BatchRequest(random_seed=seed)
                with_array.add(test_points, size)
                with_array.add(test_labels, size)

                placeholder_batch = pipeline.request_batch(with_placeholder)
                array_batch = pipeline.request_batch(with_array)

                placeholder_nodes = placeholder_batch[test_points].nodes
                array_nodes = array_batch[test_points].nodes

                for node_a, node_b in zip(placeholder_nodes, array_nodes):
                    assert all(np.isclose(node_a.location, node_b.location))
Beispiel #10
0
    def test_two_disjoint_lines_intensity(self):
        """Intensity-fusing two separated lines gives the sum of their raw volumes.

        Two lines, created ``3 * LABEL_RADIUS`` apart, are written to SWC
        files.  Each is read back and rasterized twice (labels and raw), the
        rasters are grown, and the raw one is normalized.  Both branches are
        merged and passed to ``FusionAugment`` with
        ``blend_mode="intensity"``.  After zero-padding every volume by one
        voxel, the norm of ``fused - raw_a - raw_b`` must be (almost) zero.
        """
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        BLEND_SMOOTHNESS = 3

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_paths = tuple(
            Path(self.path_to(file_name))
            for file_name in ("test_line_a.swc", "test_line_b.swc")
        )

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            line_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, line_points.to_nx_graph())

        # keys for the fused outputs and for each of the two inputs
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        fused_swc = PointsKey("FUSED_SWC")
        swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
        labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
        raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

        # request everything over the full bounding box
        request = BatchRequest()
        for key in (
            fused,
            fused_labels,
            fused_swc,
            *labels_keys,
            *raw_keys,
            *swc_keys,
        ):
            request.add(key, bb.get_shape())

        def rasterize(points_key, array_key):
            """Build a uint16, non-interpolatable rasterization node."""
            return RasterizeSkeleton(
                points=points_key,
                array=array_key,
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )

        # one source branch per swc file (a first, then b)
        branches = []
        for swc_path, swc_key, labels_key, raw_key in zip(
            swc_paths, swc_keys, labels_keys, raw_keys
        ):
            branch = tuple()
            branch = (
                branch
                + SwcFileSource(swc_path, [swc_key], [PointsSpec(roi=bb)])
                + rasterize(swc_key, labels_key)
                + GrowLabels(array=labels_key, radii=[LABEL_RADIUS])
                + rasterize(swc_key, raw_key)
                + GrowLabels(array=raw_key, radii=[RAW_RADIUS])
                + Normalize(raw_key)
            )
            branches.append(branch)
        data_sources = tuple(branches) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            swc_keys[0],
            swc_keys[1],
            fused,
            fused_labels,
            fused_swc,
            blend_mode="intensity",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        def padded(key):
            """Return the batch array for *key*, zero-padded by one voxel per side."""
            return np.pad(batch[key].data, (1,), "constant", constant_values=(0,))

        fused_data = padded(fused)
        a_data = padded(raw_keys[0])
        b_data = padded(raw_keys[1])

        residual = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(residual, 0)
Beispiel #11
0
def test_recenter():
    """Check that ``Recenter`` puts foreground at the centre of both volumes.

    The toy SWC is written to disk and read three times: once as plain
    points and twice more to rasterize an artificial image and label source
    (binarized, then grown with radius 0 and 1 respectively).  The merged
    sources go through a centred ``RandomLocation``, ``GetNeuronPair`` and
    one ``Recenter`` per output.  For a 9x9x9 request the centre voxel of
    each image must be 1, and every returned point must sit on a voxel of
    value 1 in its own image.

    NOTE(review): the signature takes no ``self`` although the body calls
    ``self.path_to`` / ``self._write_swc`` / ``self._toy_swc_points``.  This
    presumably was a ``TestCase`` method before extraction - confirm and
    restore the parameter in the enclosing class; it is left untouched here
    to keep the signature unchanged.
    """
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # keys
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()

    data_shape = 9

    for key in (points_a, points_b, img_a, img_b, labels_a, labels_b):
        request.add(key, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

    # The centre voxel is checked on the raw data, before padding.
    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b

    # channel 0: padded image a, channel 1: padded image b, channel 2: their sum
    data = np.array((data_a, data_b, data_c))

    for point in batch[points_a].data.values():
        # +1 on every axis compensates for the one-voxel padding above.
        voxel = (0,) + tuple(int(x) + 1 for x in point.location)
        # The message reads the same voxel that is checked; previously it
        # dropped the +1 shift and reported a neighbouring value.
        assert data[voxel] == 1, "data at {} is not 1, its {}".format(
            point.location, data[voxel]
        )

    for point in batch[points_b].data.values():
        voxel = (1,) + tuple(int(x) + 1 for x in point.location)
        assert data[voxel] == 1, "data at {} is not 1".format(point.location)
Beispiel #12
0
    def test_output_min_distance(self):
        """Check the vector maps produced by ``AddVectorMap``.

        The test runs two pipelines built from ``AddVectorMapTestSource``
        followed by ``AddVectorMap``, one with
        ``partner_criterion="min_distance"`` and one with
        ``partner_criterion="all"``.  For every presynaptic node inside the
        requested vector-map ROI it inspects the non-zero vectors in a
        window around the node and checks where they point:

        * ``min_distance``: every vector must end at the closest partner.
        * ``all``: every vector must end at one of the partners, and the
          per-partner vector counts may exceed the smallest count by at most
          the number of partners.
        """

        voxel_size = Coordinate((20, 2, 2))

        # Return values are discarded; presumably constructing the keys
        # registers them so they can be looked up as ``ArrayKeys.*`` /
        # ``GraphKeys.*`` below - TODO confirm.
        ArrayKey("GT_VECTORS_MAP_PRESYN")
        PointsKey("PRESYN")
        PointsKey("POSTSYN")

        # vector-map array -> (source points, target points)
        arraytypes_to_source_target_pointstypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN:
            (GraphKeys.PRESYN, GraphKeys.POSTSYN)
        }
        # vector-map array -> array passed as ``stayinside_array_keys``
        arraytypes_to_stayinside_arraytypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS
        }

        # test for partner criterion 'min_distance'
        radius_phys = 30
        pipeline_min_distance = AddVectorMapTestSource() + AddVectorMap(
            src_and_trg_points=arraytypes_to_source_target_pointstypes,
            voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
            radius_phys=radius_phys,
            partner_criterion="min_distance",
            stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
            pad_for_partners=(0, 0, 0),
        )

        with build(pipeline_min_distance):

            # Size each request after the ROI the built pipeline provides.
            request = BatchRequest()
            raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi
            gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi
            presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi

            request.add(ArrayKeys.RAW, raw_roi.get_shape())
            request.add(ArrayKeys.GT_LABELS, gt_labels_roi.get_shape())
            request.add(GraphKeys.PRESYN, presyn_roi.get_shape())
            request.add(GraphKeys.POSTSYN, presyn_roi.get_shape())
            request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN,
                        presyn_roi.get_shape())
            # Move every requested ROI by 1000 along each axis.
            for identifier, spec in request.items():
                spec.roi = spec.roi.shift((1000, 1000, 1000))

            batch = pipeline_min_distance.request_batch(request)

        # node id -> node, for both point sets
        presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
        postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        offset_vector_map_presyn = request[
            ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.get_offset()

        self.assertTrue(len(presyn_locs) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for loc_id, point in presyn_locs.items():

            # Only nodes inside the requested vector-map ROI are checked.
            if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                    Coordinate(point.location)):
                self.assertTrue(batch.arrays[
                    ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                        Coordinate(point.location)))

                # distance -> partner location, for partners present in the batch
                dist_to_loc = {}
                for partner_id in point.attrs["partner_ids"]:
                    if partner_id in postsyn_locs.keys():
                        partner_location = postsyn_locs[partner_id].location
                        dist_to_loc[
                            np.linalg.norm(partner_location -
                                           point.location)] = partner_location
                # NOTE(review): np.min raises on an empty list, i.e. when none
                # of the partners is in the batch - confirm the test source
                # rules that out.
                min_dist = np.min(list(dist_to_loc.keys()))
                relevant_partner_loc = dist_to_loc[min_dist]

                # Node position and radius in voxels relative to the vector
                # map offset, clipped to the spatial shape of the map.
                presyn_loc_shifted_vx = (
                    point.location - offset_vector_map_presyn) // voxel_size
                radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
                region_to_check = np.clip(
                    [
                        (presyn_loc_shifted_vx - radius_vx),
                        (presyn_loc_shifted_vx + radius_vx),
                    ],
                    a_min=(0, 0, 0),
                    a_max=vector_map_presyn.shape[-3:],
                )
                for x, y, z in itertools.product(
                        range(int(region_to_check[0][0]),
                              int(region_to_check[1][0])),
                        range(int(region_to_check[0][1]),
                              int(region_to_check[1][1])),
                        range(int(region_to_check[0][2]),
                              int(region_to_check[1][2])),
                ):
                    # NOTE(review): (x, y, z) are array indices while
                    # point.location is used here without the offset/voxel
                    # size conversion applied above - confirm this distance
                    # check is intended.
                    if (np.linalg.norm((np.array(
                        (x, y, z)) - np.asarray(point.location))) <
                            radius_phys):
                        # one component per leading-axis entry of the map
                        vector = [
                            vector_map_presyn[dim][x, y, z]
                            for dim in range(vector_map_presyn.shape[0])
                        ]
                        # vectors whose components sum to zero are skipped
                        if not np.sum(vector) == 0:
                            # voxel position (offset + index * voxel size)
                            # plus the stored vector
                            trg_loc_of_vector_phys = (
                                np.asarray(offset_vector_map_presyn) +
                                (voxel_size * np.array([x, y, z])) +
                                np.asarray(vector))
                            self.assertTrue(
                                np.array_equal(trg_loc_of_vector_phys,
                                               relevant_partner_loc))

        # test for partner criterion 'all'
        pipeline_all = AddVectorMapTestSource() + AddVectorMap(
            src_and_trg_points=arraytypes_to_source_target_pointstypes,
            voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
            radius_phys=radius_phys,
            partner_criterion="all",
            stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
            pad_for_partners=(0, 0, 0),
        )

        # The (already shifted) request from the first half is reused.
        with build(pipeline_all):
            batch = pipeline_all.request_batch(request)

        presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
        postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        offset_vector_map_presyn = request[
            ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.get_offset()

        self.assertTrue(len(presyn_locs) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for loc_id, point in presyn_locs.items():

            # Only nodes inside the requested vector-map ROI are checked.
            if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                    Coordinate(point.location)):
                self.assertTrue(batch.arrays[
                    ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                        Coordinate(point.location)))

                # partner id -> location (as list) and partner id -> number of
                # vectors found pointing at it
                partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {}
                for partner_id in point.attrs["partner_ids"]:
                    if partner_id in postsyn_locs.keys():
                        partner_ids_to_locs_per_src[partner_id] = postsyn_locs[
                            partner_id].location.tolist()
                        count_vectors_per_partner[partner_id] = 0

                # Same window computation as in the 'min_distance' half.
                presyn_loc_shifted_vx = (
                    point.location - offset_vector_map_presyn) // voxel_size
                radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
                region_to_check = np.clip(
                    [
                        (presyn_loc_shifted_vx - radius_vx),
                        (presyn_loc_shifted_vx + radius_vx),
                    ],
                    a_min=(0, 0, 0),
                    a_max=vector_map_presyn.shape[-3:],
                )
                for x, y, z in itertools.product(
                        range(int(region_to_check[0][0]),
                              int(region_to_check[1][0])),
                        range(int(region_to_check[0][1]),
                              int(region_to_check[1][1])),
                        range(int(region_to_check[0][2]),
                              int(region_to_check[1][2])),
                ):
                    if (np.linalg.norm((np.array(
                        (x, y, z)) - np.asarray(point.location))) <
                            radius_phys):
                        vector = [
                            vector_map_presyn[dim][x, y, z]
                            for dim in range(vector_map_presyn.shape[0])
                        ]
                        if not np.sum(vector) == 0:
                            trg_loc_of_vector_phys = (
                                np.asarray(offset_vector_map_presyn) +
                                (voxel_size * np.array([x, y, z])) +
                                np.asarray(vector))
                            # every vector must end at one of the partners
                            self.assertTrue(trg_loc_of_vector_phys.tolist(
                            ) in partner_ids_to_locs_per_src.values())

                            # tally which partner this vector points to
                            for (
                                    partner_id,
                                    partner_loc,
                            ) in partner_ids_to_locs_per_src.items():
                                if np.array_equal(
                                        np.asarray(trg_loc_of_vector_phys),
                                        partner_loc):
                                    count_vectors_per_partner[partner_id] += 1
                # Each count may exceed the smallest count by at most the
                # number of partners.
                self.assertTrue(
                    (list(count_vectors_per_partner.values()) -
                     np.min(list(count_vectors_per_partner.values())) <= len(
                         count_vectors_per_partner.keys())).all())
Beispiel #13
0
    def test_two_disjoint_lines_softmask(self):
        """Fuse two separated lines with the "labels_mask" blend mode.

        Two straight-line skeletons are written to swc files, rasterized into
        a label volume and a raw volume each, merged, and passed through
        ``FusionAugment``. The test asserts that the norm of
        ``fused - raw_a - raw_b`` is (almost) zero.
        """
        label_radius = 3
        raw_radius = 3
        # exaggerated on purpose so that a blending problem is easy to see
        blend_smoothness = 10

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_paths = tuple(
            Path(self.path_to(file_name))
            for file_name in ("test_line_a.swc", "test_line_b.swc")
        )

        # generate a pair of lines separated by a given distance and store
        # each one in its own swc file
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * label_radius)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            self._write_swc(swc_path, self._get_points(intercept, slope, bb))

        # keys for the fused outputs and for the per-line inputs
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
        labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
        raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

        # every key is requested over the full bounding box
        request = BatchRequest()
        for key in (fused, fused_labels) + labels_keys + raw_keys + swc_keys:
            request.add(key, bb.get_shape())

        def rasterize(points_key, array_key, radius):
            # one rasterization node writing `points_key` into `array_key`
            return RasterizeSkeleton(
                points=points_key,
                array=array_key,
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=radius,
            )

        def line_branch(index):
            # swc source followed by label and raw rasterization for one line
            return (
                tuple()
                + SwcFileSource(
                    swc_paths[index], swc_keys[index], PointsSpec(roi=bb)
                )
                + rasterize(swc_keys[index], labels_keys[index], label_radius)
                + rasterize(swc_keys[index], raw_keys[index], raw_radius)
            )

        branch_a = line_branch(0)
        branch_b = line_branch(1)
        data_sources = (branch_a, branch_b) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            fused,
            fused_labels,
            blend_mode="labels_mask",
            blend_smoothness=blend_smoothness,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        def zero_pad(volume):
            # surround the volume with a one-voxel border of zeros
            return np.pad(volume, (1,), "constant", constant_values=(0,))

        fused_data = zero_pad(batch[fused].data)
        a_data = zero_pad(batch[raw_keys[0]].data)
        b_data = zero_pad(batch[raw_keys[1]].data)

        # stack the volumes worth inspecting: fused, expected sum, residual,
        # and the two inputs
        residual = fused_data - a_data - b_data
        layers = (fused_data, a_data + b_data, residual, a_data, b_data)
        all_data = np.zeros((len(layers),) + fused_data.shape)
        for index, layer in enumerate(layers):
            all_data[index, :, :, :] = layer

        # only shown when the optional viewer could be imported
        if imported_volshow:
            volshow(all_data)
            # input("Press enter when you are done viewing the data: ")

        diff = np.linalg.norm(residual)
        self.assertAlmostEqual(diff, 0)
Beispiel #14
0
    def test_rasterize_speed(self):
        """Check that rasterizing two line skeletons stays within a time budget.

        Two straight-line skeletons are written to swc files and each is
        rasterized into its own label volume. Building the pipeline, requesting
        one batch and zero-padding the two label volumes must together take
        less than 0.1 seconds.
        """
        # This is worryingly slow for such a small volume (256**3) and only 2
        # straight lines for skeletons.

        LABEL_RADIUS = 3

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create swc sources
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)

        # add request
        request = BatchRequest()
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = tuple()
        data_sources_a = (
            data_sources_a
            + SwcFileSource(swc_paths[0], swc_keys[0], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )

        # data source for swc b
        data_sources_b = tuple()
        data_sources_b = (
            data_sources_b
            + SwcFileSource(swc_paths[1], swc_keys[1], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )
        data_sources = tuple([data_sources_a, data_sources_b]) + MergeProvider()

        pipeline = data_sources

        # perf_counter is monotonic and high resolution, unlike the wall clock
        # returned by time.time(), so the elapsed value cannot be distorted by
        # system clock adjustments.
        start = time.perf_counter()
        with build(pipeline):
            batch = pipeline.request_batch(request)

        # The padded arrays are not used afterwards; the padding stays inside
        # the timed region so the budget covers the same work as before.
        for labels_key in labels_keys:
            np.pad(batch[labels_key].data, (1,), "constant", constant_values=(0,))

        elapsed = time.perf_counter() - start
        self.assertLess(
            elapsed,
            0.1,
            msg="rasterizing two lines took {:.3f}s".format(elapsed),
        )
Beispiel #15
0
    def test_get_neuron_pair(self):
        """Request skeleton pairs from ``GetNeuronPair`` and check their distance.

        A toy swc file feeds three sources (points, a rasterized image volume
        and a rasterized label volume). Ten batches are requested; each must
        contain all six requested keys, and the smallest distance between any
        node of skeleton A and any node of skeleton B must lie in ``[1, 3]``.
        """
        swc_path = Path(self.path_to("test_swc_source.swc"))

        # write the toy skeleton to disk; all sources below read this file
        self._write_swc(swc_path, self._toy_swc_points().to_nx_graph())

        # keys used inside the pipeline
        swc_source = PointsKey("SWC_SOURCE")
        ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
        labels_source = ArrayKey("LABELS_SOURCE")
        img_source = ArrayKey("IMG_SOURCE")
        img_swc = PointsKey("IMG_SWC")
        label_swc = PointsKey("LABEL_SWC")
        imgs = ArrayKey("IMGS")
        labels = ArrayKey("LABELS")
        # keys produced by GetNeuronPair and requested by the test
        points_a = PointsKey("SKELETON_A")
        points_b = PointsKey("SKELETON_B")
        img_a = ArrayKey("VOLUME_A")
        img_b = ArrayKey("VOLUME_B")
        labels_a = ArrayKey("LABELS_A")
        labels_b = ArrayKey("LABELS_B")

        data_shape = 5
        output_shape = Coordinate((data_shape,) * 3)

        def points_spec():
            # fresh spec per use, all covering the same region
            return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

        def array_spec():
            # fresh spec per rasterization, unit voxel size
            return ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            )

        # points read straight from the test swc
        swc_file_source = SwcFileSource(
            swc_path,
            [swc_source, ensure_nonempty],
            [points_spec(), points_spec()],
        )

        # artificial image source: rasterize the points, binarize, grow by 0
        image_source = SwcFileSource(swc_path, [img_swc], [points_spec()])
        image_source = image_source + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=array_spec(),
        )
        image_source = image_source + BinarizeLabels(
            labels=img_source, labels_binary=imgs
        )
        image_source = image_source + GrowLabels(array=imgs, radius=0)

        # artificial label source: rasterize the points, binarize, grow by 1
        label_source = SwcFileSource(swc_path, [label_swc], [points_spec()])
        label_source = label_source + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=array_spec(),
        )
        label_source = label_source + BinarizeLabels(
            labels=labels_source, labels_binary=labels
        )
        label_source = label_source + GrowLabels(array=labels, radius=1)

        merged = (swc_file_source, image_source, label_source) + MergeProvider()
        located = merged + RandomLocation(
            ensure_nonempty=ensure_nonempty, ensure_centered=True
        )
        skeleton = tuple() + located

        pipeline = skeleton + GetNeuronPair(
            point_source=swc_source,
            nonempty_placeholder=ensure_nonempty,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=(1, 3),
            shift_attempts=100,
            request_attempts=10,
            output_shape=output_shape,
        )

        requested_keys = [points_a, points_b, img_a, img_b, labels_a, labels_b]

        request = BatchRequest()
        for key in requested_keys:
            request.add(key, output_shape)

        with build(pipeline):
            for _ in range(10):
                batch = pipeline.request_batch(request)
                assert all([key in batch for key in requested_keys])

                # smallest pairwise node distance, capped at the initial 5
                pair_distances = (
                    np.linalg.norm(node_a.location - node_b.location)
                    for node_a, node_b in itertools.product(
                        batch[points_a].nodes,
                        batch[points_b].nodes,
                    )
                )
                min_dist = min(itertools.chain((5,), pair_distances))

                self.assertLessEqual(min_dist, 3)
                self.assertGreaterEqual(min_dist, 1)
def get_test_data_sources(setup_config):
    """Assemble the synthetic test pipeline and return it with its keys.

    Reads ``INPUT_SHAPE``, ``VOXEL_SIZE`` and ``FUSION_PIPELINE`` from
    ``setup_config``. One test image source and two test point sources are
    merged, then followed by ``RandomLocation``, ``TopologicalMatcher``,
    ``RasterizeSkeleton`` and ``Normalize``.

    Returns:
        A tuple ``(data_sources, raw, labels, consensus,
        nonempty_placeholder, skeletonization, matched)``.
    """
    input_shape = Coordinate(setup_config["INPUT_SHAPE"])
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    input_size = input_shape * voxel_size

    # the first voxel-size component scales the distance parameters below
    micron_scale = voxel_size[0]

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    nonempty_placeholder = PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    # which points key RandomLocation must find non-empty
    ensure_nonempty = (
        nonempty_placeholder if setup_config["FUSION_PIPELINE"] else consensus
    )

    # the three sources; each covers three times the input size
    sources = (
        TestImageSource(
            array=raw,
            array_specs={
                raw: ArraySpec(
                    interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                )
            },
            size=input_size * 3,
            voxel_size=voxel_size,
        ),
        TestPointSource(
            points=[consensus, nonempty_placeholder],
            directed=True,
            size=input_size * 3,
            num_points=30,
        ),
        TestPointSource(
            points=[skeletonization],
            directed=False,
            size=input_size * 3,
            num_points=333,
        ),
    )

    # chain the downstream nodes one at a time, in pipeline order
    data_sources = sources + MergeProvider()
    data_sources = data_sources + RandomLocation(
        ensure_nonempty=ensure_nonempty,
        ensure_centered=True,
        point_balance_radius=10 * micron_scale,
    )
    data_sources = data_sources + TopologicalMatcher(
        skeletonization,
        consensus,
        matched,
        match_distance_threshold=50 * micron_scale,
        max_gap_crossing=30 * micron_scale,
        try_complete=False,
        use_gurobi=True,
    )
    data_sources = data_sources + RasterizeSkeleton(
        points=matched,
        array=labels,
        array_spec=ArraySpec(
            interpolatable=False, voxel_size=voxel_size, dtype=np.uint64
        ),
    )
    data_sources = data_sources + Normalize(raw)

    return (
        data_sources,
        raw,
        labels,
        consensus,
        nonempty_placeholder,
        skeletonization,
        matched,
    )