def train(iterations):
    """Train the network on synthetic toy-neuron data up to ``iterations``.

    Training resumes from the newest checkpoint found in the current
    directory. If the target lies beyond ``phase_switch``, the 'euclid'
    phase is completed first by a recursive call to ``train(phase_switch)``;
    the remaining iterations then run in the 'malis' phase.

    Args:
        iterations: Absolute iteration count to train until (not an
            increment on top of the current checkpoint).
    """

    def latest_iteration():
        # Iteration number parsed from the suffix after the last '_' of the
        # newest checkpoint in '.', or 0 when no checkpoint exists.
        # NOTE(review): Train below saves under setup_dir + 'train_net'; this
        # lookup only finds those checkpoints if setup_dir resolves to '.'.
        checkpoint = tf.train.latest_checkpoint('.')
        return int(checkpoint.split('_')[-1]) if checkpoint else 0

    trained_until = latest_iteration()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The recursive call advanced the checkpoint. Re-read it so the loop
        # below runs only the iterations still missing, instead of repeating
        # the ones already done in the 'euclid' phase.
        trained_until = latest_iteration()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')

    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    affs_key = ArrayKey('AFFINITIES')
    affs_opp_key = ArrayKey('AFFINITIES_OPP')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    joined_affs_opp_key = ArrayKey('JOINED_AFFINITIES_OPP')
    merged_labels_key = ArrayKey('MERGED_LABELS')

    gt_affs_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    # Affinity volumes are one voxel larger than the corresponding shape
    # along every axis.
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, output_shape)

    request.add(raw_key, input_shape)
    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)

    request.add(affs_key, input_affs_shape)
    request.add(affs_opp_key, input_affs_shape)
    request.add(joined_affs_key, input_affs_shape)
    request.add(joined_affs_opp_key, input_affs_shape)
    request.add(merged_labels_key, output_shape)

    request.add(gt_affs_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Synthesise the raw volume from the joined affinities of the labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    # Affinities for both neighborhoods feed MergeLabels, which writes the
    # merged ground-truth segmentation.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=affs_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood_opp,
                              labels=labels_key,
                              affinities=affs_opp_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_key,
                                    joined_affinities=joined_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=affs_opp_key,
                                    joined_affinities=joined_affs_opp_key)

    pipeline += MergeLabels(labels=labels_key,
                            joined_affinities=joined_affs_key,
                            joined_affinities_opp=joined_affs_opp_key,
                            merged_labels=merged_labels_key,
                            every=2)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=merged_labels_key,
                              affinities=gt_affs_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(labels=gt_affs_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # Map raw intensities by x * 2 - 1 for the network; undone after Train.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    if phase == 'malis':
        pipeline += RenumberConnectedComponents(labels=merged_labels_key)

    pipeline += PreCache(cache_size=32, num_workers=8)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        # MALIS phase: add_malis_loss builds the loss itself, so it also
        # needs the segmentation and the affinity mask as inputs.
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = merged_labels_key  # XXX question
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={config['pred_affs']: pred_affs_key},
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(
        dataset_names={
            merged_labels_key: 'volumes/labels',
            raw_key: 'volumes/raw',
            pred_affs_key: 'volumes/pred_affs',
            gt_affs_key: 'volumes/gt_affs'
        },
        output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
        every=4000,
        dataset_dtypes={
            merged_labels_key: np.uint64,
            raw_key: np.float32
        })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
# Example #2
# 0
def predict(checkpoint, iterations):
    """Predict affinities for every HDF5 sample and snapshot each batch.

    Labels are read from ``<data_dir>/<sample>.hdf`` for each entry of
    ``samples``. Raw data is synthesised from them with AddRealism, passed
    through the network restored from ``train_net_checkpoint_<checkpoint>``
    in ``setup_dir``, and every batch is written by Snapshot.

    Args:
        checkpoint: Iteration number of the checkpoint to restore.
        iterations: Number of batches to request.
    """
    print("checkpoint: ", checkpoint)

    labels_key = ArrayKey('LABELS')
    gt_affs_key = ArrayKey('GT_AFFINITIES')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    pred_affs_key = ArrayKey('PREDICTED_AFFS')
    sample_z_key = ArrayKey("SAMPLE_Z")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    sample_shape = Coordinate((1, 1, 6)) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    # (key, shape) pairs, added to the request in this order.
    requested = (
        (labels_key, output_shape),
        (gt_affs_key, input_shape),
        (raw_affs_key, input_shape),
        (raw_joined_affs_key, input_shape),
        (raw_key, input_shape),
        (pred_affs_key, output_shape),
        (sample_z_key, sample_shape),
    )
    request = BatchRequest()
    for key, shape in requested:
        request.add(key, shape)

    source_datasets = {labels_key: 'volumes/labels'}
    source_specs = {labels_key: ArraySpec(interpolatable=False)}

    # One padded HDF5 source per sample; SequentialProvider walks them.
    sources = tuple(
        Hdf5Source(os.path.join(data_dir, name + '.hdf'),
                   datasets=source_datasets,
                   array_specs=source_specs) + Pad(labels_key, None)
        for name in samples)

    pipeline = sources + (
        SequentialProvider() +
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels_key,
                      affinities=raw_affs_key) +
        AddJoinedAffinities(input_affinities=raw_affs_key,
                            joined_affinities=raw_joined_affs_key) +
        AddRealism(joined_affinities=raw_joined_affs_key,
                   raw=raw_key,
                   sp=0.25,
                   sigma=1,
                   contrast=0.7) +
        GrowBoundary(labels_key, steps=1, only_xy=True) +
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels_key,
                      affinities=gt_affs_key) +
        PreCache(cache_size=32, num_workers=8) +
        # Map raw by x * 2 - 1 for the network, then back afterwards.
        IntensityScaleShift(raw_key, 2, -1) +
        Predict(
            checkpoint=os.path.join(setup_dir,
                                    'train_net_checkpoint_%d' % checkpoint),
            inputs={config['raw']: raw_key},
            outputs={
                config['pred_affs']: pred_affs_key,
                config['sample_z']: sample_z_key,
            },
            graph=os.path.join(setup_dir, 'predict_net.meta')) +
        IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5) +
        Snapshot(
            dataset_names={
                labels_key: 'volumes/labels',
                gt_affs_key: 'volumes/gt_affs',
                pred_affs_key: 'volumes/pred_affs',
                sample_z_key: 'volumes/sample_z',
            },
            output_filename='prob_unet/' + setup_name + '/prediction_{id}.hdf',
            every=1,
            dataset_dtypes={
                labels_key: np.uint16,
                gt_affs_key: np.float32,
                pred_affs_key: np.float32,
                sample_z_key: np.float32,
            }))

    print("Starting prediction...")
    with build(pipeline) as built:
        for _ in range(iterations):
            built.request_batch(request)
    print("Prediction finished")
def train(iterations):
    """Train on HDF5 label volumes, choosing among pre-merged label variants.

    For every entry of ``samples``, the labels and ``num_merges`` merged-label
    volumes are read from ``<data_dir>/<sample>.hdf``. Raw data is synthesised
    from the labels with AddRealism. PickRandomLabel selects which label
    volume is the ground truth for each batch. Training resumes from the
    newest checkpoint in the current directory. If ``iterations`` exceeds
    ``phase_switch``, the 'euclid' phase is completed first by a recursive
    call, and the remainder runs in the 'malis' phase.

    Args:
        iterations: Absolute iteration count to train until.
    """

    def latest_iteration():
        # Iteration number parsed from the suffix after the last '_' of the
        # newest checkpoint in '.', or 0 when no checkpoint exists.
        # NOTE(review): Train below saves under setup_dir + 'train_net'; this
        # lookup only finds those checkpoints if setup_dir resolves to '.'.
        checkpoint = tf.train.latest_checkpoint('.')
        return int(checkpoint.split('_')[-1]) if checkpoint else 0

    trained_until = latest_iteration()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The recursive call advanced the checkpoint. Re-read it so the loop
        # below runs only the iterations still missing, instead of repeating
        # the ones already done in the 'euclid' phase.
        trained_until = latest_iteration()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    picked_labels_key = ArrayKey('PICKED_RANDOM_LABEL')

    num_merges = 3
    merged_labels_keys = [
        ArrayKey('MERGED_LABELS_%i' % (i + 1)) for i in range(num_merges)
    ]

    gt_affs_out_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    sample_z_key = ArrayKey("SAMPLE_Z")
    broadcast_key = ArrayKey("BROADCAST")
    sample_out_key = ArrayKey("SAMPLE_OUT")
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size
    sample_shape = Coordinate((1, 1, 6)) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)
    request.add(raw_key, input_shape)

    for merged_key in merged_labels_keys:
        request.add(merged_key, input_shape)
    request.add(picked_labels_key, output_shape)

    request.add(gt_affs_out_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)
    request.add(pred_affs_gradient_key, output_shape)

    request.add(broadcast_key, output_shape)
    request.add(sample_z_key, sample_shape)
    request.add(sample_out_key, sample_shape)
    request.add(debug_key, debug_shape)

    dataset_names = {
        labels_key: 'volumes/labels',
    }
    array_specs = {labels_key: ArraySpec(interpolatable=False)}
    for i, merged_key in enumerate(merged_labels_keys):
        dataset_names[merged_key] = 'volumes/merged_labels_%i' % (i + 1)
        array_specs[merged_key] = ArraySpec(interpolatable=False)

    def make_source(sample):
        # HDF5 source for one sample with every label volume wrapped in
        # Pad(key, None). Built with a loop over merged_labels_keys, so the
        # chain follows num_merges instead of hard-coding indices 0..2.
        source = Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                            datasets=dataset_names,
                            array_specs=array_specs) + Pad(labels_key, None)
        for merged_key in merged_labels_keys:
            source = source + Pad(merged_key, None)
        return source

    pipeline = tuple(make_source(sample) for sample in samples)

    pipeline += RandomProvider()

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    # One probability per input label: labels_key first, then the merged
    # volumes. 'euclid' always uses the original labels; 'malis' splits
    # evenly between the original and the first merged volume.
    if phase == "euclid":
        probabilities = [1] + [0] * num_merges
    else:
        probabilities = [0.5, 0.5] + [0] * (num_merges - 1)

    pipeline += PickRandomLabel(input_labels=[labels_key] +
                                merged_labels_keys,
                                output_label=picked_labels_key,
                                probabilities=probabilities)

    if phase != "euclid":
        pipeline += RenumberConnectedComponents(labels=picked_labels_key)

    pipeline += GrowBoundary(picked_labels_key, steps=1, only_xy=True)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=picked_labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=picked_labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(labels=gt_affs_out_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # Map raw intensities by x * 2 - 1 for the network; undone after Train.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    pipeline += PreCache(cache_size=8, num_workers=4)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = picked_labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={
                          config['pred_affs']: pred_affs_key,
                          config['broadcast']: broadcast_key,
                          config['sample_z']: sample_z_key,
                          config['sample_out']: sample_out_key,
                          config['debug']: debug_key
                      },
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(
        dataset_names={
            labels_key: 'volumes/labels',
            picked_labels_key: 'volumes/merged_labels',
            raw_affs_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            pred_affs_key: 'volumes/pred_affs',
            gt_affs_out_key: 'volumes/gt_affs_out',
            gt_affs_in_key: 'volumes/gt_affs_in'
        },
        output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
        every=2000,
        dataset_dtypes={
            labels_key: np.uint64,
            picked_labels_key: np.uint64,
            raw_key: np.float32
        })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
# Example #4
# 0
def predict(checkpoint, iterations):
    """Predict affinities on freshly generated toy-neuron volumes.

    Labels come from ToyNeuronSegmentationGenerator (``seed=0``). Raw data is
    synthesised from them with AddRealism, and the network restored from
    ``train_net_checkpoint_<checkpoint>`` in ``setup_dir`` predicts
    affinities. Every batch is written by Snapshot.

    Args:
        checkpoint: Iteration number of the checkpoint to restore.
        iterations: Number of batches to request.
    """
    print("checkpoint: ", checkpoint)

    labels_key = ArrayKey('GT_LABELS')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_affinities_key = ArrayKey('RAW_AFFINITIES_KEY')
    raw_key = ArrayKey('RAW')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size

    print("input_size: ", input_shape)
    print("output_size: ", output_shape)

    request = BatchRequest()
    # NOTE(review): labels_key is listed in the Snapshot below but is not
    # requested here, so it may be missing from the written snapshot.
    # Confirm whether request.add(labels_key, ...) should be restored.
    request.add(joined_affinities_key, input_shape)
    request.add(raw_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(pred_affinities_key, output_shape)

    pipeline = (
        ToyNeuronSegmentationGenerator(array_key=labels_key,
                                       n_objects=50,
                                       points_per_skeleton=8,
                                       smoothness=3,
                                       noise_strength=1,
                                       interpolation="random",
                                       seed=0) +
        AddAffinities(affinity_neighborhood=neighborhood,
                      labels=labels_key,
                      affinities=raw_affinities_key) +
        AddJoinedAffinities(input_affinities=raw_affinities_key,
                            joined_affinities=joined_affinities_key) +
        AddRealism(joined_affinities=joined_affinities_key,
                   raw=raw_key,
                   sp=0.65,
                   sigma=1,
                   contrast=0.7) +
        # Map raw by x * 2 - 1 for the network, then back afterwards.
        IntensityScaleShift(raw_key, 2, -1) +
        Predict(checkpoint=os.path.join(setup_dir, 'train_net_checkpoint_%d' %
                                        checkpoint),
                inputs={config['raw']: raw_key},
                outputs={config['pred_affs']: pred_affinities_key},
                graph=os.path.join(setup_dir, 'predict_net.meta')) +
        IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5) +
        Snapshot(dataset_names={
            labels_key: 'volumes/labels',
            raw_affinities_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            pred_affinities_key: 'volumes/pred_affs',
        },
                 output_filename='prob_unet/prediction_{id}.hdf',
                 every=1,
                 # No SAMPLE_Z entry: this pipeline neither defines nor
                 # produces a sample_z array.
                 dataset_dtypes={
                     labels_key: np.uint16,
                     raw_key: np.float32,
                     pred_affinities_key: np.float32,
                 })
    )

    print("Starting prediction...")
    with build(pipeline) as p:
        for _ in range(iterations):
            p.request_batch(request)
    print("Prediction finished")
def train(iterations):
    """Train the probabilistic network on synthetic toy-neuron data.

    Training resumes from the newest checkpoint found in the current
    directory. If the target lies beyond ``phase_switch``, the 'euclid'
    phase is completed first by a recursive call to ``train(phase_switch)``;
    the remaining iterations then run in the 'malis' phase.

    Args:
        iterations: Absolute iteration count to train until.
    """

    def latest_iteration():
        # Iteration number parsed from the suffix after the last '_' of the
        # newest checkpoint in '.', or 0 when no checkpoint exists.
        # NOTE(review): Train below saves under setup_dir + 'train_net'; this
        # lookup only finds those checkpoints if setup_dir resolves to '.'.
        checkpoint = tf.train.latest_checkpoint('.')
        return int(checkpoint.split('_')[-1]) if checkpoint else 0

    trained_until = latest_iteration()
    if trained_until >= iterations:
        return

    if trained_until < phase_switch and iterations > phase_switch:
        train(phase_switch)
        # The recursive call advanced the checkpoint. Re-read it so the loop
        # below runs only the iterations still missing, instead of repeating
        # the ones already done in the 'euclid' phase.
        trained_until = latest_iteration()

    phase = 'euclid' if iterations <= phase_switch else 'malis'
    print("Training in phase %s until %i" % (phase, iterations))

    # define array-keys
    labels_key = ArrayKey('LABELS')
    raw_affs_key = ArrayKey('RAW_AFFINITIES')
    raw_joined_affs_key = ArrayKey('RAW_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')

    gt_affs_out_key = ArrayKey('GT_AFFINITIES')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_mask_key = ArrayKey('GT_AFFINITIES_MASK')
    gt_affs_scale_key = ArrayKey('GT_AFFINITIES_SCALE')

    pred_affs_key = ArrayKey('PRED_AFFS')
    pred_affs_gradient_key = ArrayKey('PRED_AFFS_GRADIENT')

    sample_z_key = ArrayKey("SAMPLE_Z")
    broadcast_key = ArrayKey("BROADCAST")
    pred_logits_key = ArrayKey("PRED_LOGITS")
    sample_out_key = ArrayKey("SAMPLE_OUT")
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    input_affs_shape = Coordinate([i + 1 for i in config['input_shape']
                                   ]) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    output_affs_shape = Coordinate([i + 1 for i in config['output_shape']
                                    ]) * voxel_size
    # Latent sample arrays are (1, 1, latent_dims).
    sample_shape = Coordinate((1, 1, config['latent_dims'])) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("input_affs_shape: ", input_affs_shape)
    print("output_shape: ", output_shape)
    print("output_affs_shape: ", output_affs_shape)

    request = BatchRequest()
    request.add(labels_key, input_shape)

    request.add(raw_affs_key, input_shape)
    request.add(raw_joined_affs_key, input_shape)
    request.add(raw_key, input_shape)

    request.add(gt_affs_out_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(gt_affs_mask_key, output_shape)
    request.add(gt_affs_scale_key, output_shape)

    request.add(pred_affs_key, output_shape)
    request.add(pred_affs_gradient_key, output_shape)

    request.add(broadcast_key, output_shape)
    request.add(sample_z_key, sample_shape)
    request.add(pred_logits_key, output_shape)
    request.add(sample_out_key, sample_shape)
    request.add(debug_key, debug_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")

    # Synthesise the raw volume from the joined affinities of the labels.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=raw_affs_key)

    pipeline += AddJoinedAffinities(input_affinities=raw_affs_key,
                                    joined_affinities=raw_joined_affs_key)

    pipeline += AddRealism(joined_affinities=raw_joined_affs_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)

    pipeline += RenumberConnectedComponents(labels=labels_key)

    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)

    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask_key)

    pipeline += BalanceLabels(labels=gt_affs_out_key, scales=gt_affs_scale_key)

    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)

    # Map raw intensities by x * 2 - 1 for the network; undone after Train.
    pipeline += IntensityScaleShift(raw_key, 2, -1)

    pipeline += PreCache(cache_size=8, num_workers=4)

    train_inputs = {
        config['raw']: raw_key,
        config['gt_affs_in']: gt_affs_in_key,
        config['gt_affs_out']: gt_affs_out_key,
        config['pred_affs_loss_weights']: gt_affs_scale_key
    }

    if phase == 'euclid':
        train_loss = config['loss']
        train_optimizer = config['optimizer']
        train_summary = config['summary']
    else:
        train_loss = None
        train_optimizer = add_malis_loss
        train_inputs['gt_seg:0'] = labels_key
        train_inputs['gt_affs_mask:0'] = gt_affs_mask_key
        train_summary = 'Merge/MergeSummary:0'

    pipeline += Train(graph=setup_dir + 'train_net',
                      optimizer=train_optimizer,
                      loss=train_loss,
                      inputs=train_inputs,
                      outputs={
                          config['pred_affs']: pred_affs_key,
                          config['broadcast']: broadcast_key,
                          config['sample_z']: sample_z_key,
                          config['pred_logits']: pred_logits_key,
                          config['sample_out']: sample_out_key,
                          config['debug']: debug_key
                      },
                      gradients={config['pred_affs']: pred_affs_gradient_key},
                      summary=train_summary,
                      log_dir='log/prob_unet/' + setup_name,
                      save_every=2000)

    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)

    pipeline += Snapshot(
        dataset_names={
            labels_key: 'volumes/labels',
            raw_affs_key: 'volumes/raw_affs',
            raw_key: 'volumes/raw',
            gt_affs_in_key: 'volumes/gt_affs_in',
            gt_affs_out_key: 'volumes/gt_affs_out',
            pred_affs_key: 'volumes/pred_affs',
            pred_logits_key: 'volumes/pred_logits'
        },
        output_filename='prob_unet/' + setup_name + '/batch_{iteration}.hdf',
        every=2000,
        dataset_dtypes={
            labels_key: np.uint16,
            raw_affs_key: np.float32,
            raw_key: np.float32,
            gt_affs_in_key: np.float32,
            gt_affs_out_key: np.float32,
            pred_affs_key: np.float32,
            pred_logits_key: np.float32
        })

    pipeline += PrintProfilingStats(every=20)

    print("Starting training...")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            p.request_batch(request)
    print("Training finished")
def train(iterations):
    """Train the network in ``train/unet/train_net`` with a MALIS loss.

    Training batches are synthesised on the fly:

    * toy neuron segmentations are generated (``ToyNeuronSegmentationGenerator``),
    * input affinities are computed, joined, and rendered into a synthetic
      raw volume (``AddRealism``),
    * target affinities are computed after growing label boundaries by one
      voxel in xy, and are class-balanced (``BalanceLabels``),
    * the raw volume is defect-augmented and intensity-shifted before being
      fed to the ``Train`` node.

    Training resumes from the latest TensorFlow checkpoint in the current
    directory. The iteration count already reached is parsed from the last
    ``_``-separated field of the checkpoint name.

    Args:
        iterations: Total number of iterations to reach. This is not the
            number of additional iterations. Returns immediately if the
            latest checkpoint is already at or beyond this count.
    """
    tf.reset_default_graph()

    # Query the checkpoint directory only once.
    checkpoint = tf.train.latest_checkpoint('.')
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0
    if trained_until >= iterations:
        return

    # define array-keys
    labels_key = ArrayKey('GT_LABELS')
    gt_affs_in_key = ArrayKey('GT_AFFINITIES_IN')
    gt_affs_out_key = ArrayKey('GT_AFFINITIES_OUT')
    joined_affinities_key = ArrayKey('GT_JOINED_AFFINITIES')
    raw_key = ArrayKey('RAW')
    input_affinities_scale_key = ArrayKey('GT_AFFINITIES_SCALE')
    pred_affinities_key = ArrayKey('PREDICTED_AFFS')
    pred_affinities_gradient_key = ArrayKey('AFFS_GRADIENT')
    gt_affs_mask = ArrayKey('GT_AFFINITIES_MASK')
    debug_key = ArrayKey("DEBUG")

    voxel_size = Coordinate((1, 1, 1))
    input_shape = Coordinate(config['input_shape']) * voxel_size
    output_shape = Coordinate(config['output_shape']) * voxel_size
    debug_shape = Coordinate((1, 1, 5)) * voxel_size

    print("input_shape: ", input_shape)
    print("output_shape: ", output_shape)

    # define requests
    request = BatchRequest()
    request.add(labels_key, output_shape)
    request.add(gt_affs_in_key, input_shape)
    request.add(joined_affinities_key, input_shape)
    request.add(raw_key, input_shape)
    request.add(gt_affs_out_key, output_shape)
    request.add(input_affinities_scale_key, output_shape)
    request.add(pred_affinities_key, output_shape)
    request.add(gt_affs_mask, output_shape)
    request.add(debug_key, debug_shape)

    # Centre an output-sized ROI inside the input-sized one. Floor division
    # keeps the offset integral; ``/`` would produce floats on Python 3.
    offset = Coordinate((in_dim - out_dim) // 2
                        for in_dim, out_dim in zip(input_shape, output_shape))
    crop_roi = Roi(offset, output_shape)

    pipeline = ()
    pipeline += ToyNeuronSegmentationGenerator(array_key=labels_key,
                                               n_objects=50,
                                               points_per_skeleton=8,
                                               smoothness=3,
                                               noise_strength=1,
                                               interpolation="linear")
    # NOTE: ElasticAugment / SimpleAugment / IntensityAugment are currently
    # disabled for this setup.

    # Input-side affinities are computed BEFORE boundary growing. They feed
    # the synthetic raw rendering below.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_in_key)
    pipeline += GrowBoundary(labels_key, steps=1, only_xy=True)
    # Training targets (and their mask) are computed AFTER boundary growing.
    pipeline += AddAffinities(affinity_neighborhood=neighborhood,
                              labels=labels_key,
                              affinities=gt_affs_out_key,
                              affinities_mask=gt_affs_mask)
    pipeline += AddJoinedAffinities(input_affinities=gt_affs_in_key,
                                    joined_affinities=joined_affinities_key)
    pipeline += AddRealism(joined_affinities=joined_affinities_key,
                           raw=raw_key,
                           sp=0.25,
                           sigma=1,
                           contrast=0.7)
    pipeline += BalanceLabels(labels=gt_affs_out_key,
                              scales=input_affinities_scale_key)
    pipeline += DefectAugment(intensities=raw_key,
                              prob_missing=0.03,
                              prob_low_contrast=0.01,
                              contrast_scale=0.5,
                              axis=0)
    # raw -> 2 * raw - 1. Presumably maps [0, 1] to [-1, 1]; the matching
    # inverse shift is applied after the Train node.
    pipeline += IntensityScaleShift(raw_key, 2, -1)
    pipeline += PreCache(cache_size=32, num_workers=8)
    pipeline += Crop(key=labels_key, roi=crop_roi)
    pipeline += RenumberConnectedComponents(labels=labels_key)

    # Named ``train_node`` so it does not shadow this function's own name.
    train_node = Train(
        graph='train/unet/train_net',
        # ``optimizer`` is a callable here, which is why no ``loss`` tensor
        # name is given. Presumably add_malis_loss builds the MALIS loss
        # reported as the "malis_loss:0" summary.
        optimizer=add_malis_loss,
        loss=None,
        inputs={
            config['raw']: raw_key,
            "gt_seg:0": labels_key,
            "gt_affs_mask:0": gt_affs_mask,
            config['gt_affs']: gt_affs_out_key,
            config['pred_affs_loss_weights']: input_affinities_scale_key,
        },
        outputs={
            config['pred_affs']: pred_affinities_key,
            config['debug']: debug_key,
        },
        gradients={config['pred_affs']: pred_affinities_gradient_key},
        summary="malis_loss:0",
        log_dir='log/unet',
        # NOTE(review): a checkpoint is written on every iteration, which is
        # expensive. Confirm this is intended and not a debugging leftover.
        save_every=1)
    pipeline += train_node
    # Undo the earlier 2 * raw - 1 shift: 0.5 * (2x - 1) + 0.5 == x.
    pipeline += IntensityScaleShift(array=raw_key, scale=0.5, shift=0.5)
    # NOTE: a Snapshot node (unet/train/batch_{iteration}.hdf) is currently
    # disabled for this setup.
    pipeline += PrintProfilingStats(every=8)

    print("Starting training... COOL BEANS")
    with build(pipeline) as p:
        for _ in range(iterations - trained_until):
            req = p.request_batch(request)
            pred_affs = req[pred_affinities_key].data
            debug = req[debug_key].data
            # Debug output: prints full arrays on every iteration.
            print("debug", debug)
            print('pred_affs', pred_affs)
            print("name of pred_adds: ", req[pred_affinities_key])
    print("Training finished")
Example #7
0
def generate_data(num_batches):
    """Render synthetic raw volumes from a stored label volume.

    Builds a gunpowder pipeline with these stages:

    * read ``volumes/labels`` from ``seg_standard.hdf`` in ``data_dir``,
    * compute nearest-neighbour affinities (offsets -1 along each axis),
    * join them (``AddJoinedAffinities``),
    * render a synthetic raw volume with ``AddRealism`` (``sp=0.25``,
      ``sigma=1``, ``contrast=0.7``),
    * snapshot the raw volume on every batch (``every=1``) to
      ``results/data_gen/raw_synth/contrast_07.hdf``.

    For each batch, the sum of the label volume is printed.

    Args:
        num_batches: Number of batches to request from the pipeline.
    """

    labels_key = ArrayKey('LABELS')
    gt_affs_key = ArrayKey('GT_AFFINITIES')
    joined_affs_key = ArrayKey('JOINED_AFFINITIES')
    raw_key1 = ArrayKey('RAW1')
    # NOTE(review): raw_key2 / raw_key3 are not added to the request built
    # below; they only appear in commented-out snapshot entries.
    raw_key2 = ArrayKey('RAW2')
    raw_key3 = ArrayKey('RAW3')

    voxel_size = Coordinate((1, 1, 1))
    # Labels are requested 2 voxels larger per axis than the affinity/raw
    # arrays. Presumably this gives AddAffinities the label context its -1
    # offsets need; confirm.
    input_size = Coordinate((133, 133, 133)) * voxel_size
    affs_size = Coordinate((131, 131, 131)) * voxel_size
    output_size = Coordinate((44, 44, 44)) * voxel_size

    print("input_size: ", input_size)
    print("output_size: ", output_size)

    request = BatchRequest()
    request.add(labels_key, input_size)
    request.add(gt_affs_key, affs_size)
    request.add(joined_affs_key, affs_size)
    request.add(raw_key1, affs_size)

    pipeline = (
        # Labels are marked non-interpolatable (segmentation IDs).
        Hdf5Source(os.path.join(data_dir, 'seg_standard.hdf'),
                   datasets={labels_key: "volumes/labels"},
                   array_specs={labels_key: ArraySpec(interpolatable=False)}) +
        # NOTE(review): a size of None presumably allows unbounded padding of
        # the labels; confirm against the gunpowder Pad docs.
        Pad(labels_key, None) + AddAffinities(
            affinity_neighborhood=[[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            labels=labels_key,
            affinities=gt_affs_key) +
        AddJoinedAffinities(input_affinities=gt_affs_key,
                            joined_affinities=joined_affs_key) +
        AddRealism(joined_affinities=joined_affs_key,
                   raw=raw_key1,
                   sp=0.25,
                   sigma=1,
                   contrast=0.7) +
        # NOTE(review): output_filename has no {iteration}/{id} placeholder,
        # so each snapshot presumably overwrites the same file. Confirm.
        Snapshot(
            dataset_names={
                raw_key1: 'volumes/raw',
                # gt_affs_key: 'volumes/gt_affs',
                # joined_affs_key: 'volumes/joined_affs',
                # raw_key1: 'volumes/raw1',
                # raw_key2: 'volumes/raw2',
                # raw_key3: 'volumes/raw3',
            },
            output_filename="results/data_gen/raw_synth/contrast_07.hdf",
            every=1,
            dataset_dtypes={
                # labels_key: np.uint64,
                raw_key1: np.float32,
                # raw_key2: np.float32,
                # raw_key3: np.float32,
                # gt_affs_key: np.float32,
                # joined_affs_key: np.float32
            }))

    # NOTE(review): `hashes` starts as a list but is rebound to a scalar sum
    # on every iteration, so only the latest batch's value survives. Was
    # `hashes.append(...)` intended?
    hashes = []
    with build(pipeline) as p:
        for i in range(num_batches):
            print("\nDATA POINT: ", i)
            req = p.request_batch(request)
            labels = req[labels_key].data
            hashes = np.sum(labels)
            print(hashes)