Example #1
def add_scans(pipelines, data_shapes):
    scanned_pipelines = []
    for pipeline in pipelines:
        ref_request = gp.BatchRequest()
        for key, shape in data_shapes.items():
            ref_request.add(key, shape)
        scanned_pipelines.append(pipeline + gp.Scan(reference=ref_request))
    return scanned_pipelines
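
A minimal usage sketch for add_scans (not from the original source), assuming two already-built gunpowder pipelines pipeline_a and pipeline_b and illustrative keys and shapes:

# Hypothetical usage of add_scans; names and shapes are placeholders
raw = gp.ArrayKey('RAW')
prediction = gp.ArrayKey('PREDICTION')

data_shapes = {
    raw: gp.Coordinate((8, 96, 96)),         # network input size in world units
    prediction: gp.Coordinate((4, 64, 64)),  # network output size in world units
}

scanned = add_scans([pipeline_a, pipeline_b], data_shapes)
for p in scanned:
    with gp.build(p):
        # an empty request makes Scan tile the provider's full ROI
        p.request_batch(gp.BatchRequest())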
Example #2
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

                                
    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)
    
    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        } 
    )
  
    with gp.build(source):
        source_roi = source.spec[raw].roi
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
       
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
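
A hedged example of calling the predict function above; the checkpoint iteration and zarr path are placeholders:

# Hypothetical call; iteration and dataset path are placeholders
raw_np, affs_np = predict(50000, 'path/to/validation_data.zarr')
print(raw_np.shape, affs_np.shape)  # full-ROI raw volume and predicted affinities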
Example #3
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                   raw, labels, affs_predicted, dataset_shape, num_samples,
                   sample_size):

    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(affs_predicted, output_size)
    scan_request.add(labels, output_size)

    pipeline = (
        gp.ZarrSource(str(data_dir), {
            raw: 'validate/raw',
            labels: 'validate/gt'
        }) + gp.Pad(raw, size=None) + gp.Normalize(raw) +
        # raw: (s, h, w)
        # labels: (s, h, w)
        train.AddChannelDim(raw) +
        # raw: (c=1, s, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) +
        # raw: (s, c=1, h, w)
        # labels: (s, h, w)
        Predict(model=model, inputs={'x': raw}, outputs={0: affs_predicted}) +
        # raw: (s, c=1, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) + train.RemoveChannelDim(raw) +
        # raw: (s, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        gp.PrintProfilingStats(every=100) + gp.Scan(scan_request))

    return pipeline
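
The returned pipeline still needs to be built and driven by a request. A minimal sketch, assuming the same keys passed to build_pipeline and a validation ROI total_roi taken from the source dataset:

# Hypothetical driver for the pipeline returned by build_pipeline
pipeline = build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                          raw, labels, affs_predicted, dataset_shape, num_samples,
                          sample_size)

request = gp.BatchRequest()
request[raw] = gp.ArraySpec(roi=total_roi)            # total_roi is assumed known
request[labels] = gp.ArraySpec(roi=total_roi)
request[affs_predicted] = gp.ArraySpec(roi=total_roi)

with gp.build(pipeline):
    # Scan splits the large request into scan_request-sized chunks
    batch = pipeline.request_batch(request)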
Example #4
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        pred_affs = gp.ArrayKey('PREDICTIONS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(pred_affs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        in_dims = len(self.model.in_shape)
        if is_2d:
            # 2D: [samples, y, x] or [samples, channels, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 1)
            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=1))
            # raw [samples, channels, y, x]
        else:
            # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 0)
            needs_batch_fix = (len(data.shape) - in_dims <= 1)

            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=0))
            # Batch fix
            if needs_batch_fix:
                source = (source + AddChannelDim(raw))
            # raw: [sample, channels, z, y, x]

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (source +
                    gp.Normalize(raw, factor=self.params['norm_factor']) +
                    gp.Pad(raw, context) + gp.PreCache() + gp.torch.Predict(
                        self.model,
                        inputs={'raw': raw},
                        outputs={0: pred_affs},
                        array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline + gp.ZarrWrite({
            pred_affs: 'predictions',
        },
                                            output_dir=self.curr_log_dir,
                                            output_filename='predictions.zarr',
                                            compression_type='gzip') +
                    gp.Scan(request))

        return pipeline, request, pred_affs
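
Because this pipeline already ends in ZarrWrite + Scan, it can be driven with an empty request. A sketch, assuming predictor is an instance of the surrounding class:

# Hypothetical driver; 'predictor' stands in for an instance of this class
pipeline, request, pred_affs = predictor.make_pipeline()
with gp.build(pipeline):
    # the empty request lets Scan iterate over the whole padded ROI,
    # writing predictions.zarr chunk by chunk via ZarrWrite
    pipeline.request_batch(gp.BatchRequest())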
Example #5
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['input_folder'], name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'], name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2

    # formulate the request for what a batch should contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(pred_affs, output_shape_world)
    request.add(pred_fgbg, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("predict node for %s not implemented yet"
                                  % kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'], kwargs['sample'] + ".hdf"),
                'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'], kwargs['sample'] + ".zarr"),
            'r')
        shape = f['volumes/raw'].shape
    source = sourceNode(os.path.join(
        kwargs['data_folder'],
        kwargs['sample'] + "." + kwargs['input_format']),
                        datasets={raw: 'volumes/raw'})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")
    # pre-create zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'),
                   mode='w')
    zf.create('volumes/pred_affs',
              shape=[3] + list(shape),
              chunks=[3] + list(shape),
              dtype=np.float32)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/pred_fgbg',
              shape=[1] + list(shape),
              chunks=[1] + list(shape),
              dtype=np.float32)
    zf['volumes/pred_fgbg'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_fgbg'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/raw_cropped',
              shape=[1] + list(shape),
              chunks=[1] + list(shape),
              dtype=np.float32)
    zf['volumes/raw_cropped'].attrs['offset'] = [0, 0, 0]
    zf['volumes/raw_cropped'].attrs['resolution'] = kwargs['voxel_size']

    pipeline = (

        # read from HDF5 file
        source + gp.Pad(raw, context) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs={
                                  net_names['pred_affs']: pred_affs,
                                  net_names['pred_fgbg']: pred_fgbg,
                                  net_names['raw_cropped']: raw_cropped
                              }) +

        # store all passing batches in the same HDF5 file
        gp.ZarrWrite(
            {
                raw_cropped: '/volumes/raw_cropped',
                pred_affs: '/volumes/pred_affs',
                pred_fgbg: '/volumes/pred_fgbg',
            },
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip') +

        # show a summary of time spend in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
Example #6
        (2, 2, 2)))
source = gp.Hdf5Source(
    os.path.join(directory, "sample_A.hdf"),
    datasets={
        labels: "volumes/labels/neuron_ids"  # reads resolution from file
    })

stardist_gen = gpstardist.AddStarDist3D(
    labels,
    stardists,
    rays=96,
    anisotropy=(40, 4, 4),
    grid=(1, 2, 2),
    unlabeled_id=int(np.array(-3).astype(np.uint64)),
    max_dist=max_dist,
)

writer = gp.ZarrWrite(
    output_dir=directory,
    output_filename=result_file,
    dataset_names={stardists: ds_name},
    compression_type="gzip",
)

scan = gp.Scan(scan_request)

pipeline = source + stardist_gen + writer + scan

with gp.build(pipeline):
    pipeline.request_batch(request)
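
The snippet above is truncated at the top, so the keys, scan_request, and request it uses are not shown. A hedged reconstruction of that missing preamble, with hypothetical paths, shapes, and a max_dist value, might look like:

# Hypothetical preamble for the truncated snippet above (not from the original source)
import os
import numpy as np
import gunpowder as gp
import gpstardist

directory = 'data'
result_file = 'stardists.zarr'
ds_name = 'volumes/stardists'
max_dist = 100

labels = gp.ArrayKey('LABELS')
stardists = gp.ArrayKey('STARDISTS')

# chunk-sized reference request for Scan (world units, assuming 40x4x4 nm voxels)
scan_request = gp.BatchRequest()
scan_request.add(labels, gp.Coordinate((40 * 32, 4 * 256, 4 * 256)))
scan_request.add(stardists, gp.Coordinate((40 * 32, 4 * 256, 4 * 256)))

# an empty request is enough: Scan covers the full upstream ROI and
# ZarrWrite persists every chunk
request = gp.BatchRequest()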
Example #7
def predict(
    model: Model,
    raw_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    num_cpu_workers: int = 4,
    compute_context: ComputeContext = LocalTorch(),
    output_roi: Optional[Roi] = None,
):
    # get the model's input and output size

    input_voxel_size = Coordinate(raw_array.voxel_size)
    output_voxel_size = model.scale(input_voxel_size)
    input_shape = Coordinate(model.eval_input_shape)
    input_size = input_voxel_size * input_shape
    output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

    logger.info(
        "Predicting with input size %s, output size %s", input_size, output_size
    )

    # calculate input and output rois

    context = (input_size - output_size) / 2
    if output_roi is None:
        input_roi = raw_array.roi
        output_roi = input_roi.grow(-context, -context)
    else:
        input_roi = output_roi.grow(context, context)

    logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi)

    # prepare prediction dataset
    axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"]
    ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        output_voxel_size,
        np.float32,
    )

    # create gunpowder keys

    raw = gp.ArrayKey("RAW")
    prediction = gp.ArrayKey("PREDICTION")

    # assemble prediction pipeline

    # prepare data source
    pipeline = DaCapoArraySource(raw_array, raw)
    # raw: (c, d, h, w)
    pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
    # raw: (c, d, h, w)
    pipeline += gp.Unsqueeze([raw])
    # raw: (1, c, d, h, w)

    gt_padding = (output_size - output_roi.shape) % output_size
    prediction_roi = output_roi.grow(gt_padding)

    # predict
    pipeline += gp_torch.Predict(
        model=model,
        inputs={"x": raw},
        outputs={0: prediction},
        array_specs={
            prediction: gp.ArraySpec(
                roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32
            )
        },
        spawn_subprocess=False,
        device=str(compute_context.device),
    )
    # raw: (1, c, d, h, w)
    # prediction: (1, [c,] d, h, w)

    # prepare writing
    pipeline += gp.Squeeze([raw, prediction])
    # raw: (c, d, h, w)
    # prediction: (c, d, h, w)

    # write to zarr
    pipeline += gp.ZarrWrite(
        {prediction: prediction_array_identifier.dataset},
        prediction_array_identifier.container.parent,
        prediction_array_identifier.container.name,
    )

    # create reference batch request
    ref_request = gp.BatchRequest()
    ref_request.add(raw, input_size)
    ref_request.add(prediction, output_size)
    pipeline += gp.Scan(ref_request)

    # build pipeline and predict in complete output ROI

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())

    container = zarr.open(prediction_array_identifier.container)
    dataset = container[prediction_array_identifier.dataset]
    dataset.attrs["axes"] = (
        raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes
    )
Example #8
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_key='0/raw',
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    print("context", context, in_shape, out_shape)

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    num_additional_channels = (2 + spatial_dims) - data_dims

    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # # remove sample dimension for 3D data
    # pipeline += RemoveChannelDim(raw)
    # pipeline += RemoveChannelDim(prediction)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
Example #9
def predict_3d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #10
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['input_folder'], name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'], name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2
    chunksize = list(np.asarray(output_shape_world) // 2)

    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)
    if kwargs['overlapping_inst']:
        request.add(pred_numinst, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("predict node for %s not implemented yet"
                                  % kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'], kwargs['sample'] + ".hdf"),
                'r') as f:
            shape = f[raw_key].shape[1:]
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'], kwargs['sample'] + ".zarr"),
            'r')
        shape = f[raw_key].shape[1:]
    source = sourceNode(os.path.join(
        kwargs['data_folder'],
        kwargs['sample'] + "." + kwargs['input_format']),
                        datasets={raw: raw_key})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'),
                   mode='w')
    zf.create('volumes/pred_affs',
              shape=[int(np.prod(kwargs['patchshape']))] + list(shape),
              chunks=[int(np.prod(kwargs['patchshape']))] + list(chunksize),
              dtype=np.float16)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    if kwargs['overlapping_inst']:
        zf.create('volumes/pred_numinst',
                  shape=[int(kwargs['max_num_inst']) + 1] + list(shape),
                  chunks=[int(kwargs['max_num_inst']) + 1] + list(chunksize),
                  dtype=np.float16)
        zf['volumes/pred_numinst'].attrs['offset'] = [0, 0]
        zf['volumes/pred_numinst'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        pred_affs: '/volumes/pred_affs',
    }
    if kwargs['overlapping_inst']:
        outputs[net_names['pred_numinst']] = pred_numinst
        outVolumes[pred_numinst] = '/volumes/pred_numinst'

    pipeline = (
        source + gp.Pad(raw, context) + gp.IntensityScaleShift(raw, 2, -1) +
        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs=outputs) +

        # store all passing batches in the same HDF5 file
        gp.ZarrWrite(outVolumes,
                     output_dir=kwargs['output_folder'],
                     output_filename=kwargs['sample'] + ".zarr",
                     compression_type='gzip') +

        # show a summary of time spend in each node every 10 iterations
        gp.PrintProfilingStats(every=100) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
Example #11
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):

    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    print(
        'checkpoint: ',
        os.path.join(train_dir, train_net_name + '_checkpoint_%d' % iteration))

    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)

    with open(os.path.join(train_dir, test_net_name + '_config.json'),
              'r') as f:
        net_config = json.load(f)

    with open(os.path.join(train_dir, test_net_name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    source = (gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
        },
    ) + gp.Pad(raw, context) + nl.Clip(raw, 0, clip_max) +
              gp.Normalize(raw, factor=1.0 / clip_max) +
              gp.IntensityScaleShift(raw, 2, -1))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')

    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    zf['volumes/pred_mask'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_mask'].attrs['resolution'] = [1, 1, 1]

    pipeline = (
        source + gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask:
                gp.ArraySpec(roi=raw_roi.grow(-context, -context),
                             voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) + gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +

        # show a summary of time spend in each node every x iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):

        pipeline.request_batch(gp.BatchRequest())
Example #12
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()
    is_2d = spatial_dims == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    # ensure raw has sample and channel dims
    #
    # n = number of samples
    # c = number of channels

    # 2D raw is either (n, y, x) or (c, n, y, x)
    # 3D raw is either (z, y, x) or (c, z, y, x)
    for _ in range((2 + spatial_dims) - data_dims):
        source += AddChannelDim(raw)

    # 2D raw: (c, n, y, x)
    # 3D raw: (c, n=1, z, y, x)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    # 2D raw: (n, c, y, x)
    # 3D raw: (n=1, c, z, y, x)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # 2D raw       : (n, c, y, x)
    # 2D prediction: (n, c, y, x)
    # 3D raw       : (n=1, c, z, y, x)
    # 3D prediction: (n=1, c, z, y, x)

    if is_2d:

        # restore channels first for 2D data
        pipeline += TransposeDims(raw,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))
        pipeline += TransposeDims(prediction,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))

    else:

        # remove sample dimension for 3D data
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)

    # 2D raw       : (c, n, y, x)
    # 2D prediction: (c, n, y, x)
    # 3D raw       : (c, z, y, x)
    # 3D prediction: (c, z, y, x)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
Example #13
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        embs = gp.ArrayKey('EMBS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(embs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        if is_2d:
            source = (source + AddChannelDim(raw, axis=1))
        else:
            source = (source + AddChannelDim(raw, axis=0) + AddChannelDim(raw))

        source = (
            source
            # raw      : (c=1, roi)
        )

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (
            source + gp.Normalize(raw, factor=self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.PreCache() +
            gp.torch.Predict(self.model,
                             inputs={'raw': raw},
                             outputs={0: embs},
                             array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline +
                    gp.ZarrWrite({
                        embs: 'embs',
                    },
                                 output_dir=self.curr_log_dir,
                                 output_filename=self.dataset + '_embs.zarr',
                                 compression_type='gzip') + gp.Scan(request))

        return pipeline, request, embs
Example #14
def predict_frame(in_shape,
                  out_shape,
                  model_output,
                  model_configfile,
                  model_checkpoint,
                  input_dataset_file,
                  inference_frame,
                  out_dir,
                  out_filename,
                  out_key_or_index=1,
                  intermediate_layer=None,
                  dataset_raw_key="train/raw",
                  dataset_prediction_key="train/prediction",
                  dataset_intermediate_key="train/prediction_interm",
                  model_input_tensor_name="patches",
                  model_architecture="PatchedResnet",
                  num_workers=5):

    # initialize model
    if model_architecture == "PatchedResnet":
        model = PatchedResnet(1, 2, resnet_size=18)
    elif model_architecture == "unet":
        model = lisl.models.create(model_configfile)
    else:
        raise NotImplementedError(f"{model_architecture} not implemented")

    model.add_spatial_dim = True
    model.eval()

    # gp variables
    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    raw = gp.ArrayKey(f'RAW_{inference_frame}')
    prediction = gp.ArrayKey(f'PREDICTION_{inference_frame}')
    intermediate_prediction = gp.ArrayKey(f'ITERM_{inference_frame}')

    ds_key = f'{dataset_raw_key}/{inference_frame}'
    out_key = f'{dataset_prediction_key}/{inference_frame}'
    interm_key = f'{dataset_intermediate_key}/{inference_frame}'

    # build pipeline
    zsource = gp.ZarrSource(
        input_dataset_file, {raw: ds_key},
        {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})

    pipeline = zsource
    with gp.build(zsource):
        raw_roi = zsource.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline += AddChannelDim(raw)
    pipeline += AddChannelDim(raw)

    pipeline += gp.Pad(raw, None)
    # setup prediction node
    pred_dict = {out_key_or_index: prediction}
    pred_spec = {prediction: gp.ArraySpec(roi=raw_roi)}
    if intermediate_layer is not None:
        pred_dict[intermediate_layer] = intermediate_prediction
        pred_spec[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    pipeline += gp.torch.Predict(model,
                                 inputs={model_input_tensor_name: raw},
                                 outputs=pred_dict,
                                 array_specs=pred_spec,
                                 checkpoint=model_checkpoint,
                                 spawn_subprocess=True)

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    zarr_dict = {prediction: out_key}
    if intermediate_layer is not None:
        zarr_dict[intermediate_prediction] = interm_key
        request.add(intermediate_prediction, out_shape)
    pipeline += gp.Scan(request, num_workers=num_workers)
    pipeline += gp.ZarrWrite(zarr_dict,
                             output_dir=out_dir,
                             output_filename=out_filename,
                             compression_type='gzip')

    total_request = gp.BatchRequest()
    total_request[prediction] = gp.ArraySpec(roi=raw_roi)
    if intermediate_layer is not None:
        total_request[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)
    with gp.build(pipeline):
        pipeline.request_batch(total_request)
Example #15
def predict(iteration):

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('test_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = input_size - output_size

    # formulate the request for what a batch should contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(pred_affs, output_size)

    #############################
    # ASSEMBLE TESTING PIPELINE #
    #############################

    source = gp.Hdf5Source('sample_A_padded_20160501.hdf',
                           datasets={raw: 'volumes/raw'})

    # get the ROI provided for raw (we need it later to calculate the ROI in
    # which we can make predictions)
    with gp.build(source):
        raw_roi = source.spec[raw].roi

    pipeline = (

        # read from HDF5 file
        source +

        # convert raw to float in [0, 1]
        gp.Normalize(raw) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Predict(
            graph='test_net.meta',
            checkpoint='train_net_checkpoint_%d' % iteration,
            inputs={net_config['raw']: raw},
            outputs={net_config['pred_affs']: pred_affs},
            array_specs={
                pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))
            }) +

        # store all passing batches in the same HDF5 file
        gp.Hdf5Write({
            raw: '/volumes/raw',
            pred_affs: '/volumes/pred_affs',
        },
                     output_filename='predictions_sample_A.hdf',
                     compression_type='gzip') +

        # show a summary of time spend in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
Example #16
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = (input_shape_world - output_shape_world)//2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("predict node for %s not implemented yet"
                                  % kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(os.path.join(kwargs['data_folder'],
                                    kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(os.path.join(kwargs['data_folder'],
                                   kwargs['sample'] + ".zarr"), 'r')
        shape = f['volumes/raw'].shape
    # shape =
    source = sourceNode(
        os.path.join(kwargs['data_folder'],
                     kwargs['sample'] + "." + kwargs['input_format']),
        datasets = {
            raw: 'volumes/raw'
        },
        # array_specs = {
        #     raw: gp.ArraySpec(roi=gp.Roi(gp.Coordinate((0, 0, 400)),
        #                                  gp.Coordinate(input_shape_world)))
        # }
    )

    crop = []
    for d in range(-3, 0):
        if shape[d] < net_config['output_shape'][d]:
            crop.append((net_config['output_shape'][d]-shape[d])//2)
        else:
            crop.append(0)
    print("cropping", crop)
    context += gp.Coordinate(crop)

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    zf.create('volumes/pred_affs',
              shape=[int(np.prod(kwargs['patchshape']))] + list(shape),
              chunks=[int(np.prod(kwargs['patchshape']))] + list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/raw',
              shape=list(shape),
              chunks=list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/raw'].attrs['offset'] = [0, 0, 0]
    zf['volumes/raw'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        # raw: '/volumes/raw',
        pred_affs: '/volumes/pred_affs',
    }


    pipeline = (
        source +
        gp.Pad(raw, context) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Predict(
            graph=os.path.join(kwargs['input_folder'], name + '.meta'),
            checkpoint=kwargs['checkpoint'],
            inputs={
                net_names['raw']: raw
            },
            # array_specs={
            #     pred_affs: gp.ArraySpec(roi=gp.Roi(gp.Coordinate((46, 46, 46)),
            #                                        output_shape_world),
            #                             voxel_size=voxel_size)
            # },
            outputs=outputs) +


    # if max(crop) > 0:
    #     print("cropping", crop)
    #     pipeline += gp.Crop(pred_affs, absolute_negative=crop, absolute_positive=crop)
    # pipeline += (
        # store all passing batches in the same HDF5 file
        gp.ZarrWrite(
            outVolumes,
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip'
        ) +

        # show a summary of time spend in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request)
    )

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
Example #17
def add_scan(pipeline, data_shapes):
    ref_request = gp.BatchRequest()
    for key, shape in data_shapes.items():
        ref_request.add(key, shape)
    pipeline = pipeline + gp.Scan(reference=ref_request)
    return pipeline
Example #18
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #19
def predict_3d(raw_data, gt_data, model, predictor, aux_tasks):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = model.input_shape
    output_shape = model.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    model_output = gp.ArrayKey('MODEL_OUTPUT')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=model,
                                 inputs={'x': raw},
                                 outputs={0: model_output})
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': model_output},
                                 outputs={0: prediction})
    aux_predictions = []
    for aux_name, aux_predictor, _ in aux_tasks:
        aux_pred_key = gp.ArrayKey(f"PRED_{aux_name.upper()}")
        pipeline += gp_torch.Predict(model=aux_predictor,
                                     inputs={'x': model_output},
                                     outputs={0: aux_pred_key})
        aux_predictions.append((aux_name, aux_pred_key))
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(model_output, output_size)
    scan_request.add(prediction, output_size)
    for aux_name, aux_key in aux_predictions:
        scan_request.add(aux_key, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # only output where the gt exists
    context = (input_size - output_size) / 2

    output_roi = gt_data.roi.intersect(raw_data.roi.grow(-context, -context))
    input_roi = output_roi.grow(context, context)

    assert all([a > b for a, b in zip(input_roi.get_shape(), input_size)])
    assert all([a > b for a, b in zip(output_roi.get_shape(), output_size)])

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=input_roi)
    total_request[model_output] = gp.ArraySpec(roi=output_roi)
    total_request[prediction] = gp.ArraySpec(roi=output_roi)
    for aux_name, aux_key in aux_predictions:
        total_request[aux_key] = gp.ArraySpec(roi=output_roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=output_roi)
        total_request[target] = gp.ArraySpec(roi=output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'model_out': batch[model_output],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        for aux_name, aux_key in aux_predictions:
            ret[aux_name] = batch[aux_key]
        return ret