Example #1
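
train_simple_pipeline assembles a gunpowder training job for an embedding network: raw fluorescence is read from N5, consensus and skeletonization graphs come from MongoDB via gp.DaisyGraphProvider, TopologicalMatcher aligns them, RasterizeSkeleton and GrowLabels turn the matched points into labels, and the augmented, label-balanced batches feed gp.tensorflow.Train with periodic snapshots and profiling. The snippets on this page assume module-level imports of gunpowder as gp, numpy as np, math, logging, pathlib.Path, tensorflow as tf, and the project-specific nodes they reference (TopologicalMatcher, RasterizeSkeleton, GrowLabels, FusionAugment, GetNeuronPair, MouselightSwcFileSource, and friends).
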
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

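    # build one source sub-pipeline per sample; gp.RandomProvider (added
    # below) picks randomly among them for each batch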
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

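    # full pipeline: derive a binary foreground target and balanced loss
    # weights from the labels, pre-cache batches across workers, train, and
    # periodically write profiling stats and snapshots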
    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
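
A minimal invocation sketch, assuming the config and tensor-name dictionaries are loaded from JSON files written by a separate network-construction step (the file names here are hypothetical, not from the original source). Note that the body reads NUM_ITERATIONS from setup_config, so the n_iterations argument is effectively unused:

import json

with open("config.json") as f:  # hypothetical file name
    setup_config = json.load(f)
with open("mknet_names.json") as f:  # hypothetical file name
    mknet_tensor_names = json.load(f)
with open("loss_names.json") as f:  # hypothetical file name
    loss_tensor_names = json.load(f)

train_simple_pipeline(setup_config["NUM_ITERATIONS"], setup_config,
                      mknet_tensor_names, loss_tensor_names)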
Example #2
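
This fragment shows the batch request and data sources for a variant that pairs the carved N5 raw volume with a single high-resolution SWC file (G-002.swc) read through MouselightSwcFileSource; the snippet ends mid-expression inside the RasterizeSkeleton node.
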
request.add(masked_base, input_size)
request.add(masked_add, input_size)
request.add(softmask, input_size)
request.add(mask_maxed, input_size)
request.add(masked_base_b, input_size)
request.add(masked_add_b, input_size)
request.add(softmask_b, input_size)
request.add(mask_maxed_b, input_size)

data_sources = tuple((
    gp.N5Source(
        filename=str((
            filename /
            "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
        ).absolute()),
        datasets={raw: "volume"},
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, voxel_size=voxel_size, dtype=np.uint16)
        },
    ),
    MouselightSwcFileSource(
        filename=str((filename / "high-res-swcs/G-002.swc").absolute()),
        points=(swcs, ),
        scale=voxel_size,
        transpose=(2, 1, 0),
        transform_file=str((filename / "transform.txt").absolute()),
        ignore_human_nodes=False,
    ),
) + gp.MergeProvider() + gp.RandomLocation(
    ensure_nonempty=swcs, ensure_centered=True) + RasterizeSkeleton(
Example #3
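
train_until resumes from the latest TensorFlow checkpoint and trains a foreground network on synthetically cluttered data: GetNeuronPair extracts two neurons from the same sample and FusionAugment blends them into one fused volume before training.
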
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")
    labels = gp.ArrayKey("LABELS")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")

    # array keys for fused volume
    raw_fused = gp.ArrayKey("RAW_FUSED")
    labels_fused = gp.ArrayKey("LABELS_FUSED")
    swc_fused = gp.PointsKey("SWC_FUSED")

    # output data
    fg = gp.ArrayKey("FG")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")

    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw_fused, input_size)
    request.add(labels_fused, input_size)
    request.add(swc_fused, input_size)
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)

    # additional outputs added to the main request so they appear in snapshots
    # request.add(fg, output_size)
    # request.add(labels_fg, output_size)
    request.add(gradient_fg, output_size)
    request.add(raw_base, input_size)
    request.add(raw_add, input_size)
    request.add(labels_base, input_size)
    request.add(labels_add, input_size)
    request.add(swc_base, input_size)
    request.add(swc_add, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
                ignore_human_nodes=True
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=swcs, ensure_centered=True
        )
        + RasterizeSkeleton(
            points=swcs,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radius=10)
        # augment
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
        )
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for filename in Path(sample_dir).iterdir()
        if "2018-08-01" in filename.name  # or "2018-07-02" in filename.name
    )

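    # pair two neurons from one sample and fuse them before computing the
    # foreground target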
    pipeline = (
        data_sources
        + gp.RandomProvider()
        + GetNeuronPair(
            swcs,
            raw,
            labels,
            (swc_base, swc_add),
            (raw_base, raw_add),
            (labels_base, labels_add),
            seperate_by=150,
            shift_attempts=50,
            request_attempts=10,
        )
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            swc_base,
            swc_add,
            raw_fused,
            labels_fused,
            swc_fused,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        + Crop(labels_fused, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw_fused,
                net_names["labels_fg"]: labels_fg_bin,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw_fused: "volumes/raw_fused",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels_fused: "volumes/labels_fused",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                labels_fg_bin: "volumes/labels_fg_bin",
                fg: "volumes/pred_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            every=100,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):

        logging.info("Starting training...")
        for i in range(max_iteration - trained_until):
            logging.info("requesting batch {}".format(i))
            batch = pipeline.request_batch(request)
            """
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()

    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

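    # sources, random location sampling, matching, and rasterization mirror
    # Example #1; only the target construction (distances instead of a binary
    # mask) differs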
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched, center_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + gp.contrib.nodes.add_distance.AddDistance(
            labels,
            dist,
            dist_mask,
            max_distance=max_label_dist * micron_scale,
        ) + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1)
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

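    # regression pipeline: crop the distance target to the output ROI and
    # train against it directly (PreCache is commented out in this variant)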
    pipeline = (
        data_sources + gp.RandomProvider() + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights"
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #5
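
A data-inspection loop rather than a training run: it requests double-sized crops of raw data and SWC points and hands each batch to vis_points_with_array for visualization.
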
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple((
        gp.N5Source(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
            ).absolute()),
            datasets={raw: "volume"},
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True,
                             voxel_size=voxel_size,
                             dtype=np.uint16)
            },
        ),
        MouselightSwcFileSource(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
            ).absolute()),
            points=(swcs, ),
            scale=voxel_size,
            transpose=(2, 1, 0),
            transform_file=str((filename / "transform.txt").absolute()),
        ),
    ) + gp.MergeProvider() + gp.RandomLocation(ensure_nonempty=swcs,
                                               ensure_centered=True)
                         for filename in Path(sample_dir).iterdir()
                         if "2018-08-01" in filename.name)

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(batch[raw].data,
                                  points_to_graph(batch[swcs].data),
                                  np.array(voxel_size))