Example #1
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_DIST")
    ArrayKey("PREDICTED_DIST_LABELS")

    data_providers = []
    # cremi_version and aligned are assumed to be defined at module level
    if cremi_version == "2016":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = "sample_{0:}_padded_20160501."
    elif cremi_version == "2017":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = "sample_{0:}_padded_20170424."
    if aligned:
        filename += "aligned."
    filename += "0bg.hdf"
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size
    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # invert the validation mask (x -> 1 - x) so that it masks in the
        # training region
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        # choose a random location that is at least 99% covered by the
        # (inverted) TRAINING_MASK
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        # reject batches that contain less than 50% ground-truth-masked data
        Reject(ArrayKeys.GT_MASK) +
        # with probability 0.95, reject batches that contain no labelled clefts
        Reject(ArrayKeys.GT_LABELS,
               min_masked=0.0,
               reject_probability=0.95) for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS],
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
    })

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8,
        ) + SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) + BalanceByThreshold(
            ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, mask=ArrayKeys.GT_MASK) +
        PreCache(cache_size=40, num_workers=10) + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_dist"]: ArrayKeys.GT_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST_LABELS},
            gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT},
        ) + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_DIST_LABELS:
                "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) + PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #2
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    labels,
    net_name,
    min_masked_voxels=17561.0,
    mask_ds_name="volumes/masks/training_cropped",
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig
    # context = input_size-output_size

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (only one with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_orig)),
    }
    array_specs_pred = {}

    inputs[net_io_names["raw"]] = ArrayKeys.RAW

    snapshot[ArrayKeys.RAW] = "volumes/raw"
    snapshot[ArrayKeys.GT_LABELS] = "volumes/labels/gt_labels"

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)

    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname

        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True)

        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key

        outputs[net_io_names[label.labelname]] = label.pred_dist_key

        snapshot[
            label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.
                 pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname

        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.pred_dist_key,
                    output_size,
                    voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key,
                        output_size,
                        voxel_size=voxel_size_orig)

        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size_orig)

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(src.full_path,
                                 datasets=datasets_no_ribo,
                                 array_specs=array_specs)
        else:
            n5_source = N5Source(src.full_path,
                                 datasets=datasets_ribo,
                                 array_specs=array_specs)

        data_providers.append(n5_source)

    # create a tuple of data sources, one for each N5 container
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        # reject batches whose masked-in fraction is below keep_thr
        RejectEfficiently(
            ArrayKeys.MASK,
            min_masked=keep_thr)
        # Reject(ArrayKeys.MASK)  # reject batches which contain less than 50% labelled data
        for provider in data_providers)

    train_pipeline = (
        data_stream +
        RandomProvider(tuple([ds.labeled_voxels for ds in data_sources])) +
        gpn.SimpleAugment() + gpn.ElasticAugment(
            voxel_size_orig,
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        # ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))
    # ZeroOutConstSections(ArrayKeys.RAW))

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=30, num_workers=30) + Train(
                          net_name,
                          optimizer=net_io_names["optimizer"],
                          loss=net_io_names[loss_name],
                          inputs=inputs,
                          summary=net_io_names["summary"],
                          log_dir="log",
                          outputs=outputs,
                          gradients={},
                          log_every=5,
                          save_every=500,
                          array_specs=array_specs_pred,
                      ) + Snapshot(
                          snapshot,
                          every=500,
                          output_filename="batch_{iteration}.hdf",
                          output_dir="snapshots/",
                          additional_request=snapshot_request,
                      ) + PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))

    print("Training finished")
Example #3
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    voxel_size = Coordinate((2, 2, 2))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    for src in data_sources:
        n5_source = N5Source(
            os.path.join(data_dir.format(src)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/all',
                ArrayKeys.MASK: 'volumes/mask'
            },
            array_specs={ArrayKeys.MASK: ArraySpec(interpolatable=False)})
        data_providers.append(n5_source)

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    inputs = dict()
    inputs[net_io_names['raw']] = ArrayKeys.RAW
    outputs = dict()
    snapshot = dict()
    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'
    for label in labels:
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[
            label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.
                 pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    snapshot_request = BatchRequest()

    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size)

    for label in labels:
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size)
        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(ArrayKeys.RAW, None) +
        # choose a random location inside the provided arrays, requiring at
        # least 50% masked-in voxels
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK)
        # Reject(ArrayKeys.MASK)  # reject batches which contain less than 50% labelled data
        for provider in data_providers)

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0),
            prob_slip=0,
            prob_shift=0,
            max_misalign=0,
            subsample=8) + SimpleAugment() +
        #ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW, 0.95, 1.05, -0.05, 0.05) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW))

    for label in labels:
        train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                      distance_array_key=label.gt_dist_key,
                                      normalize='tanh',
                                      normalize_args=dt_scaling_factor,
                                      label_id=label.labelid)

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=ArrayKeys.MASK)
    train_pipeline = (train_pipeline +
                      PreCache(cache_size=40, num_workers=10) +
                      Train('build',
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={}) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #4
def train_until(
    max_iteration,
    cremi_dir,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cache_size=10,
    num_workers=10,
):
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_SYN_LABELS")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SYN_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_SYN_DIST")
    ArrayKey("PREDICTED_SYN_DIST")
    ArrayKey("GT_BDY_DIST")
    ArrayKey("PREDICTED_BDY_DIST")

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, "sample_" + sample + "_cleftsorig.hdf"),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )

        data_providers.append(h5_source)

    # todo: dvid source

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    # input_size = Coordinate((132,)*3) * voxel_size
    # output_size = Coordinate((44,)*3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_SYN_LABELS, output_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_BDY_DIST, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SYN_SCALE, output_size)
    request.add(ArrayKeys.GT_SYN_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider
        + Normalize(ArrayKeys.RAW)  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        + Pad(ArrayKeys.RAW, None)
        + Pad(ArrayKeys.GT_MASK, None)
        + Pad(ArrayKeys.TRAINING_MASK, None)
        + RandomLocation()  # choose a random location inside the provided arrays
        # reject batches that contain less than 50% ground-truth-masked data
        + Reject(ArrayKeys.GT_MASK)
        # require the batch to be at least 99% covered by TRAINING_MASK
        + Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99)
        # with probability 0.95, reject batches that contain no labelled clefts
        + Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    snapshot_request = BatchRequest(
        {
            ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS],
            ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS],
            ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_DIST],
        }
    )

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        )
        + RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK)
        + Normalize(ArrayKeys.RAW)
        + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        + ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8)
        + SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = (
        data_sources
        + RandomProvider()
        + ElasticAugment(
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            subsample=8,
        )
        + SimpleAugment(transpose_only=[1, 2])
        + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        + DefectAugment(
            ArrayKeys.RAW,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ArrayKeys.RAW,
            artifacts_mask=ArrayKeys.ALPHA_MASK,
            contrast_scale=0.5,
        )
        + IntensityScaleShift(ArrayKeys.RAW, 2, -1)
        + ZeroOutConstSections(ArrayKeys.RAW)
        + GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True)
        + AddBoundaryDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_BDY_DIST,
            normalize="tanh",
            normalize_args=100,
        )
        + AddDistance(
            label_array_key=ArrayKeys.GT_SYN_LABELS,
            distance_array_key=ArrayKeys.GT_SYN_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        )
        + BalanceLabels(
            ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE, ArrayKeys.GT_MASK
        )
        + PreCache(cache_size=cache_size, num_workers=num_workers)
        + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_syn_dist"]: ArrayKeys.GT_SYN_DIST,
                net_io_names["gt_bdy_dist"]: ArrayKeys.GT_BDY_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SYN_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["syn_dist"]: ArrayKeys.PREDICTED_SYN_DIST,
                net_io_names["bdy_dist"]: ArrayKeys.PREDICTED_BDY_DIST,
            },
            gradients={net_io_names["syn_dist"]: ArrayKeys.LOSS_GRADIENT},
        )
        + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_SYN_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_SYN_DIST: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.PREDICTED_BDY_DIST: "volumes/labels/pred_bdy_dist",
                ArrayKeys.GT_BDY_DIST: "volumes/labels/gt_bdy_dist",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #5
def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_SYN_LABELS')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('GT_SYN_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_SYN_DIST')
    ArrayKey('PREDICTED_SYN_DIST')
    ArrayKey('GT_BDY_DIST')
    ArrayKey('PREDICTED_BDY_DIST')

    data_providers = []
    cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, 'sample_'+sample+'_cleftsorig.hdf'),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_SYN_LABELS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)
            }
        )

        data_providers.append(h5_source)

    #todo: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    # input_size = Coordinate((132,)*3) * voxel_size
    # output_size = Coordinate((44,)*3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_SYN_LABELS, output_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_BDY_DIST, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SYN_SCALE, output_size)
    request.add(ArrayKeys.GT_SYN_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(ArrayKeys.RAW, None) +
        Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK) +  # reject batches that contain less than 50% labelled data
        Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)

        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS],
        ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS],
        ArrayKeys.LOSS_GRADIENT:      request[ArrayKeys.GT_SYN_DIST],
    })

    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                ArrayKeys.RAW:        'defect_sections/raw',
                ArrayKeys.ALPHA_MASK: 'defect_sections/mask',
            },
            array_specs={
                ArrayKeys.RAW:        ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi/2.0), subsample=8) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment((4, 40, 40), (0., 2., 2.), (0, math.pi/2.0),
                       prob_slip=0.05, prob_shift=0.05, max_misalign=10,
                       subsample=8) +
        SimpleAugment(transpose_only=[1,2]) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(ArrayKeys.RAW,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=ArrayKeys.RAW,
                      artifacts_mask=ArrayKeys.ALPHA_MASK,
                      contrast_scale=0.5) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True) +

        AddBoundaryDistance(label_array_key=ArrayKeys.GT_LABELS,
                            distance_array_key=ArrayKeys.GT_BDY_DIST,
                            normalize='tanh',
                            normalize_args=100) +

        AddDistance(label_array_key=ArrayKeys.GT_SYN_LABELS,
                            distance_array_key=ArrayKeys.GT_SYN_DIST,
                            normalize='tanh',
                            normalize_args=dt_scaling_factor
                            ) +
        BalanceLabels(ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE, ArrayKeys.GT_MASK) +
        PreCache(
            cache_size=40,
            num_workers=10)+

        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']:          ArrayKeys.RAW,
                net_io_names['gt_syn_dist']:  ArrayKeys.GT_SYN_DIST,
                net_io_names['gt_bdy_dist']:  ArrayKeys.GT_BDY_DIST,
                net_io_names['loss_weights']: ArrayKeys.GT_SYN_SCALE,
                net_io_names['mask']:         ArrayKeys.GT_MASK,
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['syn_dist']: ArrayKeys.PREDICTED_SYN_DIST,
                net_io_names['bdy_dist']: ArrayKeys.PREDICTED_BDY_DIST
            },
            gradients={
                net_io_names['syn_dist']: ArrayKeys.LOSS_GRADIENT
            }) +

        Snapshot({
                ArrayKeys.RAW:                'volumes/raw',
                ArrayKeys.GT_SYN_LABELS:      'volumes/labels/gt_clefts',
                ArrayKeys.GT_SYN_DIST:        'volumes/labels/gt_clefts_dist',
                ArrayKeys.PREDICTED_SYN_DIST: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT:      'volumes/loss_gradient',
                ArrayKeys.GT_LABELS:          'volumes/labels/neuron_ids',
                ArrayKeys.PREDICTED_BDY_DIST: 'volumes/labels/pred_bdy_dist',
                ArrayKeys.GT_BDY_DIST:        'volumes/labels/gt_bdy_dist'
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +

        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #6
def train_until(
    max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name
):
    raw = ArrayKey("RAW")
    # ArrayKey('ALPHA_MASK')
    clefts = ArrayKey("GT_LABELS")
    mask = ArrayKey("GT_MASK")
    scale = ArrayKey("GT_SCALE")
    # grad = ArrayKey('LOSS_GRADIENT')
    gt_dist = ArrayKey("GT_DIST")
    pred_dist = ArrayKey("PREDICTED_DIST")

    data_providers = []

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until >= max_iteration:
        return
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/fib19/mine/"
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(data_dir, "cube{0:}.hdf".format(sample)),
            datasets={
                raw: "volumes/raw",
                clefts: "volumes/labels/clefts",
                mask: "/volumes/masks/groundtruth",
            },
            array_specs={mask: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    # input_size = Coordinate((132,)*3) * voxel_size
    # output_size = Coordinate((44,)*3) * voxel_size

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(clefts, output_size)
    request.add(mask, output_size)
    # request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(scale, output_size)
    request.add(gt_dist, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider
        + Normalize(raw)  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        + Pad(raw, None)
        + RandomLocation()  # choose a random location inside the provided arrays
        # Reject(ArrayKeys.GT_MASK) +  # reject batches which contain less than 50% labelled data
        # Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        # reject batches with less than 50% ground-truth-masked data
        + Reject(mask=mask)
        # with probability 0.95, reject batches that contain no labelled clefts
        + Reject(clefts, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    snapshot_request = BatchRequest({pred_dist: request[clefts]})

    train_pipeline = (
        data_sources
        + RandomProvider()
        + ElasticAugment(
            (40, 40, 40),
            (2.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.01,
            prob_shift=0.01,
            max_misalign=1,
            subsample=8,
        )
        + SimpleAugment()
        + IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
        + IntensityScaleShift(raw, 2, -1)
        + ZeroOutConstSections(raw)
        +
        # GrowBoundary(steps=1) +
        # SplitAndRenumberSegmentationLabels() +
        # AddGtAffinities(malis.mknhood3d()) +
        AddDistance(
            label_array_key=clefts,
            distance_array_key=gt_dist,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        )
        +
        # BalanceLabels(clefts, scale, mask) +
        BalanceByThreshold(labels=ArrayKeys.GT_DIST, scales=ArrayKeys.GT_SCALE)
        +
        # {
        #     ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_SCALE
        # },
        # {
        #     ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_MASK
        # }) +
        PreCache(cache_size=40, num_workers=10)
        + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: raw,
                net_io_names["gt_dist"]: gt_dist,
                net_io_names["loss_weights"]: scale,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: pred_dist},
            gradients={},
        )
        + Snapshot(
            {
                raw: "volumes/raw",
                clefts: "volumes/labels/gt_clefts",
                gt_dist: "volumes/labels/gt_clefts_dist",
                pred_dist: "volumes/labels/pred_clefts_dist",
            },
            dataset_dtypes={clefts: np.uint64},
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
Example #7
def train_until(max_iteration, data_sources, ribo_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels, net_name,
                min_masked_voxels=17561., mask_ds_name='volumes/masks/training'):
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')
    ArrayKey('RIBO_GT')

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_input = Coordinate((8, 8, 8))
    voxel_size_output = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size_output
    # context = input_size-output_size

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/ribosomes',
    }
    # for datasets without ribosome annotations volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (only one with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/all',
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_input))
    }
    array_specs_pred = {}

    inputs[net_io_names['raw']] = ArrayKeys.RAW

    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_output)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_input)

    for label in labels:
        datasets_no_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        datasets_ribo[label.mask_key] = 'volumes/masks/' + label.labelname

        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_output, interpolatable=True)

        inputs[net_io_names['mask_' + label.labelname]] = label.mask_key
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key

        outputs[net_io_names[label.labelname]] = label.pred_dist_key

        snapshot[
            label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.
                 pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname

        request.add(label.gt_dist_key,
                    output_size,
                    voxel_size=voxel_size_output)
        request.add(label.pred_dist_key,
                    output_size,
                    voxel_size=voxel_size_output)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_output)
        if label.scale_loss:
            request.add(label.scale_key,
                        output_size,
                        voxel_size=voxel_size_output)

        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size_output)

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for src in data_sources:
        for subsample_variant in range(8):
            dnr = datasets_no_ribo.copy()
            dr = datasets_ribo.copy()
            dnr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(
                subsample_variant)
            dr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(
                subsample_variant)

            if src not in ribo_sources:
                n5_source = N5Source(src.full_path,
                                     datasets=dnr,
                                     array_specs=array_specs)
            else:
                n5_source = N5Source(src.full_path,
                                     datasets=dr,
                                     array_specs=array_specs)

            data_providers.append(n5_source)

    # create a tuple of data sources, one for each N5 container
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW and MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        # reject batches whose masked-in fraction is below keep_thr
        Reject(ArrayKeys.MASK, min_masked=keep_thr)
        for provider in data_providers)

    train_pipeline = (
        data_stream + RandomProvider(
            tuple(np.repeat([ds.labeled_voxels for ds in data_sources], 8))) +
        gpn.SimpleAugment() + gpn.ElasticAugment(voxel_size_output,
                                                 (100, 100, 100),
                                                 (10., 10., 10.),
                                                 (0, math.pi / 2.0),
                                                 spatial_dims=3,
                                                 subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))

    for label in labels:
        if label.labelname != 'ribosomes':
            train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh',
                                          normalize_args=dt_scaling_factor,
                                          label_id=label.labelid,
                                          factor=2)
        else:
            train_pipeline += AddDistance(label_array_key=ArrayKeys.RIBO_GT,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh+',
                                          normalize_args=(dt_scaling_factor,
                                                          8),
                                          label_id=label.labelid,
                                          factor=2)

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=10, num_workers=20) +
                      Train(net_name,
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={},
                            log_every=5,
                            save_every=500,
                            array_specs=array_specs_pred) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info('it {0:}: {1:}'.format(i + 1, time_it))

    print("Training finished")
Example #8
def train_until(max_iteration, data_dir, data_sources, input_shape,
                output_shape, loss_name):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_DIST_SCALE')
    # ArrayKey('GT_AFF_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_DIST')
    ArrayKey('PREDICTED_DIST')
    # ArrayKey('GT_AFF')
    # ArrayKey('PREDICTED_AFF1')
    # ArrayKey('PREDICTED_AFF3')
    # ArrayKey('PREDICTED_AFF9')

    data_providers = []
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for sample in data_sources:
        h5_source = Hdf5Source(data_dir,
                               datasets={
                                   ArrayKeys.RAW: sample + '/image',
                                   ArrayKeys.GT_LABELS: sample + '/mask',
                               },
                               array_specs={
                                   ArrayKeys.RAW:
                                   ArraySpec(voxel_size=Coordinate((1, 1))),
                                   ArrayKeys.GT_LABELS:
                                   ArraySpec(voxel_size=Coordinate((1, 1)))
                               })

        data_providers.append(h5_source)

    #todo: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((1, 1))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    # request.add(ArrayKeys.GT_AFF, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)
    request.add(ArrayKeys.GT_DIST_SCALE, output_size)
    # request.add(ArrayKeys.GT_AFF_SCALE, output_size)

    # create a tuple of data sources, one for each sample
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(ArrayKeys.RAW, None) +
        RandomLocation() +  # choose a random location inside the provided arrays
        # with probability 0.95, reject batches that contain no labelled data
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.PREDICTED_DIST: request[ArrayKeys.GT_DIST],
        # ArrayKeys.PREDICTED_AFF1: request[ArrayKeys.GT_AFF],
        # ArrayKeys.PREDICTED_AFF3: request[ArrayKeys.GT_AFF],
        # ArrayKeys.PREDICTED_AFF9: request[ArrayKeys.GT_AFF],
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
    })

    train_pipeline = (
        data_sources + RandomProvider() +
        #ElasticAugment((40, 40), (2., 2.), (0, math.pi/2.0),
        #               subsample=4, spatial_dims=2) +
        #SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        #DefectAugment(ArrayKeys.RAW, prob_low_contrast=0.01, contrast_scale=0.5) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        # AddAffinities([[-1, 0], [0, -1],
        #                [-3, 0], [0, -3],
        #                [-9, 0], [0, -9]],
        #               ArrayKeys.GT_LABELS,
        #               ArrayKeys.GT_AFF) +
        AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                    distance_array_key=ArrayKeys.GT_DIST,
                    normalize='tanh',
                    normalize_args=150) +

        # BalanceLabels(ArrayKeys.GT_AFF, ArrayKeys.GT_AFF_SCALE) +
        BalanceByThreshold(labels=ArrayKeys.GT_LABELS,
                           scales=ArrayKeys.GT_DIST_SCALE) +
        # PreCache(
        #    cache_size=40,
        #    num_workers=10) +
        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']:
                ArrayKeys.RAW,
                net_io_names['gt_dist']:
                ArrayKeys.GT_DIST,
                # net_io_names['gt_aff']:  ArrayKeys.GT_AFF,
                net_io_names['loss_weights_dist']:
                ArrayKeys.GT_DIST_SCALE,
                # net_io_names['loss_weights_aff']: ArrayKeys.GT_AFF_SCALE
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['dist']: ArrayKeys.PREDICTED_DIST,
                # net_io_names['aff1']:  ArrayKeys.PREDICTED_AFF1,
                # net_io_names['aff3']:  ArrayKeys.PREDICTED_AFF3,
                # net_io_names['aff9']:  ArrayKeys.PREDICTED_AFF9
            },
            gradients={net_io_names['dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW:
                'volumes/raw',
                ArrayKeys.GT_DIST:
                'volumes/labels/dist',
                # ArrayKeys.GT_AFF:         'volumes/labels/aff',
                ArrayKeys.GT_LABELS:
                'volumes/labels/nuclei',
                ArrayKeys.PREDICTED_DIST:
                'volumes/predictions/dist',
                # ArrayKeys.PREDICTED_AFF1: 'volumes/predictions/aff1',
                # ArrayKeys.PREDICTED_AFF3: 'volumes/predictions/aff3',
                # ArrayKeys.PREDICTED_AFF9: 'volumes/predictions/aff9',
                ArrayKeys.LOSS_GRADIENT:
                'volumes/loss_gradient',
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #9
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    ArrayKey("RAW")
    ArrayKey("RAW_UP")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("MASK_UP")
    ArrayKey("GT_DIST_CENTROSOME")
    ArrayKey("GT_DIST_GOLGI")
    ArrayKey("GT_DIST_GOLGI_MEM")
    ArrayKey("GT_DIST_ER")
    ArrayKey("GT_DIST_ER_MEM")
    ArrayKey("GT_DIST_MVB")
    ArrayKey("GT_DIST_MVB_MEM")
    ArrayKey("GT_DIST_MITO")
    ArrayKey("GT_DIST_MITO_MEM")
    ArrayKey("GT_DIST_LYSOSOME")
    ArrayKey("GT_DIST_LYSOSOME_MEM")

    ArrayKey("PRED_DIST_CENTROSOME")
    ArrayKey("PRED_DIST_GOLGI")
    ArrayKey("PRED_DIST_GOLGI_MEM")
    ArrayKey("PRED_DIST_ER")
    ArrayKey("PRED_DIST_ER_MEM")
    ArrayKey("PRED_DIST_MVB")
    ArrayKey("PRED_DIST_MVB_MEM")
    ArrayKey("PRED_DIST_MITO")
    ArrayKey("PRED_DIST_MITO_MEM")
    ArrayKey("PRED_DIST_LYSOSOME")
    ArrayKey("PRED_DIST_LYSOSOME_MEM")

    ArrayKey("SCALE_CENTROSOME")
    ArrayKey("SCALE_GOLGI")
    ArrayKey("SCALE_GOLGI_MEM")
    ArrayKey("SCALE_ER")
    ArrayKey("SCALE_ER_MEM")
    ArrayKey("SCALE_MVB")
    ArrayKey("SCALE_MVB_MEM")
    ArrayKey("SCALE_MITO")
    ArrayKey("SCALE_MITO_MEM")
    ArrayKey("SCALE_LYSOSOME")
    ArrayKey("SCALE_LYSOSOME_MEM")

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    voxel_size_up = Coordinate((4, 4, 4))
    voxel_size_orig = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for src in data_sources:
        n5_source = N5Source(
            os.path.join(data_dir.format(src)),
            datasets={
                ArrayKeys.RAW_UP: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/all",
                ArrayKeys.MASK_UP: "volumes/mask",
            },
            array_specs={ArrayKeys.MASK_UP: ArraySpec(interpolatable=False)},
        )
        data_providers.append(n5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RAW_UP, input_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK_UP, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_CENTROSOME,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_GOLGI,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_GOLGI_MEM,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_ER, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_ER_MEM,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MVB, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MVB_MEM,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MITO,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_MITO_MEM,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_LYSOSOME,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.GT_DIST_LYSOSOME_MEM,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_CENTROSOME,
                output_size,
                voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_GOLGI, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_GOLGI_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_ER, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_ER_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_MVB, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_MVB_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_MITO, output_size, voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_MITO_MEM, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.SCALE_LYSOSOME,
                output_size,
                voxel_size=voxel_size_orig)
    # request.add(ArrayKeys.SCALE_LYSOSOME_MEM, output_size, voxel_size=voxel_size_orig)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW_UP) +  # ensures RAW_UP is float in [0, 1]
        # zero-pad the provided RAW_UP to be able to draw batches close to
        # the boundary of the available data;
        # the exact pad size is more or less irrelevant
        Pad(ArrayKeys.RAW_UP, None) +
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK_UP
                       )  # choose a random location inside the provided arrays
        # Reject(ArrayKeys.MASK) # reject batches which contain less than 50% labelled data
        for provider in data_providers)

    snapshot_request = BatchRequest()
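    # the network's distance predictions are not part of the training request;
    # request them here so the Snapshot node can include them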
    snapshot_request.add(ArrayKeys.PRED_DIST_CENTROSOME, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_GOLGI_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_ER, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_ER_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MVB, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MVB_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MITO, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_MITO_MEM, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME, output_size)
    snapshot_request.add(ArrayKeys.PRED_DIST_LYSOSOME_MEM, output_size)
    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            prob_slip=0,
            prob_shift=0,
            max_misalign=0,
            subsample=8,
        ) + SimpleAugment() + ElasticAugment(
            (40, 1000, 1000), (10.0, 0.0, 0.0), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW_UP, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW_UP, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW_UP) +
        # GrowBoundary(steps=1) +
        # SplitAndRenumberSegmentationLabels() +
        # AddGtAffinities(malis.mknhood3d()) +
        AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_CENTROSOME,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=1,
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_GOLGI,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(2, 11),
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_GOLGI_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=11,
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_ER,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(3, 10),
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_ER_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=10,
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MVB,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(4, 9),
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MVB_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=9,
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(5, 8),
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_MITO_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=8,
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=(6, 7),
            factor=2,
        ) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST_LYSOSOME_MEM,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=7,
            factor=2,
        ) + DownSample(ArrayKeys.MASK_UP, 2, ArrayKeys.MASK) +
        BalanceByThreshold(
            ArrayKeys.GT_DIST_CENTROSOME,
            ArrayKeys.SCALE_CENTROSOME,
            mask=ArrayKeys.MASK,
        ) + BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI,
                               ArrayKeys.SCALE_GOLGI,
                               mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_GOLGI_MEM, ArrayKeys.SCALE_GOLGI_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(
            ArrayKeys.GT_DIST_ER, ArrayKeys.SCALE_ER, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_ER_MEM, ArrayKeys.SCALE_ER_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(
            ArrayKeys.GT_DIST_MVB, ArrayKeys.SCALE_MVB, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MVB_MEM, ArrayKeys.SCALE_MVB_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_MITO,
                           ArrayKeys.SCALE_MITO,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_MITO_MEM, ArrayKeys.SCALE_MITO_MEM, mask=ArrayKeys.MASK) +
        BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME,
                           ArrayKeys.SCALE_LYSOSOME,
                           mask=ArrayKeys.MASK) +
        # BalanceByThreshold(ArrayKeys.GT_DIST_LYSOSOME_MEM, ArrayKeys.SCALE_LYSOSOME_MEM, mask=ArrayKeys.MASK) +
        # BalanceByThreshold(
        #    labels=ArrayKeys.GT_DIST,
        #    scales= ArrayKeys.GT_SCALE) +
        # {
        #     ArrayKeys.GT_AFFINITIES: ArrayKeys.GT_SCALE
        # },
        # {
        #     ArrayKeys.GT_AFFINITIES: ArrayKeys.MASK
        # }) +
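        # downsample the augmented full-resolution raw (voxel_size_up) to the network input resolution (voxel_size_orig)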
        DownSample(ArrayKeys.RAW_UP, 2, ArrayKeys.RAW) +
        PreCache(cache_size=40, num_workers=10) + Train(
            "build",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_centrosome"]: ArrayKeys.GT_DIST_CENTROSOME,
                net_io_names["gt_golgi"]: ArrayKeys.GT_DIST_GOLGI,
                net_io_names["gt_golgi_mem"]: ArrayKeys.GT_DIST_GOLGI_MEM,
                net_io_names["gt_er"]: ArrayKeys.GT_DIST_ER,
                net_io_names["gt_er_mem"]: ArrayKeys.GT_DIST_ER_MEM,
                net_io_names["gt_mvb"]: ArrayKeys.GT_DIST_MVB,
                net_io_names["gt_mvb_mem"]: ArrayKeys.GT_DIST_MVB_MEM,
                net_io_names["gt_mito"]: ArrayKeys.GT_DIST_MITO,
                net_io_names["gt_mito_mem"]: ArrayKeys.GT_DIST_MITO_MEM,
                net_io_names["gt_lysosome"]: ArrayKeys.GT_DIST_LYSOSOME,
                net_io_names["gt_lysosome_mem"]:
                ArrayKeys.GT_DIST_LYSOSOME_MEM,
                net_io_names["w_centrosome"]: ArrayKeys.SCALE_CENTROSOME,
                net_io_names["w_golgi"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_golgi_mem"]: ArrayKeys.SCALE_GOLGI,
                net_io_names["w_er"]: ArrayKeys.SCALE_ER,
                net_io_names["w_er_mem"]: ArrayKeys.SCALE_ER,
                net_io_names["w_mvb"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mvb_mem"]: ArrayKeys.SCALE_MVB,
                net_io_names["w_mito"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_mito_mem"]: ArrayKeys.SCALE_MITO,
                net_io_names["w_lysosome"]: ArrayKeys.SCALE_LYSOSOME,
                net_io_names["w_lysosome_mem"]: ArrayKeys.SCALE_LYSOSOME,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["centrosome"]: ArrayKeys.PRED_DIST_CENTROSOME,
                net_io_names["golgi"]: ArrayKeys.PRED_DIST_GOLGI,
                net_io_names["golgi_mem"]: ArrayKeys.PRED_DIST_GOLGI_MEM,
                net_io_names["er"]: ArrayKeys.PRED_DIST_ER,
                net_io_names["er_mem"]: ArrayKeys.PRED_DIST_ER_MEM,
                net_io_names["mvb"]: ArrayKeys.PRED_DIST_MVB,
                net_io_names["mvb_mem"]: ArrayKeys.PRED_DIST_MVB_MEM,
                net_io_names["mito"]: ArrayKeys.PRED_DIST_MITO,
                net_io_names["mito_mem"]: ArrayKeys.PRED_DIST_MITO_MEM,
                net_io_names["lysosome"]: ArrayKeys.PRED_DIST_LYSOSOME,
                net_io_names["lysosome_mem"]: ArrayKeys.PRED_DIST_LYSOSOME_MEM,
            },
            gradients={},
        ) + Snapshot(
            {
                ArrayKeys.RAW:
                "volumes/raw",
                ArrayKeys.GT_LABELS:
                "volumes/labels/gt_labels",
                ArrayKeys.GT_DIST_CENTROSOME:
                "volumes/labels/gt_dist_centrosome",
                ArrayKeys.PRED_DIST_CENTROSOME:
                "volumes/labels/pred_dist_centrosome",
                ArrayKeys.GT_DIST_GOLGI:
                "volumes/labels/gt_dist_golgi",
                ArrayKeys.PRED_DIST_GOLGI:
                "volumes/labels/pred_dist_golgi",
                ArrayKeys.GT_DIST_GOLGI_MEM:
                "volumes/labels/gt_dist_golgi_mem",
                ArrayKeys.PRED_DIST_GOLGI_MEM:
                "volumes/labels/pred_dist_golgi_mem",
                ArrayKeys.GT_DIST_ER:
                "volumes/labels/gt_dist_er",
                ArrayKeys.PRED_DIST_ER:
                "volumes/labels/pred_dist_er",
                ArrayKeys.GT_DIST_ER_MEM:
                "volumes/labels/gt_dist_er_mem",
                ArrayKeys.PRED_DIST_ER_MEM:
                "volumes/labels/pred_dist_er_mem",
                ArrayKeys.GT_DIST_MVB:
                "volumes/labels/gt_dist_mvb",
                ArrayKeys.PRED_DIST_MVB:
                "volumes/labels/pred_dist_mvb",
                ArrayKeys.GT_DIST_MVB_MEM:
                "volumes/labels/gt_dist_mvb_mem",
                ArrayKeys.PRED_DIST_MVB_MEM:
                "volumes/labels/pred_dist_mvb_mem",
                ArrayKeys.GT_DIST_MITO:
                "volumes/labels/gt_dist_mito",
                ArrayKeys.PRED_DIST_MITO:
                "volumes/labels/pred_dist_mito",
                ArrayKeys.GT_DIST_MITO_MEM:
                "volumes/labels/gt_dist_mito_mem",
                ArrayKeys.PRED_DIST_MITO_MEM:
                "volumes/labels/pred_dist_mito_mem",
                ArrayKeys.GT_DIST_LYSOSOME:
                "volumes/labels/gt_dist_lysosome",
                ArrayKeys.PRED_DIST_LYSOSOME:
                "volumes/labels/pred_dist_lysosome",
                ArrayKeys.GT_DIST_LYSOSOME_MEM:
                "volumes/labels/gt_dist_lysosome_mem",
                ArrayKeys.PRED_DIST_LYSOSOME_MEM:
                "volumes/labels/pred_dist_lysosome_mem",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) + PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #10
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, cremi_version, aligned):
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_CLEFTS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('CLEFT_SCALE')
    ArrayKey('PRE_SCALE')
    ArrayKey('POST_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_CLEFT_DIST')
    ArrayKey('PRED_CLEFT_DIST')
    ArrayKey('GT_PRE_DIST')
    ArrayKey('PRED_PRE_DIST')
    ArrayKey('GT_POST_DIST')
    ArrayKey('PRED_POST_DIST')
    data_providers = []
    if cremi_version == '2016':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = 'sample_{0:}_padded_20160501.'
    elif cremi_version == '2017':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = 'sample_{0:}_padded_20170424.'
    if aligned:
        filename += 'aligned.'
    filename += '0bg.hdf'
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False)
            })
        data_providers.append(h5_source)

    if cremi_version == '2017':
        csv_files = [
            os.path.join(cremi_dir, 'cleft-partners_' + sample + '_2017.csv')
            for sample in data_sources
        ]
    elif cremi_version == '2016':
        csv_files = [
            os.path.join(
                cremi_dir,
                'cleft-partners-' + sample + '-20160501.aligned.corrected.csv')
            for sample in data_sources
        ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(
        csv_files)
    print(cleft_to_pre, cleft_to_post)
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size
    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # the pad size is more or less irrelevant as it is followed by a Reject node
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK)
        +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK)
        +  # reject batches which contain less than 50% labelled data
        Reject(ArrayKeys.GT_CLEFTS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT:
        request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST:
        request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST:
        request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST:
        request[ArrayKeys.GT_POST_DIST],
    })

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (4, 40, 40), (0., 0., 0.), (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(label_array_key=ArrayKeys.GT_CLEFTS,
                    distance_array_key=ArrayKeys.GT_CLEFT_DIST,
                    normalize='tanh',
                    normalize_args=dt_scaling_factor) +
        AddPrePostCleftDistance(ArrayKeys.GT_CLEFTS,
                                ArrayKeys.GT_LABELS,
                                ArrayKeys.GT_PRE_DIST,
                                ArrayKeys.GT_POST_DIST,
                                cleft_to_pre,
                                cleft_to_post,
                                normalize='tanh',
                                normalize_args=dt_scaling_factor,
                                include_cleft=False) +
        BalanceByThreshold(labels=ArrayKeys.GT_CLEFT_DIST,
                           scales=ArrayKeys.CLEFT_SCALE,
                           mask=ArrayKeys.GT_MASK) +
        BalanceByThreshold(labels=ArrayKeys.GT_PRE_DIST,
                           scales=ArrayKeys.PRE_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        BalanceByThreshold(labels=ArrayKeys.GT_POST_DIST,
                           scales=ArrayKeys.POST_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        PreCache(cache_size=40, num_workers=10) + Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_cleft_dist']: ArrayKeys.GT_CLEFT_DIST,
                net_io_names['gt_pre_dist']: ArrayKeys.GT_PRE_DIST,
                net_io_names['gt_post_dist']: ArrayKeys.GT_POST_DIST,
                net_io_names['loss_weights_cleft']: ArrayKeys.CLEFT_SCALE,
                net_io_names['loss_weights_pre']: ArrayKeys.PRE_SCALE,
                net_io_names['loss_weights_post']: ArrayKeys.POST_SCALE,
                net_io_names['mask']: ArrayKeys.GT_MASK
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['cleft_dist']: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names['pre_dist']: ArrayKeys.PRED_PRE_DIST,
                net_io_names['post_dist']: ArrayKeys.PRED_POST_DIST
            },
            gradients={net_io_names['cleft_dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/gt_clefts',
                ArrayKeys.GT_CLEFT_DIST: 'volumes/labels/gt_clefts_dist',
                ArrayKeys.PRED_CLEFT_DIST: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
                ArrayKeys.PRED_PRE_DIST: 'volumes/labels/pred_pre_dist',
                ArrayKeys.PRED_POST_DIST: 'volumes/labels/pred_post_dist',
                ArrayKeys.GT_PRE_DIST: 'volumes/labels/gt_pre_dist',
                ArrayKeys.GT_POST_DIST: 'volumes/labels/gt_post_dist'
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #11
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    dt_scaling_factor,
    loss_name,
    labels,
    scnet,
    raw_name="raw",
    min_masked_voxels=17561.0,
):
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    datasets_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (the only dataset with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }
    array_specs = {ArrayKeys.MASK: ArraySpec(interpolatable=False)}
    array_specs_pred = {}

    # individual mask per label
    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
    # inputs = {net_io_names['mask']: ArrayKeys.MASK}
    snapshot = {ArrayKeys.GT_LABELS: "volumes/labels/all"}

    request = BatchRequest()
    snapshot_request = BatchRequest()

    raw_array_keys = []
    contexts = []
    # input and output sizes in world coordinates
    input_sizes_wc = [
        Coordinate(inp_sh) * Coordinate(vs)
        for inp_sh, vs in zip(scnet.input_shapes, scnet.voxel_sizes)
    ]
    output_size_wc = Coordinate(scnet.output_shapes[0]) * Coordinate(
        scnet.voxel_sizes[0]
    )
    keep_thr = float(min_masked_voxels) / np.prod(scnet.output_shapes[0])

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    assert voxel_size_orig == Coordinate(
        scnet.voxel_sizes[0]
    )  # make sure that scnet has the same base voxel size
    inputs = {}
    # add multiscale raw data as inputs
    for k, (inp_sh_wc, vs) in enumerate(zip(input_sizes_wc, scnet.voxel_sizes)):
        ak = ArrayKey("RAW_S{0:}".format(k))
        raw_array_keys.append(ak)
        datasets_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        datasets_no_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        inputs[net_io_names["raw_{0:}".format(vs[0])]] = ak
        snapshot[ak] = "volumes/raw_s{0:}".format(k)
        array_specs[ak] = ArraySpec(voxel_size=Coordinate(vs))
        request.add(ak, inp_sh_wc, voxel_size=Coordinate(vs))
        contexts.append(inp_sh_wc - output_size_wc)
    outputs = dict()
    for label in labels:
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True
        )

    request.add(ArrayKeys.GT_LABELS, output_size_wc, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size_wc, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size_wc, voxel_size=voxel_size_up)
    for label in labels:
        request.add(label.gt_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        snapshot_request.add(
            label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig
        )
        request.add(label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size_wc, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size_wc, voxel_size=voxel_size_orig)

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for src in data_sources:

        if src not in ribo_sources:
            n5_source = N5Source(
                src.full_path, datasets=datasets_no_ribo, array_specs=array_specs
            )
        else:
            n5_source = N5Source(
                src.full_path, datasets=datasets_ribo, array_specs=array_specs
            )

        data_providers.append(n5_source)

    data_stream = []
    for provider in data_providers:
        data_stream.append(provider)
        for ak, context in zip(raw_array_keys, contexts):
            data_stream[-1] += Normalize(ak)
            # data_stream[-1] += Pad(ak, context) # this shouldn't be necessary as I cropped the input data to have
            # sufficient padding
        data_stream[-1] += RandomLocation()
        data_stream[-1] += Reject(ArrayKeys.MASK, min_masked=keep_thr)
    data_stream = tuple(data_stream)

    train_pipeline = (
        data_stream
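        # weight the random choice of provider by each source's number of labelled voxels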
        + RandomProvider(tuple([ds.labeled_voxels for ds in data_sources]))
        + gpn.SimpleAugment()
        + gpn.ElasticAugment(
            tuple(scnet.voxel_sizes[0]),
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        + gpn.IntensityAugment(raw_array_keys, 0.25, 1.75, -0.5, 0.35)
        + GammaAugment(raw_array_keys, 0.5, 2.0)
    )
    for ak in raw_array_keys:
        train_pipeline += IntensityScaleShift(ak, 2, -1)
        # train_pipeline += ZeroOutConstSections(ak)

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key
            )

    train_pipeline = (
        train_pipeline
        + PreCache(cache_size=10, num_workers=40)
        + Train(
            scnet.name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        )
        + Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")
Example #12
def train_until(
    max_iteration: int,
    gt_version: str,
    labels: List[CNNectome.utils.label.Label],
    net_name: str,
    input_shape: Union[np.ndarray, List[int]],
    output_shape: Union[np.ndarray, List[int]],
    loss_name: str,
    balance_global: bool = False,
    data_dir: Optional[str] = None,
    prioritized_label: Optional[CNNectome.utils.label.Label] = None,
    dataset: Optional[str] = None,
    prob_prioritized: float = 0.5,
    completion_min: int = 6,
    dt_scaling_factor: int = 50,
    cache_size: int = 5,
    num_workers: int = 10,
    min_masked_voxels: Union[float, int] = 17561.,
    voxel_size_labels: Coordinate = Coordinate((2, 2, 2)),
    voxel_size: Coordinate = Coordinate((4, 4, 4)),
    voxel_size_input: Coordinate = Coordinate((4, 4, 4))
):
    """
    Training a tensorflow network to learn signed distance transforms of specified labels (organelles) using gunpowder.
    Training data is read from crops whose metadata are organized in a database.

    Args:
        max_iteration: Total number of iterations that network should be trained for.
        gt_version: Version of groundtruth annotations, e.g. "v0003".
        labels: List of labels that the network needs to be trained for.
        net_name: Filename of tensorflow meta graph definition.
        input_shape: Input shape of network.
        output_shape: Output shape of network.
        loss_name: Name of loss used as stored in net io names json file.
        balance_global: If True, use global balancing, i.e. weigh the loss for each label using its `frac_pos` and
                        `frac_neg` attributes.
        data_dir: Path to directory where data is stored. If None, read from config file.
        prioritized_label: Label to use for prioritizing sampling from crops that contain examples of it. If None
                           (default), sample from each crop equally.
        dataset: Only consider crops that come from the specified dataset. If None (default), use all otherwise
                 eligible training data.
        prob_prioritized: If `prioritized_label` is not None, this is the probability with which to sample from the
                          crops containing the label. Default is .5, which implies sampling equally from crops
                          containing the labels and all others.
        completion_min: Minimal completion status for a crop from the database to be added to the training.
        dt_scaling_factor: Scaling factor to divide distance transform by before applying nonlinearity tanh.
        cache_size: Cache size for queue grabbing batches.
        num_workers: Number of workers grabbing batches.
        min_masked_voxels: Minimum number of voxels in a batch that need to be part of the groundtruth annotation.
        voxel_size_labels: Voxel size of the annotated labels.
        voxel_size: Voxel size of the desired output.
        voxel_size_input: Voxel size of the raw input data.
    """

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
    one_vx_thr = 1. / np.prod(output_shape)
    max_distance = 2.76 * dt_scaling_factor

    ak_raw = ArrayKey("RAW")
    ak_labels = ArrayKey("GT_LABELS")
    ak_labels_downsampled = ArrayKey("GT_LABELS_DOWNSAMPLED")
    ak_mask = ArrayKey("MASK")
    ak_labelmasks_comb = ArrayKey("LABELMASKS_COMBINED")
    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size
    crop_width = Coordinate((max_distance,) * len(voxel_size_labels))
    crop_width = crop_width//voxel_size
    if crop_width == 0:
        crop_width *= voxel_size
    else:
        crop_width = (crop_width+(1,)*len(crop_width)) * voxel_size
    # crop_width = crop_width  # (Coordinate((max_distance,) * len(voxel_size_labels))/2 )

    db = CNNectome.utils.cosem_db.MongoCosemDB(gt_version=gt_version)
    collection = db.access("crops", db.gt_version)
    db_filter = {"completion": {"$gte": completion_min}}
    if dataset is not None:
        db_filter['dataset_id'] = dataset
    skip = {"_id": 0, "number": 1, "labels": 1, "dataset_id": 1, "parent":1, "dimensions": 1}

    net_io_names, start_iteration, inputs, outputs = _network_setup(max_iteration, ak_raw, ak_mask, labels)

    # construct batch request
    request = BatchRequest()
    request.add(ak_labels, output_size, voxel_size=voxel_size_labels)
    request.add(ak_labels_downsampled, output_size, voxel_size=voxel_size)
    request.add(ak_mask, output_size, voxel_size=voxel_size)
    request.add(ak_labelmasks_comb, output_size, voxel_size=voxel_size)
    request.add(ak_raw, input_size, voxel_size=voxel_size_input)
    for label in labels:
        if label.separate_labelset:
            request.add(label.gt_key, output_size, voxel_size=voxel_size_labels)
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size)
        request.add(label.pred_dist_key, output_size, voxel_size=voxel_size)
        request.add(label.mask_key, output_size, voxel_size=voxel_size)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size)

    # specify specs for output
    array_specs_pred = dict()
    for label in labels:
        array_specs_pred[label.pred_dist_key] = ArraySpec(voxel_size=voxel_size,
                                                          interpolatable=True)
    # specify snapshot data layout
    snapshot_data = dict()
    snapshot_data[ak_raw] = "volumes/raw"
    snapshot_data[ak_mask] = "volumes/masks/all"
    if len(_label_filter(lambda l: not l.separate_labelset, labels)) > 0:
        snapshot_data[ak_labels] = "volumes/labels/gt_labels"
    for label in _label_filter(lambda l: l.separate_labelset, labels):
        snapshot_data[label.gt_key] = "volumes/labels/gt_"+label.labelname
    for label in labels:
        snapshot_data[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot_data[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        snapshot_data[label.mask_key] = "volumes/masks/" + label.labelname

    # specify snapshot request
    snapshot_request = BatchRequest()

    crop_srcs = []
    crop_sizes = []
    if prioritized_label is not None:
        crop_prioritized_label_indicator = []

    for crop in collection.find(db_filter, skip):
        if len(set(get_all_annotated_label_ids(crop)).intersection(set(get_all_labelids(labels)))) > 0:
            logging.info("Adding crop number {0:}".format(crop["number"]))
            if voxel_size_input != voxel_size:
                for subsample_variant in range(int(np.prod(voxel_size_input/voxel_size))):
                    crop_srcs.append(
                        _make_crop_source(crop, data_dir, subsample_variant, gt_version, labels, ak_raw, ak_labels,
                                          ak_labels_downsampled, ak_mask, input_size, output_size, voxel_size_input,
                                          voxel_size, crop_width, keep_thr))
                    crop_sizes.append(get_crop_size(crop))
                if prioritized_label is not None:
                    crop_prioritized = is_prioritized(crop, prioritized_label)
                    logging.info(f"Crop {crop['number']} is {'not ' if not crop_prioritized else ''}prioritized")
                    crop_prioritized_label_indicator.extend(
                        [crop_prioritized] * int(np.prod(voxel_size_input/voxel_size))
                    )
            else:
                crop_srcs.append(_make_crop_source(crop, data_dir, None, gt_version, labels, ak_raw, ak_labels,
                                                   ak_labels_downsampled, ak_mask, input_size, output_size,
                                                   voxel_size_input, voxel_size, crop_width, keep_thr))
                crop_sizes.append(get_crop_size(crop))
                if prioritized_label is not None:
                    crop_prioritized = is_prioritized(crop, prioritized_label)
                    logging.info(f"Crop {crop['number']} is {'not ' if not crop_prioritized else ''}prioritized")
                    crop_prioritized_label_indicator.append(crop_prioritized)

    if prioritized_label is not None:
        sampling_probs = prioritized_sampling_probabilities(
            crop_sizes, crop_prioritized_label_indicator, prob_prioritized
        )
    else:
        sampling_probs = crop_sizes
    print(sampling_probs)
    pipeline = (tuple(crop_srcs)
                + RandomProvider(sampling_probs)
                )

    pipeline += Normalize(ak_raw, 1.0/255)
    pipeline += IntensityCrop(ak_raw, 0., 1.)

    # augmentations
    pipeline = (pipeline
                + fuse.SimpleAugment()
                + fuse.ElasticAugment(voxel_size,
                                      (100, 100, 100),
                                      (10., 10., 10.),
                                      (0, math.pi / 2.),
                                      spatial_dims=3,
                                      subsample=8
                                      )
                + fuse.IntensityAugment(ak_raw, 0.25, 1.75, -0.5, 0.35)
                + GammaAugment(ak_raw, 0.5, 2.)
                )

    pipeline += IntensityScaleShift(ak_raw, 2, -1)

    # label generation
    for label in labels:
        pipeline += AddDistance(
            label_array_key=label.gt_key,
            distance_array_key=label.gt_dist_key,
            mask_array_key=label.mask_key,
            add_constant=label.add_constant,
            label_id=label.labelid,
            factor=2,
            max_distance=max_distance,
        )

    # combine distances for centrosomes

    centrosome = _get_label("centrosome", labels)
    microtubules = _get_label("microtubules", labels)
    microtubules_out = _get_label("microtubules_out", labels)
    subdistal_app = _get_label("subdistal_app", labels)
    distal_app = _get_label("distal_app", labels)

    # add the centrosomes to the microtubules
    if microtubules_out is not None and centrosome is not None:
        pipeline += CombineDistances(
            (microtubules_out.gt_dist_key, centrosome.gt_dist_key),
            microtubules_out.gt_dist_key,
            (microtubules_out.mask_key, centrosome.mask_key),
            microtubules_out.mask_key
        )
    if microtubules is not None and centrosome is not None:
        pipeline += CombineDistances(
            (microtubules.gt_dist_key, centrosome.gt_dist_key),
            microtubules.gt_dist_key,
            (microtubules.mask_key, centrosome.mask_key),
            microtubules.mask_key
        )

    # add the distal_app and subdistal_app to the centrosomes
    if centrosome is not None and distal_app is not None and subdistal_app is not None:
        pipeline += CombineDistances(
            (distal_app.gt_dist_key, subdistal_app.gt_dist_key, centrosome.gt_dist_key),
            centrosome.gt_dist_key,
            (distal_app.mask_key, subdistal_app.mask_key, centrosome.mask_key),
            centrosome.mask_key
        )

    arrays_that_need_to_be_cropped = []

    for label in labels:
        arrays_that_need_to_be_cropped.append(label.gt_key)
        arrays_that_need_to_be_cropped.append(label.gt_dist_key)
        arrays_that_need_to_be_cropped.append(label.mask_key)
    arrays_that_need_to_be_cropped.append(ak_labels)
    arrays_that_need_to_be_cropped.append(ak_labels_downsampled)
    arrays_that_need_to_be_cropped = list(set(arrays_that_need_to_be_cropped))
    for ak in arrays_that_need_to_be_cropped:
        pipeline += CropArray(ak, crop_width, crop_width)

    for label in labels:
        pipeline += TanhSaturate(label.gt_dist_key, dt_scaling_factor)

    for label in _label_filter(lambda l: l.scale_loss, labels):
        if balance_global:
            pipeline += BalanceGlobalByThreshold(
                label.gt_dist_key,
                label.scale_key,
                label.frac_pos,
                label.frac_neg
            )
        else:
            pipeline += BalanceByThreshold(
                label.gt_dist_key,
                label.scale_key,
                mask=(label.mask_key, ak_mask)
                )
    pipeline += Sum([l.mask_key for l in labels], ak_labelmasks_comb, sum_array_spec=ArraySpec(
                    dtype=np.uint8, interpolatable=False))
    pipeline += Reject(ak_labelmasks_comb, min_masked=one_vx_thr)

    pipeline = (pipeline
                + PreCache(cache_size=cache_size,
                           num_workers=num_workers)
                + Train(net_name,
                        optimizer=net_io_names["optimizer"],
                        loss=net_io_names[loss_name],
                        inputs=inputs,
                        summary=net_io_names["summary"],
                        log_dir="log",
                        outputs=outputs,
                        gradients={},
                        log_every=10,
                        save_every=500,
                        array_specs=array_specs_pred,
                        )
                + Snapshot(snapshot_data,
                           every=500,
                           output_filename="batch_{iteration}.hdf",
                           output_dir="snapshots/",
                           additional_request=snapshot_request,
                           )
                + PrintProfilingStats(every=50)
                )

    logging.info("Starting training...")
    with build(pipeline) as pp:
        for i in range(start_iteration, max_iteration+1):
            start_it = time.time()
            pp.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it{0:}: {1:}".format(i+1, time_it))
    logging.info("Training finished")