Example #1
0
def test_realistic_valid_examples(example, use_gurobi):
    """Match a realistic skeletonization against its consensus tree.

    Loads ``graph.obj`` (skeletonization) and ``tree.obj`` (consensus) from
    ``mouselight_examples/valid/<example>``, merges them, and runs two
    ``TopologicalMatcher`` nodes with identical settings -- one plain, one
    with ``with_fallback=True``. Asserts that the plain matching has as many
    connected components as the consensus.
    """
    example_dir = Path(__file__).parent / "mouselight_examples" / "valid" / example

    consensus = PointsKey("CONSENSUS")
    skeletonization = PointsKey("SKELETONIZATION")
    matched = PointsKey("MATCHED")
    matched_with_fallback = PointsKey("MATCHED_WITH_FALLBACK")

    # (None, None, None) offset and shape: an ROI unbounded in every axis.
    unbounded = Coordinate((None,) * 3)
    inf_roi = Roi(unbounded, unbounded)

    request = BatchRequest()
    for key in (matched, matched_with_fallback, consensus):
        request[key] = PointsSpec(roi=inf_roi)

    # Settings shared by both matchers; only the output key and the
    # fallback flag differ between them.
    matcher_kwargs = dict(
        expected_edge_len=10,
        match_distance_threshold=76,
        max_gap_crossing=48,
        use_gurobi=use_gurobi,
        location_attr="location",
        penalty_attr="penalty",
    )

    sources = (
        GraphSource(example_dir / "graph.obj", [skeletonization]),
        GraphSource(example_dir / "tree.obj", [consensus]),
    )
    pipeline = (
        sources
        + MergeProvider()
        + TopologicalMatcher(skeletonization, consensus, matched, **matcher_kwargs)
        + TopologicalMatcher(
            skeletonization,
            consensus,
            matched_with_fallback,
            with_fallback=True,
            **matcher_kwargs,
        )
    )

    with build(pipeline):
        batch = pipeline.request_batch(request)
        consensus_ccs = list(batch[consensus].connected_components)
        # Fetched so a missing fallback output still fails the test; its
        # component count is not asserted on.
        matched_with_fallback_ccs = list(batch[matched_with_fallback].connected_components)
        matched_ccs = list(batch[matched].connected_components)

        assert len(matched_ccs) == len(consensus_ccs)
Example #2
0
    def provide(self, request):
        """Serve TEST_POINTS and/or TEST_LABELS for the requested ROIs.

        TEST_POINTS: deep copies of the entries of ``self.points`` whose
        location the requested ROI contains, keyed by the same ids.

        TEST_LABELS: a uint32 volume over the requested ROI whose every second
        index along axis 1 is 100 (0 elsewhere); then, for every point in
        ``self.points``, the index returned by ``self.point_to_voxel`` is set
        to the point's id.
        NOTE(review): all points are painted, including ones outside the
        requested ROI -- what that does depends on ``point_to_voxel``.
        """
        batch = Batch()

        points_key = PointsKeys.TEST_POINTS
        if points_key in request:
            points_roi = request[points_key].roi
            contained = {
                point_id: copy.deepcopy(point)
                for point_id, point in self.points.items()
                if points_roi.contains(point.location)
            }
            batch.points[points_key] = Points(
                contained, PointsSpec(roi=points_roi)
            )

        labels_key = ArrayKeys.TEST_LABELS
        if labels_key in request:
            labels_roi = request[labels_key].roi
            labels_spec = self.spec[labels_key]
            voxel_shape = (labels_roi // labels_spec.voxel_size).get_shape()

            data = np.zeros(voxel_shape, dtype=np.uint32)
            # striped background: every second index along axis 1
            data[:, ::2] = 100

            for point_id, point in self.points.items():
                data[self.point_to_voxel(labels_roi, point.location)] = point_id

            out_spec = labels_spec.copy()
            out_spec.roi = labels_roi
            batch.arrays[labels_key] = Array(data, spec=out_spec)

        return batch
Example #3
0
    def prepare(self, request):
        """Build the upstream request needed to produce the fused output.

        The raw "base"/"add" volumes and both label volumes are requested with
        the spec requested for ``raw_fused`` (labels are enlarged to the raw
        ROI so they can be used for mask generation). The two point sets are
        requested over the raw ROI, but only if ``points_fused`` itself was
        requested.
        """
        fused_spec = request[self.raw_fused]

        deps = BatchRequest()
        for key in (self.raw_base, self.raw_add, self.labels_base, self.labels_add):
            deps[key] = fused_spec

        # points are optional
        if self.points_fused in request:
            fused_roi = fused_spec.roi
            deps[self.points_base] = PointsSpec(roi=fused_roi)
            deps[self.points_add] = PointsSpec(roi=fused_roi)

        return deps
Example #4
0
    def test_relabel_components(self):
        """Components cut apart by the ROI receive distinct labels.

        Writes the toy SWC, reads it back through ``SwcFileSource`` with a ROI
        of offset (0, 1, 0) and shape (11, 10, 1), and checks that:

        * there are 3 weakly connected components of 10 points each,
        * all points of a component share one label, and
        * no two components share a label (checked across all components,
          not just consecutive ones).
        """
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points())

        # read the points back
        swc = PointsKey("SWC")
        source = SwcFileSource(path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest(
                    {swc: PointsSpec(roi=Roi((0, 1, 0), (11, 10, 1)))}))

        points = batch.points[swc].data
        temp_g = nx.DiGraph()
        for point in points.values():
            temp_g.add_node(point.point_id, label_id=point.label_id)
            # link to the parent only if it is a real, distinct point that
            # survived the ROI crop
            if (point.parent_id != -1 and point.parent_id != point.point_id
                    and point.parent_id in points):
                temp_g.add_edge(point.point_id, point.parent_id)

        ccs = list(nx.weakly_connected_components(temp_g))
        self.assertEqual(len(ccs), 3)

        seen_labels = set()
        for cc in ccs:
            self.assertEqual(len(cc), 10)
            labels = {temp_g.nodes[point_id]["label_id"] for point_id in cc}
            # one label per component ...
            self.assertEqual(len(labels), 1)
            label = labels.pop()
            # ... that no other component uses
            self.assertNotIn(label, seen_labels)
            seen_labels.add(label)
Example #5
0
    def provide(self, request):
        """Return the CSV points that fall inside the requested ROI.

        A point is kept if, for every column ``d`` of ``self.locations``,
        ``begin[d] <= location[d] < end[d]`` of the requested ROI. The time
        spent is recorded in the batch's profiling stats.
        """
        timing = Timing(self)
        timing.start()

        roi = request[self.points].roi
        min_bb = roi.get_begin()
        max_bb = roi.get_end()

        logger.debug("CSV points source got request for %s", roi)

        # The `np.bool` alias was deprecated in NumPy 1.20 and removed in
        # 1.24; the builtin `bool` is the same dtype.
        point_filter = np.ones((self.locations.shape[0],), dtype=bool)
        for d in range(self.locations.shape[1]):
            point_filter &= self.locations[:, d] >= min_bb[d]
            point_filter &= self.locations[:, d] < max_bb[d]

        points_data = self._get_points(point_filter)
        logger.debug("Points data: %s", points_data)
        if points_data:
            # Only peek at a point when there is one: an ROI containing no
            # points must not raise IndexError just for a debug message.
            logger.debug(
                "Type of point: %s", type(next(iter(points_data.values()))))
        points_spec = PointsSpec(roi=roi.copy())

        batch = Batch()
        batch.points[self.points] = Points(points_data, points_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #6
0
    def test_ensure_center_non_zero(self):
        """A centred random location puts a skeleton voxel at the centre.

        Reads the toy SWC, picks a random location with
        ``ensure_nonempty=swc, ensure_centered=True``, rasterizes the skeleton
        into a uint32 volume, and checks that the points are non-empty and the
        centre voxel of the 5x5x5 volume is non-zero.
        """
        path = Path(self.path_to("test_swc_source.swc"))
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        swc = PointsKey("SWC")
        img = ArrayKey("IMG")

        source = SwcFileSource(
            path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))])
        random_location = RandomLocation(ensure_nonempty=swc,
                                         ensure_centered=True)
        rasterize = RasterizeSkeleton(
            points=swc,
            array=img,
            array_spec=ArraySpec(
                interpolatable=False,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        )
        pipeline = source + random_location + rasterize

        shape = Coordinate((5, 5, 5))
        request = BatchRequest()
        request.add(img, shape)
        request.add(swc, shape)

        with build(pipeline):
            batch = pipeline.request_batch(request)

            data = batch[img].data
            assert batch[swc].num_vertices() > 0

            center = tuple(np.array(data.shape) // 2)
            self.assertNotEqual(data[center], 0)
Example #7
0
    def provide(self, request):
        """Serve synthetic RAW, GT_LABELS, PRESYN and POSTSYN data.

        * RAW: each voxel holds the sum of its z, y and x voxel coordinates.
        * GT_LABELS: ones, then ``[s0//2:, s1//2:, :] = 2`` and
          ``[s0//2:, -(s1//2):, :] = 3`` where ``(s0, s1, _)`` is the voxel
          shape.
        * PRESYN/POSTSYN: both come from one call to
          ``__get_pre_and_postsyn_locations``, using the PRESYN ROI if PRESYN
          was requested, otherwise the POSTSYN ROI. Other points keys get no
          data.
        """
        batch = Batch()

        # have the pixels encode their position
        if ArrayKeys.RAW in request:

            # the z,y,x coordinates of the ROI
            roi = request[ArrayKeys.RAW].roi
            roi_voxel = roi // self.spec[ArrayKeys.RAW].voxel_size
            begin = roi_voxel.get_begin()
            end = roi_voxel.get_end()
            meshgrids = np.meshgrid(range(begin[0], end[0]),
                                    range(begin[1], end[1]),
                                    range(begin[2], end[2]),
                                    indexing='ij')
            data = meshgrids[0] + meshgrids[1] + meshgrids[2]

            spec = self.spec[ArrayKeys.RAW].copy()
            spec.roi = roi
            batch.arrays[ArrayKeys.RAW] = Array(data, spec)

        if ArrayKeys.GT_LABELS in request:
            roi = request[ArrayKeys.GT_LABELS].roi
            roi_voxel_shape = (
                roi // self.spec[ArrayKeys.GT_LABELS].voxel_size).get_shape()
            data = np.ones(roi_voxel_shape)
            data[roi_voxel_shape[0] // 2:, roi_voxel_shape[1] // 2:, :] = 2
            data[roi_voxel_shape[0] // 2:, -(roi_voxel_shape[1] // 2):, :] = 3
            spec = self.spec[ArrayKeys.GT_LABELS].copy()
            spec.roi = roi
            batch.arrays[ArrayKeys.GT_LABELS] = Array(data, spec)

        # Pre- and postsynaptic locations are generated together from a
        # single ROI: PRESYN's if requested, otherwise POSTSYN's.
        points_roi = None
        if PointsKeys.PRESYN in request:
            points_roi = request[PointsKeys.PRESYN].roi
        elif PointsKeys.POSTSYN in request:
            points_roi = request[PointsKeys.POSTSYN].roi

        if points_roi is not None:
            data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations(
                roi=points_roi)
            locations = {
                PointsKeys.PRESYN: data_presyn,
                PointsKeys.POSTSYN: data_postsyn,
            }
            for points_key, points_spec in request.points_specs.items():
                # Keys without synthetic locations are skipped so they can
                # never be paired with an unrelated leftover value (such as
                # the GT_LABELS array held in `data`).
                if points_key not in locations:
                    continue
                batch.points[points_key] = Points(
                    locations[points_key], PointsSpec(points_spec.roi))

        return batch
Example #8
0
    def setup(self):
        """Declare RAW/GT_LABELS arrays and PRESYN/POSTSYN points.

        All four outputs use an ROI with offset (1000, 1000, 1000) and shape
        (400, 400, 400); the arrays additionally use voxel size (20, 2, 2).
        """
        def make_roi():
            # a fresh Roi per declaration, as before
            return Roi((1000, 1000, 1000), (400, 400, 400))

        for array_key in (ArrayKeys.RAW, ArrayKeys.GT_LABELS):
            self.provides(
                array_key,
                ArraySpec(roi=make_roi(), voxel_size=(20, 2, 2)))

        for points_key in (PointsKeys.PRESYN, PointsKeys.POSTSYN):
            self.provides(points_key, PointsSpec(roi=make_roi()))
Example #9
0
    def setup(self):
        """Read the CSV points and declare the provided points spec.

        If ``self.points_spec`` is set it is provided as-is. Otherwise the ROI
        spans from the floor of the per-axis minimum location to the ceil of
        the per-axis maximum plus one.
        """
        self._read_points()
        logger.debug("Locations: %s", self.locations)

        if self.points_spec is None:
            lower = Coordinate(np.floor(np.amin(self.locations, 0)))
            upper = Coordinate(np.ceil(np.amax(self.locations, 0)) + 1)
            spec = PointsSpec(roi=Roi(lower, upper - lower))
        else:
            spec = self.points_spec

        self.provides(self.points, spec)
Example #10
0
    def setup(self):
        """Declare the points ROI and build a straight-line test graph.

        The ROI starts at the origin and has shape ``self.size``. Nodes
        ``i = 0 .. num_points - 1`` sit at ``i * min(size) / num_points`` in
        every dimension of that ROI. Edges run ``i -> i + 1``, plus
        ``i + 1 -> i`` when ``self.directed`` is false.
        """
        ndims = len(self.size)
        roi = Roi(Coordinate([0] * ndims), self.size)
        for points_key in self.points:
            self.provides(points_key, PointsSpec(roi=roi))

        k = min(self.size)
        # One coordinate per ROI dimension so locations always match the
        # declared ROI (identical to the previous output for 3-D sizes).
        point_list = [
            (i, {"location": np.array([i * k / self.num_points] * ndims)})
            for i in range(self.num_points)
        ]
        edge_list = [(i, i + 1, {}) for i in range(self.num_points - 1)]
        if not self.directed:
            edge_list += [(i + 1, i, {}) for i in range(self.num_points - 1)]

        self.graph = SpatialGraph()
        self.graph.add_nodes_from(point_list)
        self.graph.add_edges_from(edge_list)
Example #11
0
    def test_read_single_swc(self):
        """Every toy SWC point is read back at exactly its written location."""
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points())

        # read the points back
        swc = PointsKey("SWC")
        source = SwcFileSource(path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest(
                    {swc: PointsSpec(roi=Roi((0, 0, 0), (11, 11, 1)))}))

        read_points = batch.points[swc].data
        for point in self._toy_swc_points():
            # Ordered, per-axis comparison: assertCountEqual treats the
            # coordinates as an unordered multiset, so swapped axes passed.
            self.assertEqual(
                list(point.location),
                list(read_points[point.point_id].location))
Example #12
0
    def test_create_boundary_nodes(self):
        """Cropping the toy tree yields the expected short chain of nodes.

        Writes the toy SWC with resolution (2, 2, 2) and reads it with a ROI of
        offset (0, 5, 0) and shape (1, 3, 1). Starting from the last point
        (in iteration order) without an in-batch parent, it follows the single
        child of each node (``child -> parent`` edges, so children are
        predecessors). It expects nodes 0, 1, 2 at (0, 5, 0), (0, 6, 0) and
        (0, 7, 0), order-insensitively.
        """
        path = Path(self.path_to("test_swc_source.swc"))
        self._write_swc(path, self._toy_swc_points(),
                        {"resolution": np.array([2, 2, 2])})

        swc = PointsKey("SWC")
        source = SwcFileSource(path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest({swc: PointsSpec(roi=Roi((0, 5, 0), (1, 3, 1)))}))

        points = batch.points[swc].data
        temp_g = nx.DiGraph()
        for point in points.values():
            temp_g.add_node(point.point_id,
                            label_id=point.label_id,
                            location=point.location)
            has_parent = (point.parent_id != -1
                          and point.parent_id != point.point_id
                          and point.parent_id in points)
            if has_parent:
                temp_g.add_edge(point.point_id, point.parent_id)
            else:
                root = point.point_id

        expected_path = [(0.0, 5.0, 0.0), (0.0, 6.0, 0.0), (0.0, 7.0, 0.0)]
        expected_node_ids = [0, 1, 2]

        node_ids = []
        path = []
        current = root
        while current is not None:
            node_ids.append(current)
            path.append(tuple(temp_g.nodes[current]["location"]))
            children = list(temp_g.predecessors(current))
            current = children[0] if len(children) == 1 else None

        self.assertCountEqual(path, expected_path)
        self.assertCountEqual(node_ids, expected_node_ids)
Example #13
0
    def test_grow_labels_speed(self):
        """SWC read + rasterize + GrowLabels(radius=3) averages < 0.1 s/batch.

        Writes the points from ``_get_points`` (intercept and slope (1, 1, 1))
        over a 256^3 box and times 10 requests through
        ``SwcFileSource -> RasterizeSkeleton -> GrowLabels``. Pipeline setup
        and teardown are included in the measured time.
        """
        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_file = "test_swc.swc"
        swc_path = Path(self.path_to(swc_file))

        swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]),
                                      bb)
        self._write_swc(swc_path, swc_points.graph)

        # keys
        swc_key = PointsKey("SWC")
        labels_key = ArrayKey("LABELS")

        # add request
        request = BatchRequest()
        request.add(labels_key, bb.get_shape())
        request.add(swc_key, bb.get_shape())

        # read swc -> rasterize into labels -> grow labels
        data_source = tuple()
        data_source = (data_source + SwcFileSource(
            swc_path, [swc_key], [PointsSpec(roi=bb)]) + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
            ) + GrowLabels(array=labels_key, radius=3))

        pipeline = data_source

        num_repeats = 10
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and distort a duration.
        start = time.perf_counter()
        with build(pipeline):
            for _ in range(num_repeats):
                pipeline.request_batch(request)
        elapsed = time.perf_counter() - start

        self.assertLess(elapsed / num_repeats, 0.1)
Example #14
0
    def test_overlap(self):
        """Overlapping SWC files are read as separately labelled components.

        Three copies of the toy SWC (all with offset (0, 0, 0)) are written to
        one directory and read through a single ``SwcFileSource``. Expects 3
        weakly connected components of 41 points each, each component having
        a single label that no other component uses.
        """
        path = Path(self.path_to("test_swc_sources"))
        path.mkdir(parents=True, exist_ok=True)

        # write three identical test swcs
        for i in range(3):
            self._write_swc(
                path / "{}.swc".format(i),
                self._toy_swc_points(),
                {"offset": np.array([0, 0, 0])},
            )

        # read the points back
        swc = PointsKey("SWC")
        source = SwcFileSource(path, swc)

        with build(source):
            batch = source.request_batch(
                BatchRequest(
                    {swc: PointsSpec(roi=Roi((0, 0, 0), (11, 11, 1)))}))

        points = batch.points[swc].data
        temp_g = nx.DiGraph()
        for point in points.values():
            temp_g.add_node(point.point_id, label_id=point.label_id)
            # link to the parent only if it is a real, distinct point that
            # is present in the batch
            if (point.parent_id != -1 and point.parent_id != point.point_id
                    and point.parent_id in points):
                temp_g.add_edge(point.point_id, point.parent_id)

        ccs = list(nx.weakly_connected_components(temp_g))
        self.assertEqual(len(ccs), 3)

        seen_labels = set()
        for cc in ccs:
            self.assertEqual(len(cc), 41)
            labels = {temp_g.nodes[point_id]["label_id"] for point_id in cc}
            # one label per component ...
            self.assertEqual(len(labels), 1)
            label = labels.pop()
            # ... distinct across all components, not only consecutive ones
            self.assertNotIn(label, seen_labels)
            seen_labels.add(label)
Example #15
0
    def setup(self):
        """Create five test points and declare TEST_POINTS and TEST_LABELS.

        Point ``i`` (0..4) sits at ``(0, 10 + 20 * i, 0)``. Both outputs cover
        the ROI with offset (-100, -100, -100) and shape (300, 300, 300); the
        labels use voxel size (4, 1, 1) and are not interpolatable.
        """
        self.points = {i: Point([0, 10 + 20 * i, 0]) for i in range(5)}

        def full_roi():
            return Roi((-100, -100, -100), (300, 300, 300))

        self.provides(PointsKeys.TEST_POINTS, PointsSpec(roi=full_roi()))

        labels_spec = ArraySpec(
            roi=full_roi(),
            voxel_size=Coordinate((4, 1, 1)),
            interpolatable=False,
        )
        self.provides(ArrayKeys.TEST_LABELS, labels_spec)
Example #16
0
    def test_two_disjoint_lines_intensity(self):
        """Intensity fusion of two disjoint lines equals the sum of both raws.

        Two lines ``3 * LABEL_RADIUS`` apart are written to SWC files. Each is
        rasterized into a label volume (grown by LABEL_RADIUS) and a raw
        volume (grown by RAW_RADIUS, then normalized). FusionAugment in
        "intensity" mode must yield a fused raw whose difference from
        ``RAW_A + RAW_B`` has a norm of (almost) zero.
        """
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        BLEND_SMOOTHNESS = 3

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_paths = tuple(
            Path(self.path_to(name))
            for name in ("test_line_a.swc", "test_line_b.swc"))

        # create two lines separated by a given distance, one per swc file
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            line_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, line_points.to_nx_graph())

        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        fused_swc = PointsKey("FUSED_SWC")
        swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
        labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
        raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

        request = BatchRequest()
        shape = bb.get_shape()
        for key in (fused, fused_labels, fused_swc,
                    *labels_keys, *raw_keys, *swc_keys):
            request.add(key, shape)

        def line_source(i):
            """Read swc ``i``, rasterize + grow its labels and raw, normalize raw."""
            def rasterize(array_key):
                return RasterizeSkeleton(
                    points=swc_keys[i],
                    array=array_key,
                    array_spec=ArraySpec(
                        interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                    ),
                )

            return (
                tuple()
                + SwcFileSource(swc_paths[i], [swc_keys[i]], [PointsSpec(roi=bb)])
                + rasterize(labels_keys[i])
                + GrowLabels(array=labels_keys[i], radii=[LABEL_RADIUS])
                + rasterize(raw_keys[i])
                + GrowLabels(array=raw_keys[i], radii=[RAW_RADIUS])
                + Normalize(raw_keys[i])
            )

        data_sources = (line_source(0), line_source(1)) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            swc_keys[0],
            swc_keys[1],
            fused,
            fused_labels,
            fused_swc,
            blend_mode="intensity",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        def padded(key):
            # one voxel of zeros on every side
            return np.pad(batch[key].data, (1,), "constant", constant_values=(0,))

        fused_data = padded(fused)
        a_data = padded(raw_keys[0])
        b_data = padded(raw_keys[1])

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
Example #17
0
def test_recenter():
    """Check that Recenter leaves each skeleton on its own image.

    NOTE(review): this is a module-level function, yet it uses ``self``
    (``self.path_to``, ``self._write_swc``, ``self._toy_swc_points``). Here
    ``self`` is unbound, so calling it raises NameError. It presumably belongs
    inside the test class that defines those helpers -- confirm and move it.

    Builds a points source plus binarized image (radius 0) and label
    (radius 1) volumes from the toy SWC. It then centres a random location on
    the points, splits them with GetNeuronPair (``seperate_by=4``) and applies
    Recenter to each skeleton/image pair. Checks that each 9^3 image has
    value 1 at its centre voxel and at every node of its own skeleton.
    """
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # keys
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()

    data_shape = 9

    request.add(points_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(points_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_b, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b

    data = np.array((data_a, data_b, data_c))

    # Point locations are shifted by +1 to account for the one-voxel padding.
    for point in batch[points_a].data.values():
        index = (0,) + tuple(int(x) + 1 for x in point.location)
        # The failure message reports the same (padded) voxel that was
        # checked, instead of one shifted by the padding.
        assert data[index] == 1, "data at {} is not 1, its {}".format(
            point.location, data[index]
        )

    for point in batch[points_b].data.values():
        index = (1,) + tuple(int(x) + 1 for x in point.location)
        assert data[index] == 1, "data at {} is not 1".format(point.location)
Example #18
0
    def test_get_neuron_pair(self):
        """GetNeuronPair yields two skeletons whose closest nodes are 1-3 apart.

        Builds a points source plus binarized image (radius 0) and label
        (radius 1) volumes from the toy SWC. It picks a centred random location
        guaranteed non-empty via a second points key, and runs GetNeuronPair
        with ``seperate_by=(1, 3)``. Over 10 batches it checks that all six
        outputs are present and that the minimum node-to-node distance between
        the two skeletons (capped at 5) lies in [1, 3].
        """
        path = Path(self.path_to("test_swc_source.swc"))
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # keys
        swc_source = PointsKey("SWC_SOURCE")
        ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
        labels_source = ArrayKey("LABELS_SOURCE")
        img_source = ArrayKey("IMG_SOURCE")
        img_swc = PointsKey("IMG_SWC")
        label_swc = PointsKey("LABEL_SWC")
        imgs = ArrayKey("IMGS")
        labels = ArrayKey("LABELS")
        points_a = PointsKey("SKELETON_A")
        points_b = PointsKey("SKELETON_B")
        img_a = ArrayKey("VOLUME_A")
        img_b = ArrayKey("VOLUME_B")
        labels_a = ArrayKey("LABELS_A")
        labels_b = ArrayKey("LABELS_B")

        data_shape = 5
        output_shape = Coordinate((data_shape, data_shape, data_shape))

        def source_spec():
            return PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))

        def binarized(points_key, raw_key, binary_key, grow_radius):
            """Rasterize ``points_key`` into ``raw_key``, binarize it into
            ``binary_key`` and grow that by ``grow_radius``."""
            return (
                SwcFileSource(path, [points_key], [source_spec()])
                + RasterizeSkeleton(
                    points=points_key,
                    array=raw_key,
                    array_spec=ArraySpec(
                        interpolatable=True,
                        dtype=np.uint32,
                        voxel_size=Coordinate((1, 1, 1)),
                    ),
                )
                + BinarizeLabels(labels=raw_key, labels_binary=binary_key)
                + GrowLabels(array=binary_key, radius=grow_radius)
            )

        # points source; the second key feeds RandomLocation's
        # ensure_nonempty and GetNeuronPair's nonempty_placeholder
        swc_file_source = SwcFileSource(
            path,
            [swc_source, ensure_nonempty],
            [source_spec(), source_spec()],
        )
        # artificial image and label volumes rasterized from the same points
        image_source = binarized(img_swc, img_source, imgs, 0)
        label_source = binarized(label_swc, labels_source, labels, 1)

        skeleton = tuple()
        skeleton += (
            (swc_file_source, image_source, label_source)
            + MergeProvider()
            + RandomLocation(ensure_nonempty=ensure_nonempty,
                             ensure_centered=True)
        )

        pipeline = skeleton + GetNeuronPair(
            point_source=swc_source,
            nonempty_placeholder=ensure_nonempty,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=(1, 3),
            shift_attempts=100,
            request_attempts=10,
            output_shape=output_shape,
        )

        outputs = [points_a, points_b, img_a, img_b, labels_a, labels_b]
        request = BatchRequest()
        for key in outputs:
            request.add(key, output_shape)

        with build(pipeline):
            for _ in range(10):
                batch = pipeline.request_batch(request)
                assert all(key in batch for key in outputs)

                # closest node pair between the two skeletons, capped at 5
                min_dist = min([5] + [
                    np.linalg.norm(a.location - b.location)
                    for a, b in itertools.product(
                        batch[points_a].nodes,
                        batch[points_b].nodes,
                    )
                ])

                self.assertLessEqual(min_dist, 3)
                self.assertGreaterEqual(min_dist, 1)
Example #19
0
    def test_two_disjoint_lines_softmask(self):
        """Fusing two disjoint rasterized lines must equal the sum of the inputs.

        Two lines ``3 * LABEL_RADIUS`` apart are written to swc files. Each is
        rasterized into a labels array (``LABEL_RADIUS``) and a raw array
        (``RAW_RADIUS``). ``FusionAugment`` then combines them in
        ``"labels_mask"`` blend mode. The test asserts that the fused raw
        volume equals the sum of the two raw volumes (residual norm ~ 0).
        """
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        # exaggerated to show problem
        BLEND_SMOOTHNESS = 10

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(
            Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create keys
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
        labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))
        raw_keys = tuple(ArrayKey(name) for name in ("RAW_A", "RAW_B"))

        # request every key over the whole bounding box
        request = BatchRequest()
        for key in (fused, fused_labels, *labels_keys, *raw_keys, *swc_keys):
            request.add(key, bb.get_shape())

        def rasterized_source(swc_path, swc_key, labels_key, raw_key):
            # Read one swc file, then rasterize it twice: once into labels,
            # once into raw. Each node gets its own freshly built ArraySpec,
            # as in the original per-node construction.
            return (
                SwcFileSource(swc_path, swc_key, PointsSpec(roi=bb))
                + RasterizeSkeleton(
                    points=swc_key,
                    array=labels_key,
                    array_spec=ArraySpec(interpolatable=False,
                                         dtype=np.uint32,
                                         voxel_size=voxel_size),
                    radius=LABEL_RADIUS,
                )
                + RasterizeSkeleton(
                    points=swc_key,
                    array=raw_key,
                    array_spec=ArraySpec(interpolatable=False,
                                         dtype=np.uint32,
                                         voxel_size=voxel_size),
                    radius=RAW_RADIUS,
                )
            )

        data_sources = tuple(
            rasterized_source(*branch)
            for branch in zip(swc_paths, swc_keys, labels_keys, raw_keys)
        ) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            fused,
            fused_labels,
            blend_mode="labels_mask",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        a_data = batch[raw_keys[0]].data
        b_data = batch[raw_keys[1]].data

        # Subtract in float64. The raw arrays are rasterized as uint32 (and
        # FUSED presumably shares that dtype), so subtracting in the native
        # dtype would wrap around instead of going negative. That would
        # distort both the residual and its norm.
        residual = fused_data.astype(np.float64) - a_data - b_data

        if imported_volshow:
            # Stack fused, a + b, residual, a and b for side-by-side
            # inspection. Each volume is padded by one zero voxel per side.
            # Built only when visualizing, since the float64 stack is large
            # (5 x 258**3 values).
            volumes = (fused_data, a_data + b_data, residual, a_data, b_data)
            all_data = np.stack([
                np.pad(volume, (1, ), "constant", constant_values=(0, ))
                for volume in volumes
            ]).astype(np.float64)
            volshow(all_data)

        diff = np.linalg.norm(residual)
        self.assertAlmostEqual(diff, 0)
Example #20
0
    def test_rasterize_speed(self):
        """Rasterizing two straight swc skeletons into 256**3 labels must be fast.

        The swc fixture files are written before timing starts. Only building
        the pipeline, requesting one batch and tearing the pipeline down is
        timed, and that must stay under ``MAX_SECONDS``.
        """
        # This is worryingly slow for such a small volume (256**3) and only 2
        # straight lines for skeletons.

        LABEL_RADIUS = 3
        # wall-clock budget (seconds) for build + request_batch + teardown
        MAX_SECONDS = 0.1

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create keys
        swc_keys = tuple(PointsKey(name) for name in ("SWC_A", "SWC_B"))
        labels_keys = tuple(ArrayKey(name) for name in ("LABELS_A", "LABELS_B"))

        # request every key over the whole bounding box
        request = BatchRequest()
        for key in (*labels_keys, *swc_keys):
            request.add(key, bb.get_shape())

        # one "read swc -> rasterize labels" branch per line, merged together
        data_sources = tuple(
            SwcFileSource(swc_path, swc_key, PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
            for swc_path, swc_key, labels_key in zip(swc_paths, swc_keys, labels_keys)
        ) + MergeProvider()

        pipeline = data_sources

        # perf_counter is the clock meant for measuring short intervals; stop
        # right after the batch so unrelated post-processing is not timed.
        start = time.perf_counter()
        with build(pipeline):
            pipeline.request_batch(request)
        elapsed = time.perf_counter() - start

        self.assertLess(elapsed, MAX_SECONDS)