Example #1
def create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                  gt_neurons):
    data_sources = tuple((
        Hdf5PointsSource(os.path.join(data_dir_syn, sample + '.hdf'),
                         datasets={
                             presyn: 'annotations',
                             postsyn: 'annotations'
                         },
                         rois={
                             presyn: cremi_roi,
                             postsyn: cremi_roi
                         }),
        Hdf5PointsSource(
            os.path.join(data_dir_syn, sample + '.hdf'),
            datasets={dummypostsyn: 'annotations'},
            rois={
                # presyn: cremi_roi,
                dummypostsyn: cremi_roi
            },
            kind='postsyn'),
        gp.Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                      datasets={
                          raw: 'volumes/raw',
                          gt_neurons: 'volumes/labels/neuron_ids',
                      },
                      array_specs={
                          raw: gp.ArraySpec(interpolatable=True),
                          gt_neurons: gp.ArraySpec(interpolatable=False),
                      })))
    source_pip = data_sources + gp.MergeProvider() + gp.Normalize(
        raw) + gp.RandomLocation(ensure_nonempty=dummypostsyn,
                                 p_nonempty=parameter['reject_probability'])
    return source_pip
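
A minimal driver sketch, not part of the original example: the sample name, the key names, the request size, and the reject probability are illustrative assumptions, and data_dir, data_dir_syn, cremi_roi, and Hdf5PointsSource are assumed to be defined at module level, as in the function above.

import gunpowder as gp

raw = gp.ArrayKey('RAW')
gt_neurons = gp.ArrayKey('GT_NEURONS')
presyn = gp.PointsKey('PRESYN')
postsyn = gp.PointsKey('POSTSYN')
dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')

pipeline = create_source('sample_A', raw, presyn, postsyn, dummypostsyn,
                         {'reject_probability': 0.9}, gt_neurons)

request = gp.BatchRequest()
size = gp.Coordinate((400, 400, 400))  # world units, placeholder value
request.add(raw, size)
request.add(gt_neurons, size)
request.add(dummypostsyn, size)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)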
Example #2
    def test_pipeline3(self):
        array_key = gp.ArrayKey("TEST_ARRAY")
        points_key = gp.PointsKey("TEST_POINTS")
        voxel_size = gp.Coordinate((1, 1))
        spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                    {array_key: 'testdata'},
                                    array_specs={array_key: spec})
        csv_source = gp.CsvPointsSource(
            self.fake_points_file, points_key,
            gp.PointsSpec(
                roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

        request = gp.BatchRequest()
        shape = gp.Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = gp.ShiftAugment(prob_slip=0.2,
                                     prob_shift=0.2,
                                     sigma=5,
                                     shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                    gp.RandomLocation(ensure_nonempty=points_key) + shift_node)
        with gp.build(pipeline) as b:
            batch = b.request_batch(request)
            # print(batch[points_key])

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = batch[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(result_val, list(result_points.values()), target_vals,
                        self.fake_points))
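
A hypothetical setUp() sketch for the fixtures the test relies on (fake_data_file, fake_points_file, fake_data, fake_points): the file names, the array contents, and the point-file format are assumptions, not taken from the original test class.

import numpy as np
import h5py

def setUp(self):
    self.fake_data_file = 'fake_data.hdf5'
    self.fake_points_file = 'fake_points.csv'
    # a 100x100 array covering the ROI given to CsvPointsSource above
    self.fake_data = np.arange(100 * 100, dtype=np.float32).reshape(100, 100)
    with h5py.File(self.fake_data_file, 'w') as f:
        f.create_dataset('testdata', data=self.fake_data)
    # a few integer point locations inside the array, one per line
    # (whitespace-separated; adjust to the delimiter your CsvPointsSource expects)
    self.fake_points = np.random.randint(0, 100, size=(10, 2))
    with open(self.fake_points_file, 'w') as f:
        for point in self.fake_points:
            f.write('{}\t{}\n'.format(point[0], point[1]))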
Example #3
def predict_3d(raw_data, gt_data, model, predictor, aux_tasks):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = model.input_shape
    output_shape = model.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    model_output = gp.ArrayKey('MODEL_OUTPUT')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=model,
                                 inputs={'x': raw},
                                 outputs={0: model_output})
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': model_output},
                                 outputs={0: prediction})
    aux_predictions = []
    for aux_name, aux_predictor, _ in aux_tasks:
        aux_pred_key = gp.ArrayKey(f"PRED_{aux_name.upper()}")
        pipeline += gp_torch.Predict(model=aux_predictor,
                                     inputs={'x': model_output},
                                     outputs={0: aux_pred_key})
        aux_predictions.append((aux_name, aux_pred_key))
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(model_output, output_size)
    scan_request.add(prediction, output_size)
    for aux_name, aux_key in aux_predictions:
        scan_request.add(aux_key, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # only output where the gt exists
    context = (input_size - output_size) / 2

    output_roi = gt_data.roi.intersect(raw_data.roi.grow(-context, -context))
    input_roi = output_roi.grow(context, context)

    assert all([a > b for a, b in zip(input_roi.get_shape(), input_size)])
    assert all([a > b for a, b in zip(output_roi.get_shape(), output_size)])

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=input_roi)
    total_request[model_output] = gp.ArraySpec(roi=output_roi)
    total_request[prediction] = gp.ArraySpec(roi=output_roi)
    for aux_name, aux_key in aux_predictions:
        total_request[aux_key] = gp.ArraySpec(roi=output_roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=output_roi)
        total_request[target] = gp.ArraySpec(roi=output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'model_out': batch[model_output],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        for aux_name, aux_key in aux_predictions:
            ret[aux_name] = batch[aux_key]
        return ret
Example #4
def predict_3d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #5
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #6
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')

    loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS')

    pred_threeclass = gp.ArrayKey('PRED_THREECLASS')

    pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(loss_weights_threeclass, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted three-class output and the gradients of the loss
    # w.r.t. it
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(pred_threeclass, output_shape_world)
    # snapshot_request.add(pred_threeclass_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the HDF5/zarr file
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    gt_threeclass: 'volumes/gt_threeclass',
                    anchor: 'volumes/gt_threeclass',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_threeclass: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_threeclass, None)
            + gp.Pad(anchor, gp.Coordinate((2, 2, 2)))

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        # TODO: check
        # gp.GrowBoundary(
        #     gt_threeclass,
        #     steps=1,
        #     only_xy=False) +

        gp.BalanceLabels(
            gt_threeclass,
            loss_weights_threeclass,
            num_classes=3) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['anchor']: anchor,
                net_names['gt_threeclass']: gt_threeclass,
                net_names['loss_weights_threeclass']: loss_weights_threeclass
            },
            outputs={
                net_names['pred_threeclass']: pred_threeclass,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_threeclass']: pred_threeclass_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_threeclass: '/volumes/gt_threeclass',
                pred_threeclass: '/volumes/pred_threeclass',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every `profiling` iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #7
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = task.data.raw.num_channels
    filename = task.data.raw.train.filename
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    sources = (task.data.raw.train.get_source(raw, overwrite_spec=spec),
               task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]
    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
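
A hedged consumption sketch for the returned (pipeline, request) pair, assuming task, predictor, and optimizer objects have already been constructed by the surrounding framework; the batch size, output directory, and iteration count are placeholders.

pipeline, request = create_pipeline_2d(task, predictor, optimizer,
                                       batch_size=8, outdir='runs/exp01',
                                       snapshot_every=1000)
with gp.build(pipeline):
    for _ in range(10000):
        pipeline.request_batch(request)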
Example #8
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
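
A hedged usage sketch: the specs dict returned above maps each block's keys to their specs, so a total request can be assembled from it before building the merged pipeline; the config dict is assumed to be loaded elsewhere.

pipeline, specs = validation_pipeline(config)

request = BatchRequest()
for block, block_specs in specs.items():
    for key, spec in block_specs.items():
        request[key] = spec

with gp.build(pipeline):
    batch = pipeline.request_batch(request)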
Example #9
def emb_validation_pipeline(
    config,
    snapshot_file,
    candidates_path,
    raw_path,
    gt_path,
    candidates_mst_path=None,
    candidates_mst_dense_path=None,
    path_stat="max",
):
    checkpoint = config["EMB_EVAL_CHECKPOINT"]
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    edge_threshold_0 = config["EVAL_EDGE_THRESHOLD_0"]
    component_threshold_0 = config["COMPONENT_THRESHOLD_0"]
    component_threshold_1 = config["COMPONENT_THRESHOLD_1"]

    clip_limit = config["CLAHE_CLIP_LIMIT"]
    normalize = config["CLAHE_NORMALIZE"]

    validation_pipelines = []
    specs = {}

    emb_model = get_emb_model(config)
    emb_model.eval()

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates_1 = gp.ArrayKey(f"CANDIDATES_1_{block}")

        raw = gp.ArrayKey(f"RAW_{block}")
        mst_0 = gp.GraphKey(f"MST_0_{block}")
        mst_dense_0 = gp.GraphKey(f"MST_DENSE_0_{block}")
        mst_1 = gp.GraphKey(f"MST_1_{block}")
        mst_dense_1 = gp.GraphKey(f"MST_DENSE_1_{block}")
        mst_2 = gp.GraphKey(f"MST_2_{block}")
        mst_dense_2 = gp.GraphKey(f"MST_DENSE_2_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")
        optimal_mst = gp.GraphKey(f"OPTIMAL_MST_{block}")

        # Volume Source
        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                candidates_1: candidates_path.format(block=block),
            },
        )

        # Graph Source
        graph_datasets = {gt: gt_path.format(block=block)}
        graph_directionality = {gt: False}
        edge_attrs = {}
        if candidates_mst_path is not None:
            graph_datasets[mst_0] = candidates_mst_path.format(block=block)
            graph_directionality[mst_0] = False
            edge_attrs[mst_0] = [distance_attr]
        if candidates_mst_dense_path is not None:
            graph_datasets[mst_dense_0] = candidates_mst_dense_path.format(
                block=block)
            graph_directionality[mst_dense_0] = False
            edge_attrs[mst_dense_0] = [distance_attr]
        gt_source = SnapshotSource(
            snapshot_file,
            datasets=graph_datasets,
            directed=graph_directionality,
            edge_attrs=edge_attrs,
        )

        if config["EVAL_CLAHE"]:
            raw_source = raw_source + scipyCLAHE(
                [raw],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=clip_limit,
                normalize=normalize,
            )

        emb_source, emb, neighborhood = add_emb_pred(config, raw_source, raw,
                                                     block, emb_model)

        reference_sizes = {
            raw: input_size,
            emb: output_size,
            candidates_1: output_size
        }
        if neighborhood is not None:
            reference_sizes[neighborhood] = output_size

        emb_source = add_scan(emb_source, reference_sizes)

        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        block_spec[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            block_spec[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_0] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_1] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        block_spec[mst_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        # block_spec[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)
        block_spec[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates_1] = gp.ArraySpec(cube_roi_shifted)
        additional_request[emb] = gp.ArraySpec(cube_roi_shifted)
        if neighborhood is not None:
            additional_request[neighborhood] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst_0] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_0] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_1] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        additional_request[mst_dense_1] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)
        additional_request[mst_2] = gp.GraphSpec(cube_roi_shifted,
                                                 directed=False)
        # additional_request[mst_dense_2] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)
        additional_request[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                                       directed=False)

        pipeline = (emb_source, gt_source) + gp.MergeProvider()

        if candidates_mst_path is not None and candidates_mst_dense_path is not None:
            # mst_0 provided, just need to calculate distances.
            pass
        elif config["EVAL_MINIMAX_EMBEDDING_DIST"]:
            # No mst_0 provided, must first calculate mst_0 and dense mst_0
            pipeline += MiniMaxEmbeddings(
                emb,
                candidates_1,
                decimated=mst_0,
                dense=mst_dense_0,
                distance_attr=distance_attr,
            )

        else:
            # mst/mst_dense not provided. Simply use euclidean distance on candidates
            pipeline += EMST(
                emb,
                candidates_1,
                mst_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )
            pipeline += EMST(
                emb,
                candidates_1,
                mst_dense_0,
                distance_attr=distance_attr,
                coordinate_scale=coordinate_scale,
            )

        pipeline += ThresholdEdges(
            (mst_0, mst_1),
            edge_threshold_0,
            component_threshold_0,
            msts_dense=(mst_dense_0, mst_dense_1),
            distance_attr=distance_attr,
        )

        pipeline += ComponentWiseEMST(
            emb,
            mst_1,
            mst_2,
            distance_attr=distance_attr,
            coordinate_scale=coordinate_scale,
        )

        # pipeline += ScoreEdges(
        #     mst, mst_dense, emb, distance_attr=distance_attr, path_stat=path_stat
        # )

        pipeline += Evaluate(
            gt,
            mst_2,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold_1,
            # connectivity=mst_1,
            output_graph=optimal_mst,
        )

        if config["EVAL_SNAPSHOT"]:
            snapshot_datasets = {
                raw: f"volumes/raw",
                emb: f"volumes/embeddings",
                candidates_1: f"volumes/candidates_1",
                mst_0: f"points/mst_0",
                mst_dense_0: f"points/mst_dense_0",
                mst_1: f"points/mst_1",
                mst_dense_1: f"points/mst_dense_1",
                # mst_2: f"points/mst_2",
                gt: f"points/gt",
                details: f"points/details",
                optimal_mst: f"points/optimal_mst",
            }
            if neighborhood is not None:
                snapshot_datasets[neighborhood] = f"volumes/neighborhood"
            pipeline += gp.Snapshot(
                snapshot_datasets,
                output_dir=config["EVAL_SNAPSHOT_DIR"],
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    checkpoint=checkpoint,
                    block=block,
                    coordinate_scale=",".join(
                        [str(x) for x in coordinate_scale]),
                ),
                edge_attrs={
                    mst_0: [distance_attr],
                    mst_dense_0: [distance_attr],
                    mst_1: [distance_attr],
                    mst_dense_1: [distance_attr],
                    # mst_2: [distance_attr],
                    # optimal_mst: [distance_attr], # it is unclear how to add distances if using connectivity graph
                    # mst_dense_2: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score
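
A hedged usage sketch, assuming MergeScores expands a request for the merged score into the per-block requests described by specs; the config dict and the path arguments are placeholders loaded elsewhere.

pipeline, score_key = emb_validation_pipeline(
    config, snapshot_file, candidates_path, raw_path, gt_path)

request = BatchRequest()
request[score_key] = gp.ArraySpec(nonspatial=True)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
    print(batch[score_key].data)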
Example #10
def pre_computed_fg_validation_pipeline(config, snapshot_file, raw_path,
                                        gt_path, fg_path):
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    candidate_spacing = config["CANDIDATE_SPACING"]
    candidate_threshold = config["CANDIDATE_THRESHOLD"]

    distance_attr = config["DISTANCE_ATTR"]
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]

    component_threshold = config["COMPONENT_THRESHOLD_1"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst = gp.GraphKey(f"MST_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        fg = gp.ArrayKey(f"FG_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")

        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                fg: fg_path.format(block=block),
            },
        )
        gt_source = SnapshotSource(
            snapshot_file,
            datasets={gt: gt_path.format(block=block)},
            directed={gt: False},
        )

        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates] = gp.ArraySpec(cube_roi_shifted)
        block_spec[fg] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates] = gp.ArraySpec(cube_roi_shifted)
        additional_request[fg] = gp.ArraySpec(cube_roi_shifted)
        additional_request[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        additional_request[mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)
        additional_request[details] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)

        pipeline = ((raw_source, gt_source) + gp.MergeProvider() + Skeletonize(
            fg, candidates, candidate_spacing, candidate_threshold) +
                    MiniMax(fg, candidates, mst, distance_attr=distance_attr))

        pipeline += Evaluate(
            gt,
            mst,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold,
        )

        if config["EVAL_SNAPSHOT"]:
            pipeline += gp.Snapshot(
                {
                    raw: f"volumes/raw",
                    fg: f"volumes/foreground",
                    candidates: f"volumes/candidates",
                    mst: f"points/mst",
                    gt: f"points/gt",
                    details: f"points/details",
                },
                output_dir="eval_results",
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    block=block),
                edge_attrs={
                    mst: [distance_attr],
                    details: ["details", "label_pair"]
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score
Example #11
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # output data
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_center_add, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_base: "/volume"},
                array_specs={
                    raw_base:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_base, swc_base),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for filename in files)

    # data source for "add" volume
    data_sources_add = tuple(
        (
            gp.Hdf5Source(
                file,
                datasets={raw_add: "/volume"},
                array_specs={
                    raw_add:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=file,
                dataset="/reconstruction",
                points=(swc_center_add, swc_add),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for file in files)
    data_sources = (
        (data_sources_base + gp.RandomProvider()),
        (data_sources_add + gp.RandomProvider()),
    ) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        ) +
        # augment
        gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                          subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +
        # train
        gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        ) +
        # visualize
        gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels: "volumes/labels",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                fg: "volumes/fg",
                labels_fg: "volumes/labels_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            additional_request=snapshot_request,
            every=100,
        ) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
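
For orientation, request.add above takes sizes in world units and centers all ROIs around a common point; a small self-contained illustration with made-up shapes:

import gunpowder as gp

raw_demo = gp.ArrayKey("RAW_DEMO")
labels_demo = gp.ArrayKey("LABELS_DEMO")

voxel_size = gp.Coordinate((4, 1, 1))
input_size = gp.Coordinate((32, 128, 128)) * voxel_size    # made-up network shapes
output_size = gp.Coordinate((16, 88, 88)) * voxel_size

request = gp.BatchRequest()
request.add(raw_demo, input_size)
request.add(labels_demo, output_size)

print(request)  # both ROIs share a common center; the labels ROI lies inside the raw ROI
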
Example #12
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005,  mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = (data_sources_base, data_sources_add) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
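
The gp.Pad(..., context) calls above grow each source by the network context so that gp.RandomLocation can also draw samples near the volume border; the same idea in isolation (file name and dataset path are placeholders):

import gunpowder as gp

raw_demo = gp.ArrayKey("RAW_PAD_DEMO")
context = gp.Coordinate((10, 10, 10))           # (input_shape - output_shape) / 2

source = gp.Hdf5Source(
    "sample.hdf",                               # placeholder file
    datasets={raw_demo: "volumes/raw"},
    array_specs={raw_demo: gp.ArraySpec(interpolatable=True)},
)
pipeline = source + gp.Pad(raw_demo, context) + gp.RandomLocation()

request = gp.BatchRequest()
request.add(raw_demo, gp.Coordinate((40, 40, 40)))  # world units

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
    print(batch[raw_demo].data.shape)               # (40, 40, 40) for voxel size 1
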
Example #13
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')

    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted center points (and, optionally, the gradients of
    # the loss w.r.t. them)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet" %
                                  kwargs['input_format'])

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }),
         gp.CsvIDPointsSource(fls[t] + ".csv",
                              points,
                              points_spec=gp.PointsSpec(
                                  roi=gp.Roi(gp.Coordinate((
                                      0, 0, 0)), gp.Coordinate(shapes[t]))))) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None) + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation() for t in range(ln))
    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +
        # (gp.SimpleAugment(
        #     mirror_only=augmentation['simple'].get("mirror"),
        #     transpose_only=augmentation['simple'].get("transpose")) \
        # if augmentation.get('simple') is not None and \
        #    augmentation.get('simple') != {} else NoOp())  +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) \
        if augmentation.get('intensity') is not None and \
           augmentation.get('intensity') != {} else NoOp())  +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from this point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: 'volumes/gt_cp',
                pred_cp: 'volumes/pred_cp',
                # pred_cp_gradients: 'volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every `profiling` iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #14
    def create_train_pipeline(self, model):

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        filename = self.params['data_file']
        datasets = self.params['dataset']

        raw_0 = gp.ArrayKey('RAW_0')
        points_0 = gp.GraphKey('POINTS_0')
        locations_0 = gp.ArrayKey('LOCATIONS_0')
        emb_0 = gp.ArrayKey('EMBEDDING_0')
        raw_1 = gp.ArrayKey('RAW_1')
        points_1 = gp.GraphKey('POINTS_1')
        locations_1 = gp.ArrayKey('LOCATIONS_1')
        emb_1 = gp.ArrayKey('EMBEDDING_1')

        data = daisy.open_ds(filename, datasets[0])
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        emb_voxel_size = voxel_size

        cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                        self.params['point_density'],
                                        out_shape * voxel_size)

        # Add fake 3rd dim
        if is_2d:
            in_shape = gp.Coordinate((1, *in_shape))
            out_shape = gp.Coordinate((1, *out_shape))
            voxel_size = gp.Coordinate((1, *voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (data.shape[0], *source_roi.get_shape()))

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw_0, in_shape)
        request.add(raw_1, in_shape)
        request.add(points_0, out_shape)
        request.add(points_1, out_shape)
        request[locations_0] = gp.ArraySpec(nonspatial=True)
        request[locations_1] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
        snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

        random_point_generator = RandomPointGenerator(
            density=self.params['point_density'], repetitions=2)

        # Use volume sizes to calculate probabilities; RandomSourceGenerator
        # will normalize them to probabilities
        probabilities = np.array([
            np.prod(daisy.open_ds(filename, dataset).shape)
            for dataset in datasets
        ])
        random_source_generator = RandomSourceGenerator(
            num_sources=len(datasets),
            probabilities=probabilities,
            repetitions=2)

        array_sources = tuple(
            tuple(
                gp.ZarrSource(
                    filename,
                    {raw: dataset},
                    # fake 3D data
                    array_specs={
                        raw:
                        gp.ArraySpec(roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                    }) for dataset in datasets) for raw in [raw_0, raw_1])

        # Choose a random dataset to pull from
        array_sources = \
            tuple(arrays +
                  RandomMultiBranchSource(random_source_generator) +
                  gp.Normalize(raw, self.params['norm_factor']) +
                  gp.Pad(raw, None)
                  for raw, arrays
                  in zip([raw_0, raw_1], array_sources))

        point_sources = tuple(
            (RandomPointSource(points_0,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               random_point_generator=random_point_generator)))

        # Merge the point and array sources together.
        # There is one array and point source per branch.
        sources = tuple((array_source, point_source) + gp.MergeProvider()
                        for array_source, point_source in zip(
                            array_sources, point_sources))

        sources = tuple(
            self._make_train_augmentation_pipeline(raw, source)
            for raw, source in zip([raw_0, raw_1], sources))

        pipeline = (sources + gp.MergeProvider() + gp.Crop(raw_0, source_roi) +
                    gp.Crop(raw_1, source_roi) + gp.RandomLocation() +
                    PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                                 locations_1, is_2d) +
                    RejectArray(ensure_nonempty=locations_0) +
                    RejectArray(ensure_nonempty=locations_1))

        if not is_2d:
            pipeline = (pipeline + AddChannelDim(raw_0) + AddChannelDim(raw_1))

        pipeline = (pipeline + gp.PreCache() + gp.torch.Train(
            model,
            cv_loss,
            optimizer,
            inputs={
                'raw_0': raw_0,
                'raw_1': raw_1
            },
            loss_inputs={
                'emb_0': emb_0,
                'emb_1': emb_1,
                'locations_0': locations_0,
                'locations_1': locations_1
            },
            outputs={
                2: emb_0,
                3: emb_1
            },
            array_specs={
                emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
            },
            checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir + "/contrastive",
            log_every=self.log_every))

        if is_2d:
            pipeline = (
                pipeline +
                # everything is 3D, except emb_0 and emb_1
                AddSpatialDim(emb_0) + AddSpatialDim(emb_1))

        pipeline = (
            pipeline +
            # now everything is 3D
            RemoveChannelDim(raw_0) + RemoveChannelDim(raw_1) +
            RemoveChannelDim(emb_0) + RemoveChannelDim(emb_1) +
            gp.Snapshot(output_dir=self.logdir + '/contrastive/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw_0: 'raw_0',
                            raw_1: 'raw_1',
                            locations_0: 'locations_0',
                            locations_1: 'locations_1',
                            emb_0: 'emb_0',
                            emb_1: 'emb_1'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
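
Downstream, the returned pipeline and request are presumably driven with the usual build/request loop (a minimal sketch; trainer, model and n_iterations are placeholders for the surrounding training script):

import gunpowder as gp

pipeline, request = trainer.create_train_pipeline(model)

with gp.build(pipeline):
    for i in range(n_iterations):
        pipeline.request_batch(request)
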
Example #15
def random_point_pairs_pipeline(model,
                                loss,
                                optimizer,
                                dataset,
                                augmentation_parameters,
                                point_density,
                                out_dir,
                                normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename, {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None) for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(density=point_density,
                                                      repetitions=2)

        point_sources = tuple(
            (RandomPointSource(points_0,
                               dim,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               dim,
                               random_point_generator=random_point_generator)))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) + gp.MergeProvider() + \
            gp.SimpleAugment() + \
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) + \
            gp.IntensityAugment(r,
                                scale_min=0.8,
                                scale_max=1.2,
                                shift_min=-0.2,
                                shift_max=0.2,
                                clip=False) + \
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source
            in zip([raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                             locations_1)

    # TODO: clarify how PrepareBatch interacts with Stack
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
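
Before starting a long run, a caller could sanity-check the stacked batch shapes against the comments above (a minimal sketch; the arguments are placeholders, and note that requesting a batch already performs one training step through gp.torch.Train):

import gunpowder as gp

pipeline, request = random_point_pairs_pipeline(
    model, loss, optimizer, dataset, augmentation_parameters,
    point_density, out_dir)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
    for key, array in batch.arrays.items():
        print(key, array.data.shape)   # e.g. raw_0: (b, 1, h, w), locations_0: (b, n, 2)
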
Example #16
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube is not None and cube.exists(), f"no cube swc found in {validation_dir}"

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) + EMST(
                                    emb,
                                    candidates,
                                    mst,
                                    distance_attr=distance_attr,
                                    coordinate_scale=coordinate_scale,
                                ) + gp.Snapshot(
                                    {
                                        raw: f"volumes/{raw}",
                                        ground_truth: f"points/{ground_truth}",
                                        labels: f"volumes/{labels}",
                                        fg: f"volumes/{fg}",
                                        emb: f"volumes/{emb}",
                                        candidates: f"volumes/{candidates}",
                                        mst: f"points/{mst}",
                                    },
                                    additional_request=additional_request,
                                    output_dir="snapshots",
                                    output_filename="{id}.hdf",
                                    edge_attrs={mst: [distance_attr]},
                                ))

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
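
The per-block flow sketched in the docstring above ends in a single merged score; a minimal sketch of running the assembled validation pipeline, assuming Evaluate writes SCORE as a non-spatial array:

import gunpowder as gp

pipeline, score = validation_pipeline(config)   # `config` as expected by the function above

request = gp.BatchRequest()
request[score] = gp.ArraySpec(nonspatial=True)

with gp.build(pipeline):
    batch = pipeline.request_batch(request)
    print("evaluation score:", batch[score].data)
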
Example #17
def train_until(max_iteration, return_intermediates=False):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # input data
    ch1 = gp.ArrayKey('CH1')
    ch2 = gp.ArrayKey('CH2')
    swc = gp.PointsKey('SWC')
    swc_env = gp.PointsKey('SWC_ENV')
    swc_center = gp.PointsKey('SWC_CENTER')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')

    # intermediate fusion-augment volumes (defined unconditionally because the
    # pipeline below references them; they are only requested when
    # return_intermediates is True)
    a_ch1 = gp.ArrayKey('A_CH1')
    a_ch2 = gp.ArrayKey('A_CH2')
    b_ch1 = gp.ArrayKey('B_CH1')
    b_ch2 = gp.ArrayKey('B_CH2')
    soft_mask = gp.ArrayKey('SOFT_MASK')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(ch1, input_size)
    request.add(ch2, input_size)
    request.add(swc, input_size)
    request.add(swc_center, output_size)
    request.add(gt, output_size)
    request.add(gt_fg, output_size)
    # request.add(loss_weights, output_size)

    if return_intermediates:

        request.add(a_ch1, input_size)
        request.add(a_ch2, input_size)
        request.add(b_ch1, input_size)
        request.add(b_ch2, input_size)
        request.add(soft_mask, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # snapshot_request[fg] = request[gt]
    # snapshot_request[gt_fg] = request[gt]
    # snapshot_request[gradient_fg] = request[gt]

    data_sources = tuple()
    data_sources += tuple(
        (Hdf5ChannelSource(file,
                           datasets={
                               ch1: '/volume',
                               ch2: '/volume',
                           },
                           channel_ids={
                               ch1: 0,
                               ch2: 1,
                           },
                           data_format='channels_last',
                           array_specs={
                               ch1:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                               ch2:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                           }),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center, swc),
                   return_env=True,
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center) + RasterizeSkeleton(
            points=swc,
            array=gt,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            points_env=swc_env,
            iteration=10) for file in files)

    snapshot_datasets = {}

    if return_intermediates:

        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            a_ch1: 'volumes/a_ch1',
            a_ch2: 'volumes/a_ch2',
            b_ch1: 'volumes/b_ch1',
            b_ch2: 'volumes/b_ch2',
            soft_mask: 'volumes/soft_mask',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }

    else:

        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }

    pipeline = (
        data_sources +
        #gp.RandomProvider() +
        FusionAugment(ch1,
                      ch2,
                      gt,
                      smoothness=1,
                      return_intermediate=return_intermediates) +

        # augment
        #gp.ElasticAugment(...) +
        #gp.SimpleAugment() +
        gp.Normalize(ch1) + gp.Normalize(ch2) + gp.Normalize(a_ch1) +
        gp.Normalize(a_ch2) + gp.Normalize(b_ch1) + gp.Normalize(b_ch2) +
        gp.IntensityAugment(ch1, 0.9, 1.1, -0.001, 0.001) +
        gp.IntensityAugment(ch2, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(gt, gt_fg) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names=snapshot_datasets,
                    additional_request=snapshot_request,
                    every=20) + gp.PrintProfilingStats(every=1000))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
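
BinarizeGt is a project-specific node that is not defined in this listing; a minimal version consistent with how it is used above (an integer label volume in, a {0, 1} foreground mask out) might look as follows. This is a hedged reconstruction written against gunpowder's BatchFilter API, not the project's actual implementation:

import numpy as np
import gunpowder as gp

class BinarizeGt(gp.BatchFilter):
    """Assumed behaviour: write a uint8 mask that is 1 wherever gt is labelled."""

    def __init__(self, gt, gt_fg):
        self.gt = gt
        self.gt_fg = gt_fg

    def setup(self):
        spec = self.spec[self.gt].copy()
        spec.dtype = np.uint8
        self.provides(self.gt_fg, spec)

    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.gt] = request[self.gt_fg].copy()
        return deps

    def process(self, batch, request):
        spec = batch[self.gt].spec.copy()
        spec.dtype = np.uint8
        data = (batch[self.gt].data > 0).astype(np.uint8)
        batch.arrays[self.gt_fg] = gp.Array(data, spec)
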
Example #18
def get_mouselight_data_sources(setup_config: Dict[str, Any],
                                source_samples: List[str],
                                locations=False):
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])

    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    matching_failures_dir = (matching_failures_dir
                             if matching_failures_dir is None else
                             Path(matching_failures_dir))

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    node_offset = {
        sample.name: (daisy.persistence.MongoDbGraphProvider(
            mongo_db_template.format(sample=sample.name,
                                     source="skeletonization"),
            mongo_url,
        ).num_nodes(None) + 1)
        for sample in samples_path.iterdir() if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:

    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (tuple(
        (
            gp.ZarrSource(
                filename=str((sample / raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            DaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[matched],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            FilteredDaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[nonempty_placeholder],
                directed=True,
                node_attrs=["distance_to_fallback"],
                edge_attrs=[],
                num_nodes=max_nonempty_points,
                dist_attribute=distance_attr,
                min_dist=target_distance,
            ),
        ) + gp.MergeProvider() + random(**kwargs) + gp.Normalize(raw) +
        FilterComponents(
            matched, node_offset[sample.name], centroid_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.int64),
        ) for sample in samples_path.iterdir()
        if sample.name in source_samples) + gp.RandomProvider())

    return (data_sources, raw, labels, nonempty_placeholder, matched)
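
The tuple returned above is presumably wired into a larger training pipeline elsewhere; a minimal consumption sketch (the sample name, the INPUT_SHAPE key, and the bare PrintProfilingStats tail are assumptions, not taken from the function above):

import gunpowder as gp

data_sources, raw, labels, nonempty_placeholder, matched = \
    get_mouselight_data_sources(setup_config, ["sample_a"])   # hypothetical sample

voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
input_size = gp.Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size    # assumed key
output_size = gp.Coordinate(setup_config["OUTPUT_SHAPE"]) * voxel_size

request = gp.BatchRequest()
request.add(raw, input_size)
request.add(labels, output_size)
request.add(matched, output_size)

pipeline = data_sources + gp.PrintProfilingStats(every=100)
with gp.build(pipeline):
    batch = pipeline.request_batch(request)
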
Example #19
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = max(1, task.data.raw.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = task.data.raw.train.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    sources = (task.data.raw.train.get_source(raw),
               task.data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
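
The snapshot branch above moves channels in front of the batch dimension so that each channel is written as its own volume; a quick NumPy check of the axis order used by TransposeDims:

import numpy as np

raw = np.zeros((2, 3, 10, 12, 14))            # (b, c, d, h, w)
raw_channels_first = np.transpose(raw, (1, 0, 2, 3, 4))
print(raw_channels_first.shape)               # (3, 2, 10, 12, 14) -> (c, b, d, h, w)
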
Example #20
def create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every
):

    raw_channels = max(1, data.raw.num_channels)
    input_shape = gp.Coordinate(task.model.input_shape)
    output_shape = gp.Coordinate(task.model.output_shape)
    voxel_size = data.raw.train.voxel_size

    task.predictor = task.predictor.to("cuda")

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    mask = gp.ArrayKey("MASK")
    target = gp.ArrayKey("TARGET")
    weights = gp.ArrayKey("WEIGHTS")
    model_outputs = gp.ArrayKey("MODEL_OUTPUTS")
    model_output_grads = gp.ArrayKey("MODEL_OUT_GRAD")
    prediction = gp.ArrayKey("PREDICTION")
    pred_gradients = gp.ArrayKey("PRED_GRADIENTS")

    snapshot_dataset_names = {
        raw: "raw",
        model_outputs: "model_outputs",
        model_output_grads: "model_out_grad",
        target: "target",
        prediction: "prediction",
        pred_gradients: "pred_gradients",
        weights: "weights",
    }

    aux_keys = {}
    aux_grad_keys = {}
    for name, _, _ in task.aux_tasks:
        aux_keys[name] = (
            gp.ArrayKey(f"{name.upper()}_PREDICTION"),
            gp.ArrayKey(f"{name.upper()}_TARGET"),
            None,
        )
        aux_grad_keys[name] = gp.ArrayKey(f"{name.upper()}_PRED_GRAD")

        aux_pred, aux_target, _ = aux_keys[name]

        snapshot_dataset_names[aux_pred] = f"{name}_pred"
        snapshot_dataset_names[aux_target] = f"{name}_target"
        
        aux_grad = aux_grad_keys[name]
        snapshot_dataset_names[aux_grad] = f"{name}_aux_grad"

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = data.raw.train.num_samples
    assert num_samples == 0, "Multiple samples for 3D training not yet implemented"

    sources = (data.raw.train.get_source(raw), data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, input_shape / 2 * voxel_size)
    # pipeline += gp.Pad(gt, input_shape / 2 * voxel_size)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    
    mask_node = task.loss.add_mask(gt, mask)
    if mask_node is not None:
        pipeline += mask_node
        pipeline += gp.RandomLocation(min_masked=1e-6, mask=mask)
    else:
        # raw: ([c,] d, h, w)
        # gt: ([c,] d, h, w)
        pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    loss_inputs = []
    if weights_node:
        pipeline += weights_node
        loss_inputs.append({0: prediction, 1: target, 2: weights})
    else:
        loss_inputs.append({0: prediction, 1: target})

    head_outputs = []
    head_gradients = []
    for name, aux_predictor, aux_loss in task.aux_tasks:
        aux_prediction, aux_target, aux_weights = aux_keys[name]
        pipeline += aux_predictor.add_target(gt, aux_target)
        aux_weights_node = aux_loss.add_weights(aux_target, aux_weights)
        if aux_weights_node:
            aux_weights = gp.ArrayKey(f"{name.upper()}_WEIGHTS")
            aux_keys[name] = (
                aux_prediction,
                aux_target,
                aux_weights,
            )
            pipeline += aux_weights_node
            loss_inputs.append({0: aux_prediction, 1: aux_target, 2: aux_weights})
            snapshot_dataset_names[aux_weights] = f"{name}_weights"
        else:
            loss_inputs.append({0: aux_prediction, 1: aux_target})
        head_outputs.append({0: aux_prediction})
        aux_pred_gradient = aux_grad_keys[name]
        head_gradients.append({0: aux_pred_gradient})
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += Train(
        model=task.model,
        heads=[("opt", predictor)]
        + [(name, aux_pred) for name, aux_pred, _ in task.aux_tasks],
        losses=[task.loss] + [loss for _, _, loss in task.aux_tasks],
        optimizer=optimizer,
        inputs={"x": raw},
        outputs={0: model_outputs},
        head_outputs=[{0: prediction}] + head_outputs,
        loss_inputs=loss_inputs,
        gradients=[{0: model_output_grads}, {0: pred_gradients}] + head_gradients,
        save_every=1e6,
    )
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(
            dataset_names=snapshot_dataset_names,
            every=snapshot_every,
            output_dir=os.path.join(outdir, "snapshots"),
            output_filename="{iteration}.hdf",
        )
    pipeline += gp.PrintProfilingStats(every=10)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    if mask_node is not None:
        request.add(mask, output_size)
    request.add(target, output_size)
    for name, _, _ in task.aux_tasks:
        aux_pred, aux_target, aux_weight = aux_keys[name]
        request.add(aux_pred, output_size)
        request.add(aux_target, output_size)
        if aux_weight is not None:
            request.add(aux_weight, output_size)
        aux_pred_grad = aux_grad_keys[name]
        request.add(aux_pred_grad, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)
    request.add(pred_gradients, output_size)

    return pipeline, request
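
The helper above only assembles the pipeline and its matching request; a minimal consumption sketch, assuming the same gunpowder API used throughout these examples and a hypothetical num_iterations:

pipeline, request = create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every)
with gp.build(pipeline):
    # pull batches with the prepared request; the actual training schedule
    # lives in the caller
    for _ in range(num_iterations):
        batch = pipeline.request_batch(request)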
Example #21
def merge_pipelines(pipelines_a, pipelines_b):
    merged_pipelines = []
    for a, b in zip(pipelines_a, pipelines_b):
        merged_pipelines.append((a, b) + gp.MergeProvider())
    return merged_pipelines
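
A small usage sketch for merge_pipelines, assuming two hypothetical lists of zarr-backed providers and the array keys raw and seg (none of these names come from the original code):

import gunpowder as gp

raw = gp.ArrayKey("RAW")
seg = gp.ArrayKey("SEG")
# one raw-only and one label-only provider per (hypothetical) sample file
raw_pipelines = [
    gp.ZarrSource(f, {raw: "raw"}) for f in ("sample_a.zarr", "sample_b.zarr")]
label_pipelines = [
    gp.ZarrSource(f, {seg: "segmentation"})
    for f in ("sample_a.zarr", "sample_b.zarr")]
# each merged entry now provides both raw and seg for the same sample
merged = merge_pipelines(raw_pipelines, label_pipelines)
pipeline = tuple(merged) + gp.RandomProvider()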
Example #22
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")
    labels = gp.ArrayKey("LABELS")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")

    # array keys for fused volume
    raw_fused = gp.ArrayKey("RAW_FUSED")
    labels_fused = gp.ArrayKey("LABELS_FUSED")
    swc_fused = gp.PointsKey("SWC_FUSED")

    # output data
    fg = gp.ArrayKey("FG")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")

    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw_fused, input_size)
    request.add(labels_fused, input_size)
    request.add(swc_fused, input_size)
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    # request.add(fg, output_size)
    # request.add(labels_fg, output_size)
    request.add(gradient_fg, output_size)
    request.add(raw_base, input_size)
    request.add(raw_add, input_size)
    request.add(labels_base, input_size)
    request.add(labels_add, input_size)
    request.add(swc_base, input_size)
    request.add(swc_add, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
                ignore_human_nodes=True
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=swcs, ensure_centered=True
        )
        + RasterizeSkeleton(
            points=swcs,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radius=10)
        # augment
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
        )
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for filename in Path(sample_dir).iterdir()
        if "2018-08-01" in filename.name  # or "2018-07-02" in filename.name
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + GetNeuronPair(
            swcs,
            raw,
            labels,
            (swc_base, swc_add),
            (raw_base, raw_add),
            (labels_base, labels_add),
            seperate_by=150,
            shift_attempts=50,
            request_attempts=10,
        )
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            swc_base,
            swc_add,
            raw_fused,
            labels_fused,
            swc_fused,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        + Crop(labels_fused, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw_fused,
                net_names["labels_fg"]: labels_fg_bin,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw_fused: "volumes/raw_fused",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels_fused: "volumes/labels_fused",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                labels_fg_bin: "volumes/labels_fg_bin",
                fg: "volumes/pred_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            every=100,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):

        logging.info("Starting training...")
        for i in range(max_iteration - trained_until):
            logging.info("requesting batch {}".format(i))
            batch = pipeline.request_batch(request)
            """
Example #23
        datasets={raw: "volume"},
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, voxel_size=voxel_size, dtype=np.uint16)
        },
    ),
    MouselightSwcFileSource(
        filename=str((filename / "high-res-swcs/G-002.swc").absolute()),
        points=(swcs, ),
        scale=voxel_size,
        transpose=(2, 1, 0),
        transform_file=str((filename / "transform.txt").absolute()),
        ignore_human_nodes=False,
    ),
) + gp.MergeProvider() + gp.RandomLocation(
    ensure_nonempty=swcs, ensure_centered=True) + RasterizeSkeleton(
        points=swcs,
        array=labels,
        array_spec=gp.ArraySpec(
            interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
    ) + GrowLabels(labels, radius=20)
                     # augment
                     + gp.ElasticAugment(
                         [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                         subsample=4) + gp.SimpleAugment(
                             mirror_only=[1, 2], transpose_only=[1, 2]) +
                     gp.Normalize(raw) +
                     gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
                     for filename in path_to_data.iterdir()
                     if "2018-07-02" in filename.name)
Example #24
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)

    request.add(swc_center_base, output_size)
    request.add(swc_base, input_size)

    request.add(swc_center_add, output_size)
    request.add(swc_add, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple()
    data_sources_base += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_base: '/volume',
                       },
                       array_specs={
                           raw_base:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_base, swc_base),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=10) for file in files)
    data_sources_base += gp.RandomProvider()

    # data source for "add" volume
    data_sources_add = tuple()
    data_sources_add += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_add: '/volume',
                       },
                       array_specs={
                           raw_add:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_add, swc_add),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=1) for file in files)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add
                          ]) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(raw_base,
                                     raw_add,
                                     labels_base,
                                     labels_add,
                                     raw,
                                     labels,
                                     blend_mode='labels_mask',
                                     blend_smoothness=10,
                                     num_blended_objects=0) +

        # augment
        gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0],
                          subsample=8) +
        gp.SimpleAugment(mirror_only=[2], transpose_only=[]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train('./train_net',
                            optimizer=net_names['optimizer'],
                            loss=net_names['loss'],
                            inputs={
                                net_names['raw']: raw,
                                net_names['labels_fg']: labels_fg,
                                net_names['loss_weights']: loss_weights,
                            },
                            outputs={
                                net_names['fg']: fg,
                            },
                            gradients={
                                net_names['fg']: gradient_fg,
                            },
                            save_every=100) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names={
                        raw: 'volumes/raw',
                        raw_base: 'volumes/raw_base',
                        raw_add: 'volumes/raw_add',
                        labels: 'volumes/labels',
                        labels_base: 'volumes/labels_base',
                        labels_add: 'volumes/labels_add',
                        fg: 'volumes/fg',
                        labels_fg: 'volumes/labels_fg',
                        gradient_fg: 'volumes/gradient_fg',
                    },
                    additional_request=snapshot_request,
                    every=10) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #25
    def create_train_pipeline(self, model):

        print(f"Creating training pipeline with batch size \
              {self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        context = (in_shape - out_roi) / 2
        gt_labels_out_shape = out_roi
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            gt_labels_out_shape = (1, *gt_labels_out_shape)

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))

        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=1, source_roi)
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
Example #26
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()

    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched, center_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + gp.contrib.nodes.add_distance.AddDistance(
            labels,
            dist,
            dist_mask,
            max_distance=max_label_dist * micron_scale)
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1)
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights"
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #27
sourceB = gp.ZarrSource('../data/cropped_sample_B.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})

source = (sourceA, sourceB, sourceC) + gp.MergeProvider()
print(source)
normalize = gp.Normalize(raw)
simulate_cages = SimulateCages(raw, seg, out_cage_map, out_density_map, psf,
                               (min_density, max_density), [cage1], 0.5)
add_channel_dim = gp.Stack(1)
stack = gp.Stack(5)
prepare_data = PrepareTrainingData()
train = gp.torch.Train(model,
                       loss,
                       optimizer,
                       inputs={'input': raw},
                       loss_inputs={
                           0: prediction,
                           1: out_cage_map
                       },
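
The excerpt above stops inside the gp.torch.Train call. As a hedged sketch only, the nodes it defines would typically be chained with + into one pipeline and driven by a request; the ordering and the size placeholders below are assumptions, not part of the original snippet:

# sketch: chain the nodes defined above and pull a single batch
pipeline = (source + normalize + simulate_cages + add_channel_dim +
            stack + prepare_data + train)
request = gp.BatchRequest()
request.add(raw, input_size)            # placeholder size
request.add(out_cage_map, output_size)  # placeholder size
with gp.build(pipeline):
    batch = pipeline.request_batch(request)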
Example #28
    def build_batch_provider(self, datasets, model, task, snapshot_container=None):
        input_shape = Coordinate(model.input_shape)
        output_shape = Coordinate(model.output_shape)

        # get voxel sizes
        raw_voxel_size = datasets[0].raw.voxel_size
        prediction_voxel_size = model.scale(raw_voxel_size)

        # define input and output size:
        # switch to world units
        input_size = raw_voxel_size * input_shape
        output_size = prediction_voxel_size * output_shape

        # padding of groundtruth/mask
        gt_mask_padding = output_size + task.predictor.padding(prediction_voxel_size)

        # define keys:
        raw_key = gp.ArrayKey("RAW")
        gt_key = gp.ArrayKey("GT")
        mask_key = gp.ArrayKey("MASK")

        target_key = gp.ArrayKey("TARGET")
        weight_key = gp.ArrayKey("WEIGHT")

        # Get source nodes
        dataset_sources = []
        for dataset in datasets:

            raw_source = DaCapoArraySource(dataset.raw, raw_key)
            raw_source += gp.Pad(raw_key, None, 0)
            gt_source = DaCapoArraySource(dataset.gt, gt_key)
            gt_source += gp.Pad(gt_key, gt_mask_padding, 0)
            if dataset.mask is not None:
                mask_source = DaCapoArraySource(dataset.mask, mask_key)
            else:
                # Always provide a mask. By default it is simply an array
                # of ones with the same shape/roi as gt. Avoids making us
                # specially handle no mask case and allows padding of the
                # ground truth without worrying about training on incorrect
                # data.
                mask_source = DaCapoArraySource(OnesArray.like(dataset.gt), mask_key)
            mask_source += gp.Pad(mask_key, gt_mask_padding, 0)
            array_sources = [raw_source, gt_source, mask_source]

            dataset_source = (
                tuple(array_sources) + gp.MergeProvider() + gp.RandomLocation()
            )

            dataset_sources.append(dataset_source)
        pipeline = tuple(dataset_sources) + gp.RandomProvider()

        for augment in self.augments:
            pipeline += augment.node(raw_key, gt_key, mask_key)

        pipeline += gp.Reject(mask_key, min_masked=self.min_masked)

        # Add predictor nodes to pipeline
        pipeline += DaCapoTargetFilter(
            task.predictor,
            gt_key=gt_key,
            target_key=target_key,
            weights_key=weight_key,
            mask_key=mask_key,
        )

        # Trainer attributes:
        if self.num_data_fetchers > 1:
            pipeline += gp.PreCache(num_workers=self.num_data_fetchers)

        # stack to create a batch dimension
        pipeline += gp.Stack(self.batch_size)

        # print profiling stats
        pipeline += gp.PrintProfilingStats(every=self.print_profiling)

        # generate request for all necessary inputs to training
        request = gp.BatchRequest()
        request.add(raw_key, input_size)
        request.add(target_key, output_size)
        request.add(weight_key, output_size)
        # request additional keys for snapshots
        request.add(gt_key, output_size)
        request.add(mask_key, output_size)

        self._request = request
        self._pipeline = pipeline
        self._raw_key = raw_key
        self._gt_key = gt_key
        self._mask_key = mask_key
        self._weight_key = weight_key
        self._target_key = target_key
        self._loss = task.loss

        self.snapshot_container = snapshot_container
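
    # build_batch_provider only stores the pipeline and request on the trainer;
    # a minimal sketch of a hypothetical driver method that consumes them
    # (only gp.build and request_batch come from the gunpowder usage above):
    def iterate_batches(self, num_iterations):
        with gp.build(self._pipeline):
            for _ in range(num_iterations):
                yield self._pipeline.request_batch(self._request)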
Example #29
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
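
A short sketch of calling predict_2d and reading the stitched arrays out of the returned batch objects (the surrounding setup of raw_data, gt_data and predictor is assumed to exist):

results = predict_2d(raw_data, gt_data, predictor)
prediction_array = results['prediction'].data  # samples-first numpy array
raw_array = results['raw'].data
if gt_data:
    gt_array = results['gt'].data
    target_array = results['target'].data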
Example #30
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple((
        gp.N5Source(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
            ).absolute()),
            datasets={raw: "volume"},
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True,
                             voxel_size=voxel_size,
                             dtype=np.uint16)
            },
        ),
        MouselightSwcFileSource(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
            ).absolute()),
            points=(swcs, ),
            scale=voxel_size,
            transpose=(2, 1, 0),
            transform_file=str((filename / "transform.txt").absolute()),
        ),
    ) + gp.MergeProvider() + gp.RandomLocation(ensure_nonempty=swcs,
                                               ensure_centered=True)
                         for filename in Path(sample_dir).iterdir()
                         if "2018-08-01" in filename.name)

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(batch[raw].data,
                                  points_to_graph(batch[swcs].data),
                                  np.array(voxel_size))