Example #1
0
def create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                  gt_neurons):
    """Build the merged, randomly located source pipeline for one sample.

    Three providers are combined with ``gp.MergeProvider``: a points source
    serving ``presyn`` and ``postsyn``, a second points source (created with
    ``kind='postsyn'``) serving ``dummypostsyn``, and an HDF5 source serving
    ``raw`` and ``gt_neurons``. ``raw`` is normalized and locations are drawn
    by ``gp.RandomLocation`` with ``ensure_nonempty=dummypostsyn``.

    Args:
        sample: Sample name; ``sample + '.hdf'`` is looked up in the
            module-level ``data_dir_syn`` and ``data_dir`` directories.
        raw: Key for the raw volume (``volumes/raw``).
        presyn: Points key read from the ``annotations`` dataset.
        postsyn: Points key read from the ``annotations`` dataset.
        dummypostsyn: Points key read from ``annotations`` by the
            ``kind='postsyn'`` source.
        parameter: Mapping; ``parameter['reject_probability']`` is passed to
            ``gp.RandomLocation`` as ``p_nonempty``.
        gt_neurons: Key for the neuron id volume
            (``volumes/labels/neuron_ids``).

    Returns:
        The assembled gunpowder pipeline.
    """
    synapse_file = os.path.join(data_dir_syn, sample + '.hdf')

    # pre- and postsynaptic points share one dataset and one ROI
    synapse_points = Hdf5PointsSource(
        synapse_file,
        datasets=dict.fromkeys((presyn, postsyn), 'annotations'),
        rois=dict.fromkeys((presyn, postsyn), cremi_roi))

    dummy_points = Hdf5PointsSource(
        synapse_file,
        datasets={dummypostsyn: 'annotations'},
        rois={dummypostsyn: cremi_roi},
        kind='postsyn')

    volumes = gp.Hdf5Source(
        os.path.join(data_dir, sample + '.hdf'),
        datasets={
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
        },
        array_specs={
            raw: gp.ArraySpec(interpolatable=True),
            gt_neurons: gp.ArraySpec(interpolatable=False),
        })

    pipeline = (synapse_points, dummy_points, volumes)
    pipeline = pipeline + gp.MergeProvider()
    pipeline = pipeline + gp.Normalize(raw)
    pipeline = pipeline + gp.RandomLocation(
        ensure_nonempty=dummypostsyn,
        p_nonempty=parameter['reject_probability'])
    return pipeline
Example #2
0
    def test_pipeline3(self):
        """Shift an array together with its points and compare looked-up values.

        Every value found in the returned array at a returned point location
        must be one of the values the fake data holds at the fake points.
        """
        data_key = gp.ArrayKey("TEST_ARRAY")
        pts_key = gp.PointsKey("TEST_POINTS")
        unit_voxel = gp.Coordinate((1, 1))
        data_spec = gp.ArraySpec(voxel_size=unit_voxel, interpolatable=True)

        # two upstream providers: the array from HDF5, the points from CSV
        array_source = gp.Hdf5Source(self.fake_data_file,
                                     {data_key: 'testdata'},
                                     array_specs={data_key: data_spec})
        points_source = gp.CsvPointsSource(
            self.fake_points_file, pts_key,
            gp.PointsSpec(
                roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

        window = gp.Coordinate((60, 60))
        batch_request = gp.BatchRequest()
        batch_request.add(data_key, window, voxel_size=gp.Coordinate((1, 1)))
        batch_request.add(pts_key, window)

        augment = gp.ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=5,
                                  shift_axis=0)
        pipeline = (array_source, points_source) + gp.MergeProvider()
        pipeline = pipeline + gp.RandomLocation(ensure_nonempty=pts_key)
        pipeline = pipeline + augment

        with gp.build(pipeline) as built:
            batch = built.request_batch(batch_request)

        expected_values = [
            self.fake_data[p[0]][p[1]] for p in self.fake_points
        ]
        batch_data = batch[data_key].data
        batch_points = batch[pts_key].data

        # look up the array value underneath every returned point
        found_values = []
        for pt in batch_points.values():
            found_values.append(
                batch_data[int(pt.location[0])][int(pt.location[1])])

        failure_template = (
            "result value {} at points {} not in target values {} at points {}"
        )
        for value in found_values:
            self.assertTrue(
                value in expected_values,
                msg=failure_template.format(value,
                                            list(batch_points.values()),
                                            expected_values,
                                            self.fake_points))
Example #3
0
    def test_prepare1(self):
        """Check the state ShiftAugment derives in ``prepare`` for a 2D request.

        With a scalar ``sigma=1`` and ``shift_axis=0`` the node is expected to
        report two dimensions and per-axis sigmas of ``(0.0, 1.0)``.
        """

        key = gp.ArrayKey("TEST_ARRAY")
        spec = gp.ArraySpec(voxel_size=gp.Coordinate((1, 1)),
                            interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file, {key: 'testdata'},
                                    array_specs={key: spec})

        request = gp.BatchRequest()
        shape = gp.Coordinate((3, 3))
        request.add(key, shape, voxel_size=gp.Coordinate((1, 1)))

        shift_node = gp.ShiftAugment(sigma=1, shift_axis=0)
        with gp.build((hdf5_source + shift_node)):
            shift_node.prepare(request)
            # assertEqual (instead of assertTrue(a == b)) reports both
            # operands when the comparison fails
            self.assertEqual(shift_node.ndim, 2)
            self.assertEqual(shift_node.shift_sigmas, (0.0, 1.0))
Example #4
0
    def test_pipeline2(self):
        """Smoke test: pull one batch through ShiftAugment with (3, 1) voxels.

        Nothing is asserted; the test passes when building the pipeline and
        requesting a batch raise no exception.
        """
        array_key = gp.ArrayKey("TEST_ARRAY")
        array_spec = gp.ArraySpec(voxel_size=gp.Coordinate((3, 1)),
                                  interpolatable=True)
        source = gp.Hdf5Source(self.fake_data_file,
                               {array_key: 'testdata'},
                               array_specs={array_key: array_spec})

        batch_request = gp.BatchRequest()
        batch_request.add(array_key,
                          gp.Coordinate((3, 3)),
                          voxel_size=gp.Coordinate((3, 1)))

        augment = gp.ShiftAugment(prob_slip=0.2,
                                  prob_shift=0.2,
                                  sigma=1,
                                  shift_axis=0)
        pipeline = source + augment
        with gp.build(pipeline) as built:
            built.request_batch(batch_request)
Example #5
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train the distance-transform network up to ``max_iteration``.

    Training resumes from the latest TensorFlow checkpoint found in
    ``output_folder`` (the iteration is parsed from the text after the last
    ``_`` in the checkpoint name). Sources are built from every top-level
    entry of every file in the module-level ``data_files`` list, located in
    the module-level ``data_dir``.

    Args:
        max_iteration: Iteration count to train up to. Nothing is done when
            the latest checkpoint is already at or beyond this value.
        name: Network name; ``<name>_config.json`` and ``<name>_names.json``
            are read from ``output_folder`` and ``<name>`` is the graph /
            checkpoint basename handed to ``gp.tensorflow.Train``.
        output_folder: Directory holding the network files, checkpoints and
            the ``snapshots`` sub-directory.
        clip_max: Passed to ``nl.Clip(raw, 0, clip_max)``; ``raw`` is then
            scaled by ``1.0 / clip_max`` in ``gp.Normalize``.
    """

    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    # This check has to cover the resumed case as well: nested under the
    # ``else`` branch it could never trigger (trained_until is 0 there).
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # half of the input/output size difference, used to pad the sources
    context = gp.Coordinate(input_shape - output_shape) / 2

    # what every training batch has to contain
    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    # what is requested in addition whenever a snapshot is written
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one padded, randomly located source per top-level
    # entry of each data file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        # builtin bool: the ``np.bool`` alias was deprecated
                        # and later removed from NumPy
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1,2], transpose_only=[1,2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2,-1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # only the iterations still missing after the last checkpoint
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #6
0
def train():
    """Train the caffe network described by ``net.prototxt`` for 500000 batches.

    Data comes from one HDF5 wrapper file per FIB-SEM volume (created here,
    linking to the raw and label files and adding an all-ones mask) plus one
    DVID source. Every provider gets its own random location, mask rejection
    and normalization before one of them is picked per batch.
    """

    gunpowder.set_verbose(False)

    affinity_neighborhood = malis.mknhood3d()

    # solver configuration, applied attribute by attribute in this order
    solver_parameters = gunpowder.caffe.SolverParameters()
    solver_settings = (
        ('train_net', 'net.prototxt'),
        ('base_lr', 1e-4),
        ('momentum', 0.95),
        ('momentum2', 0.999),
        ('delta', 1e-8),
        ('weight_decay', 0.000005),
        ('lr_policy', 'inv'),
        ('gamma', 0.0001),
        ('power', 0.75),
        ('snapshot', 10000),
        ('snapshot_prefix', 'net'),
        ('type', 'Adam'),
        ('resume_from', None),
    )
    for attribute, value in solver_settings:
        setattr(solver_parameters, attribute, value)
    solver_parameters.train_state.add_stage('euclid')

    # raw is requested at the network input size, everything else at the
    # network output size
    request = BatchRequest()
    request.add_volume_request(VolumeTypes.RAW, constants.input_shape)
    for output_volume in (VolumeTypes.GT_LABELS,
                          VolumeTypes.GT_MASK,
                          VolumeTypes.GT_AFFINITIES,
                          VolumeTypes.LOSS_SCALE):
        request.add_volume_request(output_volume, constants.output_shape)

    sources = []
    fibsem_dir = "/groups/turaga/turagalab/data/FlyEM/fibsem_medulla_7col"
    for volume_name in ("tstvol-520-1-h5",):
        wrapper_path = "./{}.h5".format(volume_name)
        labels_path = os.path.join(fibsem_dir, volume_name, "groundtruth_seg.h5")

        # the mask written below has to match the label volume's shape
        with h5py.File(labels_path, "r") as labels_file:
            mask_shape = labels_file["main"].shape

        # (re)create a wrapper file: links to raw and labels, plus a mask
        # filled with ones
        with h5py.File(wrapper_path, "w") as wrapper:
            wrapper['volumes/raw'] = h5py.ExternalLink(
                os.path.join(fibsem_dir, volume_name, "im_uint8.h5"), "main")
            wrapper['volumes/labels/neuron_ids'] = h5py.ExternalLink(
                labels_path, "main")
            wrapper.create_dataset(
                name="volumes/labels/mask",
                dtype="uint8",
                shape=mask_shape,
                fillvalue=1,
            )

        hdf5_source = gunpowder.Hdf5Source(
            wrapper_path,
            datasets={
                VolumeTypes.RAW: 'volumes/raw',
                VolumeTypes.GT_LABELS: 'volumes/labels/neuron_ids',
                VolumeTypes.GT_MASK: 'volumes/labels/mask',
            },
            resolution=(8, 8, 8),
        )
        sources.append(hdf5_source)

    sources.append(
        DvidSource(
            hostname='slowpoke3',
            port=32788,
            uuid='341',
            raw_array_name='grayscale',
            gt_array_name='groundtruth',
            gt_mask_roi_name="seven_column_eroded7_z_lt_5024",
            resolution=(8, 8, 8),
        ))

    # per-source preprocessing
    located_sources = tuple(
        source + RandomLocation() + Reject(min_masked=0.5) + Normalize()
        for source in sources)

    # create a batch provider by concatenation of filters
    batch_provider = (
        located_sources +
        RandomProvider() +
        ElasticAugment([20, 20, 20], [0, 0, 0], [0, math.pi / 2.0]) +
        SimpleAugment(transpose_only_xy=False) +
        GrowBoundary(steps=2, only_xy=False) +
        AddGtAffinities(affinity_neighborhood) +
        BalanceAffinityLabels() +
        SplitAndRenumberSegmentationLabels() +
        IntensityAugment(0.9, 1.1, -0.1, 0.1, z_section_wise=False) +
        PreCache(request, cache_size=11, num_workers=10) +
        Train(solver_parameters, use_gpu=0) +
        Typecast(
            volume_dtypes={
                VolumeTypes.GT_LABELS: np.dtype("uint32"),
                VolumeTypes.GT_MASK: np.dtype("uint8"),
                VolumeTypes.LOSS_SCALE: np.dtype("float32"),
            },
            safe=True) +
        Snapshot(every=50, output_filename='batch_{id}.hdf') +
        PrintProfilingStats(every=50)
    )

    n = 500000
    print("Training for", n, "iterations")

    with gunpowder.build(batch_provider) as pipeline:
        for _ in range(n):
            pipeline.request_batch(request)

    print("Finished")
Example #7
0
def predict(iteration):
    """Predict affinities for ``sample_A_padded_20160501.hdf`` chunk by chunk.

    The network description is read from ``test_net_config.json`` and
    ``test_net.meta``; weights come from ``train_net_checkpoint_<iteration>``.
    Raw data and predictions are written to ``predictions_sample_A.hdf``.

    Args:
        iteration: Training iteration whose checkpoint is loaded.
    """

    # array keys: raw intensities in, predicted affinities out
    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open('test_net_config.json', 'r') as config_file:
        net_config = json.load(config_file)

    # size of one network pass in world units (voxels times voxel size)
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size
    # NOTE(review): this is the full input/output size difference and it is
    # removed on both sides of the raw ROI further down -- confirm that half
    # of it per side was not intended.
    context = input_size - output_size

    # reference request describing a single chunk for the Scan node
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_affs, output_size)

    source = gp.Hdf5Source('sample_A_padded_20160501.hdf',
                           datasets={raw: 'volumes/raw'})

    # the ROI provided for raw is needed to derive the ROI in which
    # predictions are announced
    with gp.build(source):
        raw_roi = source.spec[raw].roi

    # read from the HDF5 file
    pipeline = source

    # normalize raw
    pipeline += gp.Normalize(raw)

    # run the network on each passing batch, using the tensor names stored
    # in the network config
    pipeline += gp.tensorflow.Predict(
        graph='test_net.meta',
        checkpoint='train_net_checkpoint_%d' % iteration,
        inputs={net_config['raw']: raw},
        outputs={net_config['pred_affs']: pred_affs},
        array_specs={
            pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))
        })

    # store all passing batches in the same HDF5 file
    pipeline += gp.Hdf5Write(
        {
            raw: '/volumes/raw',
            pred_affs: '/volumes/pred_affs',
        },
        output_filename='predictions_sample_A.hdf',
        compression_type='gzip')

    # print a profiling summary every 10 iterations
    pipeline += gp.PrintProfilingStats(every=10)

    # scan over the whole dataset, emitting requests shaped like
    # ``chunk_request``
    pipeline += gp.Scan(reference=chunk_request)

    with gp.build(pipeline):
        # an empty request to Scan triggers scanning of the dataset without
        # keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
Example #8
0
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties=None,
            **kwargs):
    """Predict synapse indicator and partner vectors block-wise.

    Network files (``<network_config>_net_config.json``,
    ``<network_config>_net.meta``, ``train_net_checkpoint_<iteration>``) and
    the optional ``parameter.json`` are looked up next to this source file.
    Predictions are written to ``out_file`` with ``gp.ZarrWrite``; blocks
    are requested through ``gp.DaisyRequestBlocks``.

    Args:
        iteration: Training iteration whose checkpoint is loaded.
        raw_file: Input container; must end in ``.hdf``, ``.zarr`` or ``.n5``.
        raw_dataset: Name of the raw dataset inside ``raw_file``.
        out_file: Output container handed to ``gp.ZarrWrite``.
        db_host: Forwarded to ``block_done_callback``.
        db_name: Forwarded to ``block_done_callback``.
        worker_config: Mapping; ``num_cache_workers`` sets the number of
            workers, and the whole mapping is forwarded to
            ``block_done_callback``.
        network_config: Prefix of the network config and meta graph files.
        out_properties: Optional mapping with per-output settings under the
            keys ``pred_partner_vectors`` (``scale``, ``dtype``) and
            ``pred_syn_indicator_out`` (``scale``). ``None`` means no
            settings.
        **kwargs: Accepted and ignored.

    Raises:
        RuntimeError: If ``raw_file`` has an unsupported extension.
    """
    # a fresh dict per call instead of a shared mutable default argument
    if out_properties is None:
        out_properties = {}

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # parameter.json is optional
    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # pick the source node matching the container format
    if raw_file.endswith('.hdf'):
        source_class = gp.Hdf5Source
    elif raw_file.endswith(('.zarr', '.n5')):
        source_class = gp.ZarrSource
    else:
        raise RuntimeError('unknwon input data format {}'.format(raw_file))

    pipeline = source_class(raw_file,
                            datasets={raw: raw_dataset},
                            array_specs={
                                raw: gp.ArraySpec(interpolatable=True),
                            })

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    # optional rescaling of the predicted outputs
    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.
    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] in (
            'int8',
            'float32'), 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            # clip to the int8 value range
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1,
                                                0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator:
        'volumes/pred_syn_indicator',
        pred_postpre_vectors:
        'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
Example #9
0
def train(iterations):
    """Train the affinity network on ``fib.hdf`` for ``iterations`` batches.

    Tensor names and network shapes are read from ``train_net_config.json``;
    the data file is looked up in the module-level ``data_dir``. Snapshots
    are written to the ``snapshots`` directory every 1000 iterations.

    Args:
        iterations: Number of batches to request from the pipeline.
    """

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss wrt the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((8, 8, 8))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    # the single source: read batches from the HDF5 file, normalize raw and
    # choose a random location for each requested batch
    source = (
        gp.Hdf5Source(os.path.join(data_dir, 'fib.hdf'),
                      datasets={
                          raw: 'volumes/raw',
                          gt_labels: 'volumes/labels/neuron_ids'
                      }) +
        gp.Normalize(raw) +
        gp.RandomLocation())

    pipeline = (

        # RandomProvider expects a tuple of sources. There is only one here,
        # so it is wrapped in a one-element tuple; calling tuple() on the
        # source itself would try to iterate over the pipeline object.
        (source,) +

        # choose a random source from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment([8, 8, 8], [0, 2, 2], [0, math.pi / 2.0],
                          prob_slip=0.05,
                          prob_shift=0.05,
                          max_misalign=25) +

        # apply transpose augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(raw,
                            scale_min=0.9,
                            scale_max=1.1,
                            shift_min=-0.1,
                            shift_max=0.1,
                            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=3, only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels,
                         gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(gt_affs, loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(cache_size=10, num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names stored in train_net_config.json)
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=10000) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=1000,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 1000 iterations
        gp.PrintProfilingStats(every=1000))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for _ in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
Example #10
0
# Merge all summaries registered so far and export the graph definition to
# 'unet.meta'.
merged = tf.summary.merge_all()
tf.train.export_meta_graph(filename='unet.meta')
# Session configuration: let TensorFlow grow its GPU memory allocation on
# demand.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

###############################
# CREMI DATA IS STORED Z,X,Y  #
#           z   x    y        #
# voxelsize=(40 ,4   ,4)      #
# size  =   (125,1250,1250)   #
###############################

# import source: maps the raw, gt and mask keys (defined outside this
# excerpt) to their datasets in 'data_with_mask.hdf'
source = gp.Hdf5Source('data_with_mask.hdf', {
    raw: 'volumes/raw',
    gt: 'volumes/labels',
    mask: 'volumes/masks'
})

# define output snapshot: grad and prediction are both requested with a size
# of output_shape * voxel_size, computed element-wise for the three axes
snapshot_request = BatchRequest()
snapshot_request.add(grad, (output_shape[0] * voxel_size[0], output_shape[1] *
                            voxel_size[1], output_shape[2] * voxel_size[2]))
snapshot_request.add(prediction,
                     (output_shape[0] * voxel_size[0], output_shape[1] *
                      voxel_size[1], output_shape[2] * voxel_size[2]))

# define pipeline
training_pipeline = (
    source + gp.RandomLocation() +
    # gp.SimpleAugment() +
Example #11
0
# prepare requests for scanning (i.e. chunks) and overall
# Chunk ROI: offset (0, 0, 0), shape (40, 100, 100) * (40, 8, 8).
# NOTE(review): the factor (40, 8, 8) differs from voxel_size (40, 4, 4)
# below; presumably it is the voxel size after the grid=(1, 2, 2) used by
# AddStarDist3D -- confirm.
scan_request = gp.BatchRequest()
scan_request[stardists] = gp.Roi(
    gp.Coordinate((0, 0, 0)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)))
voxel_size = gp.Coordinate((40, 4, 4))
request = gp.BatchRequest(
)  # NOTE(review): starts empty, but a stardists ROI is set on it right below
# Overall ROI: offset (40, 200, 200) * (40, 8, 8), shape twice the chunk
# shape along every axis.
request[stardists] = gp.Roi(
    gp.Coordinate((40, 200, 200)) * gp.Coordinate((40, 8, 8)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)) * gp.Coordinate(
        (2, 2, 2)))
# Label source; `directory` and the `labels` key are defined outside this
# excerpt.
source = gp.Hdf5Source(
    os.path.join(directory, "sample_A.hdf"),
    datasets={
        labels: "volumes/labels/neuron_ids"  # reads resolution from file
    })

# Node computing `stardists` from `labels`. unlabeled_id is -3 converted to
# an unsigned 64-bit integer; `max_dist` is defined outside this excerpt.
stardist_gen = gpstardist.AddStarDist3D(
    labels,
    stardists,
    rays=96,
    anisotropy=(40, 4, 4),
    grid=(1, 2, 2),
    unlabeled_id=int(np.array(-3).astype(np.uint64)),
    max_dist=max_dist,
)

writer = gp.ZarrWrite(
    output_dir=directory,
Example #12
0
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):
    """Predict the mask for one sample and store it as ``<sample>.zarr``.

    Args:
        data_dir: Path handed to ``gp.Hdf5Source``. The function returns
            immediately unless it contains the substring ``"hdf"``.
        train_dir: Directory holding the checkpoint, the ``.meta`` graph and
            the ``_config.json`` / ``_names.json`` files.
        iteration: Training iteration whose checkpoint is loaded.
        sample: Group name inside the input file; ``<sample>/raw`` is read
            and ``<sample>.zarr`` is written.
        test_net_name: Prefix of the config, names and meta graph files.
        train_net_name: Prefix of the checkpoint file.
        output_dir: Directory the zarr container is created in.
        clip_max: Passed to ``nl.Clip(raw, 0, clip_max)``; ``raw`` is then
            scaled by ``1.0 / clip_max`` in ``gp.Normalize``.
    """

    def read_json(suffix):
        # load <train_dir>/<test_net_name><suffix>
        with open(os.path.join(train_dir, test_net_name + suffix),
                  'r') as handle:
            return json.load(handle)

    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    net_config = read_json('_config.json')
    net_names = read_json('_names.json')

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    # half of the input/output size difference
    context = gp.Coordinate(input_shape - output_shape) / 2

    # reference request for a single chunk
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    # raw source: pad, clip, scale by 1/clip_max, then scale by 2 and shift
    # by -1
    source = gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
        },
    )
    source = source + gp.Pad(raw, context)
    source = source + nl.Clip(raw, 0, clip_max)
    source = source + gp.Normalize(raw, factor=1.0 / clip_max)
    source = source + gp.IntensityScaleShift(raw, 2, -1)

    # build the source once to learn the ROI it provides for raw
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    container = zarr.open(os.path.join(output_dir, sample + '.zarr'),
                          mode='w')
    container.create('volumes/pred_mask',
                     shape=sample_shape,
                     chunks=output_shape,
                     dtype=np.float16)
    mask_attrs = container['volumes/pred_mask'].attrs
    mask_attrs['offset'] = [0, 0, 0]
    mask_attrs['resolution'] = [1, 1, 1]

    predict_node = gp.tensorflow.Predict(
        graph=os.path.join(train_dir, test_net_name + '.meta'),
        checkpoint=checkpoint,
        inputs={
            net_names['raw']: raw,
        },
        outputs={
            net_names['pred']: pred_mask,
        },
        array_specs={
            pred_mask:
            gp.ArraySpec(roi=raw_roi.grow(-context, -context),
                         voxel_size=voxel_size),
        },
        max_shared_memory=1024 * 1024 * 1024)
    pipeline = source + predict_node

    pipeline = pipeline + Convert(pred_mask, np.float16)

    pipeline = pipeline + gp.ZarrWrite(
        dataset_names={
            pred_mask: 'volumes/pred_mask',
        },
        output_dir=output_dir,
        output_filename=sample + '.zarr',
        compression_type='gzip',
        dataset_dtypes={pred_mask: np.float16})

    # print a profiling summary every 100 iterations
    pipeline = pipeline + gp.PrintProfilingStats(every=100)

    # scan over the whole volume in chunks shaped like ``request``
    pipeline = pipeline + gp.Scan(
        reference=request, num_workers=5, cache_size=50)

    with gp.build(pipeline):
        # an empty request lets Scan process the whole volume
        pipeline.request_batch(gp.BatchRequest())
Example #13
0
def train_until(max_iteration):
    """Request training batches until ``max_iteration`` iterations are done.

    The number of iterations already trained is parsed from the latest
    TensorFlow checkpoint found in the current directory (the text after the
    last ``_`` in the checkpoint path); without a checkpoint, training starts
    from iteration 0. If that number already reaches ``max_iteration`` the
    function returns before any pipeline is constructed.

    Two independent random sources ("base" and "add") are built per entry of
    ``files``, merged, fused by ``FusionAugment``, augmented and fed to
    ``gp.tensorflow.Train``.

    Relies on names defined outside this function: ``net_config``,
    ``net_names``, ``files``, ``SwcSource``, ``RasterizeSkeleton``,
    ``FusionAugment`` and ``BinarizeGt``.

    Args:
        max_iteration: Total number of training iterations to reach.
    """

    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint('.')
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    # Must be checked for both branches: a resumed run may already be done,
    # in which case building the pipeline would be wasted work.
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # network shapes are given in voxels; requests are sized in world units
    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)

    request.add(swc_center_base, output_size)
    request.add(swc_base, input_size)

    request.add(swc_center_add, output_size)
    request.add(swc_add, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume: one sub-pipeline per file, one of which
    # is picked per request by the RandomProvider
    data_sources_base = tuple()
    data_sources_base += tuple(
        (gp.Hdf5Source(filename,
                       datasets={
                           raw_base: '/volume',
                       },
                       array_specs={
                           raw_base:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=filename,
                   dataset='/reconstruction',
                   points=(swc_center_base, swc_base),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=10) for filename in files)
    data_sources_base += gp.RandomProvider()

    # data source for "add" volume (same layout as "base", different keys)
    data_sources_add = tuple()
    data_sources_add += tuple(
        (gp.Hdf5Source(filename,
                       datasets={
                           raw_add: '/volume',
                       },
                       array_specs={
                           raw_add:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=filename,
                   dataset='/reconstruction',
                   points=(swc_center_add, swc_add),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=1) for filename in files)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add
                          ]) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(raw_base,
                                     raw_add,
                                     labels_base,
                                     labels_add,
                                     raw,
                                     labels,
                                     blend_mode='labels_mask',
                                     blend_smoothness=10,
                                     num_blended_objects=0) +

        # augment
        gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0],
                          subsample=8) +
        gp.SimpleAugment(mirror_only=[2], transpose_only=[]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train('./train_net',
                            optimizer=net_names['optimizer'],
                            loss=net_names['loss'],
                            inputs={
                                net_names['raw']: raw,
                                net_names['labels_fg']: labels_fg,
                                net_names['loss_weights']: loss_weights,
                            },
                            outputs={
                                net_names['fg']: fg,
                            },
                            gradients={
                                net_names['fg']: gradient_fg,
                            },
                            save_every=100) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names={
                        raw: 'volumes/raw',
                        raw_base: 'volumes/raw_base',
                        raw_add: 'volumes/raw_add',
                        labels: 'volumes/labels',
                        labels_base: 'volumes/labels_base',
                        labels_add: 'volumes/labels_add',
                        fg: 'volumes/fg',
                        labels_fg: 'volumes/labels_fg',
                        gradient_fg: 'volumes/gradient_fg',
                    },
                    additional_request=snapshot_request,
                    every=10) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        # one batch request per remaining iteration
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #14
0
def train_until(max_iteration):
    """Request training batches until ``max_iteration`` iterations are done.

    The number of iterations already trained is parsed from the latest
    TensorFlow checkpoint found in the current directory (the text after the
    last ``_`` in the checkpoint path); without a checkpoint, training starts
    from iteration 0. If that number already reaches ``max_iteration`` the
    function returns before any pipeline is constructed.

    Two independent random sources ("base" and "add") are built per entry of
    ``files``, merged, fused by ``FusionAugment``, augmented and fed to
    ``gp.tensorflow.Train``.

    Relies on names defined outside this function: ``net_config``,
    ``net_names``, ``files``, ``SwcSource``, ``RasterizeSkeleton``,
    ``FusionAugment`` and ``BinarizeGt``.

    Args:
        max_iteration: Total number of training iterations to reach.
    """

    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(".")
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split("_")[-1])
    else:
        trained_until = 0
    # Must be checked for both branches: a resumed run may already be done,
    # in which case building the pipeline would be wasted work.
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # output data
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # network shapes are given in voxels; requests are sized in world units
    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_center_add, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume: one sub-pipeline per file
    data_sources_base = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_base: "/volume"},
                array_specs={
                    raw_base:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_base, swc_base),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for filename in files)

    # data source for "add" volume (same layout as "base", different keys)
    data_sources_add = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_add: "/volume"},
                array_specs={
                    raw_add:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_add, swc_add),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for filename in files)
    # each RandomProvider picks one of the per-file sub-pipelines; the
    # MergeProvider combines the "base" and "add" branches
    data_sources = (
        (data_sources_base + gp.RandomProvider()),
        (data_sources_add + gp.RandomProvider()),
    ) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        ) +
        # augment
        gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                          subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +
        # train
        gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        ) +
        # visualize
        gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels: "volumes/labels",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                fg: "volumes/fg",
                labels_fg: "volumes/labels_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            additional_request=snapshot_request,
            every=100,
        ) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        # one batch request per remaining iteration
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #15
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Request training batches until ``max_iteration`` iterations are done.

    The number of iterations already trained is parsed from the latest
    TensorFlow checkpoint found in ``output_folder`` (the text after the last
    ``_`` in the checkpoint path); without a checkpoint, training starts from
    iteration 0. If that number already reaches ``max_iteration`` the function
    returns before any file is read or any pipeline is constructed.

    Network configuration and tensor names are loaded from
    ``<output_folder>/<name>_config.json`` and
    ``<output_folder>/<name>_names.json``.

    Relies on names defined outside this function: ``data_dir``,
    ``data_files``, ``Convert``, ``BinarizeLabels`` and the ``nl`` module.

    Args:
        max_iteration: Total number of training iterations to reach.
        name: Base name of the network files inside ``output_folder``.
        output_folder: Directory holding checkpoints, the network JSON files
            and the ``snapshots`` output.
        clip_max: Upper bound raw intensities are clipped to; also used as
            the normalization divisor (``factor=1.0/clip_max``).
    """

    # get the latest checkpoint
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0
    # Must be checked for both branches: a resumed run may already be done,
    # in which case building the pipeline would be wasted work.
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # margin between network input and output on each side; used for padding
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume: one sub-pipeline per top-level entry
    # ("sample") of every HDF5 file
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        # the file is only opened here to enumerate its top-level keys; the
        # generator is fully consumed by tuple() before the file is closed
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        # np.bool_ instead of the removed alias np.bool
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005,  mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume: same layout as base, but locations are
    # drawn unconditionally and filtered afterwards by gp.Reject
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        # np.bool_ instead of the removed alias np.bool
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add]) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # one batch request per remaining iteration
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)