Example #1
    def test_ensure_center_non_zero(self):
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read arrays
        swc = PointsKey("SWC")
        img = ArrayKey("IMG")
        pipeline = (SwcFileSource(
            path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))]) +
                    RandomLocation(ensure_nonempty=swc, ensure_centered=True) +
                    RasterizeSkeleton(
                        points=swc,
                        array=img,
                        array_spec=ArraySpec(
                            interpolatable=False,
                            dtype=np.uint32,
                            voxel_size=Coordinate((1, 1, 1)),
                        ),
                    ))

        request = BatchRequest()
        request.add(img, Coordinate((5, 5, 5)))
        request.add(swc, Coordinate((5, 5, 5)))

        with build(pipeline):
            batch = pipeline.request_batch(request)

            data = batch[img].data
            g = batch[swc]
            assert g.num_vertices() > 0

            self.assertNotEqual(data[tuple(np.array(data.shape) // 2)], 0)
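All of these examples follow the same gunpowder skeleton: declare keys, chain nodes into a pipeline, describe what is wanted in a BatchRequest, then build the pipeline and pull batches. A minimal sketch of that pattern, assuming a hypothetical source node MySource that provides RAW (everything else is the public gunpowder API):

import gunpowder as gp

raw = gp.ArrayKey("RAW")

# MySource is a hypothetical stand-in for any provider
# (Hdf5Source, SwcFileSource, CsvPointsSource, ...)
pipeline = MySource() + gp.RandomLocation()

request = gp.BatchRequest()
request.add(raw, gp.Coordinate((5, 5, 5)))  # shape in world units

with gp.build(pipeline):                    # sets up / tears down every node
    batch = pipeline.request_batch(request)
    data = batch[raw].data                  # numpy array covering the requested ROI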
Example #2
    def test_without_placeholder(self):

        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D() + RandomLocation(ensure_nonempty=test_points) +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            ))

        with build(pipeline):
            for i in range(2):

                request_size = Coordinate((40, 40, 40))

                request_a = BatchRequest(random_seed=i)
                request_a.add(test_points, request_size)

                request_b = BatchRequest(random_seed=i)
                request_b.add(test_points, request_size)
                request_b.add(test_labels, request_size)

                # No array to provide a voxel size to ElasticAugment
                with pytest.raises(PipelineRequestError):
                    pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                self.assertIn(test_labels, batch_b)
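The PipelineRequestError above is raised because request_a contains no array, so ElasticAugment has no voxel size to work from. As Example #11 below shows, adding the array as a placeholder supplies that information without requiring the data; a minimal sketch with the same keys:

request_a = BatchRequest(random_seed=i)
request_a.add(test_points, request_size)
# placeholder entries take part in request preparation (voxel size,
# ROI bookkeeping) without the data itself being returned
request_a.add(test_labels, request_size, placeholder=True)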
Example #3
    def test_3d(self):

        test_graph = GraphKey("TEST_GRAPH")
        graph_spec = GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5)))
        test_array = ArrayKey("TEST_ARRAY")
        array_spec = ArraySpec(
            roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
        )
        test_array2 = ArrayKey("TEST_ARRAY2")
        array2_spec = ArraySpec(
            roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1))
        )

        snapshot_request = BatchRequest()
        snapshot_request.add(test_graph, Coordinate((5, 5, 5)))

        pipeline = ExampleSource(
            [test_graph, test_array, test_array2], [graph_spec, array_spec, array2_spec]
        ) + Snapshot(
            {
                test_graph: "graphs/graph",
                test_array: "volumes/array",
                test_array2: "volumes/array2",
            },
            output_dir=str(self.test_dir),
            every=2,
            additional_request=snapshot_request,
            output_filename="snapshot.hdf",
        )

        snapshot_file_path = Path(self.test_dir, "snapshot.hdf")

        with build(pipeline):

            request = BatchRequest()
            roi = Roi((0, 0, 0), (5, 5, 5))

            request[test_array] = ArraySpec(roi=roi)
            request[test_array2] = ArraySpec(roi=roi)

            pipeline.request_batch(request)

            assert snapshot_file_path.exists()
            with h5py.File(snapshot_file_path, "r") as f:
                assert f["volumes/array"] is not None
                assert f["graphs/graph-ids"] is not None

            snapshot_file_path.unlink()

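            # with every=2, Snapshot skips this second batch, so no new
            # file should be written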
            pipeline.request_batch(request)

            assert not snapshot_file_path.exists()
Example #4
    def test_pipeline3(self):
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))
        spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {array_key: "testdata"},
                                 array_specs={array_key: spec})
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + MergeProvider() +
                    RandomLocation(ensure_nonempty=points_key) + shift_node)
        with build(pipeline) as b:
            batch = b.request_batch(request)

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(
                    result_val,
                    list(result_points),
                    target_vals,
                    self.fake_points,
                ),
            )
Example #5
def visualize_embedding_pipeline(fusion_pipeline, train_embedding):
    setup_config = DEFAULT_CONFIG
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_EMBEDDING"] = train_embedding
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size
    pipeline, raw, output = embedding_pipeline(setup_config,
                                               get_test_data_sources)
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)
    with build(pipeline):
        pipeline.request_batch(request)
    visualize_hdf5(Path("snapshots/snapshot_1.hdf"), tuple(voxel_size))
Example #6
    def test_prepare1(self):

        key = ArrayKey("TEST_ARRAY")
        spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {key: "testdata"},
                                 array_specs={key: spec})

        request = BatchRequest()
        shape = Coordinate((3, 3))
        request.add(key, shape, voxel_size=Coordinate((1, 1)))

        shift_node = ShiftAugment(sigma=1, shift_axis=0)
        with build((hdf5_source + shift_node)):
            shift_node.prepare(request)
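            # prepare() infers the dimensionality from the request and expands
            # the scalar sigma into a per-axis tuple that is zero along
            # shift_axis (axis 0 here), as the assertions below verify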
            self.assertTrue(shift_node.ndim == 2)
            self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0]))
Example #7
    def test_pipeline2(self):

        key = ArrayKey("TEST_ARRAY")
        spec = ArraySpec(voxel_size=Coordinate((3, 1)), interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {key: "testdata"},
                                 array_specs={key: spec})

        request = BatchRequest()
        shape = Coordinate((3, 3))
        request.add(key, shape, voxel_size=Coordinate((3, 1)))

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=1,
                                  shift_axis=0)
        with build((hdf5_source + shift_node)) as b:
            b.request_batch(request)
Example #8
    def test_grow_labels_speed(self):

        bb = Roi(Coordinate([0, 0, 0]), ([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_file = "test_swc.swc"
        swc_path = Path(self.path_to(swc_file))

        swc_points = self._get_points(np.array([1, 1, 1]), np.array([1, 1, 1]),
                                      bb)
        self._write_swc(swc_path, swc_points.graph)

        # create swc sources
        swc_key = PointsKey("SWC")
        labels_key = ArrayKey("LABELS")

        # add request
        request = BatchRequest()
        request.add(labels_key, bb.get_shape())
        request.add(swc_key, bb.get_shape())

        # build the data source for the swc
        pipeline = (SwcFileSource(
            swc_path, [swc_key], [PointsSpec(roi=bb)]) + RasterizeSkeleton(
                points=swc_key,
                array=labels_key,
                array_spec=ArraySpec(interpolatable=False,
                                     dtype=np.uint32,
                                     voxel_size=voxel_size),
            ) + GrowLabels(array=labels_key, radius=3))

        num_repeats = 10
        t1 = time.time()
        with build(pipeline):
            for i in range(num_repeats):
                pipeline.request_batch(request)

        t2 = time.time()
        self.assertLess((t2 - t1) / num_repeats, 0.1)
Example #9
def test_embedding_pipeline(
    tmpdir, aux_task, blend_mode, fusion_pipeline, train_embedding, snapshot_every
):
    setup_config = DEFAULT_CONFIG
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_EMBEDDING"] = train_embedding
    setup_config["SNAPSHOT_EVERY"] = snapshot_every
    setup_config["TENSORBOARD_LOG_DIR"] = tmpdir
    setup_config["SNAPSHOT_DIR"] = tmpdir
    setup_config["SNAPSHOT_FILE_NAME"] = "test_snapshot"
    setup_config["MATCHING_FAILURES_DIR"] = None
    setup_config["BLEND_MODE"] = blend_mode
    setup_config["AUX_TASK"] = aux_task
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size
    pipeline, raw, output, inputs = embedding_pipeline(
        setup_config, get_test_data_sources
    )
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)
    for key in inputs:
        request.add(key, output_size)
    with build(pipeline):
        batch = pipeline.request_batch(request)
        assert output in batch
        assert raw in batch
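Note that setup_config = DEFAULT_CONFIG in these tests binds a reference, not a copy, so every key written above also mutates the shared DEFAULT_CONFIG dict. A safer sketch, if per-test isolation matters:

import copy

setup_config = copy.deepcopy(DEFAULT_CONFIG)  # isolate per-test mutations
setup_config["FUSION_PIPELINE"] = fusion_pipeline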
Example #10
    def test_merge_basics(self):
        voxel_size = (1, 1, 1)
        GraphKey("PRESYN")
        ArrayKey("GT_LABELS")
        graphsource = GraphTestSource(voxel_size)
        arraysource = ArrayTestSoure(voxel_size)
        pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
        window_request = Coordinate((50, 50, 50))
        with build(pipeline):
            # Check basic merging.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            request.add(ArrayKeys.GT_LABELS, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertTrue(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

            # Check that requesting from only one source also works.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertFalse(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

        # Check that it fails when two sources provide the same key.
        arraysource2 = ArrayTestSoure(voxel_size)
        pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
        with self.assertRaises(PipelineSetupError):
            with build(pipeline_fail):
                pass
Example #11
    def test_placeholder(self):

        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D() + RandomLocation(ensure_nonempty=test_points) +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            ))

        with build(pipeline):
            for i in range(2):

                request_size = Coordinate((40, 40, 40))

                request_a = BatchRequest(random_seed=i)
                request_a.add(test_points, request_size)
                request_a.add(test_labels, request_size, placeholder=True)

                request_b = BatchRequest(random_seed=i)
                request_b.add(test_points, request_size)
                request_b.add(test_labels, request_size)

                batch_a = pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                points_a = batch_a[test_points].nodes
                points_b = batch_b[test_points].nodes

                for a, b in zip(points_a, points_b):
                    assert all(np.isclose(a.location, b.location))
Example #12
def visualize_foreground_pipeline(fusion_pipeline,
                                  train_foreground,
                                  distances,
                                  test_sources=True):
    setup_config = DEFAULT_CONFIG
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_FOREGROUND"] = train_foreground
    setup_config["DISTANCES"] = distances
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size
    if test_sources:
        pipeline, raw, output = foreground_pipeline(setup_config,
                                                    get_test_data_sources)
    else:
        pipeline, raw, output = foreground_pipeline(setup_config)

    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)
    with build(pipeline):
        pipeline.request_batch(request)
    visualize_hdf5(Path("snapshots/snapshot_1.hdf"), tuple(voxel_size))
Example #13
def test_foreground_pipeline(tmpdir, fusion_pipeline, train_foreground,
                             distance_loss, snapshot_every):
    setup_config = DEFAULT_CONFIG
    setup_config["FUSION_PIPELINE"] = fusion_pipeline
    setup_config["TRAIN_FOREGROUND"] = train_foreground
    setup_config["SNAPSHOT_EVERY"] = snapshot_every
    setup_config["TENSORBOARD_LOG_DIR"] = tmpdir
    setup_config["SNAPSHOT_DIR"] = tmpdir
    setup_config["SNAPSHOT_FILE_NAME"] = "test_snapshot"
    setup_config["MATCHING_FAILURES_DIR"] = None
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size
    pipeline, raw, output, inputs = foreground_pipeline(
        setup_config, get_test_data_sources)
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)
    for key in inputs:
        request.add(key, output_size)
    with build(pipeline):
        batch = pipeline.request_batch(request)
        assert output in batch
        assert raw in batch
Example #14
    def test_output_min_distance(self):

        voxel_size = Coordinate((20, 2, 2))

        ArrayKey("GT_VECTORS_MAP_PRESYN")
        GraphKey("PRESYN")
        GraphKey("POSTSYN")

        arraytypes_to_source_target_pointstypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN: (GraphKeys.PRESYN, GraphKeys.POSTSYN)
        }
        arraytypes_to_stayinside_arraytypes = {
            ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS
        }

        # test for partner criterion 'min_distance'
        radius_phys = 30
        pipeline_min_distance = AddVectorMapTestSource() + AddVectorMap(
            src_and_trg_points=arraytypes_to_source_target_pointstypes,
            voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
            radius_phys=radius_phys,
            partner_criterion="min_distance",
            stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
            pad_for_partners=(0, 0, 0),
        )

        with build(pipeline_min_distance):

            request = BatchRequest()
            raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi
            gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi
            presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi

            request.add(ArrayKeys.RAW, raw_roi.get_shape())
            request.add(ArrayKeys.GT_LABELS, gt_labels_roi.get_shape())
            request.add(GraphKeys.PRESYN, presyn_roi.get_shape())
            request.add(GraphKeys.POSTSYN, presyn_roi.get_shape())
            request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN, presyn_roi.get_shape())
            for identifier, spec in request.items():
                spec.roi = spec.roi.shift((1000, 1000, 1000))

            batch = pipeline_min_distance.request_batch(request)

        presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
        postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        offset_vector_map_presyn = request[
            ArrayKeys.GT_VECTORS_MAP_PRESYN
        ].roi.get_offset()

        self.assertTrue(len(presyn_locs) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for loc_id, point in presyn_locs.items():

            if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                Coordinate(point.location)
            ):
                self.assertTrue(
                    batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                        Coordinate(point.location)
                    )
                )

                dist_to_loc = {}
                for partner_id in point.attrs["partner_ids"]:
                    if partner_id in postsyn_locs.keys():
                        partner_location = postsyn_locs[partner_id].location
                        dist_to_loc[
                            np.linalg.norm(partner_location - point.location)
                        ] = partner_location
                min_dist = np.min(list(dist_to_loc.keys()))
                relevant_partner_loc = dist_to_loc[min_dist]

                presyn_loc_shifted_vx = (
                    point.location - offset_vector_map_presyn
                ) // voxel_size
                radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
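                # voxel-space bounding box around the presynaptic site,
                # clipped to the extent of the vector map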
                region_to_check = np.clip(
                    [
                        (presyn_loc_shifted_vx - radius_vx),
                        (presyn_loc_shifted_vx + radius_vx),
                    ],
                    a_min=(0, 0, 0),
                    a_max=vector_map_presyn.shape[-3:],
                )
                for x, y, z in itertools.product(
                    range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                    range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                    range(int(region_to_check[0][2]), int(region_to_check[1][2])),
                ):
                    if (
                        np.linalg.norm(
                            (np.array((x, y, z)) - np.asarray(point.location))
                        )
                        < radius_phys
                    ):
                        vector = [
                            vector_map_presyn[dim][x, y, z]
                            for dim in range(vector_map_presyn.shape[0])
                        ]
                        if not np.sum(vector) == 0:
                            trg_loc_of_vector_phys = (
                                np.asarray(offset_vector_map_presyn)
                                + (voxel_size * np.array([x, y, z]))
                                + np.asarray(vector)
                            )
                            self.assertTrue(
                                np.array_equal(
                                    trg_loc_of_vector_phys, relevant_partner_loc
                                )
                            )

        # test for partner criterion 'all'
        pipeline_all = AddVectorMapTestSource() + AddVectorMap(
            src_and_trg_points=arraytypes_to_source_target_pointstypes,
            voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size},
            radius_phys=radius_phys,
            partner_criterion="all",
            stayinside_array_keys=arraytypes_to_stayinside_arraytypes,
            pad_for_partners=(0, 0, 0),
        )

        with build(pipeline_all):
            batch = pipeline_all.request_batch(request)

        presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes}
        postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes}
        vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data
        offset_vector_map_presyn = request[
            ArrayKeys.GT_VECTORS_MAP_PRESYN
        ].roi.get_offset()

        self.assertTrue(len(presyn_locs) > 0)
        self.assertTrue(len(postsyn_locs) > 0)

        for loc_id, point in presyn_locs.items():

            if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains(
                Coordinate(point.location)
            ):
                self.assertTrue(
                    batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains(
                        Coordinate(point.location)
                    )
                )

                partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {}
                for partner_id in point.attrs["partner_ids"]:
                    if partner_id in postsyn_locs.keys():
                        partner_ids_to_locs_per_src[partner_id] = postsyn_locs[
                            partner_id
                        ].location.tolist()
                        count_vectors_per_partner[partner_id] = 0

                presyn_loc_shifted_vx = (
                    point.location - offset_vector_map_presyn
                ) // voxel_size
                radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size]
                region_to_check = np.clip(
                    [
                        (presyn_loc_shifted_vx - radius_vx),
                        (presyn_loc_shifted_vx + radius_vx),
                    ],
                    a_min=(0, 0, 0),
                    a_max=vector_map_presyn.shape[-3:],
                )
                for x, y, z in itertools.product(
                    range(int(region_to_check[0][0]), int(region_to_check[1][0])),
                    range(int(region_to_check[0][1]), int(region_to_check[1][1])),
                    range(int(region_to_check[0][2]), int(region_to_check[1][2])),
                ):
                    if (
                        np.linalg.norm(
                            (np.array((x, y, z)) - np.asarray(point.location))
                        )
                        < radius_phys
                    ):
                        vector = [
                            vector_map_presyn[dim][x, y, z]
                            for dim in range(vector_map_presyn.shape[0])
                        ]
                        if not np.sum(vector) == 0:
                            trg_loc_of_vector_phys = (
                                np.asarray(offset_vector_map_presyn)
                                + (voxel_size * np.array([x, y, z]))
                                + np.asarray(vector)
                            )
                            self.assertTrue(
                                trg_loc_of_vector_phys.tolist()
                                in partner_ids_to_locs_per_src.values()
                            )

                            for (
                                partner_id,
                                partner_loc,
                            ) in partner_ids_to_locs_per_src.items():
                                if np.array_equal(
                                    np.asarray(trg_loc_of_vector_phys), partner_loc
                                ):
                                    count_vectors_per_partner[partner_id] += 1
                self.assertTrue(
                    (
                        list(count_vectors_per_partner.values())
                        - np.min(list(count_vectors_per_partner.values()))
                        <= len(count_vectors_per_partner.keys())
                    ).all()
                )
Example #15
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_mask,
        tensor_glia,
        tensor_glia_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries,
        mask_out_labels,
        snapshot_dir):

    ignore_keys_for_slip = (LABELS_KEY, GT_MASK_KEY, GT_GLIA_KEY, GLIA_MASK_KEY, UNLABELED_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # TODO Use glia scale somehow, probably not possible with tensorflow 1.3 because it does not know uint64...
    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,  voxel_size=input_voxel_size)
    request.add(LABELS_KEY,          output_size, voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(GLIA_MASK_KEY,       output_size, voxel_size=output_voxel_size)
    request.add(GLIA_KEY,            output_size, voxel_size=output_voxel_size)
    request.add(GT_GLIA_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(UNLABELED_KEY,       output_size, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    # always balance glia labels!
    request.add(GLIA_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY
    network_inputs[tensor_glia_mask]       = GLIA_SCALE_KEY  # glia labels are always balanced

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensure RAW is float, in [0, 1]

        # zero-pad the provided RAW and GT_MASK so that batches can be drawn
        # close to the boundary of the available data
        # (the pad size is more or less irrelevant, as Reject nodes follow)
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        Pad(GLIA_MASK_KEY, None) +
        Pad(LABELS_KEY, size=NETWORK_OUTPUT_SHAPE / 2, value=np.uint64(-3)) +
        Pad(GT_GLIA_KEY, size=NETWORK_OUTPUT_SHAPE / 2) +
        # Pad(LABELS_KEY, None) +
        # Pad(GT_GLIA_KEY, None) +
        RandomLocation() + # choose a random location inside the provided arrays
        Reject(mask=GT_MASK_KEY, min_masked=0.5) +
        Reject(mask=GLIA_MASK_KEY, min_masked=0.5) +
        MapNumpyArray(lambda array: np.require(array, dtype=np.int64), GT_GLIA_KEY) # this is necessary because gunpowder 1.3 only understands int64, not uint64

        for provider in data_providers)

    # extra arrays (predictions, loss gradients) that are requested in addition
    # whenever a snapshot is written (see Snapshot(additional_request=...) below)
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
    })

    # source of artifact sections (raw only): DefectAugment below replaces
    # random sections with sections drawn from this source
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                DEFECT_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                DEFECT_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=DEFECT_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()

    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )

    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: before misalign: ' % GT_MASK_KEY)
    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)
    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: after  misalign: ' % GT_MASK_KEY)

    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                                    prob_missing=0.03,
                                    prob_low_contrast=0.01,
                                    prob_artifact=0.03,
                                    artifact_source=artifact_source,
                                    artifacts=RAW_KEY,
                                    artifacts_mask=DEFECT_MASK_KEY,
                                    contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)

    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    _logger.info("Renumbering connected components? %s", renumber_connected_components)
    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=LABELS_KEY)

    train_pipeline += NewKeyFromNumpyArray(lambda array: 1 - array, GT_GLIA_KEY, UNLABELED_KEY)

    if len(mask_out_labels) > 0:
        train_pipeline += MaskOutLabels(label_key=LABELS_KEY, mask_key=GT_MASK_KEY, ids_to_be_masked=mask_out_labels)

    # labels_mask: anything that connects into labels_mask will be zeroed out
    # unlabelled: anything that points into unlabeled will have zero affinity;
    #             affinities within unlabelled will be masked out
    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY,
            unlabelled=UNLABELED_KEY
    )

    snapshot_datasets = {
        RAW_KEY: 'volumes/raw',
        LABELS_KEY: 'volumes/labels/neuron_ids',
        GT_AFFINITIES_KEY: 'volumes/affinities/gt',
        GT_GLIA_KEY: 'volumes/labels/glia_gt',
        UNLABELED_KEY: 'volumes/labels/unlabeled',
        AFFINITIES_KEY: 'volumes/affinities/prediction',
        LOSS_GRADIENT_KEY: 'volumes/loss_gradient',
        AFFINITIES_MASK_KEY: 'masks/affinities',
        GLIA_KEY: 'volumes/labels/glia_pred',
        GT_MASK_KEY: 'masks/gt',
        GLIA_MASK_KEY: 'masks/glia'}

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)
        snapshot_datasets[AFFINITIES_SCALE_KEY] = 'masks/affinity-scale'
    train_pipeline += BalanceLabels(labels=GT_GLIA_KEY, scales=GLIA_SCALE_KEY, mask=GLIA_MASK_KEY)
    snapshot_datasets[GLIA_SCALE_KEY] = 'masks/glia-scale'


    if (pre_cache_size > 0 and pre_cache_num_workers > 0):
        train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_glia: GLIA_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                GLIA_MASK_KEY        : ArraySpec(voxel_size=output_voxel_size),
                GLIA_SCALE_KEY       : ArraySpec(voxel_size=output_voxel_size),
                GLIA_KEY             : ArraySpec(voxel_size=output_voxel_size)
            }
        )

    train_pipeline += Snapshot(
            snapshot_datasets,
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir=snapshot_dir,
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())

    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
Example #16
if __name__ == "__main__":

    logging.basicConfig(level=logging.INFO)

    setup_config = DEFAULT_CONFIG
    setup_config["FUSION_PIPELINE"] = True
    setup_config["TRAIN_EMBEDDING"] = True
    setup_config["SNAPSHOT_EVERY"] = 0
    setup_config["SNAPSHOT_FILE_NAME"] = None
    setup_config["MATCHING_FAILURES_DIR"] = None
    setup_config["PROFILE_EVERY"] = 1
    setup_config["CLAHE"] = False
    setup_config["NUM_FMAPS_EMBEDDING"] = 12
    setup_config["FMAP_INC_FACTORS_EMBEDDING"] = 5
    voxel_size = Coordinate(setup_config["VOXEL_SIZE"])
    output_size = Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size
    input_size = Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size
    pipeline, raw, output, inputs = embedding_pipeline(setup_config,
                                                       get_test_data_sources)
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(output, output_size)
    for key in inputs:
        request.add(key, output_size)
    with build(pipeline):
        for i in range(1):
            batch = pipeline.request_batch(request)
            assert output in batch
            assert raw in batch
Example #17
        masked_add_b: "volumes/masked_add_b",
        softmask: "volumes/softmask",
        softmask_b: "volumes/softmask_b",
        mask_maxed: "volumes/mask_maxed",
        mask_maxed_b: "volumes/mask_maxed_b",
    },
    every=1,
))

with build(pipeline):
    for i in range(1):
        # add request (seeded so that runs are reproducible)
        request = gp.BatchRequest(random_seed=i)
        request.add(raw_fused, input_size)
        request.add(labels_fused, input_size)
        request.add(raw_fused_b, input_size)
        request.add(labels_fused_b, input_size)

        # add snapshot request
        # request.add(fg, output_size)
        # request.add(labels_fg, output_size)
        # request.add(gradient_fg, output_size)
        request.add(raw_base, input_size)
        request.add(raw_add, input_size)
        request.add(labels_base, input_size)
        request.add(labels_add, input_size)

        # debugging
        request.add(masked_base, input_size)
Example #18
File: scan.py  Project: yajivunev/gunpowder
    def test_output(self):

        source = ScanTestSource()

        chunk_request = BatchRequest()
        chunk_request.add(ArrayKeys.RAW, (400, 30, 34))
        chunk_request.add(ArrayKeys.GT_LABELS, (200, 10, 14))
        chunk_request.add(GraphKeys.GT_GRAPH, (400, 30, 34))

        pipeline = source + Scan(chunk_request, num_workers=10)

        with build(pipeline):

            raw_spec = pipeline.spec[ArrayKeys.RAW]
            labels_spec = pipeline.spec[ArrayKeys.GT_LABELS]
            graph_spec = pipeline.spec[GraphKeys.GT_GRAPH]

            full_request = BatchRequest({
                ArrayKeys.RAW: raw_spec,
                ArrayKeys.GT_LABELS: labels_spec,
                GraphKeys.GT_GRAPH: graph_spec,
            })

            batch = pipeline.request_batch(full_request)
            voxel_size = pipeline.spec[ArrayKeys.RAW].voxel_size

        # assert that pixels encode their position
        for (array_key, array) in batch.arrays.items():

            # the z,y,x coordinates of the ROI
            roi = array.spec.roi
            meshgrids = np.meshgrid(range(roi.get_begin()[0] // voxel_size[0],
                                          roi.get_end()[0] // voxel_size[0]),
                                    range(roi.get_begin()[1] // voxel_size[1],
                                          roi.get_end()[1] // voxel_size[1]),
                                    range(roi.get_begin()[2] // voxel_size[2],
                                          roi.get_end()[2] // voxel_size[2]),
                                    indexing='ij')
            data = meshgrids[0] + meshgrids[1] + meshgrids[2]

            self.assertTrue((array.data == data).all())

        for (graph_key, graph) in batch.graphs.items():

            roi = graph.spec.roi
            for i, j, k in itertools.product(range(20000, 22000, 100),
                                             range(2000, 2200, 10),
                                             range(2000, 2200, 10)):
                assert all(
                    np.isclose(
                        graph.node(coordinate_to_id(i, j, k)).location,
                        np.array([i, j, k])))

        assert (batch.arrays[ArrayKeys.RAW].spec.roi.get_offset() == (20000,
                                                                      2000,
                                                                      2000))

        # test scanning with an empty request: Scan then simply sweeps the
        # chunk request over the whole upstream ROI

        pipeline = ScanTestSource() + Scan(chunk_request, num_workers=1)
        with build(pipeline):
            batch = pipeline.request_batch(BatchRequest())
Example #19
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_nn,
        tensor_affinities_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries):

    ignore_keys_for_slip = (GT_LABELS_KEY, GT_MASK_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size
    output_size_nn = Coordinate(s - 2 for s in output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,     voxel_size=input_voxel_size)
    request.add(GT_LABELS_KEY,       output_size,    voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size,    voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_NN_KEY,   output_size_nn, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensure RAW is float, in [0, 1]

        # zero-pad the provided RAW and GT_MASK so that batches can be drawn
        # close to the boundary of the available data
        # (the pad size is more or less irrelevant, as Reject nodes follow)
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        RandomLocation() + # choose a random location inside the provided arrays
        Reject(GT_MASK_KEY) + # reject batches in which less than 50% of the data is labelled
        Reject(GT_LABELS_KEY, min_masked=0.0, reject_probability=0.95)

        for provider in data_providers)

    # extra arrays (predictions, loss gradients) that are requested in addition
    # whenever a snapshot is written (see Snapshot(additional_request=...) below)
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
        AFFINITIES_NN_KEY : request[AFFINITIES_NN_KEY]
    })

    # source of artifact sections (raw only): DefectAugment below replaces
    # random sections with sections drawn from this source
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                ALPHA_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                ALPHA_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ALPHA_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()
    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )
    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)
    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=RAW_KEY,
                      artifacts_mask=ALPHA_MASK_KEY,
                      contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)
    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(GT_LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=GT_LABELS_KEY)

    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=GT_LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY
        )

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)

    train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_affinities_nn: AFFINITIES_NN_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_NN_KEY    : ArraySpec(voxel_size=output_voxel_size)
            }
        )
    train_pipeline += Snapshot(
            dataset_names={
                RAW_KEY             : 'volumes/raw',
                GT_LABELS_KEY       : 'volumes/labels/neuron_ids',
                GT_AFFINITIES_KEY   : 'volumes/affinities/gt',
                AFFINITIES_KEY      : 'volumes/affinities/prediction',
                LOSS_GRADIENT_KEY   : 'volumes/loss_gradient',
                AFFINITIES_MASK_KEY : 'masks/affinities',
                AFFINITIES_NN_KEY   : 'volumes/affinities/prediction-nn'
            },
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
Example #20
    def test_get_neuron_pair(self):
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read arrays
        swc_source = PointsKey("SWC_SOURCE")
        ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
        labels_source = ArrayKey("LABELS_SOURCE")
        img_source = ArrayKey("IMG_SOURCE")
        img_swc = PointsKey("IMG_SWC")
        label_swc = PointsKey("LABEL_SWC")
        imgs = ArrayKey("IMGS")
        labels = ArrayKey("LABELS")
        points_a = PointsKey("SKELETON_A")
        points_b = PointsKey("SKELETON_B")
        img_a = ArrayKey("VOLUME_A")
        img_b = ArrayKey("VOLUME_B")
        labels_a = ArrayKey("LABELS_A")
        labels_b = ArrayKey("LABELS_B")

        data_shape = 5
        output_shape = Coordinate((data_shape, data_shape, data_shape))

        # Get points from test swc
        swc_file_source = SwcFileSource(
            path,
            [swc_source, ensure_nonempty],
            [
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
            ],
        )
        # Create an artificial image source by rasterizing the points
        image_source = (SwcFileSource(
            path, [img_swc],
            [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]) +
                        RasterizeSkeleton(
                            points=img_swc,
                            array=img_source,
                            array_spec=ArraySpec(
                                interpolatable=True,
                                dtype=np.uint32,
                                voxel_size=Coordinate((1, 1, 1)),
                            ),
                        ) +
                        BinarizeLabels(labels=img_source, labels_binary=imgs) +
                        GrowLabels(array=imgs, radius=0))
        # Create an artificial label source by rasterizing the points
        label_source = (SwcFileSource(path, [label_swc], [
            PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))
        ]) + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        ) + BinarizeLabels(labels=labels_source, labels_binary=labels) +
                        GrowLabels(array=labels, radius=1))

        skeleton = tuple()
        skeleton += ((swc_file_source, image_source, label_source) +
                     MergeProvider() +
                     RandomLocation(ensure_nonempty=ensure_nonempty,
                                    ensure_centered=True))

        pipeline = skeleton + GetNeuronPair(
            point_source=swc_source,
            nonempty_placeholder=ensure_nonempty,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=(1, 3),
            shift_attempts=100,
            request_attempts=10,
            output_shape=output_shape,
        )

        request = BatchRequest()

        request.add(points_a, output_shape)
        request.add(points_b, output_shape)
        request.add(img_a, output_shape)
        request.add(img_b, output_shape)
        request.add(labels_a, output_shape)
        request.add(labels_b, output_shape)

        with build(pipeline):
            for i in range(10):
                batch = pipeline.request_batch(request)
                assert all([
                    x in batch for x in
                    [points_a, points_b, img_a, img_b, labels_a, labels_b]
                ])

                min_dist = 5
                for a, b in itertools.product(
                        batch[points_a].nodes,
                        batch[points_b].nodes,
                ):
                    min_dist = min(
                        min_dist,
                        np.linalg.norm(a.location - b.location),
                    )

                self.assertLessEqual(min_dist, 3)
                self.assertGreaterEqual(min_dist, 1)
Example #21
def test_recenter(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # read arrays
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()

    data_shape = 9

    request.add(points_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(points_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_b, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b

    data = np.array((data_a, data_b, data_c))

    for _, point in batch[points_a].data.items():
        assert (
            data[(0,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1, it's {}".format(
            point.location, data[(0,) + tuple(int(x) + 1 for x in point.location)]
        )

    for _, point in batch[points_b].data.items():
        assert (
            data[(1,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1".format(point.location)
Example #22
    def test_rasterize_speed(self):
        # This is worryingly slow for such a small volume (256**3) with only
        # two straight-line skeletons.

        LABEL_RADIUS = 3

        bb = Roi(Coordinate([0, 0, 0]), Coordinate([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create swc sources
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)

        # add request
        request = BatchRequest()
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = (
            SwcFileSource(swc_paths[0], swc_keys[0], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )

        # data source for swc b
        data_sources_b = (
            SwcFileSource(swc_paths[1], swc_keys[1], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
        )
        data_sources = (data_sources_a, data_sources_b) + MergeProvider()

        pipeline = data_sources

        # time only the pipeline build and batch request, not the padding below
        t1 = time.time()
        with build(pipeline):
            batch = pipeline.request_batch(request)
        t2 = time.time()

        a_data = batch[labels_keys[0]].data
        a_data = np.pad(a_data, (1,), "constant", constant_values=(0,))

        b_data = batch[labels_keys[1]].data
        b_data = np.pad(b_data, (1,), "constant", constant_values=(0,))

        self.assertLess(t2 - t1, 0.1)
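
If the timing needs to isolate the batch request itself, time.perf_counter() is a better clock than time.time() for short intervals. A small sketch of that pattern (the timed helper is just for illustration, not part of the test suite):

import time

def timed(fn, *args, **kwargs):
    # perf_counter is monotonic and has higher resolution than time.time()
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - t0

# usage sketch with the pipeline and request built above:
# with build(pipeline):
#     batch, elapsed = timed(pipeline.request_batch, request)
# assert elapsed < 0.1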
Example #23
    def test_two_disjoint_lines_softmask(self):
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        # exaggerated to show the problem
        BLEND_SMOOTHNESS = 10

        bb = Roi(Coordinate([0, 0, 0]), Coordinate([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points)

        # create swc sources
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")
        raw_key_names = ("RAW_A", "RAW_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)
        raw_keys = tuple(ArrayKey(name) for name in raw_key_names)

        # add request
        request = BatchRequest()
        request.add(fused, bb.get_shape())
        request.add(fused_labels, bb.get_shape())
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(raw_keys[0], bb.get_shape())
        request.add(raw_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = (
            SwcFileSource(swc_paths[0], swc_keys[0], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=raw_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=RAW_RADIUS,
            )
        )

        # data source for swc b
        data_sources_b = (
            SwcFileSource(swc_paths[1], swc_keys[1], PointsSpec(roi=bb))
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=LABEL_RADIUS,
            )
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=raw_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint32, voxel_size=voxel_size
                ),
                radius=RAW_RADIUS,
            )
        )
        data_sources = (data_sources_a, data_sources_b) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            fused,
            fused_labels,
            blend_mode="labels_mask",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        fused_data = np.pad(fused_data, (1,), "constant", constant_values=(0,))

        a_data = batch[raw_keys[0]].data
        a_data = np.pad(a_data, (1,), "constant", constant_values=(0,))

        b_data = batch[raw_keys[1]].data
        b_data = np.pad(b_data, (1,), "constant", constant_values=(0,))

        all_data = np.zeros((5,) + fused_data.shape)
        all_data[0, :, :, :] = fused_data
        all_data[1, :, :, :] = a_data + b_data
        all_data[2, :, :, :] = fused_data - a_data - b_data
        all_data[3, :, :, :] = a_data
        all_data[4, :, :, :] = b_data

        # Visualize the problem (only runs if volshow was successfully imported)
        if imported_volshow:
            volshow(all_data)
            # input("Press enter when you are done viewing the data: ")

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
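
The final assertion relies on a simple identity: for spatially disjoint objects an additive blend satisfies fused == a + b exactly, so the residual norm must be zero. A tiny self-contained check with hypothetical masks:

import numpy as np

a = np.zeros((8, 8, 8)); a[1:3, 1:3, 1:3] = 1.0  # object A
b = np.zeros((8, 8, 8)); b[5:7, 5:7, 5:7] = 1.0  # object B, disjoint from A

fused = a + b  # what an ideal blend of disjoint objects reduces to
assert np.linalg.norm(fused - a - b) == 0.0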
Example #24
    def test_two_disjoint_lines_intensity(self):
        LABEL_RADIUS = 3
        RAW_RADIUS = 3
        BLEND_SMOOTHNESS = 3

        bb = Roi(Coordinate([0, 0, 0]), Coordinate([256, 256, 256]))
        voxel_size = Coordinate([1, 1, 1])
        swc_files = ("test_line_a.swc", "test_line_b.swc")
        swc_paths = tuple(Path(self.path_to(file_name)) for file_name in swc_files)

        # create two lines separated by a given distance and write them to swc files
        intercepts, slopes = self._get_line_pair(roi=bb, dist=3 * LABEL_RADIUS)
        for intercept, slope, swc_path in zip(intercepts, slopes, swc_paths):
            swc_points = self._get_points(intercept, slope, bb)
            self._write_swc(swc_path, swc_points.to_nx_graph())

        # create swc sources
        fused = ArrayKey("FUSED")
        fused_labels = ArrayKey("FUSED_LABELS")
        fused_swc = PointsKey("FUSED_SWC")
        swc_key_names = ("SWC_A", "SWC_B")
        labels_key_names = ("LABELS_A", "LABELS_B")
        raw_key_names = ("RAW_A", "RAW_B")

        swc_keys = tuple(PointsKey(name) for name in swc_key_names)
        labels_keys = tuple(ArrayKey(name) for name in labels_key_names)
        raw_keys = tuple(ArrayKey(name) for name in raw_key_names)

        # add request
        request = BatchRequest()
        request.add(fused, bb.get_shape())
        request.add(fused_labels, bb.get_shape())
        request.add(fused_swc, bb.get_shape())
        request.add(labels_keys[0], bb.get_shape())
        request.add(labels_keys[1], bb.get_shape())
        request.add(raw_keys[0], bb.get_shape())
        request.add(raw_keys[1], bb.get_shape())
        request.add(swc_keys[0], bb.get_shape())
        request.add(swc_keys[1], bb.get_shape())

        # data source for swc a
        data_sources_a = (
            SwcFileSource(swc_paths[0], [swc_keys[0]], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=labels_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=labels_keys[0], radii=[LABEL_RADIUS])
            + RasterizeSkeleton(
                points=swc_keys[0],
                array=raw_keys[0],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=raw_keys[0], radii=[RAW_RADIUS])
            + Normalize(raw_keys[0])
        )

        # data source for swc b
        data_sources_b = (
            SwcFileSource(swc_paths[1], [swc_keys[1]], [PointsSpec(roi=bb)])
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=labels_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=labels_keys[1], radii=[LABEL_RADIUS])
            + RasterizeSkeleton(
                points=swc_keys[1],
                array=raw_keys[1],
                array_spec=ArraySpec(
                    interpolatable=False, dtype=np.uint16, voxel_size=voxel_size
                ),
            )
            + GrowLabels(array=raw_keys[1], radii=[RAW_RADIUS])
            + Normalize(raw_keys[1])
        )
        data_sources = (data_sources_a, data_sources_b) + MergeProvider()

        pipeline = data_sources + FusionAugment(
            raw_keys[0],
            raw_keys[1],
            labels_keys[0],
            labels_keys[1],
            swc_keys[0],
            swc_keys[1],
            fused,
            fused_labels,
            fused_swc,
            blend_mode="intensity",
            blend_smoothness=BLEND_SMOOTHNESS,
            num_blended_objects=0,
        )

        with build(pipeline):
            batch = pipeline.request_batch(request)

        fused_data = batch[fused].data
        fused_data = np.pad(fused_data, (1,), "constant", constant_values=(0,))

        a_data = batch[raw_keys[0]].data
        a_data = np.pad(a_data, (1,), "constant", constant_values=(0,))

        b_data = batch[raw_keys[1]].data
        b_data = np.pad(b_data, (1,), "constant", constant_values=(0,))

        diff = np.linalg.norm(fused_data - a_data - b_data)
        self.assertAlmostEqual(diff, 0)
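
FusionAugment's internals are not shown here, so the following is only a plausible reading of blend_mode="intensity": an alpha blend whose mask is Gaussian-smoothed, with blend_smoothness playing the role of the smoothing sigma. A hypothetical 1-D sketch under that assumption:

import numpy as np
from scipy.ndimage import gaussian_filter

base = np.zeros(64); base[10:20] = 1.0  # stand-in for the base raw volume
add = np.zeros(64); add[40:50] = 1.0    # stand-in for the added raw volume

# assumed: a soft alpha mask derived from the added object, smoothed with a
# sigma analogous to blend_smoothness (not FusionAugment's actual code)
alpha = gaussian_filter((add > 0).astype(float), sigma=3.0)
fused = np.clip(base + alpha * add, 0.0, 1.0)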
Example #25
        labels_fused: "volumes/labels_fused",
        labels_base: "volumes/labels_base",
        labels_add: "volumes/labels_add",
        raw_fused_b: "volumes/raw_fused_b",
        labels_fused_b: "volumes/labels_fused_b",
    },
    every=1,
))

with build(pipeline):
    for i in range(1):
        # add request (seeded so repeated iterations are reproducible)
        request = BatchRequest(random_seed=i)
        request.add(raw_fused, input_size)
        request.add(labels_fused, input_size)
        request.add(swc_fused, input_size)
        request.add(raw_fused_b, input_size)
        request.add(labels_fused_b, input_size)
        request.add(swc_fused_b, input_size)

        # add snapshot request
        # request.add(fg, output_size)
        # request.add(labels_fg, output_size)
        # request.add(gradient_fg, output_size)
        request.add(raw_base, input_size)
        request.add(raw_add, input_size)
        request.add(labels_base, input_size)
        request.add(labels_add, input_size)
        request.add(swc_base, input_size)