예제 #1
0
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    """Train the "unet" cleft-distance network on CREMI HDF samples.

    One HDF file per sample provides raw data, cleft labels, a ground-truth
    mask and a validation mask. The pipeline augments the data, computes a
    tanh-normalized distance map of the cleft labels (``AddDistance``) and
    feeds it to the TensorFlow ``Train`` node. If a checkpoint exists in the
    working directory, training resumes from its iteration and stops once
    ``max_iteration`` total iterations have been reached.

    Args:
        max_iteration: total number of iterations to train until, counting
            iterations already done by a previous run.
        data_sources: iterable of sample names substituted into the CREMI
            filename template.
        input_shape: network input shape in voxels.
        output_shape: network output shape in voxels.
        dt_scaling_factor: ``normalize_args`` for the tanh normalization of
            ``AddDistance``.
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.

    Raises:
        ValueError: if ``cremi_version`` is neither ``"2016"`` nor ``"2017"``.

    Note:
        ``cremi_version`` and ``aligned`` are not parameters; they are read
        from the enclosing (module) scope.
    """
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_DIST")
    ArrayKey("PREDICTED_DIST_LABELS")

    data_providers = []
    if cremi_version == "2016":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = "sample_{0:}_padded_20160501."
    elif cremi_version == "2017":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = "sample_{0:}_padded_20170424."
    else:
        # Without this branch cremi_dir/filename would be unbound and the
        # function would fail later with an UnboundLocalError.
        raise ValueError(
            "Unsupported cremi_version {0!r}; expected '2016' or '2017'".format(
                cremi_version))
    if aligned:
        filename += "aligned."
    filename += "0bg.hdf"

    # The iteration is parsed from the checkpoint name, assumed to end in
    # "_<iteration>".
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until >= max_iteration:
        # Target already reached; do not build the pipeline at all.
        return

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size
    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        # ensures RAW is in float in [0, 1]
        Normalize(ArrayKeys.RAW) +
        # TRAINING_MASK is read from the validation mask; assuming
        # IntensityScaleShift computes x * scale + shift, x * -1 + 1 turns a
        # 0/1 validation mask into a mask of everything outside it.
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close
        # to the boundary of the available data; size more or less
        # irrelevant as followed by Reject nodes
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        # reject batches with less than 50% ground-truth mask coverage
        # (gunpowder's default min_masked)
        Reject(ArrayKeys.GT_MASK) +
        # presumably rejects, with probability 0.95, batches that contain no
        # cleft labels at all
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    # Each key appears once. The original literal listed LOSS_GRADIENT twice;
    # only the last value, kept here, ever took effect.
    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
        ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS],
    })

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8,
        ) + SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        # map RAW from [0, 1] to [-1, 1]
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) + AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) + BalanceByThreshold(
            ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, mask=ArrayKeys.GT_MASK) +
        PreCache(cache_size=40, num_workers=10) + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_dist"]: ArrayKeys.GT_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST_LABELS},
            gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT},
        ) + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_DIST_LABELS:
                "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) + PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only the remaining iterations: a resumed run stops at max_iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
예제 #2
0
def train_until(max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name):
    """Train the "unet" cleft boundary-distance network on CREMI 2017 samples.

    One ``sample_<name>_cleftsorig.hdf`` file per sample provides raw data,
    cleft labels, a ground-truth mask and a training mask. Raw data is
    augmented, including ``DefectAugment`` with artifacts drawn from a
    separate defect-sections file. The boundary distance of the cleft labels
    is computed with tanh normalization and fed to the TensorFlow ``Train``
    node. If a checkpoint exists in the working directory, training resumes
    from its iteration and stops once ``max_iteration`` total iterations
    have been reached.

    Args:
        max_iteration: total number of iterations to train until, counting
            iterations already done by a previous run.
        data_sources: iterable of sample names (strings) used to build the
            HDF filenames.
        input_shape: network input shape in voxels.
        output_shape: network output shape in voxels.
        dt_scaling_factor: ``normalize_args`` for the tanh normalization of
            ``AddBoundaryDistance``.
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.
    """
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('GT_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_DIST')
    ArrayKey('PREDICTED_DIST_LABELS')

    data_providers = []
    cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
    # The iteration is parsed from the checkpoint name, assumed to end in
    # "_<iteration>".
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    if trained_until >= max_iteration:
        # Target already reached; do not build the pipeline at all.
        return

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, 'sample_' + sample + '_cleftsorig.hdf'),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/training'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)
            }
        )
        data_providers.append(h5_source)

    # TODO: dvid source

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SCALE, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is in float in [0, 1]
        # zero-pad provided RAW and masks to be able to draw batches close to
        # the boundary of the available data; size more or less irrelevant
        # as followed by Reject nodes
        Pad(
            {
                ArrayKeys.RAW: Coordinate((8, 8, 8)) * voxel_size,
                ArrayKeys.GT_MASK: Coordinate((8, 8, 8)) * voxel_size,
                ArrayKeys.TRAINING_MASK: Coordinate((8, 8, 8)) * voxel_size
            }
        ) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.GT_MASK) +  # reject batches with less than 50% labelled data
        Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99) +
        Reject(ArrayKeys.GT_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    # Each key appears once. The original literal listed LOSS_GRADIENT twice;
    # only the last value, kept here, ever took effect.
    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT:         request[ArrayKeys.GT_DIST],
        ArrayKeys.PREDICTED_DIST_LABELS: request[ArrayKeys.GT_LABELS],
    })

    # Source of defect sections used by DefectAugment as artifacts.
    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                ArrayKeys.RAW:        'defect_sections/raw',
                ArrayKeys.ALPHA_MASK: 'defect_sections/mask',
            },
            array_specs={
                ArrayKeys.RAW:        ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi/2.0), subsample=8) +
        SimpleAugment(transpose_only_xy=True)
    )

    train_pipeline = (
        data_sources +
        RandomProvider() +
        ElasticAugment((4, 40, 40), (0., 2., 2.), (0, math.pi/2.0),
                       prob_slip=0.05, prob_shift=0.05, max_misalign=10,
                       subsample=8) +
        SimpleAugment(transpose_only_xy=True) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(ArrayKeys.RAW,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=ArrayKeys.RAW,
                      artifacts_mask=ArrayKeys.ALPHA_MASK,
                      contrast_scale=0.5) +
        # map RAW from [0, 1] to [-1, 1]
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddBoundaryDistance(label_array_key=ArrayKeys.GT_LABELS,
                            distance_array_key=ArrayKeys.GT_DIST,
                            normalize='tanh',
                            normalize_args=dt_scaling_factor
                            ) +
        BalanceLabels(ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE, ArrayKeys.GT_MASK) +
        PreCache(
            cache_size=40,
            num_workers=10) +
        Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_dist']: ArrayKeys.GT_DIST,
                net_io_names['loss_weights']: ArrayKeys.GT_SCALE
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['dist']: ArrayKeys.PREDICTED_DIST_LABELS
            },
            gradients={
                net_io_names['dist']: ArrayKeys.LOSS_GRADIENT
            }) +
        Snapshot({
                ArrayKeys.RAW:                   'volumes/raw',
                ArrayKeys.GT_LABELS:             'volumes/labels/gt_clefts',
                ArrayKeys.GT_DIST:               'volumes/labels/gt_clefts_dist',
                ArrayKeys.PREDICTED_DIST_LABELS: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT:         'volumes/loss_gradient',
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only the remaining iterations: a resumed run stops at max_iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
예제 #3
0
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels):
    """Train the multi-label distance network ("build" meta graph) on N5 data.

    One N5 container per entry of ``data_sources`` provides raw data, a
    label volume and a mask. For every entry of ``labels``, a tanh-normalized
    distance map is computed from the label volume and fed to the network.
    Labels with ``scale_loss`` set additionally get a ``BalanceByThreshold``
    weight map. If a checkpoint exists in the working directory, training
    resumes from its iteration and stops once ``max_iteration`` total
    iterations have been reached.

    Args:
        max_iteration: total number of iterations to train until, counting
            iterations already done by a previous run.
        data_sources: iterable of names substituted into the N5 path template.
        input_shape: network input shape in voxels.
        output_shape: network output shape in voxels.
        dt_scaling_factor: ``normalize_args`` for the tanh normalization of
            ``AddDistance``.
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.
        labels: label descriptors. The attributes read here are
            ``labelname``, ``labelid``, ``gt_dist_key``, ``pred_dist_key``,
            ``scale_loss`` and ``scale_key``.
    """
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    voxel_size = Coordinate((2, 2, 2))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # The iteration is parsed from the checkpoint name, assumed to end in
    # "_<iteration>".
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    if trained_until >= max_iteration:
        # Target already reached; do not build the pipeline at all.
        return

    for src in data_sources:
        n5_source = N5Source(
            data_dir.format(src),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_LABELS: 'volumes/labels/all',
                ArrayKeys.MASK: 'volumes/mask'
            },
            array_specs={ArrayKeys.MASK: ArraySpec(interpolatable=False)})
        data_providers.append(n5_source)

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    # Network feeds/fetches and snapshot datasets, filled per label below.
    inputs = dict()
    inputs[net_io_names['raw']] = ArrayKeys.RAW
    outputs = dict()
    snapshot = dict()
    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'
    for label in labels:
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        # A scale_key may be set without scale_loss; only labels with
        # scale_loss request/compute their own weights below.
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = (
            'volumes/labels/gt_dist_' + label.labelname)
        snapshot[label.pred_dist_key] = (
            'volumes/labels/pred_dist_' + label.labelname)

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    snapshot_request = BatchRequest()

    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size)

    for label in labels:
        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size)
        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size)
        if label.scale_loss:
            request.add(label.scale_key, output_size, voxel_size=voxel_size)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is in float in [0, 1]
        # zero-pad provided RAW to be able to draw batches close to the
        # boundary of the available data
        Pad(ArrayKeys.RAW, None) +
        # choose a random location where at least 50% of MASK is set
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK)
        for provider in data_providers)

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            (100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0),
            prob_slip=0,
            prob_shift=0,
            max_misalign=0,
            subsample=8) + SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.95, 1.05, -0.05, 0.05) +
        # map RAW from [0, 1] to [-1, 1]
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW))

    # one distance map per label, all derived from the shared label volume
    for label in labels:
        train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                      distance_array_key=label.gt_dist_key,
                                      normalize='tanh',
                                      normalize_args=dt_scaling_factor,
                                      label_id=label.labelid)

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=ArrayKeys.MASK)
    train_pipeline = (train_pipeline +
                      PreCache(cache_size=40, num_workers=10) +
                      Train('build',
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={}) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only the remaining iterations: a resumed run stops at max_iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
예제 #4
0
def train_until(
    max_iteration,
    cremi_dir,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cache_size=10,
    num_workers=10,
):
    """Train the "unet" joint synapse/boundary distance network on CREMI data.

    One ``sample_<name>_cleftsorig.hdf`` file per sample provides raw data,
    cleft labels (``GT_SYN_LABELS``), neuron ids (``GT_LABELS``) and two
    masks. The pipeline augments the raw data (including ``DefectAugment``)
    and computes two tanh-normalized distance maps. The boundary distance of
    the neuron ids uses ``normalize_args=100``; the cleft distance uses
    ``dt_scaling_factor``. Both are fed to the TensorFlow ``Train`` node. If
    a checkpoint exists in the working directory, training resumes from its
    iteration and stops once ``max_iteration`` total iterations have been
    reached.

    Args:
        max_iteration: total number of iterations to train until, counting
            iterations already done by a previous run.
        cremi_dir: directory containing the sample and defect HDF files.
        data_sources: iterable of sample names (strings).
        input_shape: network input shape in voxels.
        output_shape: network output shape in voxels.
        dt_scaling_factor: ``normalize_args`` for the cleft distance map.
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.
        cache_size: ``PreCache`` cache size.
        num_workers: ``PreCache`` worker count.
    """
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_SYN_LABELS")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("GT_SYN_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_SYN_DIST")
    ArrayKey("PREDICTED_SYN_DIST")
    ArrayKey("GT_BDY_DIST")
    ArrayKey("PREDICTED_BDY_DIST")

    data_providers = []
    # The iteration is parsed from the checkpoint name, assumed to end in
    # "_<iteration>".
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until >= max_iteration:
        # Target already reached; do not build the pipeline at all.
        return

    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, "sample_" + sample + "_cleftsorig.hdf"),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                # NOTE(review): this reads the *validation* mask but it is used
                # below as the training-region mask without inversion (cf. the
                # IntensityScaleShift(-1, 1) inversion in the 2016/2017 CREMI
                # variant) — confirm this is intended.
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
            },
            array_specs={ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)},
        )

        data_providers.append(h5_source)

    # TODO: dvid source

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_SYN_LABELS, output_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_BDY_DIST, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.GT_SYN_SCALE, output_size)
    request.add(ArrayKeys.GT_SYN_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider
        + Normalize(ArrayKeys.RAW)  # ensures RAW is in float in [0, 1]
        # zero-pad provided RAW and masks to be able to draw batches close to
        # the boundary of the available data; size more or less irrelevant
        # as followed by Reject nodes
        + Pad(ArrayKeys.RAW, None)
        + Pad(ArrayKeys.GT_MASK, None)
        + Pad(ArrayKeys.TRAINING_MASK, None)
        # choose a random location inside the provided arrays
        + RandomLocation()
        # reject batches with less than 50% labelled data
        + Reject(ArrayKeys.GT_MASK)
        + Reject(ArrayKeys.TRAINING_MASK, min_masked=0.99)
        # Probabilistic rejection keyed on the *cleft* labels, as in the
        # sibling cleft-only pipelines. GT_LABELS holds neuron ids here, so
        # keying on it made this node effectively a no-op.
        + Reject(ArrayKeys.GT_SYN_LABELS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    # Each key appears once. The original literal listed LOSS_GRADIENT twice;
    # only the last value, kept here, ever took effect.
    snapshot_request = BatchRequest(
        {
            ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_SYN_DIST],
            ArrayKeys.PREDICTED_SYN_DIST: request[ArrayKeys.GT_SYN_LABELS],
            ArrayKeys.PREDICTED_BDY_DIST: request[ArrayKeys.GT_LABELS],
        }
    )

    # Source of defect sections used by DefectAugment as artifacts.
    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        )
        + RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK)
        + Normalize(ArrayKeys.RAW)
        + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        + ElasticAugment((4, 40, 40), (0, 2, 2), (0, math.pi / 2.0), subsample=8)
        + SimpleAugment(transpose_only=[1, 2])
    )

    train_pipeline = (
        data_sources
        + RandomProvider()
        + ElasticAugment(
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=10,
            subsample=8,
        )
        + SimpleAugment(transpose_only=[1, 2])
        + IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
        + DefectAugment(
            ArrayKeys.RAW,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ArrayKeys.RAW,
            artifacts_mask=ArrayKeys.ALPHA_MASK,
            contrast_scale=0.5,
        )
        # map RAW from [0, 1] to [-1, 1]
        + IntensityScaleShift(ArrayKeys.RAW, 2, -1)
        + ZeroOutConstSections(ArrayKeys.RAW)
        + GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1, only_xy=True)
        + AddBoundaryDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_BDY_DIST,
            normalize="tanh",
            normalize_args=100,
        )
        + AddDistance(
            label_array_key=ArrayKeys.GT_SYN_LABELS,
            distance_array_key=ArrayKeys.GT_SYN_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        )
        + BalanceLabels(
            ArrayKeys.GT_SYN_LABELS, ArrayKeys.GT_SYN_SCALE, ArrayKeys.GT_MASK
        )
        + PreCache(cache_size=cache_size, num_workers=num_workers)
        + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_syn_dist"]: ArrayKeys.GT_SYN_DIST,
                net_io_names["gt_bdy_dist"]: ArrayKeys.GT_BDY_DIST,
                net_io_names["loss_weights"]: ArrayKeys.GT_SYN_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["syn_dist"]: ArrayKeys.PREDICTED_SYN_DIST,
                net_io_names["bdy_dist"]: ArrayKeys.PREDICTED_BDY_DIST,
            },
            gradients={net_io_names["syn_dist"]: ArrayKeys.LOSS_GRADIENT},
        )
        + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_SYN_LABELS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_SYN_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PREDICTED_SYN_DIST: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.PREDICTED_BDY_DIST: "volumes/labels/pred_bdy_dist",
                ArrayKeys.GT_BDY_DIST: "volumes/labels/gt_bdy_dist",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only the remaining iterations: a resumed run stops at max_iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
예제 #5
0
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    """Train the "unet" boundary-distance network on FIB-25 HDF volumes.

    When ``"fib25h5"`` is in ``data_sources``, four fixed FIB-25 volumes are
    read. Raw data is augmented, labels are grown by one step inside
    ``GT_MASK``, and a tanh-normalized boundary distance of ``GT_LABELS`` is
    fed to the TensorFlow ``Train`` node for ``max_iteration`` batches.

    Args:
        max_iteration: number of batches to request.
        data_sources: container of source tags; only ``"fib25h5"`` is
            recognized.
        input_shape: unused — the input size is hard-coded below.
        output_shape: unused — the output size is hard-coded below.
        dt_scaling_factor: ``normalize_args`` for the tanh normalization of
            ``AddBoundaryDistance``.
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.

    Raises:
        ValueError: if ``data_sources`` selects no volumes.
    """
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    # GT_MASK is read, padded, rejected on and balanced with below, so it
    # must be registered like every other key.
    ArrayKey("GT_MASK")
    ArrayKey("GT_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_DIST")
    ArrayKey("PREDICTED_DIST")

    data_providers = []
    fib25_dir = "/groups/saalfeld/saalfeldlab/larissa/data/gunpowder/fib25/"
    if "fib25h5" in data_sources:
        for volume_name in (
                "tstvol-520-1",
                "tstvol-520-2",
                "trvol-250-1",
                "trvol-250-2",
        ):
            h5_source = Hdf5Source(
                os.path.join(fib25_dir, volume_name + ".hdf"),
                datasets={
                    ArrayKeys.RAW: "volumes/raw",
                    ArrayKeys.GT_LABELS: "volumes/labels/clefts",
                    ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                },
                array_specs={
                    ArrayKeys.GT_MASK: ArraySpec(interpolatable=False)
                },
            )
            data_providers.append(h5_source)

    if not data_providers:
        # RandomProvider cannot work without at least one upstream source.
        raise ValueError(
            "No data providers selected; data_sources must contain 'fib25h5'")

    # TODO: fib19 and dvid sources

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    # NOTE(review): input_shape/output_shape are ignored; these sizes must
    # match the network described by net_io_names.json.
    input_size = Coordinate((196, ) * 3) * voxel_size
    output_size = Coordinate((92, ) * 3) * voxel_size

    # specify which volumes should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.GT_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is in float in [0, 1]
        # zero-pad provided RAW and GT_MASK to be able to draw batches close
        # to the boundary of the available data; size more or less irrelevant
        # as followed by Reject node
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        # choose a random location inside the provided volumes
        RandomLocation() +
        # reject batches with less than 50% labelled data
        Reject(ArrayKeys.GT_MASK)
        for provider in data_providers)

    # Each key appears once. The original literal listed LOSS_GRADIENT twice;
    # only the last value, kept here, ever took effect.
    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_DIST],
        ArrayKeys.PREDICTED_DIST: request[ArrayKeys.GT_LABELS],
    })

    train_pipeline = (
        data_sources + RandomProvider() + ElasticAugment(
            [40, 40, 40],
            [2, 2, 2],
            [0, math.pi / 2.0],
            prob_slip=0.01,
            prob_shift=0.05,
            max_misalign=1,
            subsample=8,
        ) + SimpleAugment() +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        # map RAW from [0, 1] to [-1, 1]
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        GrowBoundary(ArrayKeys.GT_LABELS, ArrayKeys.GT_MASK, steps=1) +
        AddBoundaryDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=ArrayKeys.GT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) + BalanceLabels(ArrayKeys.GT_LABELS, ArrayKeys.GT_SCALE,
                          ArrayKeys.GT_MASK) +
        PreCache(cache_size=40, num_workers=10) +
        Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_dist"]: ArrayKeys.GT_DIST,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: ArrayKeys.PREDICTED_DIST},
            gradients={net_io_names["dist"]: ArrayKeys.LOSS_GRADIENT},
        ) + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
                ArrayKeys.GT_DIST: "volumes/labels/distances",
                ArrayKeys.PREDICTED_DIST: "volumes/labels/pred_distances",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
            },
            every=1000,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) + PrintProfilingStats(every=10))

    print("Starting training...")
    with build(train_pipeline) as b:
        for _ in range(max_iteration):
            b.request_batch(request)
    print("Training finished")
예제 #6
0
def train_until(
    max_iteration, data_sources, input_shape, output_shape, dt_scaling_factor, loss_name
):
    """Train the cleft distance network on FIB-19 cubes until ``max_iteration``.

    Resumes from the latest TensorFlow checkpoint in the working directory (if
    any) and runs only the iterations still missing to reach ``max_iteration``;
    returns immediately when that count has already been reached.

    Args:
        max_iteration: total number of training iterations to reach.
        data_sources: sample identifiers; each is formatted into
            ``cube{sample}.hdf`` under the data directory and becomes one
            HDF5 source.
        input_shape: network input shape in voxels (voxel size 8x8x8).
        output_shape: network output shape in voxels.
        dt_scaling_factor: forwarded as ``normalize_args`` to ``AddDistance``
            (``"tanh"`` normalization of the cleft distance).
        loss_name: key into ``net_io_names.json`` selecting the loss tensor.
    """
    raw = ArrayKey("RAW")
    clefts = ArrayKey("GT_LABELS")
    mask = ArrayKey("GT_MASK")
    scale = ArrayKey("GT_SCALE")
    gt_dist = ArrayKey("GT_DIST")
    pred_dist = ArrayKey("PREDICTED_DIST")

    # checkpoint names are parsed as "..._<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint(".")
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    if trained_until >= max_iteration:
        return

    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/fib19/mine/"
    data_providers = []
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(data_dir, "cube{0:}.hdf".format(sample)),
            datasets={
                raw: "volumes/raw",
                clefts: "volumes/labels/clefts",
                mask: "/volumes/masks/groundtruth",
            },
            # the mask must not be interpolated by augmentations
            array_specs={mask: ArraySpec(interpolatable=False)},
        )
        data_providers.append(h5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    # specify which arrays are requested for each batch
    request = BatchRequest()
    request.add(raw, input_size)
    request.add(clefts, output_size)
    request.add(mask, output_size)
    request.add(scale, output_size)
    request.add(gt_dist, output_size)

    # one pipeline per HDF file:
    #  - Normalize: ensures RAW is float in [0, 1]
    #  - Pad(raw, None): pad RAW without limit so batches can be drawn close to
    #    the data boundary (the Reject nodes below discard bad locations)
    #  - RandomLocation: choose a random location inside the provided arrays
    #  - Reject(mask=mask): reject locations with too little masked data
    #  - Reject(clefts, ...): presumably rejects locations without cleft
    #    voxels (min_masked=0.0) with probability 0.95 -- confirm
    data_sources = tuple(
        provider
        + Normalize(raw)
        + Pad(raw, None)
        + RandomLocation()
        + Reject(mask=mask)
        + Reject(clefts, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers
    )

    # the prediction is only requested for snapshots, at the label ROI
    snapshot_request = BatchRequest({pred_dist: request[clefts]})

    train_pipeline = (
        data_sources
        + RandomProvider()
        + ElasticAugment(
            (40, 40, 40),
            (2.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            prob_slip=0.01,
            prob_shift=0.01,
            max_misalign=1,
            subsample=8,
        )
        + SimpleAugment()
        + IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1)
        # map RAW from [0, 1] to [-1, 1]
        + IntensityScaleShift(raw, 2, -1)
        + ZeroOutConstSections(raw)
        + AddDistance(
            label_array_key=clefts,
            distance_array_key=gt_dist,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        )
        # alternative: BalanceLabels(clefts, scale, mask)
        + BalanceByThreshold(labels=gt_dist, scales=scale)
        + PreCache(cache_size=40, num_workers=10)
        + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: raw,
                net_io_names["gt_dist"]: gt_dist,
                net_io_names["loss_weights"]: scale,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={net_io_names["dist"]: pred_dist},
            gradients={},
        )
        + Snapshot(
            {
                raw: "volumes/raw",
                clefts: "volumes/labels/gt_clefts",
                gt_dist: "volumes/labels/gt_clefts_dist",
                pred_dist: "volumes/labels/pred_clefts_dist",
            },
            dataset_dtypes={clefts: np.uint64},
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_mask,
        tensor_glia,
        tensor_glia_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries,
        mask_out_labels,
        snapshot_dir):
    """Train the joint affinity/glia network until iteration ``stop``.

    Resumes from the latest TensorFlow checkpoint in the working directory and
    requests one batch for every iteration in ``[trained_until, stop)``.

    Args:
        data_providers: upstream providers, one per training volume; they are
            padded/rejected on RAW_KEY, GT_MASK_KEY, GLIA_MASK_KEY, LABELS_KEY
            and GT_GLIA_KEY below, so they presumably provide those keys.
        affinity_neighborhood: neighborhood passed to ``AddAffinities``; the
            number of affinity channels is the summed length of its entries.
        meta_graph_filename: TensorFlow meta graph loaded by ``Train``.
        stop: iteration at which training ends.
        input_shape: network input shape in input voxels.
        output_shape: network output shape in output voxels.
        loss, optimizer, summary: tensor/op names forwarded to ``Train``.
        tensor_affinities, tensor_glia: output tensor names.
        tensor_affinities_mask, tensor_glia_mask: input tensor names for the
            affinity and glia loss weights.
        save_checkpoint_every: checkpoint interval forwarded to ``Train``.
        pre_cache_size, pre_cache_num_workers: ``PreCache`` settings;
            ``PreCache`` is skipped unless both are positive.
        snapshot_every: snapshot interval.
        balance_labels: if true, balance affinity labels and feed the
            resulting scale instead of the plain affinity mask.
        renumber_connected_components: if true, relabel connected components.
        network_inputs: tensor-name -> array-key mapping for ``Train``.
            Modified in place: the entries for ``tensor_affinities_mask`` and
            ``tensor_glia_mask`` are set here.
        ignore_labels_for_slip: if true, label/mask keys are excluded from
            slip misalignment.
        grow_boundaries: number of ``GrowBoundary`` steps (0 disables it).
        mask_out_labels: label ids to mask out (empty disables it).
        snapshot_dir: output directory for snapshots.
    """
    ignore_keys_for_slip = (LABELS_KEY, GT_MASK_KEY, GT_GLIA_KEY, GLIA_MASK_KEY, UNLABELED_KEY) if ignore_labels_for_slip else ()

    # Padding value for LABELS_KEY, originally written as np.uint64(-3), i.e.
    # -3 wrapped to 2**64 - 3. Converting an out-of-range Python int is
    # deprecated in NumPy 1.24 and raises OverflowError in NumPy 2, so spell
    # the wrapped value out explicitly.
    labels_padding_value = np.uint64(2**64 - 3)

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    # checkpoint names are parsed as "..._<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # TODO Use glia scale somehow, probably not possible with tensorflow 1.3 because it does not know uint64...
    # specify which arrays are requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,  voxel_size=input_voxel_size)
    request.add(LABELS_KEY,          output_size, voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size, voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size, voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(GLIA_MASK_KEY,       output_size, voxel_size=output_voxel_size)
    request.add(GLIA_KEY,            output_size, voxel_size=output_voxel_size)
    request.add(GT_GLIA_KEY,         output_size, voxel_size=output_voxel_size)
    request.add(UNLABELED_KEY,       output_size, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    # always balance glia labels!
    request.add(GLIA_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY
    network_inputs[tensor_glia_mask]       = GLIA_SCALE_KEY

    # create a tuple of data sources, one for each provider
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensures RAW is in float in [0, 1]

        # zero-pad provided RAW and masks to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        Pad(GLIA_MASK_KEY, None) +
        Pad(LABELS_KEY, size=NETWORK_OUTPUT_SHAPE / 2, value=labels_padding_value) +
        Pad(GT_GLIA_KEY, size=NETWORK_OUTPUT_SHAPE / 2) +
        RandomLocation() + # chose a random location inside the provided arrays
        Reject(mask=GT_MASK_KEY, min_masked=0.5) +
        Reject(mask=GLIA_MASK_KEY, min_masked=0.5) +
        MapNumpyArray(lambda array: np.require(array, dtype=np.int64), GT_GLIA_KEY) # this is necessary because gunpowder 1.3 only understands int64, not uint64

        for provider in data_providers)

    # predictions and loss gradients are only requested for snapshots
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
    })

    # no need to do anything here. random sections will be replaced with sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                DEFECT_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                DEFECT_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=DEFECT_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()

    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )

    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)

    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                                    prob_missing=0.03,
                                    prob_low_contrast=0.01,
                                    prob_artifact=0.03,
                                    artifact_source=artifact_source,
                                    artifacts=RAW_KEY,
                                    artifacts_mask=DEFECT_MASK_KEY,
                                    contrast_scale=0.5)
    # map RAW from [0, 1] to [-1, 1]
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)

    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    _logger.info("Renumbering connected components? %s", renumber_connected_components)
    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=LABELS_KEY)

    # UNLABELED_KEY is derived from the glia ground truth as 1 - GT_GLIA
    train_pipeline += NewKeyFromNumpyArray(lambda array: 1 - array, GT_GLIA_KEY, UNLABELED_KEY)

    if len(mask_out_labels) > 0:
        train_pipeline += MaskOutLabels(label_key=LABELS_KEY, mask_key=GT_MASK_KEY, ids_to_be_masked=mask_out_labels)

    # labels_mask: anything that connects into labels_mask will be zeroed out
    # unlabelled: anything that points into unlabeled will have zero affinity;
    #             affinities within unlabelled will be masked out
    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY,
            unlabelled=UNLABELED_KEY
    )

    snapshot_datasets = {
        RAW_KEY: 'volumes/raw',
        LABELS_KEY: 'volumes/labels/neuron_ids',
        GT_AFFINITIES_KEY: 'volumes/affinities/gt',
        GT_GLIA_KEY: 'volumes/labels/glia_gt',
        UNLABELED_KEY: 'volumes/labels/unlabeled',
        AFFINITIES_KEY: 'volumes/affinities/prediction',
        LOSS_GRADIENT_KEY: 'volumes/loss_gradient',
        AFFINITIES_MASK_KEY: 'masks/affinities',
        GLIA_KEY: 'volumes/labels/glia_pred',
        GT_MASK_KEY: 'masks/gt',
        GLIA_MASK_KEY: 'masks/glia'}

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)
        snapshot_datasets[AFFINITIES_SCALE_KEY] = 'masks/affinity-scale'
    train_pipeline += BalanceLabels(labels=GT_GLIA_KEY, scales=GLIA_SCALE_KEY, mask=GLIA_MASK_KEY)
    snapshot_datasets[GLIA_SCALE_KEY] = 'masks/glia-scale'

    if pre_cache_size > 0 and pre_cache_num_workers > 0:
        train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_glia: GLIA_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                GLIA_MASK_KEY        : ArraySpec(voxel_size=output_voxel_size),
                GLIA_SCALE_KEY       : ArraySpec(voxel_size=output_voxel_size),
                GLIA_KEY             : ArraySpec(voxel_size=output_voxel_size)
            }
        )

    train_pipeline += Snapshot(
            snapshot_datasets,
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir=snapshot_dir,
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())

    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for _ in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
예제 #8
0
def train_until(
        data_providers,
        affinity_neighborhood,
        meta_graph_filename,
        stop,
        input_shape,
        output_shape,
        loss,
        optimizer,
        tensor_affinities,
        tensor_affinities_nn,
        tensor_affinities_mask,
        summary,
        save_checkpoint_every,
        pre_cache_size,
        pre_cache_num_workers,
        snapshot_every,
        balance_labels,
        renumber_connected_components,
        network_inputs,
        ignore_labels_for_slip,
        grow_boundaries,
        snapshot_dir='snapshots/'):
    """Train the affinity network (plus nearest-neighbor affinities) until ``stop``.

    Resumes from the latest TensorFlow checkpoint in the working directory and
    requests one batch for every iteration in ``[trained_until, stop)``.

    Args:
        data_providers: upstream providers, one per training volume.
        affinity_neighborhood: neighborhood passed to ``AddAffinities``.
        meta_graph_filename: TensorFlow meta graph loaded by ``Train``.
        stop: iteration at which training ends.
        input_shape: network input shape in input voxels.
        output_shape: network output shape in output voxels; the NN-affinity
            output is requested with every dimension reduced by 2.
        loss, optimizer, summary: tensor/op names forwarded to ``Train``.
        tensor_affinities, tensor_affinities_nn: output tensor names.
        tensor_affinities_mask: input tensor name for the affinity loss weights.
        save_checkpoint_every: checkpoint interval forwarded to ``Train``.
        pre_cache_size, pre_cache_num_workers: ``PreCache`` settings;
            ``PreCache`` is skipped unless both are positive.
        snapshot_every: snapshot interval.
        balance_labels: if true, balance affinity labels and feed the
            resulting scale instead of the plain affinity mask.
        renumber_connected_components: if true, relabel connected components.
        network_inputs: tensor-name -> array-key mapping for ``Train``.
            Modified in place: the entry for ``tensor_affinities_mask`` is set.
        ignore_labels_for_slip: if true, GT labels/mask are excluded from
            slip misalignment.
        grow_boundaries: number of ``GrowBoundary`` steps (0 disables it).
        snapshot_dir: output directory for snapshots (default ``'snapshots/'``).
    """
    ignore_keys_for_slip = (GT_LABELS_KEY, GT_MASK_KEY) if ignore_labels_for_slip else ()

    defect_dir = '/groups/saalfeld/home/hanslovskyp/experiments/quasi-isotropic/data/defects'
    # checkpoint names are parsed as "..._<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    input_voxel_size = Coordinate((120, 12, 12)) * 3
    output_voxel_size = Coordinate((40, 36, 36)) * 3

    input_size = Coordinate(input_shape) * input_voxel_size
    output_size = Coordinate(output_shape) * output_voxel_size
    output_size_nn = Coordinate(s - 2 for s in output_shape) * output_voxel_size

    num_affinities = sum(len(nh) for nh in affinity_neighborhood)
    gt_affinities_size = Coordinate((num_affinities,) + tuple(s for s in output_size))
    print("gt affinities size", gt_affinities_size)

    # TODO why is GT_AFFINITIES three-dimensional? compare to
    # TODO https://github.com/funkey/gunpowder/blob/master/examples/cremi/train.py#L35
    # specify which arrays are requested for each batch
    request = BatchRequest()
    request.add(RAW_KEY,             input_size,     voxel_size=input_voxel_size)
    request.add(GT_LABELS_KEY,       output_size,    voxel_size=output_voxel_size)
    request.add(GT_AFFINITIES_KEY,   output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_MASK_KEY, output_size,    voxel_size=output_voxel_size)
    request.add(GT_MASK_KEY,         output_size,    voxel_size=output_voxel_size)
    request.add(AFFINITIES_NN_KEY,   output_size_nn, voxel_size=output_voxel_size)
    if balance_labels:
        request.add(AFFINITIES_SCALE_KEY, output_size, voxel_size=output_voxel_size)
    network_inputs[tensor_affinities_mask] = AFFINITIES_SCALE_KEY if balance_labels else AFFINITIES_MASK_KEY

    # create a tuple of data sources, one for each provider
    data_sources = tuple(
        provider +
        Normalize(RAW_KEY) + # ensures RAW is in float in [0, 1]

        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # size more or less irrelevant as followed by Reject Node
        Pad(RAW_KEY, None) +
        Pad(GT_MASK_KEY, None) +
        RandomLocation() + # chose a random location inside the provided arrays
        Reject(GT_MASK_KEY) + # reject batches with too little labelled data (Reject's default min_masked)
        Reject(GT_LABELS_KEY, min_masked=0.0, reject_probability=0.95)

        for provider in data_providers)

    # predictions and loss gradients are only requested for snapshots
    snapshot_request = BatchRequest({
        LOSS_GRADIENT_KEY : request[GT_AFFINITIES_KEY],
        AFFINITIES_KEY    : request[GT_AFFINITIES_KEY],
        AFFINITIES_NN_KEY : request[AFFINITIES_NN_KEY]
    })

    # no need to do anything here. random sections will be replaced with sections from this source (only raw)
    artifact_source = (
        Hdf5Source(
            os.path.join(defect_dir, 'sample_ABC_padded_20160501.defects.hdf'),
            datasets={
                RAW_KEY        : 'defect_sections/raw',
                ALPHA_MASK_KEY : 'defect_sections/mask',
            },
            array_specs={
                RAW_KEY        : ArraySpec(voxel_size=input_voxel_size),
                ALPHA_MASK_KEY : ArraySpec(voxel_size=input_voxel_size),
            }
        ) +
        RandomLocation(min_masked=0.05, mask=ALPHA_MASK_KEY) +
        Normalize(RAW_KEY) +
        IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            subsample=8
        ) +
        SimpleAugment(transpose_only=[1,2])
    )

    train_pipeline  = data_sources
    train_pipeline += RandomProvider()
    train_pipeline += ElasticAugment(
            voxel_size=(360, 36, 36),
            control_point_spacing=(4, 40, 40),
            control_point_displacement_sigma=(0, 2 * 36, 2 * 36),
            rotation_interval=(0, math.pi / 2.0),
            augmentation_probability=0.5,
            subsample=8
        )
    train_pipeline += Misalign(z_resolution=360, prob_slip=0.05, prob_shift=0.05, max_misalign=(360,) * 2, ignore_keys_for_slip=ignore_keys_for_slip)
    train_pipeline += SimpleAugment(transpose_only=[1,2])
    train_pipeline += IntensityAugment(RAW_KEY, 0.9, 1.1, -0.1, 0.1, z_section_wise=True)
    train_pipeline += DefectAugment(RAW_KEY,
                      prob_missing=0.03,
                      prob_low_contrast=0.01,
                      prob_artifact=0.03,
                      artifact_source=artifact_source,
                      artifacts=RAW_KEY,
                      artifacts_mask=ALPHA_MASK_KEY,
                      contrast_scale=0.5)
    # map RAW from [0, 1] to [-1, 1]
    train_pipeline += IntensityScaleShift(RAW_KEY, 2, -1)
    train_pipeline += ZeroOutConstSections(RAW_KEY)
    if grow_boundaries > 0:
        train_pipeline += GrowBoundary(GT_LABELS_KEY, GT_MASK_KEY, steps=grow_boundaries, only_xy=True)

    if renumber_connected_components:
        train_pipeline += RenumberConnectedComponents(labels=GT_LABELS_KEY)

    train_pipeline += AddAffinities(
            affinity_neighborhood=affinity_neighborhood,
            labels=GT_LABELS_KEY,
            labels_mask=GT_MASK_KEY,
            affinities=GT_AFFINITIES_KEY,
            affinities_mask=AFFINITIES_MASK_KEY
        )

    if balance_labels:
        train_pipeline += BalanceLabels(labels=GT_AFFINITIES_KEY, scales=AFFINITIES_SCALE_KEY, mask=AFFINITIES_MASK_KEY)

    # PreCache needs a positive cache size and at least one worker to ever
    # produce a batch; skip it otherwise instead of building a pool that
    # cannot fill its cache.
    if pre_cache_size > 0 and pre_cache_num_workers > 0:
        train_pipeline += PreCache(cache_size=pre_cache_size, num_workers=pre_cache_num_workers)
    train_pipeline += Train(
            summary=summary,
            graph=meta_graph_filename,
            save_every=save_checkpoint_every,
            optimizer=optimizer,
            loss=loss,
            inputs=network_inputs,
            log_dir='log',
            outputs={tensor_affinities: AFFINITIES_KEY, tensor_affinities_nn: AFFINITIES_NN_KEY},
            gradients={tensor_affinities: LOSS_GRADIENT_KEY},
            array_specs={
                AFFINITIES_KEY       : ArraySpec(voxel_size=output_voxel_size),
                LOSS_GRADIENT_KEY    : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_MASK_KEY  : ArraySpec(voxel_size=output_voxel_size),
                GT_MASK_KEY          : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_SCALE_KEY : ArraySpec(voxel_size=output_voxel_size),
                AFFINITIES_NN_KEY    : ArraySpec(voxel_size=output_voxel_size)
            }
        )
    train_pipeline += Snapshot(
            dataset_names={
                RAW_KEY             : 'volumes/raw',
                GT_LABELS_KEY       : 'volumes/labels/neuron_ids',
                GT_AFFINITIES_KEY   : 'volumes/affinities/gt',
                AFFINITIES_KEY      : 'volumes/affinities/prediction',
                LOSS_GRADIENT_KEY   : 'volumes/loss_gradient',
                AFFINITIES_MASK_KEY : 'masks/affinities',
                AFFINITIES_NN_KEY   : 'volumes/affinities/prediction-nn'
            },
            every=snapshot_every,
            output_filename='batch_{iteration}.hdf',
            output_dir=snapshot_dir,
            additional_request=snapshot_request,
            attributes_callback=Snapshot.default_attributes_callback())
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        for _ in range(trained_until, stop):
            b.request_batch(request)

    print("Training finished")
예제 #9
0
def train_until(max_iteration, data_sources, input_shape, output_shape):
    """Train the super-resolution network until ``max_iteration`` iterations.

    Resumes from the latest TensorFlow checkpoint in the working directory (if
    any) and runs only the iterations still missing to reach ``max_iteration``;
    returns immediately when that count has already been reached.

    Args:
        max_iteration: total number of training iterations to reach.
        data_sources: identifiers formatted into the N5 container path; one
            N5 source (reading ``volumes/raw``) is created per entry.
        input_shape: network input shape in voxels (voxel size 4x4x4).
        output_shape: network output shape in voxels; only used for the
            snapshot request of PRED_RAW.
    """
    ArrayKey('RAW')
    ArrayKey('PRED_RAW')
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/superresolution/{0:}.n5"
    voxel_size = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size

    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size)

    # the prediction is only requested for snapshots
    snapshot_request = BatchRequest()
    snapshot_request.add(ArrayKeys.PRED_RAW, output_size, voxel_size=voxel_size)

    # load latest ckpt for weights if available; names are parsed as "..._<iteration>"
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')
    if trained_until >= max_iteration:
        return

    # construct DAG: one normalized, padded, randomly located source per container
    data_providers = [
        N5Source(
            data_dir.format(src),
            datasets={
                ArrayKeys.RAW: 'volumes/raw'
            }
        )
        for src in data_sources
    ]

    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +
        Pad(ArrayKeys.RAW, Coordinate((400, 400, 400))) +
        RandomLocation()
        for provider in data_providers
    )

    # NOTE(review): a former comment here said intensity augmentation was left
    # out because it cannot be applied identically to input and target, yet
    # IntensityAugment is applied to RAW below -- confirm this is intended.
    train_pipeline = (
        data_sources +
        # a tuple of sources must be merged into a single upstream provider
        # before any BatchFilter; pick one source at random per batch
        RandomProvider() +
        ElasticAugment((100, 100, 100), (10., 10., 10.), (0, math.pi / 2.0),
                       prob_slip=0, prob_shift=0, max_misalign=0,
                       subsample=8) +
        SimpleAugment() +
        ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        IntensityAugment(ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1) +
        # map RAW from [0, 1] to [-1, 1]
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        PreCache(cache_size=40, num_workers=10) +
        Train('unet',
              optimizer=net_io_names['optimizer'],
              loss=net_io_names['loss'],
              inputs={
                  net_io_names['raw']: ArrayKeys.RAW
              },
              summary=net_io_names['summary'],
              log_dir='log',
              outputs={
                  net_io_names['pred_raw']: ArrayKeys.PRED_RAW
              },
              gradients={}
              ) +
        Snapshot({ArrayKeys.RAW: 'volumes/raw', ArrayKeys.PRED_RAW: 'volumes/pred_raw'},
                 every=500,
                 output_filename='batch_{iteration}.hdf',
                 output_dir='snapshots/',
                 additional_request=snapshot_request) +
        PrintProfilingStats(every=50)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        # only the iterations still missing after resuming
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")
예제 #10
0
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name):
    """Train the multi-organelle distance network up to ``max_iteration``.

    Builds a gunpowder pipeline that reads raw data, labels and a mask from
    one N5 container per entry of ``data_sources``, augments them, computes
    tanh-normalized distance transforms for each organelle class (and its
    membrane), balances the non-membrane classes against the downsampled
    mask and feeds everything to the TensorFlow graph named ``"build"``.

    Args:
        max_iteration (int): Iteration at which training stops. If a
            checkpoint exists in the working directory, training resumes at
            the iteration parsed from its name, so only the remaining
            ``max_iteration - trained_until`` batches are requested.
        data_sources (iterable of str): Names substituted into the N5 path
            template ``data_dir``.
        input_shape (tuple of int): Network input shape in voxels of size
            ``voxel_size_orig`` (8, 8, 8).
        output_shape (tuple of int): Network output shape in voxels of size
            ``voxel_size_orig`` (8, 8, 8).
        dt_scaling_factor: Passed as ``normalize_args`` to every
            ``AddDistance`` node (``normalize="tanh"``).
        loss_name (str): Key in ``net_io_names.json`` naming the loss tensor.

    Returns:
        None.
    """
    ArrayKey("RAW")
    ArrayKey("RAW_UP")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("MASK_UP")
    ArrayKey("GT_DIST_CENTROSOME")
    ArrayKey("GT_DIST_GOLGI")
    ArrayKey("GT_DIST_GOLGI_MEM")
    ArrayKey("GT_DIST_ER")
    ArrayKey("GT_DIST_ER_MEM")
    ArrayKey("GT_DIST_MVB")
    ArrayKey("GT_DIST_MVB_MEM")
    ArrayKey("GT_DIST_MITO")
    ArrayKey("GT_DIST_MITO_MEM")
    ArrayKey("GT_DIST_LYSOSOME")
    ArrayKey("GT_DIST_LYSOSOME_MEM")

    ArrayKey("PRED_DIST_CENTROSOME")
    ArrayKey("PRED_DIST_GOLGI")
    ArrayKey("PRED_DIST_GOLGI_MEM")
    ArrayKey("PRED_DIST_ER")
    ArrayKey("PRED_DIST_ER_MEM")
    ArrayKey("PRED_DIST_MVB")
    ArrayKey("PRED_DIST_MVB_MEM")
    ArrayKey("PRED_DIST_MITO")
    ArrayKey("PRED_DIST_MITO_MEM")
    ArrayKey("PRED_DIST_LYSOSOME")
    ArrayKey("PRED_DIST_LYSOSOME_MEM")

    ArrayKey("SCALE_CENTROSOME")
    ArrayKey("SCALE_GOLGI")
    ArrayKey("SCALE_GOLGI_MEM")
    ArrayKey("SCALE_ER")
    ArrayKey("SCALE_ER_MEM")
    ArrayKey("SCALE_MVB")
    ArrayKey("SCALE_MVB_MEM")
    ArrayKey("SCALE_MITO")
    ArrayKey("SCALE_MITO_MEM")
    ArrayKey("SCALE_LYSOSOME")
    ArrayKey("SCALE_LYSOSOME_MEM")

    data_providers = []
    data_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cell/{0:}.n5"
    # Data is read at voxel_size_up and downsampled by 2 (DownSample nodes
    # below) to voxel_size_orig, the resolution the network runs at.
    voxel_size_up = Coordinate((4, 4, 4))
    voxel_size_orig = Coordinate((8, 8, 8))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig

    # Resume from the latest checkpoint in the working directory, if any; the
    # iteration number is the suffix after the last underscore of its name.
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for src in data_sources:
        n5_source = N5Source(
            data_dir.format(src),
            datasets={
                ArrayKeys.RAW_UP: "volumes/raw",
                ArrayKeys.GT_LABELS: "volumes/labels/all",
                ArrayKeys.MASK_UP: "volumes/mask",
            },
            array_specs={ArrayKeys.MASK_UP: ArraySpec(interpolatable=False)},
        )
        data_providers.append(n5_source)

    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RAW_UP, input_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK_UP, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    for key in (
            ArrayKeys.GT_DIST_CENTROSOME,
            ArrayKeys.GT_DIST_GOLGI,
            ArrayKeys.GT_DIST_GOLGI_MEM,
            ArrayKeys.GT_DIST_ER,
            ArrayKeys.GT_DIST_ER_MEM,
            ArrayKeys.GT_DIST_MVB,
            ArrayKeys.GT_DIST_MVB_MEM,
            ArrayKeys.GT_DIST_MITO,
            ArrayKeys.GT_DIST_MITO_MEM,
            ArrayKeys.GT_DIST_LYSOSOME,
            ArrayKeys.GT_DIST_LYSOSOME_MEM,
            # Only the non-membrane classes are balanced; the network's
            # *_mem weight inputs reuse the parent class's scale array.
            ArrayKeys.SCALE_CENTROSOME,
            ArrayKeys.SCALE_GOLGI,
            ArrayKeys.SCALE_ER,
            ArrayKeys.SCALE_MVB,
            ArrayKeys.SCALE_MITO,
            ArrayKeys.SCALE_LYSOSOME,
    ):
        request.add(key, output_size, voxel_size=voxel_size_orig)

    # create a tuple of data sources, one for each N5 container
    data_sources = tuple(
        provider +
        Normalize(ArrayKeys.RAW_UP) +  # ensures RAW_UP is float in [0, 1]
        # zero-pad RAW_UP to be able to draw batches close to the boundary of
        # the available data
        Pad(ArrayKeys.RAW_UP, None) +
        # choose a random location where at least half of MASK_UP is set
        RandomLocation(min_masked=0.5, mask=ArrayKeys.MASK_UP)
        for provider in data_providers)

    snapshot_request = BatchRequest()
    for key in (
            ArrayKeys.PRED_DIST_CENTROSOME,
            ArrayKeys.PRED_DIST_GOLGI,
            ArrayKeys.PRED_DIST_GOLGI_MEM,
            ArrayKeys.PRED_DIST_ER,
            ArrayKeys.PRED_DIST_ER_MEM,
            ArrayKeys.PRED_DIST_MVB,
            ArrayKeys.PRED_DIST_MVB_MEM,
            ArrayKeys.PRED_DIST_MITO,
            ArrayKeys.PRED_DIST_MITO_MEM,
            ArrayKeys.PRED_DIST_LYSOSOME,
            ArrayKeys.PRED_DIST_LYSOSOME_MEM,
    ):
        snapshot_request.add(key, output_size)

    train_pipeline = (
        data_sources + RandomProvider() +
        ElasticAugment(
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            prob_slip=0,
            prob_shift=0,
            max_misalign=0,
            subsample=8,
        ) +
        SimpleAugment() +
        ElasticAugment((40, 1000, 1000), (10.0, 0.0, 0.0), (0, 0),
                       subsample=8) +
        IntensityAugment(ArrayKeys.RAW_UP, 0.9, 1.1, -0.1, 0.1) +
        IntensityScaleShift(ArrayKeys.RAW_UP, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW_UP))

    # One distance transform per class on GT_LABELS. A tuple label_id groups
    # an organelle's label with its membrane label; the *_MEM entries use the
    # membrane label alone.
    for gt_key, label_id in (
        (ArrayKeys.GT_DIST_CENTROSOME, 1),
        (ArrayKeys.GT_DIST_GOLGI, (2, 11)),
        (ArrayKeys.GT_DIST_GOLGI_MEM, 11),
        (ArrayKeys.GT_DIST_ER, (3, 10)),
        (ArrayKeys.GT_DIST_ER_MEM, 10),
        (ArrayKeys.GT_DIST_MVB, (4, 9)),
        (ArrayKeys.GT_DIST_MVB_MEM, 9),
        (ArrayKeys.GT_DIST_MITO, (5, 8)),
        (ArrayKeys.GT_DIST_MITO_MEM, 8),
        (ArrayKeys.GT_DIST_LYSOSOME, (6, 7)),
        (ArrayKeys.GT_DIST_LYSOSOME_MEM, 7),
    ):
        train_pipeline += AddDistance(
            label_array_key=ArrayKeys.GT_LABELS,
            distance_array_key=gt_key,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            label_id=label_id,
            factor=2,
        )

    train_pipeline += DownSample(ArrayKeys.MASK_UP, 2, ArrayKeys.MASK)

    # Balance only the non-membrane classes (see the note on the request).
    for gt_key, scale_key in (
        (ArrayKeys.GT_DIST_CENTROSOME, ArrayKeys.SCALE_CENTROSOME),
        (ArrayKeys.GT_DIST_GOLGI, ArrayKeys.SCALE_GOLGI),
        (ArrayKeys.GT_DIST_ER, ArrayKeys.SCALE_ER),
        (ArrayKeys.GT_DIST_MVB, ArrayKeys.SCALE_MVB),
        (ArrayKeys.GT_DIST_MITO, ArrayKeys.SCALE_MITO),
        (ArrayKeys.GT_DIST_LYSOSOME, ArrayKeys.SCALE_LYSOSOME),
    ):
        train_pipeline += BalanceByThreshold(gt_key, scale_key,
                                             mask=ArrayKeys.MASK)

    train_pipeline += DownSample(ArrayKeys.RAW_UP, 2, ArrayKeys.RAW)
    train_pipeline += PreCache(cache_size=40, num_workers=10)
    train_pipeline += Train(
        "build",
        optimizer=net_io_names["optimizer"],
        loss=net_io_names[loss_name],
        inputs={
            net_io_names["raw"]: ArrayKeys.RAW,
            net_io_names["gt_centrosome"]: ArrayKeys.GT_DIST_CENTROSOME,
            net_io_names["gt_golgi"]: ArrayKeys.GT_DIST_GOLGI,
            net_io_names["gt_golgi_mem"]: ArrayKeys.GT_DIST_GOLGI_MEM,
            net_io_names["gt_er"]: ArrayKeys.GT_DIST_ER,
            net_io_names["gt_er_mem"]: ArrayKeys.GT_DIST_ER_MEM,
            net_io_names["gt_mvb"]: ArrayKeys.GT_DIST_MVB,
            net_io_names["gt_mvb_mem"]: ArrayKeys.GT_DIST_MVB_MEM,
            net_io_names["gt_mito"]: ArrayKeys.GT_DIST_MITO,
            net_io_names["gt_mito_mem"]: ArrayKeys.GT_DIST_MITO_MEM,
            net_io_names["gt_lysosome"]: ArrayKeys.GT_DIST_LYSOSOME,
            net_io_names["gt_lysosome_mem"]: ArrayKeys.GT_DIST_LYSOSOME_MEM,
            net_io_names["w_centrosome"]: ArrayKeys.SCALE_CENTROSOME,
            net_io_names["w_golgi"]: ArrayKeys.SCALE_GOLGI,
            net_io_names["w_golgi_mem"]: ArrayKeys.SCALE_GOLGI,
            net_io_names["w_er"]: ArrayKeys.SCALE_ER,
            net_io_names["w_er_mem"]: ArrayKeys.SCALE_ER,
            net_io_names["w_mvb"]: ArrayKeys.SCALE_MVB,
            net_io_names["w_mvb_mem"]: ArrayKeys.SCALE_MVB,
            net_io_names["w_mito"]: ArrayKeys.SCALE_MITO,
            net_io_names["w_mito_mem"]: ArrayKeys.SCALE_MITO,
            net_io_names["w_lysosome"]: ArrayKeys.SCALE_LYSOSOME,
            net_io_names["w_lysosome_mem"]: ArrayKeys.SCALE_LYSOSOME,
        },
        summary=net_io_names["summary"],
        log_dir="log",
        outputs={
            net_io_names["centrosome"]: ArrayKeys.PRED_DIST_CENTROSOME,
            net_io_names["golgi"]: ArrayKeys.PRED_DIST_GOLGI,
            net_io_names["golgi_mem"]: ArrayKeys.PRED_DIST_GOLGI_MEM,
            net_io_names["er"]: ArrayKeys.PRED_DIST_ER,
            net_io_names["er_mem"]: ArrayKeys.PRED_DIST_ER_MEM,
            net_io_names["mvb"]: ArrayKeys.PRED_DIST_MVB,
            net_io_names["mvb_mem"]: ArrayKeys.PRED_DIST_MVB_MEM,
            net_io_names["mito"]: ArrayKeys.PRED_DIST_MITO,
            net_io_names["mito_mem"]: ArrayKeys.PRED_DIST_MITO_MEM,
            net_io_names["lysosome"]: ArrayKeys.PRED_DIST_LYSOSOME,
            net_io_names["lysosome_mem"]: ArrayKeys.PRED_DIST_LYSOSOME_MEM,
        },
        gradients={},
    )
    train_pipeline += Snapshot(
        {
            ArrayKeys.RAW: "volumes/raw",
            ArrayKeys.GT_LABELS: "volumes/labels/gt_labels",
            ArrayKeys.GT_DIST_CENTROSOME: "volumes/labels/gt_dist_centrosome",
            ArrayKeys.PRED_DIST_CENTROSOME:
            "volumes/labels/pred_dist_centrosome",
            ArrayKeys.GT_DIST_GOLGI: "volumes/labels/gt_dist_golgi",
            ArrayKeys.PRED_DIST_GOLGI: "volumes/labels/pred_dist_golgi",
            ArrayKeys.GT_DIST_GOLGI_MEM: "volumes/labels/gt_dist_golgi_mem",
            ArrayKeys.PRED_DIST_GOLGI_MEM:
            "volumes/labels/pred_dist_golgi_mem",
            ArrayKeys.GT_DIST_ER: "volumes/labels/gt_dist_er",
            ArrayKeys.PRED_DIST_ER: "volumes/labels/pred_dist_er",
            ArrayKeys.GT_DIST_ER_MEM: "volumes/labels/gt_dist_er_mem",
            ArrayKeys.PRED_DIST_ER_MEM: "volumes/labels/pred_dist_er_mem",
            ArrayKeys.GT_DIST_MVB: "volumes/labels/gt_dist_mvb",
            ArrayKeys.PRED_DIST_MVB: "volumes/labels/pred_dist_mvb",
            ArrayKeys.GT_DIST_MVB_MEM: "volumes/labels/gt_dist_mvb_mem",
            ArrayKeys.PRED_DIST_MVB_MEM: "volumes/labels/pred_dist_mvb_mem",
            ArrayKeys.GT_DIST_MITO: "volumes/labels/gt_dist_mito",
            ArrayKeys.PRED_DIST_MITO: "volumes/labels/pred_dist_mito",
            ArrayKeys.GT_DIST_MITO_MEM: "volumes/labels/gt_dist_mito_mem",
            ArrayKeys.PRED_DIST_MITO_MEM: "volumes/labels/pred_dist_mito_mem",
            ArrayKeys.GT_DIST_LYSOSOME: "volumes/labels/gt_dist_lysosome",
            ArrayKeys.PRED_DIST_LYSOSOME: "volumes/labels/pred_dist_lysosome",
            ArrayKeys.GT_DIST_LYSOSOME_MEM:
            "volumes/labels/gt_dist_lysosome_mem",
            ArrayKeys.PRED_DIST_LYSOSOME_MEM:
            "volumes/labels/pred_dist_lysosome_mem",
        },
        every=500,
        output_filename="batch_{iteration}.hdf",
        output_dir="snapshots/",
        additional_request=snapshot_request,
    )
    train_pipeline += PrintProfilingStats(every=50)

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only request the batches still missing after the restored
        # checkpoint; an up-to-date checkpoint yields an empty range.
        for i in range(trained_until, max_iteration):
            b.request_batch(request)

    print("Training finished")
예제 #11
0
def train_until(
        max_iteration,
        cremi_dir,
        samples,
        n5_filename_format,
        csv_filename_format,
        filter_comments_pre,
        filter_comments_post,
        labels,
        net_name,
        input_shape,
        output_shape,
        loss_name,
        aug_mode,
        include_cleft=False,
        dt_scaling_factor=50,
        cache_size=5,
        num_workers=10,
        min_masked_voxels=17561.0,
        voxel_size=Coordinate((40, 4, 4)),
):
    '''
    Trains a network to predict signed distance boundaries of synapses.

    Args:
        max_iteration(int): The number of iterations to train the network.
        cremi_dir(str): The path to the directory containing n5 files for training.
        samples (:obj:`list` of :obj:`str`): The names of samples to train on. This is used as input to format the
            `n5_filename_format` and `csv_filename_format`.
        n5_filename_format(str): The format string for n5 files.
        csv_filename_format (str): The format string for csv files.
        filter_comments_pre (:obj:`list` of :obj: `str`): A list of pre- or postsynaptic comments that should be
            excluded from the mapping of cleft ids to presynaptic neuron ids.
        filter_comments_post (:obj:`list` of :obj: `str`): A list of pre- or postsynaptic comments that should be
            excluded from the mapping of cleft ids to postsynaptic neuron ids.
        labels(:obj:`list` of :class:`Label`): The list of labels to be trained for.
        net_name(str): The name of the network, referring to the .meta file.
        input_shape(:obj:`tuple` of int): The shape of input arrays of the network.
        output_shape(:obj:`tuple` of int): The shape of output arrays of the network.
        loss_name (str): The name of the loss function as saved in the net_io_names.
        aug_mode (str): The augmentation mode ("deluxe", "classic" or "lite"); any other value applies no
            augmentation.
        include_cleft (boolean, optional): whether to include the whole cleft as part of the label when calculating
            the masked distance transform for pre-and postsynaptic sites
        dt_scaling_factor (int, optional): The factor for scaling the signed distance transform before applying tanh
            using formula tanh(distance_transform/dt_scaling_factor), default:50.
        cache_size (int, optional): The size of the cache for pulling batches, default: 5.
        num_workers(int, optional): The number of workers for pulling batches, default: 10.
        min_masked_voxels(Union(int,float), optional): The number of voxels that need to be contained in the groundtruth
            mask for a batch to be viable, default: 17561.
        voxel_size(:class:`Coordinate`): The voxel size of the input and output of the network.

    Returns:
        None.
    '''
    def get_label(name):
        # First entry of ``labels`` whose labelname equals ``name``, or None.
        return next((ll for ll in labels if ll.labelname == name), None)

    def network_setup():
        # Returns (net_io_names, start_iteration, inputs, outputs). Called
        # only after the ak_* keys below are defined, which it closes over.
        # load net_io_names.json
        with open("net_io_names.json", "r") as f:
            net_io_names = json.load(f)

        # find checkpoint from previous training, start a new one if not found
        if tf.train.latest_checkpoint("."):
            start_iteration = int(
                tf.train.latest_checkpoint(".").split("_")[-1])
            if start_iteration >= max_iteration:
                logging.info(
                    "Network has already been trained for {0:} iterations".
                    format(start_iteration))
            else:
                logging.info(
                    "Resuming training from {0:}".format(start_iteration))
        else:
            start_iteration = 0
            logging.info("Starting fresh training")

        # define network inputs
        inputs = dict()
        inputs[net_io_names["raw"]] = ak_raw
        inputs[net_io_names["mask"]] = ak_training
        for label in labels:
            inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
            inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
            if label.scale_loss or label.scale_key is not None:
                inputs[net_io_names["w_" + label.labelname]] = label.scale_key

        # define network outputs
        outputs = dict()
        for label in labels:
            outputs[net_io_names[label.labelname]] = label.pred_dist_key
        return net_io_names, start_iteration, inputs, outputs

    # fraction of the output volume that has to be inside the groundtruth mask
    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
    # tanh(2.76) ~= 0.992, so distances beyond this are effectively saturated
    # once scaled by dt_scaling_factor.
    max_distance = 2.76 * dt_scaling_factor

    ak_raw = ArrayKey("RAW")
    ak_alpha = ArrayKey("ALPHA_MASK")
    ak_neurons = ArrayKey("GT_NEURONS")
    ak_training = ArrayKey("TRAINING_MASK")
    ak_integral = ArrayKey("INTEGRAL_MASK")
    ak_clefts = ArrayKey("GT_CLEFTS")

    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    pad_width = input_size - output_size + voxel_size * Coordinate(
        (20, 20, 20))
    crop_width = Coordinate((max_distance, ) * len(voxel_size))
    crop_width = crop_width // voxel_size
    # NOTE(review): crop_width is a Coordinate, so comparing it with the int 0
    # is presumably always False and the else-branch always runs -- confirm.
    if crop_width == 0:
        crop_width *= voxel_size
    else:
        crop_width = (crop_width + (1, ) * len(crop_width)) * voxel_size

    net_io_names, start_iteration, inputs, outputs = network_setup()

    # specify which Arrays should be requested for each batch
    request = BatchRequest()
    request.add(ak_raw, input_size, voxel_size=voxel_size)
    request.add(ak_neurons, output_size, voxel_size=voxel_size)
    request.add(ak_clefts, output_size, voxel_size=voxel_size)
    request.add(ak_training, output_size, voxel_size=voxel_size)
    request.add(ak_integral, output_size, voxel_size=voxel_size)
    for l in labels:
        request.add(l.mask_key, output_size, voxel_size=voxel_size)
        request.add(l.scale_key, output_size, voxel_size=voxel_size)
        request.add(l.gt_dist_key, output_size, voxel_size=voxel_size)
    arrays_that_need_to_be_cropped = [ak_neurons, ak_clefts]
    for l in labels:
        arrays_that_need_to_be_cropped.append(l.mask_key)
        arrays_that_need_to_be_cropped.append(l.gt_dist_key)

    # specify specs for output
    array_specs_pred = dict()
    for l in labels:
        array_specs_pred[l.pred_dist_key] = ArraySpec(voxel_size=voxel_size,
                                                      interpolatable=True)

    snapshot_data = {
        ak_raw: "volumes/raw",
        ak_training: "volumes/masks/training",
        ak_clefts: "volumes/labels/gt_clefts",
        ak_neurons: "volumes/labels/gt_neurons",
        ak_integral: "volumes/masks/gt_integral"
    }

    # specify snapshot data layout
    for l in labels:
        snapshot_data[l.mask_key] = "volumes/masks/" + l.labelname
        snapshot_data[
            l.pred_dist_key] = "volumes/labels/pred_dist_" + l.labelname
        snapshot_data[l.gt_dist_key] = "volumes/labels/gt_dist_" + l.labelname

    # specify snapshot request: predictions use the same spec as their gt
    snapshot_request = BatchRequest(
        {l.pred_dist_key: request[l.gt_dist_key] for l in labels})

    csv_files = [
        os.path.join(cremi_dir, csv_filename_format.format(sample))
        for sample in samples
    ]

    cleft_to_pre, cleft_to_post, cleft_to_pre_filtered, cleft_to_post_filtered = \
        make_cleft_to_prepostsyn_neuron_id_dict(csv_files, filter_comments_pre, filter_comments_post)

    data_providers = []

    for sample in samples:
        logging.info("Adding sample {0:}".format(sample))
        datasets = {
            ak_raw: "volumes/raw",
            ak_training: "volumes/masks/validation",
            ak_integral: "volumes/masks/groundtruth_integral",
            ak_clefts: "volumes/labels/clefts",
            ak_neurons: "volumes/labels/neuron_ids",
        }
        specs = {
            ak_clefts: ArraySpec(interpolatable=False),
            ak_training: ArraySpec(interpolatable=False),
            ak_integral: ArraySpec(interpolatable=False),
        }
        for l in labels:
            datasets[l.mask_key] = "volumes/masks/groundtruth"
            specs[l.mask_key] = ArraySpec(interpolatable=False)

        n5_source = ZarrSource(
            os.path.join(cremi_dir, n5_filename_format.format(sample)),
            datasets=datasets,
            array_specs=specs,
        )
        data_providers.append(n5_source)
    data_sources = []
    for provider in data_providers:
        provider += Normalize(ak_raw)
        provider += Pad(ak_training, pad_width)
        provider += Pad(ak_neurons, pad_width)
        for l in labels:
            provider += Pad(l.mask_key, pad_width)
        # invert the validation mask into a training mask
        provider += IntensityScaleShift(ak_training, -1, 1)
        provider += RandomLocationWithIntegralMask(integral_mask=ak_integral,
                                                   min_masked=keep_thr)
        provider += Reject(ak_training, min_masked=0.999)
        provider += Reject(ak_clefts, min_masked=0.0, reject_probability=0.95)
        data_sources.append(provider)

    # source of defect sections used by DefectAugment as artifacts
    artifact_source = (
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ak_raw: "defect_sections/raw",
                ak_alpha: "defect_sections/mask",
            },
            array_specs={
                ak_raw: ArraySpec(voxel_size=(40, 4, 4)),
                ak_alpha: ArraySpec(voxel_size=(40, 4, 4)),
            },
        ) + RandomLocation(min_masked=0.05, mask=ak_alpha) +
        Normalize(ak_raw) + IntensityAugment(
            ak_raw, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2),
                       (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]))

    train_pipeline = tuple(data_sources) + RandomProvider()
    if aug_mode == "deluxe":
        slip_ignore = [ak_clefts, ak_training, ak_neurons, ak_integral]
        for l in labels:
            slip_ignore.append(l.mask_key)

        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 2.0, 2.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += SimpleAugment(transpose_only=[1, 2],
                                        mirror_only=[1, 2])
        train_pipeline += fuse.Misalign(
            40,
            prob_slip=0.05,
            prob_shift=0.05,
            max_misalign=(10, 10),
            ignore_keys_for_slip=tuple(slip_ignore),
        )
        train_pipeline += IntensityAugment(ak_raw,
                                           0.9,
                                           1.1,
                                           -0.1,
                                           0.1,
                                           z_section_wise=True)
        train_pipeline += DefectAugment(
            ak_raw,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ak_raw,
            artifacts_mask=ak_alpha,
            contrast_scale=0.5,
        )

    elif aug_mode == "classic":
        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += fuse.SimpleAugment(transpose_only=[1, 2],
                                             mirror_only=[1, 2])
        train_pipeline += IntensityAugment(ak_raw,
                                           0.9,
                                           1.1,
                                           -0.1,
                                           0.1,
                                           z_section_wise=True)
        train_pipeline += DefectAugment(
            ak_raw,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ak_raw,
            artifacts_mask=ak_alpha,
            contrast_scale=0.5,
        )
    elif aug_mode == "lite":
        train_pipeline += fuse.ElasticAugment(
            (40, 4, 4),
            (4, 40, 40),
            (0.0, 0.0, 0.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        train_pipeline += fuse.SimpleAugment(transpose_only=[1, 2],
                                             mirror_only=[1, 2])
        train_pipeline += IntensityAugment(ak_raw,
                                           0.9,
                                           1.1,
                                           -0.1,
                                           0.1,
                                           z_section_wise=False)

    else:
        # any other aug_mode: train without augmentation
        pass
    train_pipeline += IntensityScaleShift(ak_raw, 2, -1)
    train_pipeline += ZeroOutConstSections(ak_raw)
    clefts = get_label("clefts")
    pre = get_label("pre")
    post = get_label("post")

    if clefts is not None or pre is not None or post is not None:
        # Each distance/mask key is only passed for a label that was
        # configured; every attribute access is guarded by its own label.
        train_pipeline += AddPrePostCleftDistance(
            ak_clefts,
            ak_neurons,
            clefts.gt_dist_key if clefts is not None else None,
            pre.gt_dist_key if pre is not None else None,
            post.gt_dist_key if post is not None else None,
            clefts.mask_key if clefts is not None else None,
            pre.mask_key if pre is not None else None,
            post.mask_key if post is not None else None,
            cleft_to_pre,
            cleft_to_post,
            cleft_to_presyn_neuron_id_filtered=cleft_to_pre_filtered,
            cleft_to_postsyn_neuron_id_filtered=cleft_to_post_filtered,
            bg_value=(0, 18446744073709551613),
            include_cleft=include_cleft,
            max_distance=max_distance,
        )

    for ak in arrays_that_need_to_be_cropped:
        train_pipeline += CropArray(ak, crop_width, crop_width)
    for l in labels:
        train_pipeline += TanhSaturate(l.gt_dist_key, dt_scaling_factor)
    for l in labels:
        train_pipeline += BalanceByThreshold(
            labels=l.gt_dist_key,
            scales=l.scale_key,
            mask=(l.mask_key, ak_training),
            threshold=l.thr,
        )

    train_pipeline += PreCache(cache_size=cache_size, num_workers=num_workers)
    train_pipeline += Train(
        net_name,
        optimizer=net_io_names["optimizer"],
        loss=net_io_names[loss_name],
        inputs=inputs,
        summary=net_io_names["summary"],
        log_dir="log",
        save_every=500,
        log_every=5,
        outputs=outputs,
        gradients={},
        array_specs=array_specs_pred,
    )
    train_pipeline += Snapshot(
        snapshot_data,
        every=500,
        output_filename="batch_{iteration}.hdf",
        output_dir="snapshots/",
        additional_request=snapshot_request,
    )
    train_pipeline += PrintProfilingStats(every=50)

    logging.info("Starting training...")
    with build(train_pipeline) as pp:
        # NOTE(review): the range is inclusive of max_iteration, so
        # max_iteration - start_iteration + 1 batches are requested -- confirm
        # this off-by-one is intended.
        for i in range(start_iteration, max_iteration + 1):
            start_it = time.time()
            pp.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it{0:}: {1:}".format(i + 1, time_it))
    logging.info("Training finished")
예제 #12
0
def train_until(max_iteration, data_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, cremi_version, aligned):
    """Train the cleft / pre-synaptic / post-synaptic distance network.

    Builds a gunpowder pipeline over the CREMI HDF5 samples named in
    ``data_sources`` and requests batches until ``max_iteration`` iterations
    have been run in total. If a TensorFlow checkpoint exists in the working
    directory, the iteration count parsed from its name is subtracted so
    that training resumes rather than restarting the count.

    Args:
        max_iteration: Total number of iterations to train up to, including
            iterations already recorded by an existing checkpoint.
        data_sources: Sample identifiers substituted into the HDF5 and
            cleft-partner CSV file name templates.
        input_shape: Network input shape in voxels; multiplied by the voxel
            size (40, 4, 4).
        output_shape: Network output shape in voxels; multiplied by the voxel
            size (40, 4, 4).
        dt_scaling_factor: Passed as ``normalize_args`` to the ``'tanh'``
            normalization of the distance arrays.
        loss_name: Key into ``net_io_names.json`` selecting the loss tensor.
        cremi_version: ``'2016'`` or ``'2017'``; selects the data directory
            and file name templates.
        aligned: If true, read the ``aligned.`` variant of the HDF5 files.

    Raises:
        ValueError: If ``cremi_version`` is neither ``'2016'`` nor ``'2017'``.
    """
    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('GT_CLEFTS')
    ArrayKey('GT_MASK')
    ArrayKey('TRAINING_MASK')
    ArrayKey('CLEFT_SCALE')
    ArrayKey('PRE_SCALE')
    ArrayKey('POST_SCALE')
    ArrayKey('LOSS_GRADIENT')
    ArrayKey('GT_CLEFT_DIST')
    ArrayKey('PRED_CLEFT_DIST')
    ArrayKey('GT_PRE_DIST')
    ArrayKey('PRED_PRE_DIST')
    ArrayKey('GT_POST_DIST')
    ArrayKey('PRED_POST_DIST')

    # Resolve all version-dependent paths in one place so that an unsupported
    # version fails immediately instead of with a NameError further down.
    if cremi_version == '2016':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = 'sample_{0:}_padded_20160501.'
        csv_template = 'cleft-partners-{0:}-20160501.aligned.corrected.csv'
    elif cremi_version == '2017':
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = 'sample_{0:}_padded_20170424.'
        csv_template = 'cleft-partners_{0:}_2017.csv'
    else:
        raise ValueError(
            "cremi_version must be '2016' or '2017', got {0!r}".format(
                cremi_version))
    if aligned:
        filename += 'aligned.'
    filename += '0bg.hdf'

    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        # The iteration number is taken from the last "_"-separated field of
        # the checkpoint path.
        trained_until = int(latest_checkpoint.split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    data_providers = []
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/clefts',
                ArrayKeys.GT_MASK: 'volumes/masks/groundtruth',
                ArrayKeys.TRAINING_MASK: 'volumes/masks/validation',
                ArrayKeys.GT_LABELS: 'volumes/labels/neuron_ids'
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False)
            })
        data_providers.append(h5_source)

    csv_files = [
        os.path.join(cremi_dir, csv_template.format(sample))
        for sample in data_sources
    ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(
        csv_files)
    print(cleft_to_pre, cleft_to_post)
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
    context = input_size - output_size

    # Specify which arrays should be requested for each batch.
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    # PRE_SCALE / POST_SCALE are fed to the network as loss weights below, so
    # they are requested just like CLEFT_SCALE.
    request.add(ArrayKeys.PRE_SCALE, output_size)
    request.add(ArrayKeys.POST_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # One source pipeline per HDF file.
    source_pipelines = tuple(
        provider +
        # Ensures RAW is float in [0, 1].
        Normalize(ArrayKeys.RAW) +
        # scale -1, shift 1: maps x -> 1 - x, presumably turning the
        # validation mask into a training mask (confirm).
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # Zero-pad RAW and GT_MASK so batches can be drawn close to the
        # boundary of the available data; the amount is largely irrelevant
        # because a Reject node follows.
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        # Choose a random location inside the provided arrays.
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        # Reject batches which contain less than 50% labelled data.
        Reject(ArrayKeys.GT_MASK) +
        # NOTE(review): presumably rejects most (95%) batches that contain
        # no cleft voxels at all; confirm against gunpowder's Reject.
        Reject(ArrayKeys.GT_CLEFTS, min_masked=0.0, reject_probability=0.95)
        for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST: request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST: request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST: request[ArrayKeys.GT_POST_DIST],
    })

    train_pipeline = (
        source_pipelines + RandomProvider() + ElasticAugment(
            (4, 40, 40), (0., 0., 0.), (0, math.pi / 2.0),
            prob_slip=0.0,
            prob_shift=0.0,
            max_misalign=0,
            subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]) +
        IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        # Rescale RAW with scale 2, shift -1 (i.e. [0, 1] -> [-1, 1]).
        IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) +
        AddDistance(label_array_key=ArrayKeys.GT_CLEFTS,
                    distance_array_key=ArrayKeys.GT_CLEFT_DIST,
                    normalize='tanh',
                    normalize_args=dt_scaling_factor) +
        AddPrePostCleftDistance(ArrayKeys.GT_CLEFTS,
                                ArrayKeys.GT_LABELS,
                                ArrayKeys.GT_PRE_DIST,
                                ArrayKeys.GT_POST_DIST,
                                cleft_to_pre,
                                cleft_to_post,
                                normalize='tanh',
                                normalize_args=dt_scaling_factor,
                                include_cleft=False) +
        BalanceByThreshold(labels=ArrayKeys.GT_CLEFT_DIST,
                           scales=ArrayKeys.CLEFT_SCALE,
                           mask=ArrayKeys.GT_MASK) +
        BalanceByThreshold(labels=ArrayKeys.GT_PRE_DIST,
                           scales=ArrayKeys.PRE_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        BalanceByThreshold(labels=ArrayKeys.GT_POST_DIST,
                           scales=ArrayKeys.POST_SCALE,
                           mask=ArrayKeys.GT_MASK,
                           threshold=-0.5) +
        PreCache(cache_size=40, num_workers=10) + Train(
            'unet',
            optimizer=net_io_names['optimizer'],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names['raw']: ArrayKeys.RAW,
                net_io_names['gt_cleft_dist']: ArrayKeys.GT_CLEFT_DIST,
                net_io_names['gt_pre_dist']: ArrayKeys.GT_PRE_DIST,
                net_io_names['gt_post_dist']: ArrayKeys.GT_POST_DIST,
                net_io_names['loss_weights_cleft']: ArrayKeys.CLEFT_SCALE,
                # Each target uses the scales balanced on its own distance
                # array (previously all three used CLEFT_SCALE, leaving the
                # PRE_SCALE / POST_SCALE balancing unused).
                net_io_names['loss_weights_pre']: ArrayKeys.PRE_SCALE,
                net_io_names['loss_weights_post']: ArrayKeys.POST_SCALE,
                net_io_names['mask']: ArrayKeys.GT_MASK
            },
            summary=net_io_names['summary'],
            log_dir='log',
            outputs={
                net_io_names['cleft_dist']: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names['pre_dist']: ArrayKeys.PRED_PRE_DIST,
                net_io_names['post_dist']: ArrayKeys.PRED_POST_DIST
            },
            gradients={net_io_names['cleft_dist']: ArrayKeys.LOSS_GRADIENT}) +
        Snapshot(
            {
                ArrayKeys.RAW: 'volumes/raw',
                ArrayKeys.GT_CLEFTS: 'volumes/labels/gt_clefts',
                ArrayKeys.GT_CLEFT_DIST: 'volumes/labels/gt_clefts_dist',
                ArrayKeys.PRED_CLEFT_DIST: 'volumes/labels/pred_clefts_dist',
                ArrayKeys.LOSS_GRADIENT: 'volumes/loss_gradient',
                ArrayKeys.PRED_PRE_DIST: 'volumes/labels/pred_pre_dist',
                ArrayKeys.PRED_POST_DIST: 'volumes/labels/pred_post_dist',
                ArrayKeys.GT_PRE_DIST: 'volumes/labels/gt_pre_dist',
                ArrayKeys.GT_POST_DIST: 'volumes/labels/gt_post_dist'
            },
            every=500,
            output_filename='batch_{iteration}.hdf',
            output_dir='snapshots/',
            additional_request=snapshot_request) +
        PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        # Only run the iterations still missing after a resumed checkpoint;
        # range() is empty if the checkpoint is already past max_iteration.
        for _ in range(max_iteration - trained_until):
            b.request_batch(request)

    print("Training finished")