Example no. 1
    def test_merge_basics(self):
        voxel_size = (1, 1, 1)
        GraphKey("PRESYN")
        ArrayKey("GT_LABELS")
        graphsource = GraphTestSource(voxel_size)
        arraysource = ArrayTestSoure(voxel_size)
        pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation()
        window_request = Coordinate((50, 50, 50))
        with build(pipeline):
            # Check basic merging.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            request.add(ArrayKeys.GT_LABELS, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertTrue(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

            # Check that request of only one source also works.
            request = BatchRequest()
            request.add(GraphKeys.PRESYN, window_request)
            batch_res = pipeline.request_batch(request)
            self.assertFalse(ArrayKeys.GT_LABELS in batch_res.arrays)
            self.assertTrue(GraphKeys.PRESYN in batch_res.graphs)

        # Check that it fails when two sources provide the same type.
        arraysource2 = ArrayTestSoure(voxel_size)
        pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation()
        with self.assertRaises(PipelineSetupError):
            with build(pipeline_fail):
                pass
Example no. 2
    def test_ensure_center_non_zero(self):
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read arrays
        swc = PointsKey("SWC")
        img = ArrayKey("IMG")
        pipeline = (SwcFileSource(
            path, [swc], [PointsSpec(roi=Roi((0, 0, 0), (11, 11, 11)))]) +
                    RandomLocation(ensure_nonempty=swc, ensure_centered=True) +
                    RasterizeSkeleton(
                        points=swc,
                        array=img,
                        array_spec=ArraySpec(
                            interpolatable=False,
                            dtype=np.uint32,
                            voxel_size=Coordinate((1, 1, 1)),
                        ),
                    ))

        request = BatchRequest()
        request.add(img, Coordinate((5, 5, 5)))
        request.add(swc, Coordinate((5, 5, 5)))

        with build(pipeline):
            batch = pipeline.request_batch(request)

            data = batch[img].data
            g = batch[swc]
            assert g.num_vertices() > 0

            self.assertNotEqual(data[tuple(np.array(data.shape) // 2)], 0)
    def test_placeholder(self):

        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D() + RandomLocation(ensure_nonempty=test_points) +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            ))

        with build(pipeline):
            for i in range(2):

                request_size = Coordinate((40, 40, 40))

                request_a = BatchRequest(random_seed=i)
                request_a.add(test_points, request_size)
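                # placeholder=True keeps test_labels in the request spec (here it
                # lets ElasticAugment pick up the array's voxel size) without
                # requiring the label data itself to be produced in this batch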
                request_a.add(test_labels, request_size, placeholder=True)

                request_b = BatchRequest(random_seed=i)
                request_b.add(test_points, request_size)
                request_b.add(test_labels, request_size)

                batch_a = pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                points_a = batch_a[test_points].nodes
                points_b = batch_b[test_points].nodes

                for a, b in zip(points_a, points_b):
                    assert all(np.isclose(a.location, b.location))
    def test_without_placeholder(self):

        test_labels = ArrayKey("TEST_LABELS")
        test_points = GraphKey("TEST_POINTS")

        pipeline = (
            PointTestSource3D() + RandomLocation(ensure_nonempty=test_points) +
            ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) +
            Snapshot(
                {test_labels: "volumes/labels"},
                output_dir=self.path_to(),
                output_filename="elastic_augment_test{id}-{iteration}.hdf",
            ))

        with build(pipeline):
            for i in range(2):

                request_size = Coordinate((40, 40, 40))

                request_a = BatchRequest(random_seed=i)
                request_a.add(test_points, request_size)

                request_b = BatchRequest(random_seed=i)
                request_b.add(test_points, request_size)
                request_b.add(test_labels, request_size)

                # No array to provide a voxel size to ElasticAugment
                with pytest.raises(PipelineRequestError):
                    pipeline.request_batch(request_a)
                batch_b = pipeline.request_batch(request_b)

                self.assertIn(test_labels, batch_b)
Example no. 5
    def test_pipeline3(self):
        array_key = ArrayKey("TEST_ARRAY")
        points_key = GraphKey("TEST_POINTS")
        voxel_size = Coordinate((1, 1))
        spec = ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = Hdf5Source(self.fake_data_file, {array_key: "testdata"},
                                 array_specs={array_key: spec})
        csv_source = CsvPointsSource(
            self.fake_points_file,
            points_key,
            GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))),
        )

        request = BatchRequest()
        shape = Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + MergeProvider() +
                    RandomLocation(ensure_nonempty=points_key) + shift_node)
        with build(pipeline) as b:
            batch = b.request_batch(request)
            # print(batch[points_key])

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = list(batch[points_key].nodes)
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(
                    result_val,
                    list(result_points),
                    target_vals,
                    self.fake_points,
                ),
            )
    def test_output(self):
        """
        Fails due to probabilities being calculated in advance, rather than after creating
        each roi. The new approach does not account for all possible rois containing
        each point, some of which may not contain its nearest neighbors.
        """

        GraphKey('TEST_POINTS')

        pipeline = (ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=GraphKeys.TEST_POINTS, point_balance_radius=100))

        # count the number of times we get each point
        histogram = {}

        with build(pipeline):

            for i in range(5000):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_POINTS:
                        GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                    }))

                points = {
                    node.id: node
                    for node in batch[GraphKeys.TEST_POINTS].nodes
                }

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for node in batch[GraphKeys.TEST_POINTS].nodes:
                    if node.id not in histogram:
                        histogram[node.id] = 1
                    else:
                        histogram[node.id] += 1

        total = sum(histogram.values())
        for k, v in histogram.items():
            histogram[k] = float(v) / total

        # we should get roughly the same count for each point
        for i in histogram.keys():
            for j in histogram.keys():
                self.assertAlmostEqual(histogram[i], histogram[j], 1)
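
The docstring above pins the failure on sampling probabilities that are computed once, up front, rather than per ROI. As a rough illustration of that pre-computed, inverse-density weighting idea, here is a minimal hypothetical sketch; the helper name and the use of scipy's cKDTree are illustrative assumptions, not gunpowder's actual point_balance_radius implementation.

import numpy as np
from scipy.spatial import cKDTree

def balanced_sampling_weights(locations, balance_radius):
    # weight every point by the inverse of its neighbour count within
    # balance_radius, so dense clusters do not dominate the sampling
    locations = np.asarray(locations, dtype=float)
    tree = cKDTree(locations)
    counts = np.array(
        [len(tree.query_ball_point(p, balance_radius)) for p in locations]
    )
    weights = 1.0 / counts
    return weights / weights.sum()

# toy usage: a clustered pair and one isolated point
probs = balanced_sampling_weights([(0, 0, 0), (1, 0, 0), (100, 0, 0)], 10)
# probs is roughly [0.25, 0.25, 0.5]; because the weights are fixed up front,
# they ignore which ROIs can actually contain each point -- the failure mode
# described in the docstring above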
    def test_req_full_roi(self):

        GraphKey("TEST_GRAPH")

        possible_roi = Roi((0, 0, 0), (1000, 1000, 1000))

        pipeline = (SourceGraphLocation() +
                    BatchTester(possible_roi, exact=False) +
                    RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH))
        with build(pipeline):

            batch = pipeline.request_batch(
                BatchRequest({
                    GraphKeys.TEST_GRAPH:
                    GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))
                }))

            assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
    def test_roi_one_point(self):

        GraphKey("TEST_GRAPH")
        upstream_roi = Roi((500, 500, 500), (1, 1, 1))

        pipeline = (SourceGraphLocation() +
                    BatchTester(upstream_roi, exact=True) +
                    RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH))

        with build(pipeline):
            for i in range(500):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_GRAPH:
                        GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))
                    }))

                assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
    def test_dim_size_1(self):

        GraphKey("TEST_GRAPH")
        upstream_roi = Roi((500, 401, 401), (1, 200, 200))
        pipeline = (SourceGraphLocation() +
                    BatchTester(upstream_roi, exact=False) +
                    RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH))

        # count the number of times we get each node
        with build(pipeline):

            for i in range(500):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_GRAPH:
                        GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))
                    }))

                assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1
    def test_output(self):

        GraphKey("TEST_GRAPH")

        pipeline = TestSourceRandomLocation() + RandomLocation(
            ensure_nonempty=GraphKeys.TEST_GRAPH)

        # count the number of times we get each node
        histogram = {}

        with build(pipeline):

            for i in range(5000):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_GRAPH:
                        GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                    }))

                nodes = list(batch[GraphKeys.TEST_GRAPH].nodes)
                node_ids = [v.id for v in nodes]

                self.assertTrue(len(nodes) > 0)
                self.assertTrue(
                    (1 in node_ids) != (2 in node_ids or 3 in node_ids),
                    node_ids,
                )

                for node in batch[GraphKeys.TEST_GRAPH].nodes:
                    if node.id not in histogram:
                        histogram[node.id] = 1
                    else:
                        histogram[node.id] += 1

        total = sum(histogram.values())
        for k, v in histogram.items():
            histogram[k] = float(v) / total

        # we should get roughly the same count for each point
        for i in histogram.keys():
            for j in histogram.keys():
                self.assertAlmostEqual(histogram[i], histogram[j], 1)
    def test_equal_probability(self):

        GraphKey('TEST_POINTS')

        pipeline = (ExampleSourceRandomLocation() +
                    RandomLocation(ensure_nonempty=GraphKeys.TEST_POINTS))

        # count the number of times we get each point
        histogram = {}

        with build(pipeline):

            for i in range(5000):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_POINTS:
                        GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))
                    }))

                points = {
                    node.id: node
                    for node in batch[GraphKeys.TEST_POINTS].nodes
                }

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for point in batch[GraphKeys.TEST_POINTS].nodes:
                    if point.id not in histogram:
                        histogram[point.id] = 1
                    else:
                        histogram[point.id] += 1

        total = sum(histogram.values())
        for k, v in histogram.items():
            histogram[k] = float(v) / total

        # we should get roughly the same count for each point
        for i in histogram.keys():
            for j in histogram.keys():
                self.assertAlmostEqual(histogram[i], histogram[j], 1)
Example no. 12
def train():

    gunpowder.set_verbose(False)

    affinity_neighborhood = malis.mknhood3d()
    solver_parameters = gunpowder.caffe.SolverParameters()
    solver_parameters.train_net = 'net.prototxt'
    solver_parameters.base_lr = 1e-4
    solver_parameters.momentum = 0.95
    solver_parameters.momentum2 = 0.999
    solver_parameters.delta = 1e-8
    solver_parameters.weight_decay = 0.000005
    solver_parameters.lr_policy = 'inv'
    solver_parameters.gamma = 0.0001
    solver_parameters.power = 0.75
    solver_parameters.snapshot = 10000
    solver_parameters.snapshot_prefix = 'net'
    solver_parameters.type = 'Adam'
    solver_parameters.resume_from = None
    solver_parameters.train_state.add_stage('euclid')

    request = BatchRequest()
    request.add_volume_request(VolumeTypes.RAW, constants.input_shape)
    request.add_volume_request(VolumeTypes.GT_LABELS, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_MASK, constants.output_shape)
    request.add_volume_request(VolumeTypes.GT_AFFINITIES, constants.output_shape)
    request.add_volume_request(VolumeTypes.LOSS_SCALE, constants.output_shape)

    data_providers = list()
    fibsem_dir = "/groups/turaga/turagalab/data/FlyEM/fibsem_medulla_7col"
    for volume_name in ("tstvol-520-1-h5",):
        h5_filepath = "./{}.h5".format(volume_name)
        path_to_labels = os.path.join(fibsem_dir, volume_name, "groundtruth_seg.h5")
        with h5py.File(path_to_labels, "r") as f_labels:
            mask_shape = f_labels["main"].shape
        with h5py.File(h5_filepath, "w") as h5:
            h5['volumes/raw'] = h5py.ExternalLink(os.path.join(fibsem_dir, volume_name, "im_uint8.h5"), "main")
            h5['volumes/labels/neuron_ids'] = h5py.ExternalLink(path_to_labels, "main")
            h5.create_dataset(
                name="volumes/labels/mask",
                dtype="uint8",
                shape=mask_shape,
                fillvalue=1,
            )
        data_providers.append(
            gunpowder.Hdf5Source(
                h5_filepath,
                datasets={
                    VolumeTypes.RAW: 'volumes/raw',
                    VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                    VolumeTypes.GT_MASK: 'volumes/labels/mask',
                },
                resolution=(8, 8, 8),
            )
        )
    dvid_source = DvidSource(
        hostname='slowpoke3',
        port=32788,
        uuid='341',
        raw_array_name='grayscale',
        gt_array_name='groundtruth',
        gt_mask_roi_name="seven_column_eroded7_z_lt_5024",
        resolution=(8, 8, 8),
    )
    data_providers.extend([dvid_source])
    data_providers = tuple(
        provider +
        RandomLocation() +
        Reject(min_masked=0.5) +
        Normalize()
        for provider in data_providers
    )

    # create a batch provider by concatenation of filters
    batch_provider = (
        data_providers +
        RandomProvider() +
        ElasticAugment([20, 20, 20], [0, 0, 0], [0, math.pi / 2.0]) +
        SimpleAugment(transpose_only_xy=False) +
        GrowBoundary(steps=2, only_xy=False) +
        AddGtAffinities(affinity_neighborhood) +
        BalanceAffinityLabels() +
        SplitAndRenumberSegmentationLabels() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        PreCache(
            request,
            cache_size=11,
            num_workers=10) +
        Train(solver_parameters, use_gpu=0) +
        Typecast(volume_dtypes={
            VolumeTypes.GT_LABELS: np.dtype("uint32"),
            VolumeTypes.GT_MASK: np.dtype("uint8"),
            VolumeTypes.LOSS_SCALE: np.dtype("float32"),
        }, safe=True) +
        Snapshot(every=50, output_filename='batch_{id}.hdf') +
        PrintProfilingStats(every=50)
    )

    n = 500000
    print("Training for", n, "iterations")

    with gunpowder.build(batch_provider) as pipeline:
        for i in range(n):
            pipeline.request_batch(request)

    print("Finished")
Example no. 13
def visualize_augmentations_paintera(args=None):

    data_providers = []
    data_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data'
    # data_dir = os.path.expanduser('~/Dropbox/cremi-upsampled/')
    file_pattern = 'sample_A_padded_20160501-2-additional-sections-fixed-offset.h5'
    file_pattern = 'sample_B_padded_20160501-2-additional-sections-fixed-offset.h5'
    file_pattern = 'sample_C_padded_20160501-2-additional-sections-fixed-offset.h5'
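    # only the last file_pattern assignment takes effect; reorder or comment
    # out the lines above to visualize a different sample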

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'

    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW: 'defect_sections/raw',
                ALPHA_MASK: 'defect_sections/mask',
            },
            array_specs={
                RAW: ArraySpec(voxel_size=tuple(d * 9 for d in (40, 4, 4))),
                ALPHA_MASK: ArraySpec(voxel_size=tuple(d * 9
                                                       for d in (40, 4, 4))),
            }) + RandomLocation(min_masked=0.05, mask=ALPHA_MASK) +
        Normalize(RAW) +
        IntensityAugment(RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(voxel_size=(360, 36, 36),
                       control_point_spacing=(4, 40, 40),
                       control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
                       rotation_interval=(0, np.pi / 2.0),
                       subsample=8) + SimpleAugment(transpose_only=[1, 2]))

    for data in glob.glob(os.path.join(data_dir, file_pattern)):
        h5_source = Hdf5Source(data,
                               datasets={
                                   RAW:
                                   'volumes/raw',
                                   GT_LABELS:
                                   'volumes/labels/neuron_ids-downsampled',
                               })
        data_providers.append(h5_source)

    input_resolution = (360.0, 36.0, 36.0)
    output_resolution = Coordinate((120.0, 108.0, 108.0))
    offset = (13640, 10932, 10932)

    output_shape = Coordinate((60.0, 100.0, 100.0)) * output_resolution
    output_offset = (13320 + 3600, 32796 + 36 + 10800, 32796 + 36 + 10800)

    overhang = Coordinate((360.0, 108.0, 108.0)) * 16

    input_shape = output_shape + overhang * 2
    input_offset = Coordinate(output_offset) - overhang

    output_roi = Roi(offset=output_offset, shape=output_shape)
    input_roi = Roi(offset=input_offset, shape=input_shape)

    augmentations = (
        Snapshot(dataset_names={
            GT_LABELS: 'volumes/gt',
            RAW: 'volumes/raw'
        },
                 output_dir='.',
                 output_filename='snapshot-before.h5',
                 attributes_callback=Snapshot.default_attributes_callback()),
        ElasticAugment(voxel_size=(360.0, 36.0, 36.0),
                       control_point_spacing=(4, 40, 40),
                       control_point_displacement_sigma=(0, 5 * 2 * 36,
                                                         5 * 2 * 36),
                       rotation_interval=(0 * np.pi / 8, 0 * 2 * np.pi),
                       subsample=8,
                       augmentation_probability=1.0,
                       seed=None),
        Misalign(z_resolution=360,
                 prob_slip=0.2,
                 prob_shift=0.5,
                 max_misalign=(3600, 0),
                 seed=100,
                 ignore_keys_for_slip=(GT_LABELS, )),
        Snapshot(dataset_names={
            GT_LABELS: 'volumes/gt',
            RAW: 'volumes/raw'
        },
                 output_dir='.',
                 output_filename='snapshot-after.h5',
                 attributes_callback=Snapshot.default_attributes_callback()))

    keys = (RAW, GT_LABELS)[:]

    batch, snapshot = run_augmentations(
        data_providers=data_providers,
        roi=lambda key: output_roi.copy()
        if key == GT_LABELS else input_roi.copy(),
        augmentations=augmentations,
        keys=keys,
        voxel_size=lambda key: input_resolution
        if key == RAW else (output_resolution if key == GT_LABELS else None))

    args = get_parser().parse_args() if args is None else args
    jnius_config.add_options('-Xmx{}'.format(args.max_heap_size))

    import payntera.jfx
    from jnius import autoclass
    payntera.jfx.init_platform()

    PainteraBaseView = autoclass(
        'org.janelia.saalfeldlab.paintera.PainteraBaseView')
    viewer = PainteraBaseView.defaultView()
    pbv = viewer.baseView
    scene, stage = payntera.jfx.start_stage(viewer.paneWithStatus.getPane())
    screen_scale_setter = lambda: pbv.orthogonalViews().setScreenScales(
        [0.3, 0.1, 0.03])
    # payntera.jfx.invoke_on_jfx_application_thread(screen_scale_setter)

    snapshot_states = add_to_viewer(
        snapshot, keys=keys, name=lambda key: '%s-snapshot' % key.identifier)
    states = add_to_viewer(batch, keys=keys)

    viewer.keyTracker.installInto(scene)
    scene.addEventFilter(
        autoclass('javafx.scene.input.MouseEvent').ANY, viewer.mouseTracker)

    while stage.isShowing():
        time.sleep(0.1)
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_mask,
        tensor_glia,
        tensor_glia_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries,
        mask_out_labels,
        snapshot_dir):

    ignore_keys_for_slip = (LABELS_KEY, GT_MASK_KEY, GT_GLIA_KEY, GLIA_MASK_KEY, UNLABELED_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # TODO Use glia scale somehow, probably not possible with tensorflow 1.3 because it does not know uint64...
    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,  voxel_size=input_voxel_size)
    request.add(LABELS_KEY,          output_size, voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(GLIA_MASK_KEY,       output_size, voxel_size=output_voxel_size)
    request.add(GLIA_KEY,            output_size, voxel_size=output_voxel_size)
    request.add(GT_GLIA_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(UNLABELED_KEY,       output_size, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    # always balance glia labels!
    request.add(GLIA_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY
    network_inputs[tensor_glia_mask]       = GLIA_SCALE_KEY  # GLIA_SCALE_KEY if balance_labels else GLIA_MASK_KEY

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensures RAW is a float array in [0, 1]

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # the size is more or less irrelevant since Reject nodes follow below
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        Pad(GLIA_MASK_KEY, None) +
        Pad(LABELS_KEY, size=NETWORK_OUTPUT_SHAPE / 2, value=np.uint64(-3)) +
        Pad(GT_GLIA_KEY, size=NETWORK_OUTPUT_SHAPE / 2) +
        # Pad(LABELS_KEY, None) +
        # Pad(GT_GLIA_KEY, None) +
        RandomLocation() + # choose a random location inside the provided arrays
        Reject(mask=GT_MASK_KEY, min_masked=0.5) +
        Reject(mask=GLIA_MASK_KEY, min_masked=0.5) +
        MapNumpyArray(lambda array: np.require(array, dtype=np.int64), GT_GLIA_KEY) # this is necessary because gunpowder 1.3 only understands int64, not uint64

        for provider in data_providers)

    # TODO figure out what this is for
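    # -> it is passed to Snapshot below as additional_request, so the predicted
    #    affinities and the loss gradient are also written into each snapshot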
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
    })

    # no need to do anything here. random sections will be replaced with sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                DEFECT_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                DEFECT_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=DEFECT_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()

    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )

    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: before misalign: ' % GT_MASK_KEY)
    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)
    # train_pipeline += Log.log_numpy_array_stats_after_process(GT_MASK_KEY, 'min', 'max', 'dtype', logging_prefix='%s: after  misalign: ' % GT_MASK_KEY)

    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                                    prob_missing=0.03,
                                    prob_low_contrast=0.01,
                                    prob_artifact=0.03,
                                    artifact_source=artifact_source,
                                    artifacts=RAW_KEY,
                                    artifacts_mask=DEFECT_MASK_KEY,
                                    contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)

    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    _logger.info("Renumbering connected components? %s", renumber_connected_components)
    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=LABELS_KEY)

    train_pipeline += NewKeyFromNumpyArray(lambda array: 1 - array, GT_GLIA_KEY, UNLABELED_KEY)

    if len(mask_out_labels) > 0:
        train_pipeline += MaskOutLabels(label_key=LABELS_KEY, mask_key=GT_MASK_KEY, ids_to_be_masked=mask_out_labels)

    # labels_mask: anything that connects into labels_mask will be zeroed out
    # unlabelled: anything that points into unlabeled will have zero affinity;
    #             affinities within unlabelled will be masked out
    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY,
            unlabelled=UNLABELED_KEY
    )

    snapshot_datasets = {
        RAW_KEY: 'volumes/raw',
        LABELS_KEY: 'volumes/labels/neuron_ids',
        GT_AFFINITIES_KEY: 'volumes/affinities/gt',
        GT_GLIA_KEY: 'volumes/labels/glia_gt',
        UNLABELED_KEY: 'volumes/labels/unlabeled',
        AFFINITIES_KEY: 'volumes/affinities/prediction',
        LOSS_GRADIENT_KEY: 'volumes/loss_gradient',
        AFFINITIES_MASK_KEY: 'masks/affinities',
        GLIA_KEY: 'volumes/labels/glia_pred',
        GT_MASK_KEY: 'masks/gt',
        GLIA_MASK_KEY: 'masks/glia'}

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)
        snapshot_datasets[AFFINITIES_SCALE_KEY] = 'masks/affinity-scale'
    train_pipeline += BalanceLabels(labels=GT_GLIA_KEY, scales=GLIA_SCALE_KEY, mask=GLIA_MASK_KEY)
    snapshot_datasets[GLIA_SCALE_KEY] = 'masks/glia-scale'


    if (pre_cache_size > 0 and pre_cache_num_workers > 0):
        train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_glia: GLIA_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                GLIA_MASK_KEY        : ArraySpec(voxel_size=output_voxel_size),
                GLIA_SCALE_KEY       : ArraySpec(voxel_size=output_voxel_size),
                GLIA_KEY             : ArraySpec(voxel_size=output_voxel_size)
            }
        )

    train_pipeline += Snapshot(
            snapshot_datasets,
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir=snapshot_dir,
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())

    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
Example no. 15
    def test_get_neuron_pair(self):
        path = Path(self.path_to("test_swc_source.swc"))

        # write test swc
        self._write_swc(path, self._toy_swc_points().to_nx_graph())

        # read arrays
        swc_source = PointsKey("SWC_SOURCE")
        ensure_nonempty = PointsKey("ENSURE_NONEMPTY")
        labels_source = ArrayKey("LABELS_SOURCE")
        img_source = ArrayKey("IMG_SOURCE")
        img_swc = PointsKey("IMG_SWC")
        label_swc = PointsKey("LABEL_SWC")
        imgs = ArrayKey("IMGS")
        labels = ArrayKey("LABELS")
        points_a = PointsKey("SKELETON_A")
        points_b = PointsKey("SKELETON_B")
        img_a = ArrayKey("VOLUME_A")
        img_b = ArrayKey("VOLUME_B")
        labels_a = ArrayKey("LABELS_A")
        labels_b = ArrayKey("LABELS_B")

        data_shape = 5
        output_shape = Coordinate((data_shape, data_shape, data_shape))

        # Get points from test swc
        swc_file_source = SwcFileSource(
            path,
            [swc_source, ensure_nonempty],
            [
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
                PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31))),
            ],
        )
        # Create an artificial image source by rasterizing the points
        image_source = (SwcFileSource(
            path, [img_swc],
            [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]) +
                        RasterizeSkeleton(
                            points=img_swc,
                            array=img_source,
                            array_spec=ArraySpec(
                                interpolatable=True,
                                dtype=np.uint32,
                                voxel_size=Coordinate((1, 1, 1)),
                            ),
                        ) +
                        BinarizeLabels(labels=img_source, labels_binary=imgs) +
                        GrowLabels(array=imgs, radius=0))
        # Create an artificial label source by rasterizing the points
        label_source = (SwcFileSource(path, [label_swc], [
            PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))
        ]) + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True,
                dtype=np.uint32,
                voxel_size=Coordinate((1, 1, 1)),
            ),
        ) + BinarizeLabels(labels=labels_source, labels_binary=labels) +
                        GrowLabels(array=labels, radius=1))

        skeleton = tuple()
        skeleton += ((swc_file_source, image_source, label_source) +
                     MergeProvider() +
                     RandomLocation(ensure_nonempty=ensure_nonempty,
                                    ensure_centered=True))

        pipeline = skeleton + GetNeuronPair(
            point_source=swc_source,
            nonempty_placeholder=ensure_nonempty,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=(1, 3),
            shift_attempts=100,
            request_attempts=10,
            output_shape=output_shape,
        )

        request = BatchRequest()

        request.add(points_a, output_shape)
        request.add(points_b, output_shape)
        request.add(img_a, output_shape)
        request.add(img_b, output_shape)
        request.add(labels_a, output_shape)
        request.add(labels_b, output_shape)

        with build(pipeline):
            for i in range(10):
                batch = pipeline.request_batch(request)
                assert all([
                    x in batch for x in
                    [points_a, points_b, img_a, img_b, labels_a, labels_b]
                ])

                min_dist = 5
                for a, b in itertools.product(
                        batch[points_a].nodes,
                        batch[points_b].nodes,
                ):
                    min_dist = min(
                        min_dist,
                        np.linalg.norm(a.location - b.location),
                    )

                self.assertLessEqual(min_dist, 3)
                self.assertGreaterEqual(min_dist, 1)
Example no. 17
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_nn,
        tensor_affinities_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries):

    ignore_keys_for_slip = (GT_LABELS_KEY, GT_MASK_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size
    output_size_nn = Coordinate(s - 2 for s in output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,     voxel_size=input_voxel_size)
    request.add(GT_LABELS_KEY,       output_size,    voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size,    voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_NN_KEY,   output_size_nn, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensures RAW is a float array in [0, 1]

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # the size is more or less irrelevant since Reject nodes follow below
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        RandomLocation() + # choose a random location inside the provided arrays
        Reject(GT_MASK_KEY) + # reject batches which contain less than 50% labelled data
        Reject(GT_LABELS_KEY, min_masked=0.0, reject_probability=0.95)

        for provider in data_providers)

    # TODO figure out what this is for
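    # -> it is passed to Snapshot below as additional_request, so the predicted
    #    affinities (including AFFINITIES_NN_KEY) and the loss gradient are also
    #    written into each snapshot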
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
        AFFINITIES_NN_KEY : request[AFFINITIES_NN_KEY]
    })

    # no need to do anything here. random sections will be replaced with sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                ALPHA_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                ALPHA_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ALPHA_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()
    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )
    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)
    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=RAW_KEY,
                      artifacts_mask=ALPHA_MASK_KEY,
                      contrast_scale=0.5)
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)
    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(GT_LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=GT_LABELS_KEY)

    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=GT_LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY
        )

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)

    train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_affinities_nn: AFFINITIES_NN_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_NN_KEY    : ArraySpec(voxel_size=output_voxel_size)
            }
        )
    train_pipeline += Snapshot(
            dataset_names={
                RAW_KEY             : 'volumes/raw',
                GT_LABELS_KEY       : 'volumes/labels/neuron_ids',
                GT_AFFINITIES_KEY   : 'volumes/affinities/gt',
                AFFINITIES_KEY      : 'volumes/affinities/prediction',
                LOSS_GRADIENT_KEY   : 'volumes/loss_gradient',
                AFFINITIES_MASK_KEY : 'masks/affinities',
                AFFINITIES_NN_KEY   : 'volumes/affinities/prediction-nn'
            },
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
Example no. 18
def test_recenter(self):
    path = Path(self.path_to("test_swc_source.swc"))

    # write test swc
    self._write_swc(path, self._toy_swc_points())

    # read arrays
    swc_source = PointsKey("SWC_SOURCE")
    labels_source = ArrayKey("LABELS_SOURCE")
    img_source = ArrayKey("IMG_SOURCE")
    img_swc = PointsKey("IMG_SWC")
    label_swc = PointsKey("LABEL_SWC")
    imgs = ArrayKey("IMGS")
    labels = ArrayKey("LABELS")
    points_a = PointsKey("SKELETON_A")
    points_b = PointsKey("SKELETON_B")
    img_a = ArrayKey("VOLUME_A")
    img_b = ArrayKey("VOLUME_B")
    labels_a = ArrayKey("LABELS_A")
    labels_b = ArrayKey("LABELS_B")

    # Get points from test swc
    swc_file_source = SwcFileSource(
        path, [swc_source], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
    )
    # Create an artificial image source by rasterizing the points
    image_source = (
        SwcFileSource(
            path, [img_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=img_swc,
            array=img_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=img_source, labels_binary=imgs)
        + GrowLabels(array=imgs, radius=0)
    )
    # Create an artificial label source by rasterizing the points
    label_source = (
        SwcFileSource(
            path, [label_swc], [PointsSpec(roi=Roi((-10, -10, -10), (31, 31, 31)))]
        )
        + RasterizeSkeleton(
            points=label_swc,
            array=labels_source,
            array_spec=ArraySpec(
                interpolatable=True, dtype=np.uint32, voxel_size=Coordinate((1, 1, 1))
            ),
        )
        + BinarizeLabels(labels=labels_source, labels_binary=labels)
        + GrowLabels(array=labels, radius=1)
    )

    skeleton = tuple()
    skeleton += (
        (swc_file_source, image_source, label_source)
        + MergeProvider()
        + RandomLocation(ensure_nonempty=swc_source, ensure_centered=True)
    )

    pipeline = (
        skeleton
        + GetNeuronPair(
            point_source=swc_source,
            array_source=imgs,
            label_source=labels,
            points=(points_a, points_b),
            arrays=(img_a, img_b),
            labels=(labels_a, labels_b),
            seperate_by=4,
            shift_attempts=100,
        )
        + Recenter(points_a, img_a, max_offset=4)
        + Recenter(points_b, img_b, max_offset=4)
    )

    request = BatchRequest()

    data_shape = 9

    request.add(points_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(points_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(img_b, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_a, Coordinate((data_shape, data_shape, data_shape)))
    request.add(labels_b, Coordinate((data_shape, data_shape, data_shape)))

    with build(pipeline):
        batch = pipeline.request_batch(request)

    data_a = batch[img_a].data
    assert data_a[tuple(np.array(data_a.shape) // 2)] == 1
    data_a = np.pad(data_a, (1,), "constant", constant_values=(0,))
    data_b = batch[img_b].data
    assert data_b[tuple(np.array(data_b.shape) // 2)] == 1
    data_b = np.pad(data_b, (1,), "constant", constant_values=(0,))

    data_c = data_a + data_b

    data = np.array((data_a, data_b, data_c))

    for _, point in batch[points_a].data.items():
        assert (
            data[(0,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1, its {}".format(
            point.location, data[(0,) + tuple(int(x) for x in point.location)]
        )

    for _, point in batch[points_b].data.items():
        assert (
            data[(1,) + tuple(int(x) + 1 for x in point.location)] == 1
        ), "data at {} is not 1".format(point.location)
    def test_ensure_centered(self):
        """
        Expected failure due to emergent behavior of two desired rules:
        1) Points on the upper bound of Roi are not considered contained
        2) When considering a point as a center of a random location,
            scale by the number of points within some delta distance

        if two points are equally likely to be chosen, and centering
        a roi on either of them means the other is on the bounding box
        of the roi, then it can be the case that if the roi is centered
        one of them, the roi contains only that one, but if the roi is
        centered on the second, then both are considered contained,
        breaking the equal likelihood of picking each point.
        """

        GraphKey("TEST_POINTS")

        pipeline = ExampleSourceRandomLocation() + RandomLocation(
            ensure_nonempty=GraphKeys.TEST_POINTS, ensure_centered=True)

        # count the number of times we get each point
        histogram = {}

        with build(pipeline):

            for i in range(5000):
                batch = pipeline.request_batch(
                    BatchRequest({
                        GraphKeys.TEST_POINTS:
                        GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))
                    }))

                points = batch[GraphKeys.TEST_POINTS].data
                roi = batch[GraphKeys.TEST_POINTS].spec.roi

                locations = tuple(
                    [Coordinate(point.location) for point in points.values()])
                self.assertTrue(
                    Coordinate([50, 50, 50]) in locations,
                    f"locations: {tuple([point.location for point in points.values()])}"
                )

                self.assertTrue(len(points) > 0)
                self.assertTrue((1 in points) != (2 in points or 3 in points),
                                points)

                for point_id in batch[GraphKeys.TEST_POINTS].data.keys():
                    if point_id not in histogram:
                        histogram[point_id] = 1
                    else:
                        histogram[point_id] += 1

        total = sum(histogram.values())
        for k, v in histogram.items():
            histogram[k] = float(v) / total

        # we should get roughly the same count for each point
        for i in histogram.keys():
            for j in histogram.keys():
                self.assertAlmostEqual(histogram[i], histogram[j], 1,
                                       histogram)