Example #1
    def test_shift_points5(self):
        data = {
            0: gp.Point([3, 0]),
            1: gp.Point([3, 2]),
            2: gp.Point([3, 4]),
            3: gp.Point([3, 6]),
            4: gp.Point([3, 8])
        }
        spec = gp.PointsSpec(gp.Roi(offset=(0, 0), shape=(15, 10)))
        points = gp.Points(data, spec)
        request_roi = gp.Roi(offset=(3, 0), shape=(9, 10))
        shift_array = np.array([[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]],
                               dtype=int)

        lcm_voxel_size = gp.Coordinate((3, 2))
        shifted_data = {
            0: gp.Point([6, 0]),
            2: gp.Point([3, 4]),
            4: gp.Point([6, 8])
        }
        result = gp.ShiftAugment.shift_points(points,
                                              request_roi,
                                              shift_array,
                                              shift_axis=1,
                                              lcm_voxel_size=lcm_voxel_size)
        # print("test 4", result.data, shifted_data)
        self.assertTrue(self.points_equal(result.data, shifted_data))
        self.assertTrue(result.spec == gp.PointsSpec(request_roi))
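The assertions above rely on a `points_equal` helper that is not shown in this snippet. A minimal sketch of what such a helper could look like, assuming points are compared by ID and element-wise by their `location` arrays:

    def points_equal(self, points_a, points_b):
        # hypothetical helper: two point dicts match if they contain the same
        # IDs and the corresponding locations agree element-wise
        if set(points_a.keys()) != set(points_b.keys()):
            return False
        return all(
            np.array_equal(points_a[i].location, points_b[i].location)
            for i in points_a)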
Example #2
    def setup(self):

        self.spec_src = gp.PointsSpec()
        self.spec_trg = gp.PointsSpec()

        self.provides(self.srcpoints, self.spec_src)
        self.provides(self.trgpoints, self.spec_trg)

        self.enable_autoskip()
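A `setup` like this normally lives in a custom gunpowder node. The sketch below shows one plausible surrounding class; the class name, the constructor and the empty `prepare`/`process` stubs are assumptions, only the `setup` body is taken from the example:

class ProvideTargetPoints(gp.BatchFilter):
    # hypothetical node wrapping the setup() shown above

    def __init__(self, srcpoints, trgpoints):
        self.srcpoints = srcpoints  # gp.PointsKey provided by the caller
        self.trgpoints = trgpoints  # gp.PointsKey provided by the caller

    def setup(self):
        self.spec_src = gp.PointsSpec()
        self.spec_trg = gp.PointsSpec()
        self.provides(self.srcpoints, self.spec_src)
        self.provides(self.trgpoints, self.spec_trg)
        self.enable_autoskip()

    def prepare(self, request):
        pass  # declare what this node needs from upstream

    def process(self, batch, request):
        pass  # compute the provided points and attach them to the batch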
Example #3
    def test_context(self):
        d_pred = gp.ArrayKeys.D_PRED
        m_pred = gp.ArrayKeys.M_PRED
        presyn = gp.PointsKeys.PRESYN
        postsyn = gp.PointsKeys.POSTSYN

        outdir = tempfile.mkdtemp()

        voxel_size = gp.Coordinate((10, 10, 10))
        size = (200, 200, 200)
        # Check that the score of the entire cube is measured, even though the
        # cube around the border point lies partially outside the request ROI.
        context = 40
        shape = gp.Coordinate(size) / voxel_size
        m_predar = np.zeros(shape, dtype=np.float32)
        outsidepoint = gp.Coordinate((13, 13, 13))
        borderpoint = (4, 4, 4)
        m_predar[3:5, 3:5, 3:5] = 1
        m_predar[outsidepoint] = 1

        d_predar = np.ones((3, shape[0], shape[1], shape[2])) * 0

        pipeline = (TestSource(m_predar, d_predar, voxel_size=voxel_size) +
                    ExtractSynapses(m_pred,
                                    d_pred,
                                    presyn,
                                    postsyn,
                                    out_dir=outdir,
                                    settings=parameters,
                                    context=context) +
                    gp.PrintProfilingStats())

        request = gp.BatchRequest()

        roi = gp.Roi((40, 40, 40), (80, 80, 80))

        request[presyn] = gp.PointsSpec(roi=roi)
        request[postsyn] = gp.PointsSpec(roi=roi)
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)

        synapsefile = os.path.join(outdir, "40", "40", "40.npz")
        with np.load(synapsefile) as data:
            data = dict(data)

        self.assertTrue(len(data['ids']) == 1)
        self.assertEqual(data['scores'][0], 2.0**3)  # Size of the cube.
        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][0][ii],
                             borderpoint[ii] * voxel_size[ii])

        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][1][ii],
                             borderpoint[ii] * voxel_size[ii] + 0)
        shutil.rmtree(outdir)
Example #4
    def test_output_basics(self):
        d_pred = gp.ArrayKeys.D_PRED
        m_pred = gp.ArrayKeys.M_PRED
        presyn = gp.PointsKeys.PRESYN
        postsyn = gp.PointsKeys.POSTSYN

        voxel_size = gp.Coordinate((10, 10, 10))
        size = (200, 200, 200)
        context = 40
        shape = gp.Coordinate(size) / voxel_size
        m_predar = np.zeros(shape, dtype=np.float32)
        insidepoint = gp.Coordinate((10, 10, 10))
        outsidepoint = gp.Coordinate((15, 15, 15))
        m_predar[insidepoint] = 1
        m_predar[outsidepoint] = 1

        d_predar = np.ones((3, shape[0], shape[1], shape[2])) * 10

        outdir = tempfile.mkdtemp()

        pipeline = (TestSource(m_predar, d_predar, voxel_size=voxel_size) +
                    ExtractSynapses(m_pred,
                                    d_pred,
                                    presyn,
                                    postsyn,
                                    out_dir=outdir,
                                    settings=parameters,
                                    context=context))

        request = gp.BatchRequest()

        roi = gp.Roi((40, 40, 40), (80, 80, 80))

        request[presyn] = gp.PointsSpec(roi=roi)
        request[postsyn] = gp.PointsSpec(roi=roi)
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)
        print(outdir, "outdir")
        synapsefile = os.path.join(outdir, "40", "40", "40.npz")
        with np.load(synapsefile) as data:
            data = dict(data)

        self.assertTrue(len(data['ids']) == 1)
        self.assertEqual(data['scores'][0], 1.0)  # Size of the cube.
        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][1][ii],
                             insidepoint[ii] * voxel_size[ii])

        for ii in range(len(voxel_size)):
            self.assertEqual(data['positions'][0][0][ii],
                             insidepoint[ii] * voxel_size[ii] + 10)
        shutil.rmtree(outdir)
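Outside the tests, the .npz files written by ExtractSynapses can be read back the same way. A small standalone sketch, assuming the block layout used above (`<out_dir>/<z>/<y>/<x>.npz` with 'ids', 'scores' and 'positions' entries); the helper name is made up for illustration:

import os

import numpy as np


def load_synapse_block(out_dir, offset):
    # offset is the ROI offset of the block, e.g. (40, 40, 40)
    path = os.path.join(out_dir, *[str(o) for o in offset[:-1]],
                        "{}.npz".format(offset[-1]))
    with np.load(path) as f:
        data = dict(f)
    # each entry i holds a synapse id, a score and a pair of partner positions
    return list(zip(data['ids'], data['scores'], data['positions']))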
Example #5
    def test_shift_points2(self):
        data = {1: gp.Point([0, 1])}
        spec = gp.PointsSpec(gp.Roi(offset=(0, 0), shape=(5, 5)))
        points = gp.Points(data, spec)
        request_roi = gp.Roi(offset=(0, 1), shape=(5, 3))
        shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]],
                               dtype=int)
        lcm_voxel_size = gp.Coordinate((1, 1))

        result = gp.ShiftAugment.shift_points(points,
                                              request_roi,
                                              shift_array,
                                              shift_axis=0,
                                              lcm_voxel_size=lcm_voxel_size)
        # print("test 2", result.data, data)
        self.assertTrue(self.points_equal(result.data, data))
        self.assertTrue(result.spec == gp.PointsSpec(request_roi))
Example #6
    def test_pipeline3(self):
        array_key = gp.ArrayKey("TEST_ARRAY")
        points_key = gp.PointsKey("TEST_POINTS")
        voxel_size = gp.Coordinate((1, 1))
        spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                    {array_key: 'testdata'},
                                    array_specs={array_key: spec})
        csv_source = gp.CsvPointsSource(
            self.fake_points_file, points_key,
            gp.PointsSpec(
                roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

        request = gp.BatchRequest()
        shape = gp.Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = gp.ShiftAugment(prob_slip=0.2,
                                     prob_shift=0.2,
                                     sigma=5,
                                     shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                    gp.RandomLocation(ensure_nonempty=points_key) + shift_node)
        with gp.build(pipeline) as b:
            batch = b.request_batch(request)
            # print(batch[points_key])

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = batch[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(result_val, list(result_points.values()), target_vals,
                        self.fake_points))
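The test above assumes two fixture files (`self.fake_data_file` with an HDF5 dataset named 'testdata', and `self.fake_points_file` with one 2D point per CSV row) prepared elsewhere, e.g. in a `setUp` method. A rough sketch of how such fixtures could be written; the file names and shapes are chosen here purely for illustration:

import csv

import h5py
import numpy as np


def write_fixtures(data_file="fake_data.hdf", points_file="fake_points.csv"):
    # 2D raw data covering the (100, 100) ROI declared for the CSV source
    fake_data = np.random.randint(0, 256, size=(100, 100)).astype(np.float32)
    fake_points = np.random.randint(0, 100, size=(10, 2))

    with h5py.File(data_file, "w") as f:
        f.create_dataset("testdata", data=fake_data)

    with open(points_file, "w", newline="") as f:
        writer = csv.writer(f)
        for point in fake_points:
            writer.writerow(point)

    return fake_data, fake_points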
Example #7
def validation_data_sources_recomputed(config, blocks):
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube is not None and cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        pipeline = ((
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size)
                },
            ),
            nl.gunpowder.nodes.MouselightSwcFileSource(
                validation_dir,
                [ground_truth],
                transform_file=transform_template.format(sample=sample),
                ignore_human_nodes=False,
                scale=voxel_size,
                transpose=[2, 1, 0],
                points_spec=[
                    gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ))
                ],
            ),
        ) + gp.nodes.MergeProvider() + gp.Normalize(
            raw, dtype=np.float32) + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            ) + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000]))

        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))
    return validation_pipelines, (raw, labels, ground_truth)
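The returned (pipeline, request) pairs are built and executed one at a time with the usual gunpowder pattern. A minimal consumption sketch, assuming a `config` dict with the keys read above; the block numbers are placeholders:

pipelines, (raw, labels, ground_truth) = validation_data_sources_recomputed(
    config, blocks=[1, 2, 3])

for pipeline, request in pipelines:
    with gp.build(pipeline):
        batch = pipeline.request_batch(request)
    # batch[raw].data and batch[labels].data hold the validation volume and
    # rasterized labels for this block, batch[ground_truth] the skeletons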
Example #8
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube is not None and cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = gp.BatchRequest()

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
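Since every key/spec pair needed per block is recorded in `specs`, one combined request can be assembled from it before running the merged pipeline. A minimal driver sketch, assuming the snapshots written inside the pipeline are the actual output of interest:

pipeline, specs = validation_pipeline(config)

request = gp.BatchRequest()
for block, block_spec in specs.items():
    for key, spec in block_spec.items():
        request[key] = spec

with gp.build(pipeline):
    # one request covering all blocks; the Snapshot nodes write
    # validations/validations.hdf as a side effect
    pipeline.request_batch(request)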
Example #9
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_cpv = gp.ArrayKey('GT_CPV')
    gt_points = gp.PointsKey('GT_CPV_POINTS')

    loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')
    # loss_weights_cpv = gp.ArrayKey('LOSS_WEIGHTS_CPV')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')
    pred_cpv = gp.ArrayKey('PRED_CPV')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')
    pred_cpv_gradients = gp.ArrayKey('PRED_CPV_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_cpv, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)
    snapshot_request.add(pred_cpv, output_shape_world)
    # snapshot_request.add(pred_cpv_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the HDF5/zarr files
            (
                sourceNode(
                    fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        gt_labels: 'volumes/gt_labels',
                        gt_fgbg: 'volumes/gt_fgbg',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        gt_labels: gp.ArraySpec(interpolatable=False),
                        gt_fgbg: gp.ArraySpec(interpolatable=False),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }
                ),
                gp.CsvIDPointsSource(
                    fls[t] + ".csv",
                    gt_points,
                    points_spec=gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate((0, 0, 0)),
                        gp.Coordinate(shapes[t])))
                )
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_points, None)
            + gp.Pad(gt_labels, None)
            + gp.Pad(gt_fgbg, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        gp.AddCPV(
            gt_points,
            gt_labels,
            gt_cpv) +
        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights_affs) +

        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_cpv']: gt_cpv,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['pred_cpv']: pred_cpv,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
                net_names['pred_cpv']: pred_cpv_gradients
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                gt_cpv: '/volumes/gt_cpv',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
                pred_cpv: '/volumes/pred_cpv',
                pred_cpv_gradients: '/volumes/pred_cpv_gradients'
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 'profiling' iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #10
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube is not None and cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) + EMST(
                                    emb,
                                    candidates,
                                    mst,
                                    distance_attr=distance_attr,
                                    coordinate_scale=coordinate_scale,
                                ) + gp.Snapshot(
                                    {
                                        raw: f"volumes/{raw}",
                                        ground_truth: f"points/{ground_truth}",
                                        labels: f"volumes/{labels}",
                                        fg: f"volumes/{fg}",
                                        emb: f"volumes/{emb}",
                                        candidates: f"volumes/{candidates}",
                                        mst: f"points/{mst}",
                                    },
                                    additional_request=additional_request,
                                    output_dir="snapshots",
                                    output_filename="{id}.hdf",
                                    edge_attrs={mst: [distance_attr]},
                                ))

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
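To run the evaluation, the returned pipeline is built once and the `score` key is requested. A minimal driver sketch, assuming the custom `Evaluate` node provides `score` as a nonspatial array:

pipeline, score = validation_pipeline(config)

request = gp.BatchRequest()
request[score] = gp.ArraySpec(nonspatial=True)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)

print("validation score:", batch[score].data)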
Example #11
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')

    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted center points
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet",
                                  kwargs['input_format'])

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }),
         gp.CsvIDPointsSource(fls[t] + ".csv",
                              points,
                              points_spec=gp.PointsSpec(
                                  roi=gp.Roi(gp.Coordinate((
                                      0, 0, 0)), gp.Coordinate(shapes[t]))))) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None) + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation() for t in range(ln))
    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

       # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +
        # (gp.SimpleAugment(
        #     mirror_only=augmentation['simple'].get("mirror"),
        #     transpose_only=augmentation['simple'].get("transpose")) \
        # if augmentation.get('simple') is not None and \
        #    augmentation.get('simple') != {} else NoOp())  +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) \
        if augmentation.get('intensity') is not None and \
           augmentation.get('intensity') != {} else NoOp())  +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 'profiling' iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")