Example no. 1
    def get_source(self, array, overwrite_spec=None):
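        # Wrap the dataset in a ZarrSource; an optional overwrite_spec lets the
        # caller pin a specific ROI or voxel size for the array.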

        if overwrite_spec:
            return gp.ZarrSource(self.filename, {array: self.ds_name},
                                 array_specs={array: overwrite_spec})
        else:
            return gp.ZarrSource(self.filename, {array: self.ds_name})
Example no. 2
    def build_source(self):
        data = daisy.open_ds(self.filename, self.key)
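        # data.roi spans the full dataset; below it is cropped along axis 1
        # (the time axis) if a time window was configured.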

        if self.time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += self.time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = self.time_window[1] - self.time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        return gp.ZarrSource(filename,
                             {
                                 self.raw_0: key,
                                 self.raw_1: key
                             },
                             array_specs={
                                 self.raw_0: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True),
                                 self.raw_1: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                             })
Example no. 3
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)
    
    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        } 
    )
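    # Read the full ROI of the raw dataset so the entire volume can be
    # requested at once; gp.Scan below tiles it into reference_request-sized
    # chunks.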
  
    with gp.build(source):
        source_roi = source.spec[raw].roi
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
Example no. 4
    def __init__(self, filename,
                 key,
                 density=None,
                 channels=0,
                 shape=(16, 256, 256),
                 time_window=None,
                 add_sparse_mosaic_channel=True,
                 random_rot=False):

        self.filename = filename
        self.key = key
        self.shape = shape
        self.density = density
        self.raw = gp.ArrayKey('RAW_0')
        self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
        self.random_rot = random_rot
        self.channels = channels

        data = daisy.open_ds(filename, key)

        if time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = time_window[1] - time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        self.pipeline = gp.ZarrSource(
            filename,
            {
                self.raw: key
            },
            array_specs={
                self.raw: gp.ArraySpec(
                    roi=source_roi,
                    voxel_size=voxel_size,
                    interpolatable=True)
            }) + gp.RandomLocation() + IntensityDiffFilter(self.raw, 0, min_distance=0.1, channels=Slice(None))

        # add augmentations
        self.pipeline = self.pipeline + gp.ElasticAugment([40, 40],
                                                          [2, 2],
                                                          [0, math.pi / 2.0],
                                                          prob_slip=-1,
                                                          spatial_dims=2)

        self.pipeline.setup()
        np.random.seed(os.getpid() + int(time.time()))
Example no. 5
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                   raw, labels, affs_predicted, dataset_shape, num_samples,
                   sample_size):

    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(affs_predicted, output_size)
    scan_request.add(labels, output_size)

    pipeline = (
        gp.ZarrSource(str(data_dir), {
            raw: 'validate/raw',
            labels: 'validate/gt'
        }) + gp.Pad(raw, size=None) + gp.Normalize(raw) +
        # raw: (s, h, w)
        # labels: (s, h, w)
        train.AddChannelDim(raw) +
        # raw: (c=1, s, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) +
        # raw: (s, c=1, h, w)
        # labels: (s, h, w)
        Predict(model=model, inputs={'x': raw}, outputs={0: affs_predicted}) +
        # raw: (s, c=1, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) + train.RemoveChannelDim(raw) +
        # raw: (s, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        gp.PrintProfilingStats(every=100) + gp.Scan(scan_request))

    return pipeline
Example no. 6
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    print("context", context, in_shape, out_shape)

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    num_additional_channels = (2 + spatial_dims) - data_dims
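    # raw needs both a sample and a channel dimension in addition to its
    # spatial dims, i.e. 2 + spatial_dims axes in total; add leading
    # singleton dimensions until that holds.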

    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # # remove sample dimension for 3D data
    # pipeline += RemoveChannelDim(raw)
    # pipeline += RemoveChannelDim(prediction)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
Example no. 7
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties={},
            **kwargs):
    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # Hdf5Source
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_file.endswith('.zarr') or raw_file.endswith('.n5'):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)
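    # Normalize maps raw to [0, 1]; the scale/shift above maps it on to [-1, 1].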

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters.get('d_scale')
    if d_scale != 1 and d_scale is not None:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.
    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] == 'int8' or d_property[
            'dtype'] == 'float32', 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1,
                                                0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator:
        'volumes/pred_syn_indicator',
        pred_postpre_vectors:
        'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
Example no. 8
    def create_train_pipeline(self, model):

        print(
            f"Creating training pipeline with batch size {self.params['batch_size']}"
        )

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        gt_aff = gp.ArrayKey('AFFINITIES')
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_shape = out_shape * out_voxel_size

        context = (in_shape - out_shape) / 2
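        # context: how far the input ROI extends past the output ROI on each side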
        gt_labels_out_shape = out_shape
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
            gt_labels_out_shape = (1, *gt_labels_out_shape)
        else:
            aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(gt_aff, out_shape)
        request.add(predictions, out_shape)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))
        snapshot_request[gt_labels] = gp.ArraySpec(
            roi=gp.Roi(context, gt_labels_out_shape))

        source = (
            gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                          array_specs={
                              raw:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size,
                                           interpolatable=True),
                              gt_labels:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size)
                          }) + gp.Normalize(raw, self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.Pad(gt_labels, context) +
            gp.RandomLocation()
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
            gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
            SetDtype(gt_aff, np.float32) +
            # raw      : (l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
            AddChannelDim(raw)
            # raw      : (c=1, l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
        )

        if is_2d:
            pipeline = (
                pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_aff)
                # raw      : (c=1, h, w)
                # gt_aff   : (c=2, h, w)
            )

        pipeline = (
            pipeline + gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, h, w)
            # gt_aff   : (b, c=2, h, w)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={'raw': raw},
                loss_inputs={
                    0: predictions,
                    1: gt_aff
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(voxel_size=out_voxel_size),
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, h, w)
            # gt_aff     : (b, c=2, h, w)
            # predictions: (b, c=2, h, w)

            # Crop GT to look at labels
            gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            gt_aff: 'gt_aff',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
Example no. 9
def train(until):

    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    pipeline = (
        (
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample1/raw',
                    labels: 'train/sample1/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample2/raw',
                    labels: 'train/sample2/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample3/raw',
                    labels: 'train/sample3/labels'
                })
        ) +
        gp.RandomProvider() +
        gp.Normalize(raw) +
        gp.RandomLocation() +
        gp.SimpleAugment(transpose_only=(1, 2)) +
        gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi]) +
        gp.AddAffinities(
            [(1, 0, 0), (0, 1, 0), (0, 0, 1)],
            labels,
            affs) +
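        # factor=1.0: convert affinities to float32 without rescaling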
        gp.Normalize(affs, factor=1.0) +
        #gp.PreCache(num_workers=1) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        # affs: (1, 3, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        # affs: (1, 3, d, h, w)
        gp_torch.Train(
            model,
            loss,
            optimizer,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            loss_inputs={0: affs_predicted, 1: affs},
            save_every=10000) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Snapshot(
            {
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=500,
            output_filename='iteration_{iteration}.hdf')
    )

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(affs, input_size)
    request.add(affs_predicted, input_size)

    with gp.build(pipeline):
        for i in range(until):
            pipeline.request_batch(request)
Example no. 10
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()
    is_2d = spatial_dims == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    # ensure raw has sample and channel dims
    #
    # n = number of samples
    # c = number of channels

    # 2D raw is either (n, y, x) or (c, n, y, x)
    # 3D raw is either (z, y, x) or (c, z, y, x)
    for _ in range((2 + spatial_dims) - data_dims):
        source += AddChannelDim(raw)

    # 2D raw: (c, n, y, x)
    # 3D raw: (c, n=1, z, y, x)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    # 2D raw: (n, c, y, x)
    # 3D raw: (n=1, c, z, y, x)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # 2D raw       : (n, c, y, x)
    # 2D prediction: (n, c, y, x)
    # 3D raw       : (n=1, c, z, y, x)
    # 3D prediction: (n=1, c, z, y, x)

    if is_2d:

        # restore channels first for 2D data
        pipeline += TransposeDims(raw,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))
        pipeline += TransposeDims(prediction,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))

    else:

        # remove sample dimension for 3D data
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)

    # 2D raw       : (c, n, y, x)
    # 2D prediction: (c, n, y, x)
    # 3D raw       : (c, z, y, x)
    # 3D prediction: (c, z, y, x)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
Example no. 11
def train_until(max_iteration):

    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)

    model = Convolve(unet, 12, 1)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:

    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))

    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)

    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(points, output_size)
    request.add(groundtruth, output_size)
    request.add(prediction, output_size)
    request.add(grad, output_size)

    pos_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddCenterPoint(points, raw) + gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    neg_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddNoPoint(points, raw) + gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    data_sources = (pos_sources, neg_sources)
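    # draw 90% of the batches from positive samples, 10% from negative samples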
    data_sources += gp.RandomProvider(probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    train_pipeline = data_sources
    train_pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                        jitter_sigma=[0, 2, 2],
                                        rotation_interval=[0, math.pi / 2.0],
                                        prob_slip=0.05,
                                        prob_shift=0.05,
                                        max_misalign=10,
                                        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])

    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1,
                                          z_section_wise=True)
    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)

    train_pipeline += gp_torch.Train(model=model,
                                     loss=loss,
                                     optimizer=optimizer,
                                     inputs={'x': raw},
                                     outputs={0: prediction},
                                     loss_inputs={
                                         0: prediction,
                                         1: groundtruth
                                     },
                                     gradients={0: grad},
                                     save_every=1000,
                                     log_dir='log')

    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)

    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for i in range(max_iteration):
            train_pipeline.request_batch(request)
Example no. 12
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        embs = gp.ArrayKey('EMBS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(embs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        if is_2d:
            source = (source + AddChannelDim(raw, axis=1))
        else:
            source = (source + AddChannelDim(raw, axis=0) + AddChannelDim(raw))

        source = (
            source
            # raw      : (c=1, roi)
        )

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (
            source + gp.Normalize(raw, factor=self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.PreCache() +
            gp.torch.Predict(self.model,
                             inputs={'raw': raw},
                             outputs={0: embs},
                             array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline +
                    gp.ZarrWrite({
                        embs: 'embs',
                    },
                                 output_dir=self.curr_log_dir,
                                 output_filename=self.dataset + '_embs.zarr',
                                 compression_type='gzip') + gp.Scan(request))

        return pipeline, request, embs
Example no. 13
def predict_frame(in_shape,
                  out_shape,
                  model_output,
                  model_configfile,
                  model_checkpoint,
                  input_dataset_file,
                  inference_frame,
                  out_dir,
                  out_filename,
                  out_key_or_index=1,
                  intermediate_layer=None,
                  dataset_raw_key="train/raw",
                  dataset_prediction_key="train/prediction",
                  dataset_intermediate_key="train/prediction_interm",
                  model_input_tensor_name="patches",
                  model_architecture="PatchedResnet",
                  num_workers=5):

    # initialize model
    if model_architecture == "PatchedResnet":
        model = PatchedResnet(1, 2, resnet_size=18)
    elif model_architecture == "unet":
        model = lisl.models.create(model_configfile)
    else:
        raise NotImplementedError(f"{model_architecture} not implemented")

    model.add_spatial_dim = True
    model.eval()

    # gp variables
    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    raw = gp.ArrayKey(f'RAW_{inference_frame}')
    prediction = gp.ArrayKey(f'PREDICTION_{inference_frame}')
    intermediate_prediction = gp.ArrayKey(f'ITERM_{inference_frame}')

    ds_key = f'{dataset_raw_key}/{inference_frame}'
    out_key = f'{dataset_prediction_key}/{inference_frame}'
    interm_key = f'{dataset_intermediate_key}/{inference_frame}'

    # build pipeline
    zsource = gp.ZarrSource(
        input_dataset_file, {raw: ds_key},
        {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})

    pipeline = zsource
    with gp.build(zsource):
        raw_roi = zsource.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline += AddChannelDim(raw)
    pipeline += AddChannelDim(raw)
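    # raw now carries two leading singleton dimensions, e.g. (1, 1, y, x)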

    pipeline += gp.Pad(raw, None)
    # setup prediction node
    pred_dict = {out_key_or_index: prediction}
    pred_spec = {prediction: gp.ArraySpec(roi=raw_roi)}
    if intermediate_layer is not None:
        pred_dict[intermediate_layer] = intermediate_prediction
        pred_spec[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    pipeline += gp.torch.Predict(model,
                                 inputs={model_input_tensor_name: raw},
                                 outputs=pred_dict,
                                 array_specs=pred_spec,
                                 checkpoint=model_checkpoint,
                                 spawn_subprocess=True)

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    zarr_dict = {prediction: out_key}
    if intermediate_layer is not None:
        zarr_dict[intermediate_prediction] = interm_key
        request.add(intermediate_prediction, out_shape)
    pipeline += gp.Scan(request, num_workers=num_workers)
    pipeline += gp.ZarrWrite(zarr_dict,
                             output_dir=out_dir,
                             output_filename=out_filename,
                             compression_type='gzip')

    total_request = gp.BatchRequest()
    total_request[prediction] = gp.ArraySpec(roi=raw_roi)
    if intermediate_layer is not None:
        total_request[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)
    with gp.build(pipeline):
        pipeline.request_batch(total_request)
Example no. 14
def get_mouselight_data_sources(setup_config: Dict[str, Any],
                                source_samples: List[str],
                                locations=False):
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])

    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    matching_failures_dir = (matching_failures_dir
                             if matching_failures_dir is None else
                             Path(matching_failures_dir))

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    node_offset = {
        sample.name: (daisy.persistence.MongoDbGraphProvider(
            mongo_db_template.format(sample=sample.name,
                                     source="skeletonization"),
            mongo_url,
        ).num_nodes(None) + 1)
        for sample in samples_path.iterdir() if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:

    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (tuple(
        (
            gp.ZarrSource(
                filename=str((sample / raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            DaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[matched],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            FilteredDaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[nonempty_placeholder],
                directed=True,
                node_attrs=["distance_to_fallback"],
                edge_attrs=[],
                num_nodes=max_nonempty_points,
                dist_attribute=distance_attr,
                min_dist=target_distance,
            ),
        ) + gp.MergeProvider() + random(**kwargs) + gp.Normalize(raw) +
        FilterComponents(
            matched, node_offset[sample.name], centroid_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.int64),
        ) for sample in samples_path.iterdir()
        if sample.name in source_samples) + gp.RandomProvider())

    return (data_sources, raw, labels, nonempty_placeholder, matched)
Example no. 15
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block, fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) + EMST(
                                    emb,
                                    candidates,
                                    mst,
                                    distance_attr=distance_attr,
                                    coordinate_scale=coordinate_scale,
                                ) + gp.Snapshot(
                                    {
                                        raw: f"volumes/{raw}",
                                        ground_truth: f"points/{ground_truth}",
                                        labels: f"volumes/{labels}",
                                        fg: f"volumes/{fg}",
                                        emb: f"volumes/{emb}",
                                        candidates: f"volumes/{candidates}",
                                        mst: f"points/{mst}",
                                    },
                                    additional_request=additional_request,
                                    output_dir="snapshots",
                                    output_filename="{id}.hdf",
                                    edge_attrs={mst: [distance_attr]},
                                ))

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
Example no. 16
def validation_data_sources_recomputed(config, blocks):
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    validation_dirs = {}
    for group in benchmark_datasets_path.iterdir():
        if "validation" in group.name and group.is_dir():
            for validation_dir in group.iterdir():
                validation_num = int(validation_dir.name.split("_")[-1])
                if validation_num in blocks:
                    validation_dirs[validation_num] = validation_dir

    validation_dirs = [validation_dirs[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    validation_pipelines = []
    for validation_dir in validation_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        pipeline = ((
            gp.ZarrSource(
                filename=str(Path(sample_dir, sample, raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True,
                                      voxel_size=voxel_size)
                },
            ),
            nl.gunpowder.nodes.MouselightSwcFileSource(
                validation_dir,
                [ground_truth],
                transform_file=transform_template.format(sample=sample),
                ignore_human_nodes=False,
                scale=voxel_size,
                transpose=[2, 1, 0],
                points_spec=[
                    gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ))
                ],
            ),
        ) + gp.nodes.MergeProvider() + gp.Normalize(
            raw, dtype=np.float32) + nl.gunpowder.RasterizeSkeleton(
                ground_truth,
                labels,
                connected_component_labeling=True,
                array_spec=gp.ArraySpec(
                    voxel_size=voxel_size,
                    dtype=np.int64,
                    roi=gp.Roi(
                        gp.Coordinate([None, None, None]),
                        gp.Coordinate([None, None, None]),
                    ),
                ),
            ) + nl.gunpowder.GrowLabels(labels, radii=[neuron_width * 1000]))

        request = gp.BatchRequest()
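        # grow the cube ROI by the network context so predictions cover the
        # whole labelled cube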
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, request))
    return validation_pipelines, (raw, labels, ground_truth)
Example no. 17
                 chunks=(50, 50, 50),
                 overwrite=True)
f["sphere"].attrs["offset"] = (0, 0, 0)
f["sphere"].attrs["resolution"] = (1, 1, 1)
f["sphere"][:] = sphere

# declare arrays to use
labels = gp.ArrayKey("LABELS")
stardists = gp.ArrayKey("STARDIST")

# prepare requests
scan_request = gp.BatchRequest()
scan_request[stardists] = gp.Roi((0, 0, 0), (50, 50, 50))
request = gp.BatchRequest()

source = gp.ZarrSource(os.path.join(directory, "sphere.n5"),
                       datasets={labels: "sphere"})

# prepare node for 3D stardist generation with a maximum distance
stardist_gen = gpstardist.AddStarDist3D(labels,
                                        stardists,
                                        rays=96,
                                        anisotropy=(1, 1, 1),
                                        grid=(1, 1, 1),
                                        max_dist=max_dist,
                                        unlabeled_id=-1,
                                        invalid_value=-3)

# write result to a new dataset
writer = gp.ZarrWrite(
    output_dir=directory,
    output_filename="sphere.n5",
Example no. 18
def random_point_pairs_pipeline(model,
                                loss,
                                optimizer,
                                dataset,
                                augmentation_parameters,
                                point_density,
                                out_dir,
                                normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
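        # read the same frame into raw_0 and raw_1: two views of one image that
        # are augmented independently further down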
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename, {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None) for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(density=point_density,
                                                      repetitions=2)

        point_sources = tuple(
            (RandomPointSource(points_0,
                               dim,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               dim,
                               random_point_generator=random_point_generator)))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) + gp.MergeProvider() + \
            gp.SimpleAugment() + \
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) + \
            gp.IntensityAugment(r,
                                scale_min=0.8,
                                scale_max=1.2,
                                shift_min=-0.2,
                                shift_max=0.2,
                                clip=False) + \
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source
            in zip([raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                             locations_1)

    # TODO: clarify how PrepareBatch interacts with Stack below
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
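
# A minimal consumption sketch for the pipeline above (illustrative only: the
# `model`, `loss`, `optimizer` and `dataset` objects, the point density, the
# iteration count and the output directory are placeholders supplied by the caller).
def run_point_pairs_training(model, loss, optimizer, dataset, n_iterations=100000):
    pipeline, request = random_point_pairs_pipeline(
        model, loss, optimizer, dataset,
        augmentation_parameters={},
        point_density=0.0005,
        out_dir='runs/point_pairs')
    with gp.build(pipeline):
        for _ in range(n_iterations):
            # each request pulls one batch through the pipeline;
            # gp.torch.Train performs the optimizer step internally
            pipeline.request_batch(request)
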
def batch__aug_data_generator(input_path,
                              batch_size=12,
                              voxel_shape=[1, 1, 1],
                              input_shape=[240, 240, 4],
                              output_shape=[240, 240, 4],
                              without_background=False,
                              mix_output=False,
                              validate=False,
                              seq=None):
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')
    files = os.listdir(input_path)
    files = [os.path.join(input_path, f) for f in files]
    pipeline = (
        tuple(
            gp.ZarrSource(
                files[t],  # the zarr container
                {
                    raw: 'raw',
                    gt: 'ground_truth'
                },  # which dataset to associate to the array key
                {
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape),
                    gt:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape)
                }  # meta-information
            ) + gp.RandomLocation()
            for t in range(len(files))) + gp.RandomProvider()
        #    +gp.Stack(batch_size)
    )
    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, input_size)
    diff = input_shape[1] - output_shape[1]
    diff = int(diff / 2)
    max_p = input_shape[1] - diff
    different_shape = diff > 0
    if different_shape:
        print('Difference padding: {}'.format(diff))
    with gp.build(pipeline):
        while 1:
            b = 0
            imgs = []
            masks = []
            while b < batch_size:
                valid = False
                batch = pipeline.request_batch(request)
                if validate:
                    valid = validate_mask(batch[gt].data)
                else:
                    valid = True

                while not valid:
                    batch = pipeline.request_batch(request)
                    valid = validate_mask(batch[gt].data)
                im = batch[raw].data
                out = batch[gt].data
                # if different_shape:
                #     out = out[diff:max_p,diff:max_p,:]
                if without_background:
                    out = out[:, :, 1:4]
                if mix_output:
                    out = out.argmax(axis=3).astype(float)
                imgs.append(im)
                masks.append(out)
                b = b + 1
            imgs = np.asarray(imgs)
            masks = np.asarray(masks)
            if seq is not None:
                imgs, masks = augmentation(imgs, masks, seq)
            if different_shape:
                out = []
                for m in masks:
                    out.append(m[diff:max_p, diff:max_p, :])
                masks = np.asarray(out)
            yield imgs, masks
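
# Usage sketch for the generator above (illustrative: the input directory is a
# placeholder, and any Keras-style model accepting a python generator of
# (images, masks) pairs could consume it via `model.fit(gen, ...)`).
def preview_batch(input_path='data/train_zarrs', batch_size=4):
    gen = batch__aug_data_generator(input_path,
                                    batch_size=batch_size,
                                    validate=False)
    imgs, masks = next(gen)
    # with the default shapes both arrays come back as (batch_size, 240, 240, 4)
    print(imgs.shape, masks.shape)
    return imgs, masks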
            

def batch_data_generator(input_path,
                         batch_size=12,
                         voxel_shape=[1, 1, 1],
                         input_shape=[240, 240, 4],
                         output_shape=[240, 240, 4],
                         without_background=False,
                         mix_output=False,
                         validate=False):
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')
    files = os.listdir(input_path)
    files = [os.path.join(input_path, f) for f in files]
    pipeline = (
        tuple(
            gp.ZarrSource(
                files[t],  # the zarr container
                {
                    raw: 'raw',
                    gt: 'ground_truth'
                },  # which dataset to associate to the array key
                {
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape),
                    gt:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape)
                }  # meta-information
            ) + gp.RandomLocation()
            for t in range(len(files))) + gp.RandomProvider()
        #    +gp.Stack(batch_size)
    )
    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, input_size)
    diff = input_shape[1] - output_shape[1]
Exemplo n.º 21
0
    def create_train_pipeline(self, model):

        print(f"Creating training pipeline with batch size \
              {self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        context = (in_shape - out_roi) / 2
        gt_labels_out_shape = out_roi
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            gt_labels_out_shape = (1, *gt_labels_out_shape)

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))

        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=1, source_roi)
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
Exemplo n.º 22
0
def build_pipeline(
        data_dir,  
        model, 
        save_every,
        batch_size, 
        input_size, 
        output_size,
        raw, 
        labels,
        affs,
        affs_predicted,
        lr=1e-5): 

    dataset_shape = zarr.open(str(data_dir))['train/raw'].shape
    num_samples = dataset_shape[0]
    sample_size = dataset_shape[1:]

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)
    
    pipeline = (
            gp.ZarrSource(
                data_dir,
                {
                    raw: 'train/raw',
                    labels: 'train/gt'
                },
                array_specs={
                    raw: gp.ArraySpec(
                        roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                        voxel_size=(1, 1, 1)),
                    labels: gp.ArraySpec(
                        roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                        voxel_size=(1, 1, 1))
                }) +
            # raw: (d=1, h, w)
            # labels: (d=1, h, w)
            gp.RandomLocation() +
            # raw: (d=1, h, w)
            # labels: (d=1, h, w)
            gp.AddAffinities(
                affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
                labels=labels,
                affinities=affs) +
            gp.Normalize(affs, factor=1.0) +
            # raw: (d=1, h, w)
            # affs: (c=2, d=1, h, w)
            Squash(dim=-3) +
            # get rid of z dim
            # raw: (h, w)
            # affs: (c=2, h, w)
            AddChannelDim(raw) +
            # raw: (c=1, h, w)
            # affs: (c=2, h, w)
            gp.PreCache() +
            gp.Stack(batch_size) +
            # raw: (b=10, c=1, h, w)
            # affs: (b=10, c=2, h, w)
            Train(
                model=model,
                loss=loss,
                optimizer=optimizer,
                inputs={'x': raw},
                target=affs,
                output=affs_predicted,
                save_every=save_every,
                log_dir='log') +
            # raw: (b=10, c=1, h, w)
            # affs: (b=10, c=2, h, w)
            # affs_predicted: (b=10, c=2, h, w)
            TransposeDims(raw, (1, 0, 2, 3)) +
            TransposeDims(affs, (1, 0, 2, 3)) +
            TransposeDims(affs_predicted, (1, 0, 2, 3)) +
            # raw: (c=1, b=10, h, w)
            # affs: (c=2, b=10, h, w)
            # affs_predicted: (c=2, b=10, h, w)
            RemoveChannelDim(raw) +
            # raw: (b=10, h, w)
            # affs: (c=2, b=10, h, w)
            # affs_predicted: (c=2, b=10, h, w)
            gp.Snapshot(
                dataset_names={
                    raw: 'raw',
                    labels: 'labels',
                    affs: 'affs',
                    affs_predicted: 'affs_predicted'
                },
                every=100) +
            gp.PrintProfilingStats(every=100)
        )
    return pipeline 
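
# build_pipeline returns only the pipeline; the matching request is assembled by
# the caller from the same ArrayKeys and the input/output sizes passed in. A rough
# sketch (the iteration count is illustrative; `input_size`/`output_size` are the
# gp.Coordinate values the caller already has):
def run_affinity_training(pipeline, raw, affs, affs_predicted,
                          input_size, output_size, n_iterations=10000):
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(affs, output_size)
    request.add(affs_predicted, output_size)
    with gp.build(pipeline):
        for _ in range(n_iterations):
            pipeline.request_batch(request)
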
Exemplo n.º 23
0
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        pred_affs = gp.ArrayKey('PREDICTIONS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(pred_affs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        in_dims = len(self.model.in_shape)
        if is_2d:
            # 2D: [samples, y, x] or [samples, channels, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 1)
            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=1))
            # raw [samples, channels, y, x]
        else:
            # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 0)
            needs_batch_fix = (len(data.shape) - in_dims <= 1)

            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=0))
            # Batch fix
            if needs_batch_fix:
                source = (source + AddChannelDim(raw))
            # raw: [sample, channels, z, y, x]

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (source +
                    gp.Normalize(raw, factor=self.params['norm_factor']) +
                    gp.Pad(raw, context) + gp.PreCache() + gp.torch.Predict(
                        self.model,
                        inputs={'raw': raw},
                        outputs={0: pred_affs},
                        array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline + gp.ZarrWrite({
            pred_affs: 'predictions',
        },
                                            output_dir=self.curr_log_dir,
                                            output_filename='predictions.zarr',
                                            compression_type='gzip') +
                    gp.Scan(request))

        return pipeline, request, pred_affs
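
# Prediction sketch (illustrative; `predictor` stands for an instance of the class
# this method belongs to). Because the pipeline ends in gp.Scan, an empty request
# sweeps the full provided ROI chunk by chunk while ZarrWrite persists each chunk;
# the returned `request` only documents the per-chunk shapes Scan uses.
def run_prediction(predictor):
    pipeline, request, pred_affs = predictor.make_pipeline()
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
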
Exemplo n.º 24
0
def validation_pipeline(config):
    """
    Per block
    {
        Raw -> predict -> scan
        gt -> rasterize        -> merge -> candidates -> trees
    } -> merge -> comatch + evaluate
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)

        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
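
# Usage sketch (illustrative): `specs` maps each block to the keys it provides, so
# a single request covering all blocks can be assembled and satisfied in one pass
# through the merged pipeline.
def run_validation(config):
    pipeline, specs = validation_pipeline(config)
    request = gp.BatchRequest()
    for block_spec in specs.values():
        for key, spec in block_spec.items():
            request[key] = spec
    with gp.build(pipeline):
        return pipeline.request_batch(request)
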
Exemplo n.º 25
0
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP')
prediction = gp.ArrayKey('PREDICTION')


class PrepareTrainingData(gp.BatchFilter):
    def process(self, batch, request):

        batch[out_cage_map].data = batch[out_cage_map].data.astype(np.float32)
        batch[out_cage_map].spec.dtype = np.float32


# assemble pipeline
sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
sourceB = gp.ZarrSource('../data/cropped_sample_B.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)  # completed to match sourceA/B; the original snippet was truncated here
})
Exemplo n.º 26
0
    def create_train_pipeline(self, model):

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        filename = self.params['data_file']
        datasets = self.params['dataset']

        raw_0 = gp.ArrayKey('RAW_0')
        points_0 = gp.GraphKey('POINTS_0')
        locations_0 = gp.ArrayKey('LOCATIONS_0')
        emb_0 = gp.ArrayKey('EMBEDDING_0')
        raw_1 = gp.ArrayKey('RAW_1')
        points_1 = gp.GraphKey('POINTS_1')
        locations_1 = gp.ArrayKey('LOCATIONS_1')
        emb_1 = gp.ArrayKey('EMBEDDING_1')

        data = daisy.open_ds(filename, datasets[0])
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        emb_voxel_size = voxel_size

        cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                        self.params['point_density'],
                                        out_shape * voxel_size)

        # Add fake 3rd dim
        if is_2d:
            in_shape = gp.Coordinate((1, *in_shape))
            out_shape = gp.Coordinate((1, *out_shape))
            voxel_size = gp.Coordinate((1, *voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (data.shape[0], *source_roi.get_shape()))

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw_0, in_shape)
        request.add(raw_1, in_shape)
        request.add(points_0, out_shape)
        request.add(points_1, out_shape)
        request[locations_0] = gp.ArraySpec(nonspatial=True)
        request[locations_1] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
        snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

        random_point_generator = RandomPointGenerator(
            density=self.params['point_density'], repetitions=2)

        # Use volume to calculate probabilities; RandomSourceGenerator will
        # normalize volumes to probabilities
        probabilities = np.array([
            np.product(daisy.open_ds(filename, dataset).shape)
            for dataset in datasets
        ])
        random_source_generator = RandomSourceGenerator(
            num_sources=len(datasets),
            probabilities=probabilities,
            repetitions=2)

        array_sources = tuple(
            tuple(
                gp.ZarrSource(
                    filename,
                    {raw: dataset},
                    # fake 3D data
                    array_specs={
                        raw:
                        gp.ArraySpec(roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                    }) for dataset in datasets) for raw in [raw_0, raw_1])

        # Choose a random dataset to pull from
        array_sources = \
            tuple(arrays +
                  RandomMultiBranchSource(random_source_generator) +
                  gp.Normalize(raw, self.params['norm_factor']) +
                  gp.Pad(raw, None)
                  for raw, arrays
                  in zip([raw_0, raw_1], array_sources))

        point_sources = tuple(
            (RandomPointSource(points_0,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               random_point_generator=random_point_generator)))

        # Merge the point and array sources together.
        # There is one array and point source per branch.
        sources = tuple((array_source, point_source) + gp.MergeProvider()
                        for array_source, point_source in zip(
                            array_sources, point_sources))

        sources = tuple(
            self._make_train_augmentation_pipeline(raw, source)
            for raw, source in zip([raw_0, raw_1], sources))

        pipeline = (sources + gp.MergeProvider() + gp.Crop(raw_0, source_roi) +
                    gp.Crop(raw_1, source_roi) + gp.RandomLocation() +
                    PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                                 locations_1, is_2d) +
                    RejectArray(ensure_nonempty=locations_0) +
                    RejectArray(ensure_nonempty=locations_1))

        if not is_2d:
            pipeline = (pipeline + AddChannelDim(raw_0) + AddChannelDim(raw_1))

        pipeline = (pipeline + gp.PreCache() + gp.torch.Train(
            model,
            cv_loss,
            optimizer,
            inputs={
                'raw_0': raw_0,
                'raw_1': raw_1
            },
            loss_inputs={
                'emb_0': emb_0,
                'emb_1': emb_1,
                'locations_0': locations_0,
                'locations_1': locations_1
            },
            outputs={
                2: emb_0,
                3: emb_1
            },
            array_specs={
                emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
            },
            checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir + "/contrastive",
            log_every=self.log_every))

        if is_2d:
            pipeline = (
                pipeline +
                # everything is 3D, except emb_0 and emb_1
                AddSpatialDim(emb_0) + AddSpatialDim(emb_1))

        pipeline = (
            pipeline +
            # now everything is 3D
            RemoveChannelDim(raw_0) + RemoveChannelDim(raw_1) +
            RemoveChannelDim(emb_0) + RemoveChannelDim(emb_1) +
            gp.Snapshot(output_dir=self.logdir + '/contrastive/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw_0: 'raw_0',
                            raw_1: 'raw_1',
                            locations_0: 'locations_0',
                            locations_1: 'locations_1',
                            emb_0: 'emb_0',
                            emb_1: 'emb_1'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request