def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):
    """Run a trained network block-wise over one HDF5 sample, write to zarr.

    The raw volume ``<sample>/raw`` is read from ``data_dir``. It is padded
    by the network context and clipped via ``nl.Clip(raw, 0, clip_max)``.
    It is then multiplied by ``1 / clip_max`` and mapped with ``x * 2 - 1``.
    The result is fed through the TensorFlow graph
    ``<train_dir>/<test_net_name>.meta``, restored from
    ``<train_dir>/<train_net_name>_checkpoint_<iteration>``. The network
    output is stored as float16 in ``<output_dir>/<sample>.zarr`` under
    ``volumes/pred_mask``.

    Args:
        data_dir: Path of the input HDF5 file. Paths that do not contain
            ``"hdf"`` are skipped (the function returns without output).
        train_dir: Directory holding ``<test_net_name>_config.json``,
            ``<test_net_name>_names.json``, the meta graph and checkpoints.
        iteration: Training iteration of the checkpoint to restore.
        sample: Top-level HDF5 group to predict on; also the base name of
            the output zarr container.
        test_net_name: File prefix of the inference graph description.
        train_net_name: File prefix of the checkpoints.
        output_dir: Directory in which ``<sample>.zarr`` is created
            (overwritten if it exists, ``mode='w'``).
        clip_max: Upper clip value, also used as the normalization factor.
    """
    if "hdf" not in data_dir:
        # Only HDF5 inputs are handled; make the skip visible.
        print("Skipping %s: not an HDF file" % data_dir)
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    with open(os.path.join(train_dir, test_net_name + '_config.json'),
              'r') as f:
        net_config = json.load(f)

    with open(os.path.join(train_dir, test_net_name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    # Border the network consumes on each side of its output; floor
    # division makes the integer intent explicit.
    context = gp.Coordinate(input_shape - output_shape) // 2

    # Per-block reference request for Scan.
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    source = (gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
        },
    ) + gp.Pad(raw, context) + nl.Clip(raw, 0, clip_max) +
              gp.Normalize(raw, factor=1.0 / clip_max) +
              gp.IntensityScaleShift(raw, 2, -1))

    # Build the source once only to query the extent of the padded raw data.
    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)

    # Region covered by network output: the padded raw ROI minus the padding.
    output_roi = raw_roi.grow(-context, -context)
    sample_shape = output_roi.get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')

    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    # Record the real offset of the output ROI rather than a fixed origin,
    # so that ZarrWrite places blocks correctly when the source data has a
    # non-zero offset (identical to [0, 0, 0] otherwise).
    zf['volumes/pred_mask'].attrs['offset'] = list(output_roi.get_offset())
    zf['volumes/pred_mask'].attrs['resolution'] = list(voxel_size)

    pipeline = (
        source + gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask:
                gp.ArraySpec(roi=output_roi, voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) + gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +

        # show a summary of time spend in each node every x iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):
        # Empty request: Scan presumably walks the whole upstream ROI in
        # `request`-sized blocks; results reach disk via ZarrWrite.
        pipeline.request_batch(gp.BatchRequest())
Exemple #2
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train the distance-transform network up to ``max_iteration``.

    Training resumes from the latest TensorFlow checkpoint in
    ``output_folder``. Sources are built for every top-level group of every
    HDF5 file in the module-level ``data_files`` (joined with the
    module-level ``data_dir``). Each source reads ``<group>/raw`` and
    ``<group>/fg``. The target ``gt_dt`` is produced by
    ``DistanceTransform(gt_mask, gt_dt, 3)`` and the network predicts
    ``pred_dt``.

    Args:
        max_iteration: Total iteration count to reach, including iterations
            already done by earlier runs.
        name: File prefix of ``_config.json``/``_names.json`` and of the
            checkpoints in ``output_folder``.
        output_folder: Directory with the network description files,
            checkpoints and the ``snapshots`` sub-directory.
        clip_max: Upper clip value, also used as the normalization factor.
    """
    # Resume from the latest checkpoint; its name ends in "_<iteration>".
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0

    # Nothing left to train: return before building the (costly) pipeline.
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # Border the network consumes on each side of its output.
    context = gp.Coordinate(input_shape - output_shape) // 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    # Snapshots additionally store the prediction and the loss gradient.
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one per top-level group of every data file.
    # NOTE(review): `data_dir` and `data_files` are not defined in this
    # function; presumably module-level globals.
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        # np.bool_ instead of the np.bool alias (deprecated
                        # in NumPy 1.20, removed in 1.24).
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # One request per remaining training iteration.
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train the foreground-mask network on fused volumes up to ``max_iteration``.

    Training resumes from the latest TensorFlow checkpoint in
    ``output_folder``. Two independent random sources, "base" and "add",
    are built from every top-level group of every HDF5 file in the
    module-level ``data_files`` (joined with the module-level ``data_dir``).
    Each reads ``<group>/raw``, ``<group>/gt`` and ``<group>/fg``. The
    sources are merged, blended by ``nl.FusionAugment`` into ``raw`` /
    ``gt_instances``, and ``gt_instances`` is turned into ``gt_mask`` by
    ``BinarizeLabels``. The network predicts ``pred_mask``.

    Args:
        max_iteration: Total iteration count to reach, including iterations
            already done by earlier runs.
        name: File prefix of ``_config.json``/``_names.json`` and of the
            checkpoints in ``output_folder``.
        output_folder: Directory with the network description files,
            checkpoints and the ``snapshots`` sub-directory.
        clip_max: Upper clip value, also used as the normalization factor.
    """
    # Resume from the latest checkpoint; its name ends in "_<iteration>".
    latest_checkpoint = tf.train.latest_checkpoint(output_folder)
    if latest_checkpoint:
        trained_until = int(latest_checkpoint.split('_')[-1])
    else:
        trained_until = 0

    # Nothing left to train: return before building the (costly) pipeline.
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # Border the network consumes on each side of its output.
    context = gp.Coordinate(input_shape - output_shape) // 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    # Snapshots additionally store the prediction and the loss gradients.
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # NOTE(review): `data_dir` and `data_files` are not defined in this
    # function; presumably module-level globals.
    # data source for base volume: random locations with >= 0.5% foreground
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        # np.bool_ instead of the np.bool alias (deprecated
                        # in NumPy 1.20, removed in 1.24).
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume: mostly (95%) locations with foreground
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = (data_sources_base, data_sources_add) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                # NOTE(review): gradients are taken w.r.t. 'output', not
                # 'pred' as in `outputs`; presumably the pre-activation
                # tensor - confirm against the network definition.
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # One request per remaining training iteration.
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)