Example #1
0
def create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                  gt_neurons):
    """Assemble the data-source stage of the pipeline for one sample.

    Three providers are merged:

    * the point annotations of ``<sample>.hdf`` in ``data_dir_syn``, exposed
      under both ``presyn`` and ``postsyn``;
    * the same annotations read again (``kind='postsyn'``) under
      ``dummypostsyn``;
    * the ``volumes/raw`` and ``volumes/labels/neuron_ids`` datasets of
      ``<sample>.hdf`` in ``data_dir``, exposed as ``raw`` (interpolatable)
      and ``gt_neurons`` (not interpolatable).

    The merged stream is normalized on ``raw`` and then sampled by
    ``gp.RandomLocation``, with ``dummypostsyn`` as the key that should be
    non-empty and ``parameter['reject_probability']`` as ``p_nonempty``.

    ``data_dir``, ``data_dir_syn`` and ``cremi_roi`` are not defined in this
    function and must be provided by the enclosing module.

    Returns:
        The source pipeline.
    """
    annotation_file = os.path.join(data_dir_syn, sample + '.hdf')
    volume_file = os.path.join(data_dir, sample + '.hdf')

    partner_points = Hdf5PointsSource(
        annotation_file,
        datasets={presyn: 'annotations', postsyn: 'annotations'},
        rois={presyn: cremi_roi, postsyn: cremi_roi})

    dummy_points = Hdf5PointsSource(
        annotation_file,
        datasets={dummypostsyn: 'annotations'},
        rois={dummypostsyn: cremi_roi},
        kind='postsyn')

    volumes = gp.Hdf5Source(
        volume_file,
        datasets={
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
        },
        array_specs={
            raw: gp.ArraySpec(interpolatable=True),
            gt_neurons: gp.ArraySpec(interpolatable=False),
        })

    merged = (partner_points, dummy_points, volumes) + gp.MergeProvider()
    normalized = merged + gp.Normalize(raw)
    return normalized + gp.RandomLocation(
        ensure_nonempty=dummypostsyn,
        p_nonempty=parameter['reject_probability'])
Example #2
0
    def add_target(self, gt, target):
        """Return the nodes that turn the labels ``gt`` into affinities ``target``.

        Affinities are computed over ``self.neighborhood``. The following
        ``gp.Normalize`` with ``factor=1.0`` does not rescale. It is there
        so that ``target`` ends up as a float array.
        """
        affinities = gp.AddAffinities(affinity_neighborhood=self.neighborhood,
                                      labels=gt,
                                      affinities=target)
        as_float = gp.Normalize(target, factor=1.0)
        return affinities + as_float
Example #3
0
def predict(iteration, path_to_dataGP,
            checkpoint_dir='C:/Users/filip/spine_yodl',
            raw_dataset='validate/sample1/raw'):
    """Run the SpineUNet checkpoint of ``iteration`` over a whole raw volume.

    Args:
        iteration: Training iteration whose checkpoint is loaded, i.e.
            ``<checkpoint_dir>/model_checkpoint_<iteration>``.
        path_to_dataGP: Path of the zarr container holding the raw data.
        checkpoint_dir: Directory containing the checkpoints. The default is
            the location that used to be hard-coded here.
        raw_dataset: Dataset inside ``path_to_dataGP`` to read raw data from.

    Returns:
        Tuple ``(raw, affs_predicted)`` with the ``.data`` of both arrays,
        each covering the full ROI of ``raw_dataset``.
    """
    # Network chunk shapes and the padding applied around raw.
    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    # NOTE(review): this passes the *string* 'output_size', not the
    # output_size tuple above. Confirm this is what SpineUNet expects.
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # Per-chunk request that Scan uses to tile the full request.
    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: raw_dataset
        }
    )

    # Request the complete extent of the raw dataset for both arrays.
    with gp.build(source):
        source_roi = source.spec[raw].roi
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'{checkpoint_dir}/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
Example #4
0
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                   raw, labels, affs_predicted, dataset_shape, num_samples,
                   sample_size):
    """Build a validation pipeline that predicts affinities with ``model``.

    The weights under ``'model_state_dict'`` in ``checkpoint_file`` are
    loaded into ``model`` first. The pipeline then:

    * reads ``validate/raw`` and ``validate/gt`` from the zarr container
      ``data_dir``;
    * pads and normalizes raw and runs the model on it;
    * scans the data in chunks of ``input_size`` (raw) and ``output_size``
      (predictions and labels).

    ``dataset_shape``, ``num_samples`` and ``sample_size`` are accepted but
    not used here.

    Returns:
        The assembled, not yet built, pipeline.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(checkpoint_file)
    model.load_state_dict(state['model_state_dict'])

    chunk = gp.BatchRequest()
    chunk.add(raw, input_size)
    chunk.add(affs_predicted, output_size)
    chunk.add(labels, output_size)

    source = gp.ZarrSource(str(data_dir), {
        raw: 'validate/raw',
        labels: 'validate/gt'
    })

    pipeline = source + gp.Pad(raw, size=None) + gp.Normalize(raw)
    # raw, labels: (s, h, w)
    pipeline += train.AddChannelDim(raw)
    # raw: (c=1, s, h, w)
    pipeline += train.TransposeDims(raw, (1, 0, 2, 3))
    # raw: (s, c=1, h, w). Samples lead, presumably so Predict sees them as a batch.
    pipeline += Predict(model=model, inputs={'x': raw},
                        outputs={0: affs_predicted})
    # affs_predicted: (s, c=2, h, w)
    pipeline += train.TransposeDims(raw, (1, 0, 2, 3))
    pipeline += train.RemoveChannelDim(raw)
    # raw back to (s, h, w); labels untouched: (s, h, w)
    pipeline += gp.PrintProfilingStats(every=100)
    pipeline += gp.Scan(chunk)

    return pipeline
Example #5
0
def train_until(max_iteration):
    """Train a UNet to predict rasterized center points.

    Positive samples (``pos_samples``) get a point from ``AddCenterPoint``
    and are sampled where that point is present. Negative samples
    (``neg_samples``) get ``AddNoPoint`` instead. Batches come from the
    positive pool with probability 0.9 and from the negative pool with
    probability 0.1. They are then augmented, and the points are rasterized
    (``mode='peak'``) into the target. The model is trained with an MSE loss.

    ``max_iteration`` batches are requested. Snapshots are written to
    ``test_{iteration}.hdf`` every 500 iterations.
    """
    # Model, loss and optimizer.
    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)
    model = Convolve(unet, 12, 1)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # Keys.
    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    # Shapes are in voxels; sizes are in world units.
    voxel_size = gp.Coordinate((40, 4, 4))
    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)
    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    for key in (points, groundtruth, prediction, grad):
        request.add(key, output_size)

    # Positive sources: a point is added and must be inside the sampled ROI.
    positive_branches = []
    for path in pos_samples:
        branch = gp.ZarrSource(path, {raw: 'volumes/raw'},
                               {raw: gp.ArraySpec(interpolatable=True)})
        branch += AddCenterPoint(points, raw)
        branch += gp.Pad(raw, None)
        branch += gp.RandomLocation(ensure_nonempty=points)
        positive_branches.append(branch)
    pos_sources = tuple(positive_branches) + gp.RandomProvider()

    # Negative sources: no point.
    negative_branches = []
    for path in neg_samples:
        branch = gp.ZarrSource(path, {raw: 'volumes/raw'},
                               {raw: gp.ArraySpec(interpolatable=True)})
        branch += AddNoPoint(points, raw)
        branch += gp.RandomLocation()
        negative_branches.append(branch)
    neg_sources = tuple(negative_branches) + gp.RandomProvider()

    data_sources = (pos_sources, neg_sources) + gp.RandomProvider(
        probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    pipeline = data_sources
    pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                  jitter_sigma=[0, 2, 2],
                                  rotation_interval=[0, math.pi / 2.0],
                                  prob_slip=0.05,
                                  prob_shift=0.05,
                                  max_misalign=10,
                                  subsample=8)
    pipeline += gp.SimpleAugment(transpose_only=[1, 2])
    pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                    z_section_wise=True)
    pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    pipeline += gp.PreCache(cache_size=40, num_workers=10)

    # Add two leading singleton axes before training.
    for key, shape in ((raw, input_shape), (groundtruth, output_shape)):
        pipeline += Reshape(key, (1, 1) + shape)

    pipeline += gp_torch.Train(model=model,
                               loss=loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               outputs={0: prediction},
                               loss_inputs={
                                   0: prediction,
                                   1: groundtruth
                               },
                               gradients={0: grad},
                               save_every=1000,
                               log_dir='log')

    # Back to plain spatial shapes for snapshots.
    for key, shape in ((raw, input_shape), (groundtruth, output_shape),
                       (prediction, output_shape), (grad, output_shape)):
        pipeline += Reshape(key, shape)

    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(pipeline):
        for _ in range(max_iteration):
            pipeline.request_batch(request)
Example #6
0
def predict_2d(raw_data, gt_data, predictor):
    """Run ``predictor`` over every 2D sample of ``raw_data``.

    The samples of ``raw_data`` (and of ``gt_data``, if given) are treated
    as the first spatial axis of a 3D volume. This volume is scanned in
    chunks of the predictor's input/output shape and assembled into a single
    batch.

    Args:
        raw_data: Raw dataset. Apart from an optional leading channel axis,
            its shape must be 3D (samples, h, w).
        gt_data: Optional ground-truth dataset. When truthy, ``gt`` and the
            target built by ``predictor.add_target`` are processed and
            returned too.
        predictor: Model exposing ``input_shape``, ``output_shape``,
            ``target_channels`` and ``add_target``. It is also the model run
            by ``gp_torch.Predict``.

    Returns:
        dict with the arrays under ``'raw'`` and ``'prediction'``, plus
        ``'gt'`` and ``'target'`` when ``gt_data`` is given.

    Raises:
        RuntimeError: If ``raw_data`` is not 3D after removing the channel axis.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    # Single-channel data is treated as having no channel axis; otherwise
    # the channel axis is the first one.
    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims
    if data_dims != 3:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    # Take the sample count from the dataset object: for multi-channel data,
    # dataset_shape[0] would be the channel axis, not the samples.
    num_samples = raw_data.num_samples
    sample_shape = gp.Coordinate(dataset_shape[channel_dims + 1:])
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #7
0
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
# Third sample: 'raw' is interpolatable, the 'segmentation' labels are not.
sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})

# Combine the per-sample sources into one provider.
# NOTE(review): sourceC (and, judging from the visible tail, sourceB) provide
# the same keys (raw, seg). gp.MergeProvider is meant for upstream providers
# with disjoint keys. If the intent is to draw each batch from one of the
# samples, gp.RandomProvider() is probably what is wanted. Confirm.
source = (sourceA, sourceB, sourceC) + gp.MergeProvider()
print(source)
# Individual nodes, combined into `pipeline` below.
normalize = gp.Normalize(raw)
simulate_cages = SimulateCages(raw, seg, out_cage_map, out_density_map, psf,
                               (min_density, max_density), [cage1], 0.5)
# NOTE(review): despite its name, gp.Stack(1) stacks batches (adds a leading
# axis of size 1). It is not a dedicated channel-dimension node.
add_channel_dim = gp.Stack(1)
stack = gp.Stack(5)
prepare_data = PrepareTrainingData()
# Feeds raw to the model as 'input'; the loss compares output 0 (prediction)
# against out_cage_map.
train = gp.torch.Train(model,
                       loss,
                       optimizer,
                       inputs={'input': raw},
                       loss_inputs={
                           0: prediction,
                           1: out_cage_map
                       },
                       outputs={0: prediction})
pipeline = (source + normalize + gp.RandomLocation() + simulate_cages +
Example #8
0
def predict(iteration):
    """Predict affinities for ``sample_A_padded_20160501.hdf`` with a TF network.

    The network description is read from ``test_net_config.json``: shapes
    under ``'input_shape'`` / ``'output_shape'`` and tensor names under
    ``'raw'`` / ``'pred_affs'``. The graph ``test_net.meta`` is restored from
    ``train_net_checkpoint_<iteration>``.

    The whole raw volume is scanned in network-sized chunks, and ``raw`` and
    ``pred_affs`` are written to ``predictions_sample_A.hdf``.
    """
    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open('test_net_config.json', 'r') as config_file:
        net_config = json.load(config_file)

    # Network shapes are in voxels; requests are in world units (nm).
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = input_size - output_size

    # One network-sized chunk; Scan tiles the dataset with it.
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_affs, output_size)

    source = gp.Hdf5Source('sample_A_padded_20160501.hdf',
                           datasets={raw: 'volumes/raw'})

    # The raw ROI bounds the region in which predictions can be made.
    with gp.build(source):
        raw_roi = source.spec[raw].roi

    # NOTE(review): `context` is the full input/output difference, so growing
    # by -context on both sides shrinks the ROI by twice the context. The
    # valid region only loses half the context per side. Confirm whether the
    # extra margin is intended.
    pred_roi = raw_roi.grow(-context, -context)

    # Scale raw to [0, 1].
    pipeline = source + gp.Normalize(raw)
    pipeline += gp.tensorflow.Predict(
        graph='test_net.meta',
        checkpoint='train_net_checkpoint_%d' % iteration,
        inputs={net_config['raw']: raw},
        outputs={net_config['pred_affs']: pred_affs},
        array_specs={pred_affs: gp.ArraySpec(roi=pred_roi)})
    # Every chunk is written into the same output file.
    pipeline += gp.Hdf5Write({
        raw: '/volumes/raw',
        pred_affs: '/volumes/pred_affs',
    },
                             output_filename='predictions_sample_A.hdf',
                             compression_type='gzip')
    pipeline += gp.PrintProfilingStats(every=10)
    pipeline += gp.Scan(reference=chunk_request)

    with gp.build(pipeline):
        # An empty request makes Scan walk the whole dataset chunk by chunk
        # without assembling it in memory.
        pipeline.request_batch(gp.BatchRequest())
Example #9
0
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties=None,
            **kwargs):
    """Predict synaptic partner vectors and post-synaptic indicators blockwise.

    Restores ``train_net_checkpoint_<iteration>`` from this script's
    directory and runs it over ``raw_dataset`` of ``raw_file``. Both outputs
    are written to the zarr container ``out_file``. Blocks are scheduled by
    ``gp.DaisyRequestBlocks``. Each finished block is reported through
    ``block_done_callback(db_host, db_name, worker_config, b, s, d)``.

    Args:
        iteration: Checkpoint iteration to load.
        raw_file: Input container. Must end in ``.hdf``, ``.zarr`` or ``.n5``.
        raw_dataset: Name of the raw dataset inside ``raw_file``.
        out_file: Output filename passed to ``gp.ZarrWrite``.
        db_host, db_name: Forwarded to ``block_done_callback``.
        worker_config: Must contain ``'num_cache_workers'``. Also forwarded
            to ``block_done_callback``.
        network_config: Prefix of ``<network_config>_net_config.json`` and
            ``<network_config>_net.meta`` in the script directory.
        out_properties: Optional dict. The ``'pred_syn_indicator_out'`` entry
            may contain ``'scale'``. The ``'pred_partner_vectors'`` entry may
            contain ``'scale'`` and ``'dtype'`` (``'int8'`` or
            ``'float32'``). ``None`` (the default) means no extra properties.
        **kwargs: Accepted and ignored.

    Raises:
        RuntimeError: If ``raw_file`` has an unsupported extension.
        AssertionError: If the requested partner-vector dtype is unsupported.
    """
    # A None default avoids one dict object being shared by all calls.
    if out_properties is None:
        out_properties = {}
    setup_dir = os.path.dirname(os.path.realpath(__file__))

    config_path = os.path.join(setup_dir,
                               '{}_net_config.json'.format(network_config))
    with open(config_path, 'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # Optional per-setup parameters; only 'd_scale' is read below.
    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # Pick the source type from the file extension.
    if raw_file.endswith('.hdf'):
        source_type = gp.Hdf5Source
    elif raw_file.endswith(('.zarr', '.n5')):
        source_type = gp.ZarrSource
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))
    pipeline = source_type(raw_file,
                           datasets={raw: raw_dataset},
                           array_specs={
                               raw: gp.ArraySpec(interpolatable=True),
                           })

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    # Affine map x -> 2x - 1 (i.e. [0, 1] to [-1, 1] if Normalize yields [0, 1]).
    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.
    if (m_property is not None and 'scale' in m_property
            and m_property['scale'] != 1):
        pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                           m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] in ('int8', 'float32'), \
            'predict not adapted to dtype {}'.format(d_property['dtype'])
        if d_property['dtype'] == 'int8':
            # Keep vectors inside the int8 range before they are stored.
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1,
                                                0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator: 'volumes/pred_syn_indicator',
        pred_postpre_vectors: 'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
Example #10
0
def build_pipeline(
        data_dir,
        model,
        save_every,
        batch_size,
        input_size,
        output_size,
        raw,
        labels,
        affs,
        affs_predicted,
        lr=1e-5):
    """Assemble a 2D affinity-training pipeline over ``train/`` of ``data_dir``.

    ``train/raw`` and ``train/gt`` are read with unit voxel size; their first
    axis is treated as z. Processing steps:

    * pick random locations and compute affinities with the neighborhood
      ``(0, 1, 0), (0, 0, 1)``;
    * squash the z axis and stack ``batch_size`` samples;
    * train ``model`` with an MSE loss and ``RAdam(lr=lr)``, saving every
      ``save_every`` iterations;
    * write snapshots every 100 iterations.

    ``input_size`` and ``output_size`` are accepted but not used here.

    Returns:
        The assembled, not yet built, pipeline.
    """
    stack_shape = zarr.open(str(data_dir))['train/raw'].shape
    n_slices = stack_shape[0]
    slice_shape = stack_shape[1:]

    def whole_stack_spec():
        # A fresh spec (and Roi) per array, covering all slices at unit resolution.
        return gp.ArraySpec(
            roi=gp.Roi((0, 0, 0), (n_slices,) + slice_shape),
            voxel_size=(1, 1, 1))

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)

    pipeline = gp.ZarrSource(
        data_dir,
        {
            raw: 'train/raw',
            labels: 'train/gt'
        },
        array_specs={
            raw: whole_stack_spec(),
            labels: whole_stack_spec()
        })
    # raw: (d=1, h, w)
    # labels: (d=1, h, w)
    pipeline += gp.RandomLocation()
    pipeline += gp.AddAffinities(
        affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
        labels=labels,
        affinities=affs)
    pipeline += gp.Normalize(affs, factor=1.0)
    # affs: (c=2, d=1, h, w)
    pipeline += Squash(dim=-3)
    # z axis removed -- raw: (h, w), affs: (c=2, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (c=1, h, w)
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c=1, h, w), affs: (b, c=2, h, w), with b = batch_size
    pipeline += Train(
        model=model,
        loss=loss,
        optimizer=optimizer,
        inputs={'x': raw},
        target=affs,
        output=affs_predicted,
        save_every=save_every,
        log_dir='log')
    # affs_predicted: (b, c=2, h, w)
    for key in (raw, affs, affs_predicted):
        pipeline += TransposeDims(key, (1, 0, 2, 3))
    # raw: (c=1, b, h, w), affs / affs_predicted: (c=2, b, h, w)
    pipeline += RemoveChannelDim(raw)
    # raw: (b, h, w)
    pipeline += gp.Snapshot(
        dataset_names={
            raw: 'raw',
            labels: 'labels',
            affs: 'affs',
            affs_predicted: 'affs_predicted'
        },
        every=100)
    pipeline += gp.PrintProfilingStats(every=100)
    return pipeline
Example #11
0
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):
    """Predict a mask for ``sample`` of an HDF5 file and store it in zarr.

    Raw data (``<sample>/raw`` in ``data_dir``) is processed as follows:

    * padded by the network context;
    * clipped to ``[0, clip_max]``;
    * normalized with factor ``1 / clip_max``;
    * mapped with ``x -> 2x - 1``.

    It is then fed to the TensorFlow graph ``<test_net_name>.meta``, restored
    from ``<train_net_name>_checkpoint_<iteration>`` in ``train_dir``. Tensor
    names come from ``<test_net_name>_names.json``; shapes come from
    ``<test_net_name>_config.json``.

    The prediction is written as float16 to ``volumes/pred_mask`` in
    ``<output_dir>/<sample>.zarr``.

    Returns ``None``. Nothing is done when ``data_dir`` does not contain "hdf".
    """
    # Only HDF5 inputs are handled; everything else is skipped silently.
    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)
    print('checkpoint: ', checkpoint)

    def read_json(suffix):
        # Network description files stored next to the checkpoints.
        with open(os.path.join(train_dir, test_net_name + suffix), 'r') as f:
            return json.load(f)

    net_config = read_json('_config.json')
    net_names = read_json('_names.json')

    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # Unit voxels, so shapes and world-unit sizes coincide.
    voxel_size = gp.Coordinate((1, 1, 1))
    # Half of the input/output difference: the margin per side.
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    raw_spec = gp.ArraySpec(
        interpolatable=True, dtype=np.uint16, voxel_size=voxel_size)
    source = gp.Hdf5Source(
        data_dir,
        datasets={raw: sample + '/raw'},
        array_specs={raw: raw_spec},
    )
    source += gp.Pad(raw, context)
    source += nl.Clip(raw, 0, clip_max)
    source += gp.Normalize(raw, factor=1.0 / clip_max)
    source += gp.IntensityScaleShift(raw, 2, -1)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        # Region where the network produces output.
        valid_roi = raw_roi.grow(-context, -context)
        sample_shape = valid_roi.get_shape()

    print(sample_shape)

    # Pre-create the output dataset with one chunk per network output.
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')
    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    mask_dataset = zf['volumes/pred_mask']
    mask_dataset.attrs['offset'] = [0, 0, 0]
    mask_dataset.attrs['resolution'] = [1, 1, 1]

    pipeline = source + gp.tensorflow.Predict(
        graph=os.path.join(train_dir, test_net_name + '.meta'),
        checkpoint=checkpoint,
        inputs={
            net_names['raw']: raw,
        },
        outputs={
            net_names['pred']: pred_mask,
        },
        array_specs={
            pred_mask: gp.ArraySpec(roi=valid_roi, voxel_size=voxel_size),
        },
        max_shared_memory=1024 * 1024 * 1024)
    pipeline += Convert(pred_mask, np.float16)
    pipeline += gp.ZarrWrite(
        dataset_names={
            pred_mask: 'volumes/pred_mask',
        },
        output_dir=output_dir,
        output_filename=sample + '.zarr',
        compression_type='gzip',
        dataset_dtypes={pred_mask: np.float16})
    pipeline += gp.PrintProfilingStats(every=100)
    pipeline += gp.Scan(reference=request, num_workers=5, cache_size=50)

    with gp.build(pipeline):
        # Empty request: Scan walks the whole ROI chunk by chunk.
        pipeline.request_batch(gp.BatchRequest())
Example #12
0
def train_until(max_iteration, return_intermediates=False):
    """Run the fusion-augment pipeline until ``max_iteration`` is reached.

    The iteration already reached is parsed from the latest TensorFlow
    checkpoint in the working directory (the part after the last ``_``). If
    that is at or beyond ``max_iteration``, the function returns immediately.

    ``net_config`` (``'input_shape'``, ``'output_shape'``) and ``files`` are
    not defined in this function and must be provided by the enclosing module.

    Args:
        max_iteration: Total iteration count to reach.
        return_intermediates: If true, the intermediate arrays of
            ``FusionAugment`` (``A_CH1``, ``A_CH2``, ``B_CH1``, ``B_CH2``,
            ``SOFT_MASK``) are also requested, normalized (except the soft
            mask) and written to snapshots.
    """
    # get the latest checkpoint
    latest = tf.train.latest_checkpoint('.')
    trained_until = int(latest.split('_')[-1]) if latest else 0
    # Nothing to do if a previous run already reached max_iteration.
    if trained_until >= max_iteration:
        return

    # input data
    ch1 = gp.ArrayKey('CH1')
    ch2 = gp.ArrayKey('CH2')
    swc = gp.PointsKey('SWC')
    swc_env = gp.PointsKey('SWC_ENV')
    swc_center = gp.PointsKey('SWC_CENTER')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')

    # show fusion augment batches
    if return_intermediates:
        a_ch1 = gp.ArrayKey('A_CH1')
        a_ch2 = gp.ArrayKey('A_CH2')
        b_ch1 = gp.ArrayKey('B_CH1')
        b_ch2 = gp.ArrayKey('B_CH2')
        soft_mask = gp.ArrayKey('SOFT_MASK')

    # output data
    # NOTE(review): no node below produces FG or GRADIENT_FG, and
    # LOSS_WEIGHTS is unused. The training node appears to be missing.
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(ch1, input_size)
    request.add(ch2, input_size)
    request.add(swc, input_size)
    request.add(swc_center, output_size)
    request.add(gt, output_size)
    request.add(gt_fg, output_size)

    if return_intermediates:
        for key in (a_ch1, a_ch2, b_ch1, b_ch2, soft_mask):
            request.add(key, input_size)

    # Extra arrays for snapshots (currently none).
    snapshot_request = gp.BatchRequest()

    # One branch per file: the two channels of '/volume' merged with the SWC
    # skeleton, sampled near skeleton centers and rasterized into gt.
    data_sources = tuple(
        (Hdf5ChannelSource(filename,
                           datasets={
                               ch1: '/volume',
                               ch2: '/volume',
                           },
                           channel_ids={
                               ch1: 0,
                               ch2: 1,
                           },
                           data_format='channels_last',
                           array_specs={
                               ch1:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                               ch2:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                           }),
         SwcSource(filename=filename,
                   dataset='/reconstruction',
                   points=(swc_center, swc),
                   return_env=True,
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center) + RasterizeSkeleton(
            points=swc,
            array=gt,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            points_env=swc_env,
            iteration=10) for filename in files)

    snapshot_datasets = {
        ch1: 'volumes/ch1',
        ch2: 'volumes/ch2',
    }
    if return_intermediates:
        snapshot_datasets.update({
            a_ch1: 'volumes/a_ch1',
            a_ch2: 'volumes/a_ch2',
            b_ch1: 'volumes/b_ch1',
            b_ch2: 'volumes/b_ch2',
            soft_mask: 'volumes/soft_mask',
        })
    snapshot_datasets.update({
        gt: 'volumes/gt',
        fg: 'volumes/fg',
        gt_fg: 'volumes/gt_fg',
        gradient_fg: 'volumes/gradient_fg',
    })

    pipeline = data_sources + FusionAugment(
        ch1,
        ch2,
        gt,
        smoothness=1,
        return_intermediate=return_intermediates)
    pipeline += gp.Normalize(ch1)
    pipeline += gp.Normalize(ch2)
    if return_intermediates:
        # These keys exist only when intermediates are requested; without
        # this guard the names would be undefined.
        for key in (a_ch1, a_ch2, b_ch1, b_ch2):
            pipeline += gp.Normalize(key)
    pipeline += gp.IntensityAugment(ch1, 0.9, 1.1, -0.001, 0.001)
    pipeline += gp.IntensityAugment(ch2, 0.9, 1.1, -0.001, 0.001)
    pipeline += BinarizeGt(gt, gt_fg)

    # visualize
    pipeline += gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                            dataset_names=snapshot_datasets,
                            additional_request=snapshot_request,
                            every=20)
    pipeline += gp.PrintProfilingStats(every=1000)

    with gp.build(pipeline):
        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    """Train the foreground-distance network on MouseLight samples.

    Builds a gunpowder pipeline that reads raw volumes from
    ``fluorescence-near-consensus.n5`` per sample, pulls consensus and
    skeletonization graphs through ``gp.DaisyGraphProvider`` (backed by
    ``MONGO_URL``), matches the skeletonization onto the consensus, rasterizes
    the matched skeleton, derives a tanh-saturated distance target plus loss
    weights, augments, and feeds batches to a TensorFlow ``Train`` node.

    Args:
        n_iterations: Number of batches to request. If ``None``, the value of
            ``setup_config["NUM_ITERATIONS"]`` is used instead.
        setup_config: Dict of setup parameters (shapes, voxel size, snapshot /
            checkpoint / profiling intervals, matching thresholds, paths).
        mknet_tensor_names: Mapping from logical names (``"optimizer"``,
            ``"fg_loss"``, ``"raw"``, ``"gt_distances"``, ``"loss_weights"``,
            ``"fg_pred"``) to TensorFlow tensor names.
        loss_tensor_names: Currently unused; kept for interface compatibility.

    Raises:
        ValueError: If no directory under ``SAMPLES_PATH`` is one of the
            hard-coded training samples.
    """
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    # CACHE_SIZE / NUM_WORKERS are only consumed by the (disabled) PreCache
    # node further below.
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    # honour the caller's iteration count; previously it was silently ignored
    iterations = num_iterations if n_iterations is None else n_iterations

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()

    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    # only these sample directories are used for training
    training_samples = ("2018-07-02", "2018-08-01")

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched, center_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + gp.contrib.nodes.add_distance.AddDistance(
            labels,
            dist,
            dist_mask,
            max_distance=max_label_dist * micron_scale)
        + gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1)
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in training_samples)

    # RandomProvider cannot choose from zero upstream providers; fail clearly
    if not data_sources:
        raise ValueError(
            f"no training samples {training_samples} found in {samples_path}")

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            # "{id}" is left as a placeholder for Snapshot to fill in
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights"
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(iterations):
            pipeline.request_batch(request)
Example #14
0
def train_until(max_iteration):
    """Train the foreground network on fused neuron pairs up to ``max_iteration``.

    Resumes from the latest TensorFlow checkpoint in the working directory
    (its iteration is parsed from the part of the name after the last ``_``)
    and returns immediately if that checkpoint already reached
    ``max_iteration``. Two independent random sources ("base" and "add") each
    read a raw volume plus an SWC reconstruction from the same set of
    ``files``. They are fused by ``FusionAugment``, augmented, binarized and
    fed to ``gp.tensorflow.Train``.

    Relies on the module-level names ``files``, ``net_config`` and
    ``net_names`` being defined elsewhere in the file.

    Args:
        max_iteration: Total number of training iterations to reach.
    """
    # get the latest checkpoint (queried once)
    checkpoint = tf.train.latest_checkpoint('.')
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0

    # nothing left to do if a checkpoint already covers the requested amount;
    # this check used to sit inside the "no checkpoint" branch, where
    # trained_until is always 0, so it never stopped a resumed run
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)

    request.add(swc_center_base, output_size)
    request.add(swc_base, input_size)

    request.add(swc_center_add, output_size)
    request.add(swc_add, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple()
    data_sources_base += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_base: '/volume',
                       },
                       array_specs={
                           raw_base:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_base, swc_base),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            # NOTE(review): base uses iteration=10 while "add" uses 1 —
            # confirm this asymmetry is intended
            iteration=10) for file in files)
    data_sources_base += gp.RandomProvider()

    # data source for "add" volume
    data_sources_add = tuple()
    data_sources_add += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_add: '/volume',
                       },
                       array_specs={
                           raw_add:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_add, swc_add),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=1) for file in files)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add
                          ]) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(raw_base,
                                     raw_add,
                                     labels_base,
                                     labels_add,
                                     raw,
                                     labels,
                                     blend_mode='labels_mask',
                                     blend_smoothness=10,
                                     num_blended_objects=0) +

        # augment
        gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0],
                          subsample=8) +
        gp.SimpleAugment(mirror_only=[2], transpose_only=[]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train('./train_net',
                            optimizer=net_names['optimizer'],
                            loss=net_names['loss'],
                            inputs={
                                net_names['raw']: raw,
                                net_names['labels_fg']: labels_fg,
                                net_names['loss_weights']: loss_weights,
                            },
                            outputs={
                                net_names['fg']: fg,
                            },
                            gradients={
                                net_names['fg']: gradient_fg,
                            },
                            save_every=100) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names={
                        raw: 'volumes/raw',
                        raw_base: 'volumes/raw_base',
                        raw_add: 'volumes/raw_add',
                        labels: 'volumes/labels',
                        labels_base: 'volumes/labels_base',
                        labels_add: 'volumes/labels_add',
                        fg: 'volumes/fg',
                        labels_fg: 'volumes/labels_fg',
                        gradient_fg: 'volumes/gradient_fg',
                    },
                    additional_request=snapshot_request,
                    every=10) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #15
0
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):
    """Run ``model`` over the whole first dataset of ``dataset``; write zarr.

    The input is read with ``gp.ZarrSource``, reshaped to samples-first /
    channels-second, optionally normalized, padded by the network context and
    passed through ``gp.torch.Predict``. The axis layout is then restored, and
    the prediction is written to ``out_dir/out_filename`` under
    ``out_ds_names[0]`` while ``gp.Scan`` tiles the input.

    Args:
        model: torch model; ``model.in_shape`` / ``model.out_shape`` (in
            voxels) are used when ``in_shape`` / ``out_shape`` are ``None``.
        dataset: object exposing ``filename`` and ``ds_names``; only
            ``ds_names[0]`` is read.
        out_dir, out_filename: destination of the zarr container.
        out_ds_names: dataset names in the output; only ``[0]`` is written.
        checkpoint: checkpoint handed to ``gp.torch.Predict``.
        input_name: name of the model input receiving the raw array.
        normalize_factor: forwarded to ``gp.Normalize``; ``"skip"`` disables
            normalization entirely.
        model_output: model output index mapped to the prediction array.
        in_shape, out_shape: optional network shapes in voxels.
        spawn_subprocess: forwarded to ``gp.torch.Predict``.
        num_workers: forwarded to ``gp.Scan``.
        apply_voxel_size: NOTE(review): not read anywhere in this body; the
            shapes are always multiplied by the voxel size.
    """
    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # network shapes in voxels; fall back to the model's declared shapes
    in_voxels = gp.Coordinate(
        model.in_shape if in_shape is None else in_shape)
    out_voxels = gp.Coordinate(
        model.out_shape if out_shape is None else out_shape)
    spatial_dims = in_voxels.dims()
    is_2d = spatial_dims == 2

    # the same shapes in world units
    in_shape = in_voxels * voxel_size
    out_shape = out_voxels * voxel_size

    for label, value in (("source roi", source_roi),
                         ("in_shape", in_shape),
                         ("out_shape", out_shape),
                         ("voxel_size", voxel_size)):
        logger.info(f"{label}: {value}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    # axis permutation exchanging the first two axes, spatial axes untouched
    swap_first_two = (1, 0) + tuple(range(2, 2 + spatial_dims))

    source = gp.ZarrSource(
        dataset.filename,
        {raw: dataset.ds_names[0]},
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)})

    # Bring raw up to 2 + spatial_dims axes:
    #   2D: (n, y, x) or (c, n, y, x)  ->  (c, n, y, x)
    #   3D: (z, y, x) or (c, z, y, x)  ->  (c, n=1, z, y, x)
    missing_dims = (2 + spatial_dims) - data_dims
    for _ in range(missing_dims):
        source += AddChannelDim(raw)

    # Predict expects samples first, channels second:
    #   2D: (n, c, y, x)   3D: (n=1, c, z, y, x)
    source += TransposeDims(raw, swap_first_two)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source
    if normalize_factor != "skip":
        pipeline += gp.Normalize(raw, factor=normalize_factor)

    pipeline += (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    if not is_2d:
        # 3D: drop the singleton sample axis -> (c, z, y, x)
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)
    else:
        # 2D: restore channels first -> (c, n, y, x)
        for key in (raw, prediction):
            pipeline += TransposeDims(key, swap_first_two)

    writer = gp.ZarrWrite({prediction: out_ds_names[0]},
                          output_dir=out_dir,
                          output_filename=out_filename,
                          compression_type='gzip')
    pipeline += (writer + gp.Scan(request, num_workers=num_workers))

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
Example #16
0
        transform_file=str((filename / "transform.txt").absolute()),
        ignore_human_nodes=False,
    ),
) + gp.MergeProvider() + gp.RandomLocation(
    ensure_nonempty=swcs, ensure_centered=True) + RasterizeSkeleton(
        points=swcs,
        array=labels,
        array_spec=gp.ArraySpec(
            interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
    ) + GrowLabels(labels, radius=20)
                     # augment
                     + gp.ElasticAugment(
                         [40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                         subsample=4) + gp.SimpleAugment(
                             mirror_only=[1, 2], transpose_only=[1, 2]) +
                     gp.Normalize(raw) +
                     gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
                     for filename in path_to_data.iterdir()
                     if "2018-07-02" in filename.name)

pipeline = (data_sources + gp.RandomProvider() + GetNeuronPair(
    swcs,
    raw,
    labels,
    (swc_base, swc_add),
    (raw_base, raw_add),
    (labels_base, labels_add),
    seperate_by=SEPERATE_DISTANCE,
    shift_attempts=50,
    request_attempts=10,
) + FusionAugment(
Example #17
0
def train(iterations):
    """Train an affinity network on ``fib.hdf`` for ``iterations`` batches.

    Reads the network description (input/output shapes and tensor names) from
    ``train_net_config.json``. It builds a gunpowder pipeline with
    normalization, random location, elastic/simple/intensity augmentation,
    boundary growing, affinity computation, label balancing and
    pre-caching. It then runs ``gp.tensorflow.Train`` and writes periodic
    snapshots.

    Relies on a module-level ``data_dir`` defined elsewhere in the file.

    Args:
        iterations: Number of training batches to request.
    """

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss wrt to the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((8, 8, 8))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    pipeline = (

        # a tuple of sources for RandomProvider; here it holds a single
        # source. tuple(<pipeline>) would try to iterate the pipeline object
        # and fail, so the one-element tuple is spelled out explicitly.
        (
            # read batches from the HDF5 file
            gp.Hdf5Source(os.path.join(data_dir, 'fib.hdf'),
                          datasets={
                              raw: 'volumes/raw',
                              gt_labels: 'volumes/labels/neuron_ids'
                          }) +

            # convert raw to float in [0, 1]
            gp.Normalize(raw) +

            # choose a random location for each requested batch
            gp.RandomLocation(),
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment([8, 8, 8], [0, 2, 2], [0, math.pi / 2.0],
                          prob_slip=0.05,
                          prob_shift=0.05,
                          max_misalign=25) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(raw,
                            scale_min=0.9,
                            scale_max=1.1,
                            shift_min=-0.1,
                            shift_max=0.1,
                            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=3, only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels,
                         gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(gt_affs, loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(cache_size=10, num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names stored in train_net_config.json)
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=10000) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=1000,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 1000 iterations
        gp.PrintProfilingStats(every=1000))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for _ in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
Example #18
0
def validation_pipeline(config):
    """Build one validation pipeline per benchmark block and merge them.

    Intended overall design::

        Per block
        {
            Raw -> predict -> scan
            gt -> rasterize        -> merge -> candidates -> trees
        } -> merge -> comatch + evaluate

    Only the data-loading half is assembled here. For each block it reads
    raw data and a CLAHE-processed copy from zarr. It reads ground-truth
    ``tree*.swc`` files, rasterizes and grows them into labels, crops
    everything to the block's ``cube*.swc`` ROI and writes a snapshot to
    ``validations/validations.hdf``.

    Args:
        config: Dict with keys such as ``BLOCKS``, ``BENCHMARK_DATA_PATH``,
            ``VALIDATION_SAMPLES``, ``SAMPLES_PATH``, ``RAW_N5``,
            ``NEURON_RADIUS``, ``VOXEL_SIZE``, ``INPUT_SHAPE``,
            ``OUTPUT_SHAPE`` and ``DISTANCE_ATTR``.

    Returns:
        A ``(pipeline, specs)`` tuple: the merged pipeline and a dict mapping
        each block to ``{key: spec}`` for the arrays/graphs it provides.

    Raises:
        FileNotFoundError: If a block's validation directory has no
            ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    # amount the input must extend beyond the output on each side
    context = (input_size - output_size) // 2

    # NOTE: read but not used yet in this function
    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            name = gt_file.name
            if name.startswith("tree") and name.endswith(".swc"):
                trees.append(gt_file)
            if name.startswith("cube") and name.endswith(".swc"):
                cube = gt_file
        # explicit check: `assert cube.exists()` crashed with AttributeError
        # when no cube file was found and is stripped under `python -O`
        if cube is None or not cube.exists():
            raise FileNotFoundError(
                f"no cube*.swc file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()

        # the cube ROI moved to the origin; the input ROI adds the network
        # context around it
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(context, context)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
Example #19
0
    def create_train_pipeline(self, model):
        """Build the gunpowder training pipeline and its batch request.

        Raw and ground-truth arrays are read with ``gp.ZarrSource`` from
        ``self.params['data_file']`` (datasets
        ``self.params['dataset']['train']['raw']`` / ``['gt']``) and merged
        with the point annotations produced by
        ``PointsLabelsSource(points, self.data, ...)``. The source section
        pads, picks a random location that contains points and normalizes
        raw with ``self.params['norm_factor']``; ``self._augmentation_pipeline``
        then adds augmentations. The tail stacks ``self.params['batch_size']``
        samples, trains ``model`` with ``self.loss`` and writes snapshots every
        ``self.params['save_every']`` iterations.

        For a 2D model (``model.in_shape`` has two entries) the source gets a
        fake leading spatial dimension that spans all ``raw_data.shape[0]``
        slices, so that ``RandomLocation`` picks a slice; ``RemoveSpatialDim``
        drops that dimension again before ``FillLocations``.

        Args:
            model: torch model to train. ``model.parameters()``,
                ``model.in_shape``, ``model.out_shape`` and
                ``model.base_encoder.out_shape`` are read here.

        Returns:
            tuple: ``(pipeline, request)`` -- the assembled pipeline and the
            ``gp.BatchRequest`` to issue once per training iteration.
        """
        # Implicit string concatenation: a backslash continuation inside the
        # f-string would embed the next line's indentation in the message.
        print("Creating training pipeline with batch size "
              f"{self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape (voxels first, then world units)
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        # Only used for logging.
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        # NOTE(review): the network context ``(in_shape - out_roi) / 2`` is
        # not used for padding; raw and gt_labels are padded without bound
        # (None) in the spatial dimensions. Confirm this is intended.
        if is_2d:
            # Add fake 3rd dim covering every slice of the dataset.
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            # No padding along the fake dimension, unbounded otherwise.
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        # The embedding covers the encoder's output ROI, anchored at the
        # origin.
        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(), out_roi))

        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=1, source_roi)
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
Example #20
0
def train(n_iterations):
    """Train ``"train_net"`` on ``Synthetic2DSource`` data.

    Each of the ``n_iterations`` requests asks for a 200x200 raw crop and a
    160x160 ground-truth crop. Every 100 iterations ``gp.Snapshot`` also
    writes the network outputs, the EMST edge arrays and the gradients of the
    embedding and foreground tensors.

    Uses the module-level names ``tensor_names``, ``add_loss``, ``emst_name``,
    ``edges_u_name`` and ``edges_v_name`` defined elsewhere in this file.
    """
    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    gt_fg = gp.ArrayKey("GT_FP")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, (200, 200))
    request.add(gt, (160, 160))

    # Volume-like outputs reuse the ground-truth spec; the EMST arrays get
    # an empty spec.
    snapshot_request = gp.BatchRequest()
    for volume_key in (embedding, fg, gt_fg, maxima, gradient_embedding,
                       gradient_fg):
        snapshot_request[volume_key] = request[gt]
    for tree_key in (emst, edges_u, edges_v):
        snapshot_request[tree_key] = gp.ArraySpec()

    source = Synthetic2DSource(raw, gt) + gp.Normalize(raw)

    trainer = gp.tensorflow.Train(
        "train_net",
        optimizer=add_loss,
        loss=None,
        inputs={
            tensor_names["raw"]: raw,
            tensor_names["gt_labels"]: gt
        },
        outputs={
            tensor_names["embedding"]: embedding,
            tensor_names["fg"]: fg,
            "maxima:0": maxima,
            "gt_fg:0": gt_fg,
            emst_name: emst,
            edges_u_name: edges_u,
            edges_v_name: edges_v,
        },
        gradients={
            tensor_names["embedding"]: gradient_embedding,
            tensor_names["fg"]: gradient_fg,
        },
    )

    snapshot = gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            gt: "volumes/gt",
            embedding: "volumes/embedding",
            fg: "volumes/fg",
            maxima: "volumes/maxima",
            gt_fg: "volumes/gt_fg",
            gradient_embedding: "volumes/gradient_embedding",
            gradient_fg: "volumes/gradient_fg",
            emst: "emst",
            edges_u: "edges_u",
            edges_v: "edges_v",
        },
        dataset_dtypes={
            maxima: np.float32,
            gt_fg: np.float32
        },
        every=100,
        additional_request=snapshot_request,
    )

    pipeline = source + trainer + snapshot

    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
Example #21
0
    def make_pipeline(self):
        """Build a scanning prediction pipeline for ``self.dataset``.

        Raw data is read with ``gp.ZarrSource`` from ``self.data_file``. It is
        given the sample/channel dimensions ``self.model`` expects, then
        normalized with ``self.params['norm_factor']`` and padded by the
        network context. ``gp.torch.Predict`` runs ``self.model``, and
        ``gp.ZarrWrite`` stores the output as dataset ``predictions`` in
        ``<self.curr_log_dir>/predictions.zarr`` (gzip). ``gp.Scan(request)``
        drives the pipeline chunk by chunk.

        Returns:
            tuple: ``(pipeline, request, pred_affs)`` -- the pipeline, the
            per-chunk request used by ``gp.Scan``, and the prediction key.
        """
        raw = gp.ArrayKey('RAW')
        pred_affs = gp.ArrayKey('PREDICTIONS')

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape (voxels first, then world units)
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(pred_affs, out_shape)

        # Half the difference between input and output size on each side.
        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        in_dims = len(self.model.in_shape)
        if is_2d:
            # 2D: [samples, y, x] or [samples, channels, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 1)
            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=1))
            # raw [samples, channels, y, x]
        else:
            # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 0)
            needs_batch_fix = (len(data.shape) - in_dims <= 1)

            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=0))
            # Batch fix
            if needs_batch_fix:
                source = (source + AddChannelDim(raw))
            # raw: [sample, channels, z, y, x]

        # The ROI of raw after the dimension fixes above is what Predict
        # declares for its output array.
        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (source +
                    gp.Normalize(raw, factor=self.params['norm_factor']) +
                    gp.Pad(raw, context) + gp.PreCache() + gp.torch.Predict(
                        self.model,
                        inputs={'raw': raw},
                        outputs={0: pred_affs},
                        array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline + gp.ZarrWrite({
            pred_affs: 'predictions',
        },
                                            output_dir=self.curr_log_dir,
                                            output_filename='predictions.zarr',
                                            compression_type='gzip') +
                    gp.Scan(request))

        return pipeline, request, pred_affs
Example #22
0
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):
    """Assemble a 3D training pipeline and the batch request that drives it.

    Args:
        task: provides ``data.raw``/``data.gt`` (sources, voxel size, channel
            and sample counts), ``loss`` (also asked for an optional weights
            node via ``add_weights``) and the ``augmentations`` expression.
        predictor: the model; ``input_shape``, ``output_shape``,
            ``target_channels``, ``prediction_channels`` and ``add_target``
            are used.
        optimizer: optimizer handed to ``gp_torch.Train``.
        batch_size: number of samples stacked per batch.
        outdir: snapshots go to ``<outdir>/snapshots``.
        snapshot_every: snapshot interval in iterations; ``<= 0`` disables
            snapshots.

    Returns:
        tuple: ``(pipeline, request)``.
    """
    # Everything defined before the ``eval`` below is in scope of the
    # ``task.augmentations`` expression, so these local names are kept.
    raw_channels = max(1, task.data.raw.num_channels)
    input_shape, output_shape = predictor.input_shape, predictor.output_shape
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 1 if raw_channels != 1 else 0

    num_samples = task.data.raw.train.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    sources = (task.data.raw.train.get_source(raw),
               task.data.gt.train.get_source(gt))

    # raw/gt: ([c,] d, h, w) until the target is added
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()
    # NOTE(review): eval() executes arbitrary code from the task
    # configuration; only use trusted configuration files here.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # gt is not used further; raw/target: ([c,] d, h, w)

    loss_inputs = {0: prediction, 1: target}
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs[2] = weights

    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w); target [, weights]: ([c,] d, h, w)
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w); target [, weights]: (b, [c,] d, h, w)
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # prediction: (b, [c,] d, h, w)

    if snapshot_every > 0:
        # Put channels in front of the batch axis for the snapshot.
        channels_first = (1, 0, 2, 3, 4)
        keys_to_transpose = [raw]
        if predictor.target_channels > 0:
            keys_to_transpose.append(target)
            if weights_node:
                keys_to_transpose.append(weights)
        if predictor.prediction_channels > 0:
            keys_to_transpose.append(prediction)
        for key in keys_to_transpose:
            pipeline += TransposeDims(key, channels_first)
        # raw: (c, b, d, h, w); others: ([c,] b, d, h, w)

        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)

        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")

    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    for key, size in ((raw, input_size), (gt, output_size),
                      (target, output_size)):
        request.add(key, size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
Example #23
0
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    """Train ``"train_net"`` on the MouseLight samples with a custom loss.

    One source branch is built per directory in
    ``setup_config["SAMPLES_PATH"]`` named ``2018-07-02`` or ``2018-08-01``.
    Each branch reads raw from ``fluorescence-near-consensus.n5`` and the
    consensus/skeletonization graphs from MongoDB. It matches the
    skeletonization to the consensus, rasterizes and grows the matched
    skeleton into ``labels``, and augments. The merged stream is
    foreground-cropped, binarized and balanced before training with the
    optimizer from ``create_custom_loss``. Snapshots are written every
    ``SNAPSHOT_EVERY`` and checkpoints every ``CHECKPOINT_EVERY`` iterations.

    Args:
        n_iterations: NOTE(review): currently unused -- the loop runs
            ``setup_config["NUM_ITERATIONS"]`` times. Confirm which is meant.
        setup_config: configuration mapping; every upper-case key read below
            must be present.
        mknet_tensor_names: maps ``raw``, ``gt_labels``, ``loss_weights``,
            ``embedding`` and ``fg`` to network tensor names.
        loss_tensor_names: maps the loss/debug output names used below to
            tensor names.
    """
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    # labels_fg is already part of the main request above, so it needs no
    # entry here.
    snapshot_request = gp.BatchRequest()

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    # One source branch per selected sample directory.
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            # "{id}" is left as a literal placeholder for Snapshot to fill.
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #24
0
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):
    """Assemble a 2D training pipeline and the batch request that drives it.

    The training dataset must be a 3D array, plus an optional channel
    dimension, whose first dimension indexes the 2D samples. That samples
    axis is exposed to gunpowder as a spatial dimension with voxel size 1, so
    ``RandomLocation`` picks a sample; ``Squash(dim=-3)`` removes it again.

    Args:
        task: provides ``data.raw``/``data.gt`` (sources, shape, ROI, voxel
            size, channel count), ``loss`` (also asked for an optional weights
            node via ``add_weights``) and the ``augmentations`` expression.
        predictor: the model; ``input_shape``, ``output_shape``,
            ``target_channels``, ``prediction_channels`` and ``add_target``
            are used.
        optimizer: optimizer handed to ``gp_torch.Train``.
        batch_size: number of samples stacked per batch.
        outdir: snapshots go to ``<outdir>/snapshots``.
        snapshot_every: snapshot interval in iterations; ``<= 0`` disables
            snapshots.

    Returns:
        tuple: ``(pipeline, request)``.

    Raises:
        RuntimeError: if the dataset, without its channel dimension, is not
            3D.
    """
    # Everything defined before the ``eval`` below is in scope of the
    # ``task.augmentations`` expression, so these local names (including the
    # otherwise unused ``filename``) are kept.
    #
    # Clamp to at least one channel, as create_pipeline_3d does, so that a
    # channel count of 0 (no channel dimension) is not mistaken for a
    # multi-channel dataset by the shape check below.
    raw_channels = max(1, task.data.raw.num_channels)
    filename = task.data.raw.train.filename
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    sources = (task.data.raw.train.get_source(raw, overwrite_spec=spec),
               task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    # NOTE(review): eval() executes arbitrary code from the task
    # configuration; only use trusted configuration files here.
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]
    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
Example #25
0
def validation_data_sources_recomputed(config, blocks):
    """Build one ``(pipeline, request)`` pair per requested validation block.

    Validation directories are looked up under every directory of
    ``config["BENCHMARK_DATA_PATH"]`` whose name contains ``"validation"``.
    The block number is the integer after the last ``_`` of a sub-directory
    name. Each pipeline merges the raw zarr volume of the first validation
    sample with the SWC ground truth of that directory, normalizes raw to
    float32, and rasterizes the ground truth into connected-component
    ``labels`` that are grown by ``NEURON_RADIUS * 1000``. The request covers
    the ROI of the directory's ``cube*.swc`` file; raw is grown by half the
    input/output size difference on each side.

    Args:
        config: configuration mapping; every upper-case key read below must
            be present.
        blocks: block numbers to build pipelines for, in output order.

    Returns:
        tuple: ``(validation_pipelines, (raw, labels, ground_truth))`` where
        ``validation_pipelines`` is a list of ``(pipeline, request)`` pairs.

    Raises:
        KeyError: if a block in ``blocks`` has no validation directory.
        FileNotFoundError: if a validation directory has no ``cube*.swc``.
    """
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"
    transform_file = transform_template.format(sample=sample)

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        # The cube SWC defines the evaluation ROI; if several match, the
        # last one listed wins.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(
                    ".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(
                f"No cube*.swc file found in {validation_dir}")
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_file),
            np.array([300, 300, 1000]),
        )

        pipeline = ((
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size)
                },
            ),
            nl.gunpowder.nodes.MouselightSwcFileSource(
                validation_dir,
                [ground_truth],
                transform_file=transform_file,
                ignore_human_nodes=False,
                scale=voxel_size,
                transpose=[2, 1, 0],
                points_spec=[
                    gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ))
                ],
            ),
        ) + gp.nodes.MergeProvider() + gp.Normalize(
            raw, dtype=np.float32) + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            ) + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000]))

        # Raw needs the network context around the cube on each side.
        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))
    return validation_pipelines, (raw, labels, ground_truth)
Example #26
0
def train(until):
    """Train ``SpineUNet`` to predict 3-channel affinities from raw data.

    Every iteration requests one batch from a gunpowder pipeline. The pipeline
    picks one of three zarr samples at random, samples a random location,
    augments it, derives ground-truth affinities from the labels and runs one
    optimizer step through ``gp_torch.Train``. Checkpoints are saved every
    10000 iterations and HDF snapshots are written every 500 iterations.

    Args:
        until: number of batches (training iterations) to request.
    """
    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # (raw dataset, labels dataset) of every training sample in the container
    sample_datasets = (
        ('train/sample1/raw', 'train/sample1/labels'),
        ('train/sample2/raw', 'train/sample2/labels'),
        ('train/sample3/raw', 'train/sample3/labels'),
    )
    sources = tuple(
        gp.ZarrSource('data/20200201.zarr', {raw: raw_ds, labels: labels_ds})
        for raw_ds, labels_ds in sample_datasets)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()
    pipeline += gp.SimpleAugment(transpose_only=(1, 2))
    pipeline += gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi])
    pipeline += gp.AddAffinities(
        [(1, 0, 0), (0, 1, 0), (0, 0, 1)],
        labels,
        affs)
    # ensure the affinities are float; factor 1.0 leaves the values unchanged
    pipeline += gp.Normalize(affs, factor=1.0)
    # raw: (d, h, w)        affs: (3, d, h, w)
    pipeline += gp.Stack(1)
    # raw: (1, d, h, w)     affs: (1, 3, d, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (1, 1, d, h, w)  affs: (1, 3, d, h, w)
    pipeline += gp_torch.Train(
        model,
        loss,
        optimizer,
        inputs={'x': raw},
        outputs={0: affs_predicted},
        loss_inputs={0: affs_predicted, 1: affs},
        save_every=10000)
    # raw loses its channel and batch dims, the affinities their batch dim
    for key in (raw, raw, affs, affs_predicted):
        pipeline += RemoveChannelDim(key)
    # raw: (d, h, w)        affs / affs_predicted: (3, d, h, w)
    pipeline += gp.Snapshot(
        {
            raw: 'raw',
            labels: 'labels',
            affs: 'affs',
            affs_predicted: 'affs_predicted'
        },
        every=500,
        output_filename='iteration_{iteration}.hdf')

    request = gp.BatchRequest()
    for key in (raw, labels, affs, affs_predicted):
        request.add(key, input_size)

    with gp.build(pipeline):
        for _ in range(until):
            pipeline.request_batch(request)
Example #27
0
def predict_3d(raw_data, gt_data, predictor):
    """Run ``predictor`` over a whole 3D dataset, chunk by chunk.

    The prediction is computed with ``gp.Scan``, using chunks of the
    predictor's input/output shape (converted to world units with the
    dataset's voxel size). Raw (and gt, if given) are padded upstream, so
    chunks may extend past the data. When ``gt_data`` is given, the
    predictor's training target is computed alongside the prediction.

    Args:
        raw_data: raw dataset wrapper (``get_source``, ``roi``, ``voxel_size``,
            ``num_channels``, ``num_samples``). Only ``num_samples == 0`` is
            supported.
        gt_data: optional ground-truth dataset wrapper, or a falsy value.
        predictor: model with ``input_shape``/``output_shape`` (in voxels)
            and ``add_target(gt, target)``.

    Returns:
        dict with the ``'raw'`` and ``'prediction'`` arrays, plus ``'gt'`` and
        ``'target'`` when ``gt_data`` is given.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # network shapes in world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    # single-channel raw arrives without a channel axis and needs one added
    single_channel = raw_channels == 1
    has_gt = bool(gt_data)

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if has_gt:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # sources, padded so that scan chunks may extend past the data
    if has_gt:
        pipeline = (raw_data.get_source(raw),
                    gt_data.get_source(gt)) + gp.MergeProvider()
        pipeline += gp.Pad(raw, None)
        pipeline += gp.Pad(gt, None)
    else:
        pipeline = raw_data.get_source(raw)
        pipeline += gp.Pad(raw, None)

    pipeline += gp.Normalize(raw)
    if has_gt:
        # target: ([c,] d, h, w)
        pipeline += predictor.add_target(gt, target)

    if single_channel:
        pipeline += AddChannelDim(raw)   # raw: (c, d, h, w)
    pipeline += AddChannelDim(raw)       # "batch" axis -> raw: (1, c, d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # drop the "batch" axis again
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    if single_channel:
        pipeline += RemoveChannelDim(raw)  # raw: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # Grow the validation ROI by half the network input on every side, which
    # makes it at least as large as the network input.
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    requested_keys = [raw, prediction]
    if has_gt:
        requested_keys += [gt, target]
    total_request = gp.BatchRequest()
    for key in requested_keys:
        total_request[key] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        output_keys = {'raw': raw, 'prediction': prediction}
        if has_gt:
            output_keys.update(gt=gt, target=target)
        return {label: batch[key] for label, key in output_keys.items()}
Example #28
0
    def create_train_pipeline(self, model):
        """Assemble the gunpowder training pipeline for ``model``.

        Raw data and ground-truth labels are read from the zarr container
        ``self.params['data_file']``. Affinities are derived from the labels,
        ``model`` is trained with ``gp.torch.Train``, and snapshots are
        written every ``self.params['save_every']`` iterations.

        For 2D models (``model.in_shape`` has two entries), a leading spatial
        dimension with voxel size 1 is prepended to the source and later
        removed again with ``RemoveSpatialDim``.

        Args:
            model: network to train. It must expose ``in_shape``,
                ``out_shape`` (the first two entries are skipped; presumably
                batch and channel) and ``base_encoder.out_shape``.

        Returns:
            ``(pipeline, request)``: the pipeline and the batch request to
            send through it on every iteration.
        """
        print(
            f"Creating training pipeline with batch size {self.params['batch_size']}"
        )

        data_file = self.params['data_file']
        train_datasets = self.params['dataset']['train']
        raw_dataset = train_datasets['raw']
        gt_dataset = train_datasets['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        gt_aff = gp.ArrayKey('AFFINITIES')
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        dataset = daisy.open_ds(data_file, raw_dataset)
        source_roi = gp.Roi(dataset.roi.get_offset(), dataset.roi.get_shape())
        source_voxel_size = gp.Coordinate(dataset.voxel_size)
        out_voxel_size = gp.Coordinate(dataset.voxel_size)

        # network shapes in voxels; out_shape's first two entries are skipped
        in_voxels = gp.Coordinate(model.in_shape)
        out_voxels = gp.Coordinate(model.out_shape[2:])
        is_2d = in_voxels.dims() == 2

        # ... and in world units
        in_shape = in_voxels * out_voxel_size
        out_shape = out_voxels * out_voxel_size
        context = (in_shape - out_shape) / 2
        gt_labels_out_shape = out_shape

        if not is_2d:
            aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
        else:
            # Prepend a fake leading dimension (voxel size 1, offset 0,
            # extent = first axis of the dataset) so the 2D network can be
            # fed sections of extent 1 along that axis.
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (dataset.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
            gt_labels_out_shape = (1, *gt_labels_out_shape)

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")

        request = gp.BatchRequest()
        for key, shape in ((raw, in_shape),
                           (gt_aff, out_shape),
                           (predictions, out_shape)):
            request.add(key, shape)

        # extra arrays that are only needed for snapshots
        snapshot_request = gp.BatchRequest()
        emb_shape = gp.Coordinate(
            (*model.base_encoder.out_shape[2:], )) * out_voxel_size
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(), emb_shape))
        snapshot_request[gt_labels] = gp.ArraySpec(
            roi=gp.Roi(context, gt_labels_out_shape))

        # source: read, normalize, pad by the context, pick a random location
        source_specs = {
            raw: gp.ArraySpec(roi=source_roi,
                              voxel_size=source_voxel_size,
                              interpolatable=True),
            gt_labels: gp.ArraySpec(roi=source_roi,
                                    voxel_size=source_voxel_size),
        }
        source = gp.ZarrSource(data_file,
                               {raw: raw_dataset, gt_labels: gt_dataset},
                               array_specs=source_specs)
        source += gp.Normalize(raw, self.params['norm_factor'])
        source += gp.Pad(raw, context)
        source += gp.Pad(gt_labels, context)
        source += gp.RandomLocation()
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
        source = self._augmentation_pipeline(raw, source)

        pipeline = source + gp.AddAffinities(aff_neighborhood, gt_labels,
                                             gt_aff)
        pipeline += SetDtype(gt_aff, np.float32)
        # raw      : (l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
        pipeline += AddChannelDim(raw)
        # raw      : (c=1, l=1, h, w)
        if is_2d:
            pipeline += RemoveSpatialDim(raw)
            pipeline += RemoveSpatialDim(gt_aff)
            # raw      : (c=1, h, w)
            # gt_aff   : (c=2, h, w)

        pipeline += gp.Stack(self.params['batch_size'])
        pipeline += gp.PreCache()
        # raw: (b, c=1, h, w), gt_aff: (b, c=2, h, w), as Train expects
        pipeline += gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'raw': raw},
            loss_inputs={0: predictions, 1: gt_aff},
            outputs={0: predictions, 1: emb},
            array_specs={
                predictions: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every)
        # predictions: (b, c=2, h, w)
        # crop the GT labels to the output region so they can be inspected
        pipeline += gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape))
        pipeline += gp.Snapshot(
            output_dir=self.logdir + '/snapshots',
            output_filename='it{iteration}.hdf',
            dataset_names={
                raw: 'raw',
                gt_labels: 'gt_labels',
                predictions: 'predictions',
                gt_aff: 'gt_aff',
                emb: 'emb'
            },
            additional_request=snapshot_request,
            every=self.params['save_every'])
        pipeline += gp.PrintProfilingStats(every=500)

        return pipeline, request
Example #29
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train the TensorFlow network ``name`` up to ``max_iteration`` iterations.

    Training resumes from the latest checkpoint in ``output_folder``. The
    number of iterations already done is parsed from the checkpoint path's
    trailing ``_<iteration>`` suffix. If that number already reaches
    ``max_iteration``, the function returns without building the pipeline.

    Network tensor/op names and the input/output shapes are read from
    ``<name>_names.json`` and ``<name>_config.json`` in ``output_folder``.
    Training samples are the top-level members of every HDF5 file listed in
    the module-level ``data_files`` (relative to ``data_dir``). Each member
    must provide ``raw`` and ``fg`` datasets.

    Args:
        max_iteration: total number of training iterations, including those
            already stored in the latest checkpoint.
        name: basename of the network meta-graph and of its JSON files.
        output_folder: directory holding the network files, the checkpoints
            and the ``snapshots`` output.
        clip_max: raw intensities are clipped to ``[0, clip_max]`` and then
            scaled by ``1 / clip_max``.
    """
    # Resume from the latest checkpoint, if any (queried only once).
    checkpoint = tf.train.latest_checkpoint(output_folder)
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0

    # Already trained far enough: don't build the pipeline (and start the
    # PreCache workers) just to run zero iterations.
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # one padded, randomly located source per top-level member of each file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        # np.bool_ instead of the np.bool alias, which was
                        # removed in NumPy 1.24 (same dtype).
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            # always reject batches whose mask covers less than 0.5%
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1,2], transpose_only=[1,2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            # x -> 2x - 1
            gp.IntensityScaleShift(raw, 2,-1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):

        print("Starting training...")
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #30
0
    def make_pipeline(self):
        """Build a scan pipeline that writes ``self.model`` embeddings to zarr.

        ``self.dataset`` is read from ``self.data_file``, normalized by
        ``self.params['norm_factor']`` and padded by the network context. It
        is then passed through ``gp.torch.Predict`` chunk by chunk
        (``gp.Scan``), and the embeddings are written to
        ``<self.curr_log_dir>/<self.dataset>_embs.zarr`` (dataset ``embs``,
        gzip compressed).

        Returns:
            ``(pipeline, request, embs)``: the pipeline, the per-chunk
            reference request used by ``gp.Scan``, and the array key of the
            embeddings.
        """
        raw = gp.ArrayKey('RAW')
        embs = gp.ArrayKey('EMBS')

        # Note: the ROI of the embeddings is taken from the built source
        # below. A separate zarr.open() of the dataset is not needed (its
        # result was always overwritten, and zarr's default mode 'a' may
        # create the store).
        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # network in/out shapes in voxels; out_shape's first two entries
        # are skipped
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        # switch to world units
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(embs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        # add the singleton axes the model input expects (exact resulting
        # layout depends on AddChannelDim; presumably (c=1, roi) per sample)
        if is_2d:
            source = (source + AddChannelDim(raw, axis=1))
        else:
            source = (source + AddChannelDim(raw, axis=0) + AddChannelDim(raw))

        # build the source alone to learn the ROI it provides for raw; the
        # embeddings are declared over that ROI
        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (
            source + gp.Normalize(raw, factor=self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.PreCache() +
            gp.torch.Predict(self.model,
                             inputs={'raw': raw},
                             outputs={0: embs},
                             array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline +
                    gp.ZarrWrite({
                        embs: 'embs',
                    },
                                 output_dir=self.curr_log_dir,
                                 output_filename=self.dataset + '_embs.zarr',
                                 compression_type='gzip') + gp.Scan(request))

        return pipeline, request, embs