def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    labels,
    net_name,
    min_masked_voxels=17561.0,
    mask_ds_name="volumes/masks/training_cropped",
):
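    """Train `net_name` for `max_iteration` iterations on distance-transform targets.

    Builds a gunpowder pipeline that reads the N5 volumes in `data_sources`
    (using the dedicated ribosome labels for sources listed in `ribo_sources`),
    turns the label volumes into per-label tanh-normalized distance targets,
    and feeds them to the TensorFlow graph described by net_io_names.json.
    `min_masked_voxels` sets how many voxels of each output block must lie
    inside `mask_ds_name` for a batch to be accepted.
    """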
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_orig
    output_size = Coordinate(output_shape) * voxel_size_orig
    # context = input_size-output_size

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)
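    # i.e. the minimum fraction of voxels in an output block that must be masked-in for a batch to be kept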

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (the only other dataset with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: "volumes/raw/data/s0",
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_orig)),
    }
    array_specs_pred = {}

    inputs[net_io_names["raw"]] = ArrayKeys.RAW

    snapshot[ArrayKeys.RAW] = "volumes/raw"
    snapshot[ArrayKeys.GT_LABELS] = "volumes/labels/gt_labels"

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_orig)

    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname

        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True)

        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key

        outputs[net_io_names[label.labelname]] = label.pred_dist_key

        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname

        request.add(label.gt_dist_key, output_size, voxel_size=voxel_size_orig)
        request.add(label.pred_dist_key,
                    output_size,
                    voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key,
                        output_size,
                        voxel_size=voxel_size_orig)

        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size_orig)

    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")

    for src in data_sources:
        if src not in ribo_sources:
            n5_source = N5Source(src.full_path,
                                 datasets=datasets_no_ribo,
                                 array_specs=array_specs)
        else:
            n5_source = N5Source(src.full_path,
                                 datasets=datasets_ribo,
                                 array_specs=array_specs)

        data_providers.append(n5_source)

    # create a tuple of data streams, one for each N5 source
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is a float in [0, 1]
        # zero-pad provided RAW and MASK to be able to draw batches close to
        # the boundary of the available data
        # (size more or less irrelevant, as it is followed by a Reject node)
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        RejectEfficiently(ArrayKeys.MASK, min_masked=keep_thr)
        # Reject(ArrayKeys.MASK)  # reject batches which contain less than 50% labelled data
        for provider in data_providers)

    train_pipeline = (
        data_stream +
        RandomProvider(tuple([ds.labeled_voxels for ds in data_sources])) +
        gpn.SimpleAugment() + gpn.ElasticAugment(
            voxel_size_orig,
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        ) +
        # ElasticAugment((40, 1000, 1000), (10., 0., 0.), (0, 0), subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))
    # ZeroOutConstSections(ArrayKeys.RAW))

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
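            # ribosomes use the separately annotated RIBO_GT volume and the 'tanh+' normalization variant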
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=30, num_workers=30) + Train(
                          net_name,
                          optimizer=net_io_names["optimizer"],
                          loss=net_io_names[loss_name],
                          inputs=inputs,
                          summary=net_io_names["summary"],
                          log_dir="log",
                          outputs=outputs,
                          gradients={},
                          log_every=5,
                          save_every=500,
                          array_specs=array_specs_pred,
                      ) + Snapshot(
                          snapshot,
                          every=500,
                          output_filename="batch_{iteration}.hdf",
                          output_dir="snapshots/",
                          additional_request=snapshot_request,
                      ) + PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))

    print("Training finished")
Example #2
def train_until(
    max_iteration,
    data_sources,
    input_shape,
    output_shape,
    dt_scaling_factor,
    loss_name,
    cremi_version,
    aligned,
):
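    """Train a "unet" on CREMI data to predict cleft, pre- and postsynaptic distances.

    Builds a gunpowder pipeline over the padded CREMI HDF5 samples in
    `data_sources` (2016 or 2017 release, optionally aligned), converts the
    cleft and neuron-ID labels into tanh-normalized distance targets, and
    trains for `max_iteration` iterations with elastic, misalignment and
    defect-section augmentations.
    """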
    ArrayKey("RAW")
    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("GT_CLEFTS")
    ArrayKey("GT_MASK")
    ArrayKey("TRAINING_MASK")
    ArrayKey("CLEFT_SCALE")
    ArrayKey("PRE_SCALE")
    ArrayKey("POST_SCALE")
    ArrayKey("LOSS_GRADIENT")
    ArrayKey("GT_CLEFT_DIST")
    ArrayKey("PRED_CLEFT_DIST")
    ArrayKey("GT_PRE_DIST")
    ArrayKey("PRED_PRE_DIST")
    ArrayKey("GT_POST_DIST")
    ArrayKey("PRED_POST_DIST")
    ArrayKey("GT_POST_DIST")
    data_providers = []
    if cremi_version == "2016":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2016/"
        filename = "sample_{0:}_padded_20160501."
    elif cremi_version == "2017":
        cremi_dir = "/groups/saalfeld/saalfeldlab/larissa/data/cremi-2017/"
        filename = "sample_{0:}_padded_20170424."
    if aligned:
        filename += "aligned."
    filename += "0bg.hdf"
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for sample in data_sources:
        print(sample)
        h5_source = Hdf5Source(
            os.path.join(cremi_dir, filename.format(sample)),
            datasets={
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_CLEFTS: "volumes/labels/clefts",
                ArrayKeys.GT_MASK: "volumes/masks/groundtruth",
                ArrayKeys.TRAINING_MASK: "volumes/masks/validation",
                ArrayKeys.GT_LABELS: "volumes/labels/neuron_ids",
            },
            array_specs={
                ArrayKeys.GT_MASK: ArraySpec(interpolatable=False),
                ArrayKeys.GT_CLEFTS: ArraySpec(interpolatable=False),
                ArrayKeys.TRAINING_MASK: ArraySpec(interpolatable=False),
            },
        )
        data_providers.append(h5_source)

    if cremi_version == "2017":
        csv_files = [
            os.path.join(cremi_dir, "cleft-partners_" + sample + "_2017.csv")
            for sample in data_sources
        ]
    elif cremi_version == "2016":
        csv_files = [
            os.path.join(
                cremi_dir,
                "cleft-partners-" + sample + "-20160501.aligned.corrected.csv",
            ) for sample in data_sources
        ]
    cleft_to_pre, cleft_to_post = make_cleft_to_prepostsyn_neuron_id_dict(
        csv_files)
    print(cleft_to_pre, cleft_to_post)
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    voxel_size = Coordinate((40, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size
    output_size = Coordinate(output_shape) * voxel_size
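    # world-space margin between input and output; used below to pad TRAINING_MASK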
    context = input_size - output_size
    # specify which arrays should be requested for each batch
    request = BatchRequest()
    request.add(ArrayKeys.RAW, input_size)
    request.add(ArrayKeys.GT_LABELS, output_size)
    request.add(ArrayKeys.GT_CLEFTS, output_size)
    request.add(ArrayKeys.GT_MASK, output_size)
    request.add(ArrayKeys.TRAINING_MASK, output_size)
    request.add(ArrayKeys.CLEFT_SCALE, output_size)
    request.add(ArrayKeys.PRE_SCALE, output_size)
    request.add(ArrayKeys.POST_SCALE, output_size)
    request.add(ArrayKeys.GT_CLEFT_DIST, output_size)
    request.add(ArrayKeys.GT_PRE_DIST, output_size)
    request.add(ArrayKeys.GT_POST_DIST, output_size)

    # create a tuple of data sources, one for each HDF file
    data_sources = tuple(
        provider + Normalize(ArrayKeys.RAW) +  # ensures RAW is a float in [0, 1]
        IntensityScaleShift(ArrayKeys.TRAINING_MASK, -1, 1) +
        # zero-pad provided RAW and GT_MASK to be able to draw batches close to
        # the boundary of the available data
        # (size more or less irrelevant, as it is followed by a Reject node)
        Pad(ArrayKeys.RAW, None) + Pad(ArrayKeys.GT_MASK, None) +
        Pad(ArrayKeys.TRAINING_MASK, context) +
        # choose a random location inside the provided arrays
        RandomLocation(min_masked=0.99, mask=ArrayKeys.TRAINING_MASK) +
        Reject(ArrayKeys.GT_MASK) +  # reject batches with less than 50% groundtruth mask
        # reject batches that contain no cleft labels (with 95% probability)
        Reject(ArrayKeys.GT_CLEFTS,
               min_masked=0.0,
               reject_probability=0.95) for provider in data_providers)

    snapshot_request = BatchRequest({
        ArrayKeys.LOSS_GRADIENT: request[ArrayKeys.GT_CLEFTS],
        ArrayKeys.PRED_CLEFT_DIST: request[ArrayKeys.GT_CLEFT_DIST],
        ArrayKeys.PRED_PRE_DIST: request[ArrayKeys.GT_PRE_DIST],
        ArrayKeys.PRED_POST_DIST: request[ArrayKeys.GT_POST_DIST],
    })

    artifact_source = (
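        # defect sections from CREMI sample ABC; DefectAugment below pastes these artifacts into RAW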
        Hdf5Source(
            os.path.join(cremi_dir, "sample_ABC_padded_20160501.defects.hdf"),
            datasets={
                ArrayKeys.RAW: "defect_sections/raw",
                ArrayKeys.ALPHA_MASK: "defect_sections/mask",
            },
            array_specs={
                ArrayKeys.RAW: ArraySpec(voxel_size=(40, 4, 4)),
                ArrayKeys.ALPHA_MASK: ArraySpec(voxel_size=(40, 4, 4)),
            },
        ) + RandomLocation(min_masked=0.05, mask=ArrayKeys.ALPHA_MASK) +
        Normalize(ArrayKeys.RAW) + IntensityAugment(
            ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        ElasticAugment((4, 40, 40), (0, 2, 2),
                       (0, math.pi / 2.0), subsample=8) +
        SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2]))

    train_pipeline = (
        data_sources + RandomProvider() + SimpleAugment(
            transpose_only=[1, 2], mirror_only=[1, 2]) + gpn.ElasticAugment(
                (40, 4, 4),
                (4, 40, 40),
                (0.0, 2.0, 2.0),
                (0, math.pi / 2.0),
                spatial_dims=3,
                subsample=8,
            ) + gpn.Misalign(
                40,
                prob_slip=0.05,
                prob_shift=0.05,
                max_misalign=10,
                ignore_keys_for_slip=(
                    ArrayKeys.GT_CLEFTS,
                    ArrayKeys.GT_MASK,
                    ArrayKeys.TRAINING_MASK,
                    ArrayKeys.GT_LABELS,
                ),
            ) + IntensityAugment(
                ArrayKeys.RAW, 0.9, 1.1, -0.1, 0.1, z_section_wise=True) +
        DefectAugment(
            ArrayKeys.RAW,
            prob_missing=0.03,
            prob_low_contrast=0.01,
            prob_artifact=0.03,
            artifact_source=artifact_source,
            artifacts=ArrayKeys.RAW,
            artifacts_mask=ArrayKeys.ALPHA_MASK,
            contrast_scale=0.5,
        ) + IntensityScaleShift(ArrayKeys.RAW, 2, -1) +
        ZeroOutConstSections(ArrayKeys.RAW) + AddDistance(
            label_array_key=ArrayKeys.GT_CLEFTS,
            distance_array_key=ArrayKeys.GT_CLEFT_DIST,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
        ) + AddPrePostCleftDistance(
            ArrayKeys.GT_CLEFTS,
            ArrayKeys.GT_LABELS,
            ArrayKeys.GT_PRE_DIST,
            ArrayKeys.GT_POST_DIST,
            cleft_to_pre,
            cleft_to_post,
            normalize="tanh",
            normalize_args=dt_scaling_factor,
            include_cleft=False,
        ) + BalanceByThreshold(
            labels=ArrayKeys.GT_CLEFT_DIST,
            scales=ArrayKeys.CLEFT_SCALE,
            mask=ArrayKeys.GT_MASK,
        ) + BalanceByThreshold(
            labels=ArrayKeys.GT_PRE_DIST,
            scales=ArrayKeys.PRE_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5,
        ) + BalanceByThreshold(
            labels=ArrayKeys.GT_POST_DIST,
            scales=ArrayKeys.POST_SCALE,
            mask=ArrayKeys.GT_MASK,
            threshold=-0.5,
        ) + PreCache(cache_size=40, num_workers=10) + Train(
            "unet",
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs={
                net_io_names["raw"]: ArrayKeys.RAW,
                net_io_names["gt_cleft_dist"]: ArrayKeys.GT_CLEFT_DIST,
                net_io_names["gt_pre_dist"]: ArrayKeys.GT_PRE_DIST,
                net_io_names["gt_post_dist"]: ArrayKeys.GT_POST_DIST,
                net_io_names["loss_weights_cleft"]: ArrayKeys.CLEFT_SCALE,
                net_io_names["loss_weights_pre"]: ArrayKeys.CLEFT_SCALE,
                net_io_names["loss_weights_post"]: ArrayKeys.CLEFT_SCALE,
                net_io_names["mask"]: ArrayKeys.GT_MASK,
            },
            summary=net_io_names["summary"],
            log_dir="log",
            outputs={
                net_io_names["cleft_dist"]: ArrayKeys.PRED_CLEFT_DIST,
                net_io_names["pre_dist"]: ArrayKeys.PRED_PRE_DIST,
                net_io_names["post_dist"]: ArrayKeys.PRED_POST_DIST,
            },
            gradients={net_io_names["cleft_dist"]: ArrayKeys.LOSS_GRADIENT},
        ) + Snapshot(
            {
                ArrayKeys.RAW: "volumes/raw",
                ArrayKeys.GT_CLEFTS: "volumes/labels/gt_clefts",
                ArrayKeys.GT_CLEFT_DIST: "volumes/labels/gt_clefts_dist",
                ArrayKeys.PRED_CLEFT_DIST: "volumes/labels/pred_clefts_dist",
                ArrayKeys.LOSS_GRADIENT: "volumes/loss_gradient",
                ArrayKeys.PRED_PRE_DIST: "volumes/labels/pred_pre_dist",
                ArrayKeys.PRED_POST_DIST: "volumes/labels/pred_post_dist",
                ArrayKeys.GT_PRE_DIST: "volumes/labels/gt_pre_dist",
                ArrayKeys.GT_POST_DIST: "volumes/labels/gt_post_dist",
            },
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        ) + PrintProfilingStats(every=50))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)

    print("Training finished")
Example #3
def train_until(max_iteration, data_sources, ribo_sources, input_shape, output_shape,
                dt_scaling_factor, loss_name, labels, net_name,
                min_masked_voxels=17561., mask_ds_name='volumes/masks/training'):
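    """Variant of the training pipeline that reads subsampled raw data.

    Raw data is read at a voxel size of (8, 8, 8) from one of the eight
    'volumes/subsampled/raw{0..7}' variants of each source, while distance
    targets and predictions are produced at (4, 4, 4) (label ground truth is
    requested at (2, 2, 2)).
    """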
    with open('net_io_names.json', 'r') as f:
        net_io_names = json.load(f)

    ArrayKey('RAW')
    ArrayKey('ALPHA_MASK')
    ArrayKey('GT_LABELS')
    ArrayKey('MASK')
    ArrayKey('RIBO_GT')

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_input = Coordinate((8, 8, 8))
    voxel_size_output = Coordinate((4, 4, 4))
    input_size = Coordinate(input_shape) * voxel_size_input
    output_size = Coordinate(output_shape) * voxel_size_output
    # context = input_size-output_size

    keep_thr = float(min_masked_voxels) / np.prod(output_shape)

    data_providers = []
    inputs = dict()
    outputs = dict()
    snapshot = dict()
    request = BatchRequest()
    snapshot_request = BatchRequest()

    datasets_ribo = {
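        # the RAW dataset path is filled in per subsampling variant in the provider loop below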
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/ribosomes',
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (the only other dataset with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.RAW: None,
        ArrayKeys.GT_LABELS: 'volumes/labels/all',
        ArrayKeys.MASK: mask_ds_name,
        ArrayKeys.RIBO_GT: 'volumes/labels/all',
    }

    array_specs = {
        ArrayKeys.MASK: ArraySpec(interpolatable=False),
        ArrayKeys.RAW: ArraySpec(voxel_size=Coordinate(voxel_size_input))
    }
    array_specs_pred = {}

    inputs[net_io_names['raw']] = ArrayKeys.RAW

    snapshot[ArrayKeys.RAW] = 'volumes/raw'
    snapshot[ArrayKeys.GT_LABELS] = 'volumes/labels/gt_labels'

    request.add(ArrayKeys.GT_LABELS, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size, voxel_size=voxel_size_output)
    request.add(ArrayKeys.RIBO_GT, output_size, voxel_size=voxel_size_up)
    request.add(ArrayKeys.RAW, input_size, voxel_size=voxel_size_input)

    for label in labels:
        datasets_no_ribo[label.mask_key] = 'volumes/masks/' + label.labelname
        datasets_ribo[label.mask_key] = 'volumes/masks/' + label.labelname

        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_output, interpolatable=True)

        inputs[net_io_names['mask_' + label.labelname]] = label.mask_key
        inputs[net_io_names['gt_' + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names['w_' + label.labelname]] = label.scale_key

        outputs[net_io_names[label.labelname]] = label.pred_dist_key

        snapshot[label.gt_dist_key] = 'volumes/labels/gt_dist_' + label.labelname
        snapshot[label.pred_dist_key] = 'volumes/labels/pred_dist_' + label.labelname

        request.add(label.gt_dist_key,
                    output_size,
                    voxel_size=voxel_size_output)
        request.add(label.pred_dist_key,
                    output_size,
                    voxel_size=voxel_size_output)
        request.add(label.mask_key, output_size, voxel_size=voxel_size_output)
        if label.scale_loss:
            request.add(label.scale_key,
                        output_size,
                        voxel_size=voxel_size_output)

        snapshot_request.add(label.pred_dist_key,
                             output_size,
                             voxel_size=voxel_size_output)

    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
        print('Resuming training from', trained_until)
    else:
        trained_until = 0
        print('Starting fresh training')

    for src in data_sources:
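        # add one provider per subsampling variant of the raw data (eight per source)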
        for subsample_variant in range(8):
            dnr = datasets_no_ribo.copy()
            dr = datasets_ribo.copy()
            dnr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(
                subsample_variant)
            dr[ArrayKeys.RAW] = 'volumes/subsampled/raw{0:}/'.format(
                subsample_variant)

            if src not in ribo_sources:
                n5_source = N5Source(src.full_path,
                                     datasets=dnr,
                                     array_specs=array_specs)
            else:
                n5_source = N5Source(src.full_path,
                                     datasets=dr,
                                     array_specs=array_specs)

            data_providers.append(n5_source)

    # create a tuple of data streams, one for each N5 source
    data_stream = tuple(
        provider +
        Normalize(ArrayKeys.RAW) +  # ensures RAW is a float in [0, 1]
        # zero-pad provided RAW and MASK to be able to draw batches close to
        # the boundary of the available data
        # (size more or less irrelevant, as it is followed by a Reject node)
        # Pad(ArrayKeys.RAW, context) +
        RandomLocation() +  # choose a random location inside the provided arrays
        Reject(ArrayKeys.MASK, min_masked=keep_thr)
        for provider in data_providers)

    train_pipeline = (
        data_stream + RandomProvider(
            tuple(np.repeat([ds.labeled_voxels for ds in data_sources], 8))) +
        gpn.SimpleAugment() + gpn.ElasticAugment(voxel_size_output,
                                                 (100, 100, 100),
                                                 (10., 10., 10.),
                                                 (0, math.pi / 2.0),
                                                 spatial_dims=3,
                                                 subsample=8) +
        gpn.IntensityAugment(ArrayKeys.RAW, 0.25, 1.75, -0.5, 0.35) +
        GammaAugment(ArrayKeys.RAW, 0.5, 2.0) +
        IntensityScaleShift(ArrayKeys.RAW, 2, -1))

    for label in labels:
        if label.labelname != 'ribosomes':
            train_pipeline += AddDistance(label_array_key=ArrayKeys.GT_LABELS,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh',
                                          normalize_args=dt_scaling_factor,
                                          label_id=label.labelid,
                                          factor=2)
        else:
            train_pipeline += AddDistance(label_array_key=ArrayKeys.RIBO_GT,
                                          distance_array_key=label.gt_dist_key,
                                          normalize='tanh+',
                                          normalize_args=(dt_scaling_factor,
                                                          8),
                                          label_id=label.labelid,
                                          factor=2)

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(label.gt_dist_key,
                                                 label.scale_key,
                                                 mask=label.mask_key)

    train_pipeline = (train_pipeline +
                      PreCache(cache_size=10, num_workers=20) +
                      Train(net_name,
                            optimizer=net_io_names['optimizer'],
                            loss=net_io_names[loss_name],
                            inputs=inputs,
                            summary=net_io_names['summary'],
                            log_dir='log',
                            outputs=outputs,
                            gradients={},
                            log_every=5,
                            save_every=500,
                            array_specs=array_specs_pred) +
                      Snapshot(snapshot,
                               every=500,
                               output_filename='batch_{iteration}.hdf',
                               output_dir='snapshots/',
                               additional_request=snapshot_request) +
                      PrintProfilingStats(every=500))

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info('it {0:}: {1:}'.format(i + 1, time_it))

    print("Training finished")
Example #4
def train_until(
    max_iteration,
    data_sources,
    ribo_sources,
    dt_scaling_factor,
    loss_name,
    labels,
    scnet,
    raw_name="raw",
    min_masked_voxels=17561.0,
):
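    """Train the multi-scale network `scnet` on distance-transform targets.

    `scnet` provides one input shape and voxel size per scale; a separate
    RAW_S{k} array is requested for each scale and wired to the corresponding
    'raw_<voxel size>' network input. Labels, masks and predictions are handled
    at the base voxel size (4, 4, 4), as in the single-scale variants above.
    """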
    with open("net_io_names.json", "r") as f:
        net_io_names = json.load(f)

    ArrayKey("ALPHA_MASK")
    ArrayKey("GT_LABELS")
    ArrayKey("MASK")
    ArrayKey("RIBO_GT")

    datasets_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/ribosomes",
    }
    # for datasets without ribosome annotations, volumes/labels/ribosomes doesn't exist, so use volumes/labels/all
    # instead (the only other dataset with the right resolution)
    datasets_no_ribo = {
        ArrayKeys.GT_LABELS: "volumes/labels/all",
        ArrayKeys.MASK: "volumes/masks/training_cropped",
        ArrayKeys.RIBO_GT: "volumes/labels/all",
    }
    array_specs = {ArrayKeys.MASK: ArraySpec(interpolatable=False)}
    array_specs_pred = {}

    # individual mask per label
    for label in labels:
        datasets_no_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        datasets_ribo[label.mask_key] = "volumes/masks/" + label.labelname
        array_specs[label.mask_key] = ArraySpec(interpolatable=False)
    # inputs = {net_io_names['mask']: ArrayKeys.MASK}
    snapshot = {ArrayKeys.GT_LABELS: "volumes/labels/all"}

    request = BatchRequest()
    snapshot_request = BatchRequest()

    raw_array_keys = []
    contexts = []
    # input and output sizes in world coordinates
    input_sizes_wc = [
        Coordinate(inp_sh) * Coordinate(vs)
        for inp_sh, vs in zip(scnet.input_shapes, scnet.voxel_sizes)
    ]
    output_size_wc = Coordinate(scnet.output_shapes[0]) * Coordinate(
        scnet.voxel_sizes[0]
    )
    keep_thr = float(min_masked_voxels) / np.prod(scnet.output_shapes[0])

    voxel_size_up = Coordinate((2, 2, 2))
    voxel_size_orig = Coordinate((4, 4, 4))
    assert voxel_size_orig == Coordinate(
        scnet.voxel_sizes[0]
    )  # make sure that scnet has the same base voxel size
    inputs = {}
    # add multiscale raw data as inputs
    for k, (inp_sh_wc, vs) in enumerate(zip(input_sizes_wc, scnet.voxel_sizes)):
        ak = ArrayKey("RAW_S{0:}".format(k))
        raw_array_keys.append(ak)
        datasets_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        datasets_no_ribo[ak] = "volumes/{0:}/data/s{1:}".format(raw_name, k)
        inputs[net_io_names["raw_{0:}".format(vs[0])]] = ak
        snapshot[ak] = "volumes/raw_s{0:}".format(k)
        array_specs[ak] = ArraySpec(voxel_size=Coordinate(vs))
        request.add(ak, inp_sh_wc, voxel_size=Coordinate(vs))
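        # world-space margin between this scale's input and the base-scale output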
        contexts.append(inp_sh_wc - output_size_wc)
    outputs = dict()
    for label in labels:
        inputs[net_io_names["gt_" + label.labelname]] = label.gt_dist_key
        if label.scale_loss or label.scale_key is not None:
            inputs[net_io_names["w_" + label.labelname]] = label.scale_key
        inputs[net_io_names["mask_" + label.labelname]] = label.mask_key
        outputs[net_io_names[label.labelname]] = label.pred_dist_key
        snapshot[label.gt_dist_key] = "volumes/labels/gt_dist_" + label.labelname
        snapshot[label.pred_dist_key] = "volumes/labels/pred_dist_" + label.labelname
        array_specs_pred[label.pred_dist_key] = ArraySpec(
            voxel_size=voxel_size_orig, interpolatable=True
        )

    request.add(ArrayKeys.GT_LABELS, output_size_wc, voxel_size=voxel_size_up)
    request.add(ArrayKeys.MASK, output_size_wc, voxel_size=voxel_size_orig)
    request.add(ArrayKeys.RIBO_GT, output_size_wc, voxel_size=voxel_size_up)
    for label in labels:
        request.add(label.gt_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        snapshot_request.add(
            label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig
        )
        request.add(label.pred_dist_key, output_size_wc, voxel_size=voxel_size_orig)
        request.add(label.mask_key, output_size_wc, voxel_size=voxel_size_orig)
        if label.scale_loss:
            request.add(label.scale_key, output_size_wc, voxel_size=voxel_size_orig)

    data_providers = []
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
        print("Resuming training from", trained_until)
    else:
        trained_until = 0
        print("Starting fresh training")
    for src in data_sources:

        if src not in ribo_sources:
            n5_source = N5Source(
                src.full_path, datasets=datasets_no_ribo, array_specs=array_specs
            )
        else:
            n5_source = N5Source(
                src.full_path, datasets=datasets_ribo, array_specs=array_specs
            )

        data_providers.append(n5_source)

    data_stream = []
    for provider in data_providers:
        data_stream.append(provider)
        for ak, context in zip(raw_array_keys, contexts):
            data_stream[-1] += Normalize(ak)
            # data_stream[-1] += Pad(ak, context) # this shouldn't be necessary as I cropped the input data to have
            # sufficient padding
        data_stream[-1] += RandomLocation()
        data_stream[-1] += Reject(ArrayKeys.MASK, min_masked=keep_thr)
    data_stream = tuple(data_stream)

    train_pipeline = (
        data_stream
        + RandomProvider(tuple([ds.labeled_voxels for ds in data_sources]))
        + gpn.SimpleAugment()
        + gpn.ElasticAugment(
            tuple(scnet.voxel_sizes[0]),
            (100, 100, 100),
            (10.0, 10.0, 10.0),
            (0, math.pi / 2.0),
            spatial_dims=3,
            subsample=8,
        )
        + gpn.IntensityAugment(raw_array_keys, 0.25, 1.75, -0.5, 0.35)
        + GammaAugment(raw_array_keys, 0.5, 2.0)
    )
    for ak in raw_array_keys:
        train_pipeline += IntensityScaleShift(ak, 2, -1)
        # train_pipeline += ZeroOutConstSections(ak)

    for label in labels:
        if label.labelname != "ribosomes":
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.GT_LABELS,
                distance_array_key=label.gt_dist_key,
                normalize="tanh",
                normalize_args=dt_scaling_factor,
                label_id=label.labelid,
                factor=2,
            )
        else:
            train_pipeline += AddDistance(
                label_array_key=ArrayKeys.RIBO_GT,
                distance_array_key=label.gt_dist_key,
                normalize="tanh+",
                normalize_args=(dt_scaling_factor, 8),
                label_id=label.labelid,
                factor=2,
            )

    for label in labels:
        if label.scale_loss:
            train_pipeline += BalanceByThreshold(
                label.gt_dist_key, label.scale_key, mask=label.mask_key
            )

    train_pipeline = (
        train_pipeline
        + PreCache(cache_size=10, num_workers=40)
        + Train(
            scnet.name,
            optimizer=net_io_names["optimizer"],
            loss=net_io_names[loss_name],
            inputs=inputs,
            summary=net_io_names["summary"],
            log_dir="log",
            outputs=outputs,
            gradients={},
            log_every=5,
            save_every=500,
            array_specs=array_specs_pred,
        )
        + Snapshot(
            snapshot,
            every=500,
            output_filename="batch_{iteration}.hdf",
            output_dir="snapshots/",
            additional_request=snapshot_request,
        )
        + PrintProfilingStats(every=10)
    )

    print("Starting training...")
    with build(train_pipeline) as b:
        for i in range(max_iteration):
            start_it = time.time()
            b.request_batch(request)
            time_it = time.time() - start_it
            logging.info("it {0:}: {1:}".format(i + 1, time_it))
    print("Training finished")