コード例 #1
0
def predict(iteration, path_to_dataGP, *,
            checkpoint_template='C:/Users/filip/spine_yodl/model_checkpoint_{iteration}',
            raw_dataset='validate/sample1/raw'):
    """Run a trained SpineUNet over a whole zarr dataset and return the result.

    Args:
        iteration: Training iteration whose checkpoint is loaded; it is
            substituted into ``checkpoint_template`` as ``{iteration}``.
        path_to_dataGP: Path of the zarr container read by ``gp.ZarrSource``.
        checkpoint_template: Format string for the checkpoint file. The
            default is the original hard-coded location, so existing callers
            behave exactly as before; other machines can point elsewhere.
        raw_dataset: Dataset inside the container that holds the raw volume.

    Returns:
        Tuple ``(raw, affs)`` holding the ``.data`` of the raw array and of
        the predicted array, both requested over the full source ROI.
    """
    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    # Equals (input_size - output_size) / 2, i.e. the network context on
    # each side; padding raw by this much lets predictions be requested over
    # the whole source ROI.
    amount_size = gp.Coordinate((2, 16, 16))
    # NOTE(review): crop_output receives the *string* 'output_size', not the
    # output_size tuple above -- confirm this is what SpineUNet expects.
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # Per-chunk request that gp.Scan tiles over the full request below.
    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: raw_dataset
        }
    )

    # Build the source on its own once, only to read the raw dataset's ROI.
    with gp.build(source):
        source_roi = source.spec[raw].roi
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    checkpoint = checkpoint_template.format(iteration=iteration)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=checkpoint) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
コード例 #2
0
ファイル: predict_pipeline.py プロジェクト: funkelab/dacapo
def predict_2d(raw_data, gt_data, predictor):
    """Predict a stack of 2D samples, using the sample axis as a batch axis.

    The raw (and, if given, GT) source is re-declared with a leading sample
    axis of voxel size 1. Raw is then transposed to ``(s, c, h, w)`` so that
    ``predictor`` receives the samples as its batch dimension.

    Args:
        raw_data: Raw dataset. Excluding an optional leading channel axis, it
            must be 3D: (samples, h, w).
        gt_data: Optional ground-truth dataset. When truthy, GT and the
            target produced by ``predictor.add_target`` are also returned.
        predictor: Torch module run by ``gp_torch.Predict``. It also provides
            ``input_shape``, ``output_shape``, ``target_channels`` and
            ``add_target``.

    Returns:
        dict with ``'raw'`` and ``'prediction'`` arrays, plus ``'gt'`` and
        ``'target'`` when ``gt_data`` is given.

    Raises:
        RuntimeError: if the raw data, excluding channels, is not 3D.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims != 3:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    # Skip the channel axis (if any) and the sample axis; what remains is the
    # (h, w) shape of a single sample.
    sample_shape = gp.Coordinate(dataset_shape[channel_dims + 1:])
    # Taken from the dataset rather than dataset_shape[0], which would be the
    # channel axis whenever channels are present.
    num_samples = raw_data.num_samples

    sample_size = sample_shape * voxel_size

    # NOTE(review): input_size, output_size and sample_size are built from
    # the per-sample voxel_size, while the source spec below prepends a
    # sample axis. Confirm that predictor.input_shape's rank matches the
    # (samples, h, w) ROI that gunpowder will see.
    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
コード例 #3
0
ファイル: predict_pipeline.py プロジェクト: funkelab/dacapo
def predict_3d(raw_data, gt_data, predictor):
    """Predict over a single 3D volume, optionally alongside its ground truth.

    Raw gets a channel axis (if it has none) and a batch axis before
    ``predictor`` runs. Both axes are stripped again afterwards, and
    ``gp.Scan`` tiles the whole request with input/output-sized chunks.

    Args:
        raw_data: Raw dataset; must report ``num_samples == 0``.
        gt_data: Optional ground-truth dataset. When truthy, GT and the
            target built by ``predictor.add_target`` are returned too.
        predictor: Torch module run by ``gp_torch.Predict``. It also provides
            ``input_shape``, ``output_shape`` and ``add_target``.

    Returns:
        dict with ``'raw'`` and ``'prediction'``, plus ``'gt'`` and
        ``'target'`` when ``gt_data`` is given.
    """
    # A single (or absent) channel means raw carries no channel axis yet.
    lacks_channel_axis = max(1, raw_data.num_channels) == 1

    # Network sizes converted from voxels to world units.
    voxel_size = raw_data.voxel_size
    input_size = voxel_size * predictor.input_shape
    output_size = voxel_size * predictor.output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    assert raw_data.num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    # Chunk shape that Scan tiles over the full request.
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(prediction, output_size)
    if gt_data:
        for label_key in (gt, target):
            chunk_request.add(label_key, output_size)

    if gt_data:
        pipeline = (
            raw_data.get_source(raw), gt_data.get_source(gt)
        ) + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)

    # Nodes are collected in pipeline order and appended in one pass below.
    stages = [gp.Pad(raw, None)]
    if gt_data:
        stages.append(gp.Pad(gt, None))
    stages.append(gp.Normalize(raw))
    if gt_data:
        stages.append(predictor.add_target(gt, target))
    if lacks_channel_axis:
        # raw: (d, h, w) -> (c, d, h, w)
        stages.append(AddChannelDim(raw))
    # raw: (c, d, h, w) -> (1, c, d, h, w); the leading 1 is a batch axis
    stages.append(AddChannelDim(raw))
    stages.append(gp_torch.Predict(model=predictor,
                                   inputs={'x': raw},
                                   outputs={0: prediction}))
    # Strip the batch axis from raw and from the prediction again.
    stages.append(RemoveChannelDim(raw))
    stages.append(RemoveChannelDim(prediction))
    if lacks_channel_axis:
        # Undo the channel axis that was added to raw above.
        stages.append(RemoveChannelDim(raw))
    stages.append(gp.Scan(chunk_request))

    for stage in stages:
        pipeline += stage

    # Extend the raw ROI by half a network input on every side. The output
    # therefore reaches past the raw extent; padding supplies the extra data.
    half_input = input_size / 2
    roi = raw_data.roi.grow(half_input, half_input)

    requested_keys = [raw, prediction]
    if gt_data:
        requested_keys += [gt, target]
    total_request = gp.BatchRequest()
    for requested_key in requested_keys:
        total_request[requested_key] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        result = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            result.update({'gt': batch[gt], 'target': batch[target]})
        return result
コード例 #4
0
ファイル: predict.py プロジェクト: pattonw/dacapo
def predict(
    model: Model,
    raw_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    num_cpu_workers: int = 4,
    compute_context: Optional[ComputeContext] = None,
    output_roi: Optional[Roi] = None,
):
    """Predict ``model`` over ``raw_array`` and write the result to zarr.

    A new zarr dataset with a leading channel axis is created at
    ``prediction_array_identifier``. It is then filled by a gunpowder
    pipeline that scans the model over the requested output ROI.

    Args:
        model: Model to run. It also supplies ``scale``, ``eval_input_shape``,
            ``compute_output_shape`` and ``num_out_channels``.
        raw_array: Input array to predict on.
        prediction_array_identifier: Container/dataset the prediction is
            written to.
        num_cpu_workers: Currently not used by this function.
        compute_context: Provides the torch device for prediction. When
            ``None`` (the default), a fresh ``LocalTorch()`` is used per call.
        output_roi: ROI to predict. When ``None``, the raw ROI shrunk by the
            network context on each side is used.
    """
    if compute_context is None:
        # Created per call: a ``LocalTorch()`` default argument would be
        # instantiated once at import time and shared by every call.
        compute_context = LocalTorch()

    # get the model's input and output size

    input_voxel_size = Coordinate(raw_array.voxel_size)
    output_voxel_size = model.scale(input_voxel_size)
    input_shape = Coordinate(model.eval_input_shape)
    input_size = input_voxel_size * input_shape
    output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

    logger.info(
        "Predicting with input size %s, output size %s", input_size, output_size
    )

    # calculate input and output rois

    # Half the input/output size difference is the context the network
    # consumes on each side.
    context = (input_size - output_size) / 2
    if output_roi is None:
        input_roi = raw_array.roi
        output_roi = input_roi.grow(-context, -context)
    else:
        input_roi = output_roi.grow(context, context)

    logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi)

    # prepare prediction dataset; the channel axis is always written first
    axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"]
    ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        output_voxel_size,
        np.float32,
    )

    # create gunpowder keys

    raw = gp.ArrayKey("RAW")
    prediction = gp.ArrayKey("PREDICTION")

    # assemble prediction pipeline

    # prepare data source
    pipeline = DaCapoArraySource(raw_array, raw)
    # raw: (c, d, h, w)
    pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
    # raw: (c, d, h, w)
    pipeline += gp.Unsqueeze([raw])
    # raw: (1, c, d, h, w)

    # Amount that rounds the output ROI shape up to a multiple of output_size,
    # presumably so Scan can tile it with whole output chunks.
    # NOTE(review): Roi.grow with a single argument presumably grows only the
    # lower bound -- confirm that this is the intended side.
    prediction_padding = (output_size - output_roi.shape) % output_size
    prediction_roi = output_roi.grow(prediction_padding)

    # predict
    pipeline += gp_torch.Predict(
        model=model,
        inputs={"x": raw},
        outputs={0: prediction},
        array_specs={
            prediction: gp.ArraySpec(
                roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32
            )
        },
        spawn_subprocess=False,
        device=str(compute_context.device),
    )
    # raw: (1, c, d, h, w)
    # prediction: (1, [c,] d, h, w)

    # prepare writing
    pipeline += gp.Squeeze([raw, prediction])
    # raw: (c, d, h, w)
    # prediction: (c, d, h, w)

    # write to zarr
    pipeline += gp.ZarrWrite(
        {prediction: prediction_array_identifier.dataset},
        prediction_array_identifier.container.parent,
        prediction_array_identifier.container.name,
    )

    # create reference batch request
    ref_request = gp.BatchRequest()
    ref_request.add(raw, input_size)
    ref_request.add(prediction, output_size)
    pipeline += gp.Scan(ref_request)

    # build pipeline and predict in complete output ROI

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())

    container = zarr.open(prediction_array_identifier.container)
    dataset = container[prediction_array_identifier.dataset]
    # Record the same "c"-first axis order the dataset was created with.
    # Re-deriving it from raw_array.axes gave a different order whenever "c"
    # was not raw's first axis, and raised TypeError for tuple axes.
    dataset.attrs["axes"] = axes
コード例 #5
0
def predict_3d(raw_data, gt_data, model, predictor, aux_tasks):
    """Predict a single 3D volume with a backbone ``model`` plus task heads.

    Raw is fed through ``model``. Its output feeds ``predictor`` and every
    auxiliary predictor listed in ``aux_tasks``, and ``gp.Scan`` tiles the
    whole request.

    Args:
        raw_data: Raw dataset; must report ``num_samples == 0``.
        gt_data: Optional ground-truth dataset. When truthy, GT and the
            target built by ``predictor.add_target`` are returned, and the
            output ROI is restricted to the GT ROI.
        model: Backbone torch module. Its ``input_shape``/``output_shape``
            define the scan chunk.
        predictor: Head applied to the model output; also provides
            ``add_target``.
        aux_tasks: Iterable of ``(name, aux_predictor, _)`` triples; the third
            element is not used here.

    Returns:
        dict with ``'raw'``, ``'model_out'``, ``'prediction'``, one entry per
        aux task name, and ``'gt'``/``'target'`` when ``gt_data`` is given.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = model.input_shape
    output_shape = model.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    model_output = gp.ArrayKey('MODEL_OUTPUT')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=model,
                                 inputs={'x': raw},
                                 outputs={0: model_output})
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': model_output},
                                 outputs={0: prediction})
    aux_predictions = []
    for aux_name, aux_predictor, _ in aux_tasks:
        aux_pred_key = gp.ArrayKey(f"PRED_{aux_name.upper()}")
        pipeline += gp_torch.Predict(model=aux_predictor,
                                     inputs={'x': model_output},
                                     outputs={0: aux_pred_key})
        aux_predictions.append((aux_name, aux_pred_key))
    # remove "batch" dimension
    # NOTE(review): only raw and prediction have the batch axis removed;
    # model_output and the aux predictions presumably keep a leading axis of
    # size 1 -- confirm callers expect that.
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(model_output, output_size)
    scan_request.add(prediction, output_size)
    for _, aux_key in aux_predictions:
        scan_request.add(aux_key, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # Half the input/output size difference is the network context per side.
    context = (input_size - output_size) / 2

    # Output wherever the network has full raw context ...
    output_roi = raw_data.roi.grow(-context, -context)
    if gt_data:
        # ... and, when GT is available, only where the GT exists. Without GT
        # there is no GT ROI to intersect with (gt_data.roi would fail).
        output_roi = gt_data.roi.intersect(output_roi)
    input_roi = output_roi.grow(context, context)

    # The requested ROIs must be strictly larger than one network chunk.
    assert all(a > b for a, b in zip(input_roi.get_shape(), input_size))
    assert all(a > b for a, b in zip(output_roi.get_shape(), output_size))

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=input_roi)
    total_request[model_output] = gp.ArraySpec(roi=output_roi)
    total_request[prediction] = gp.ArraySpec(roi=output_roi)
    for _, aux_key in aux_predictions:
        total_request[aux_key] = gp.ArraySpec(roi=output_roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=output_roi)
        total_request[target] = gp.ArraySpec(roi=output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'model_out': batch[model_output],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        for aux_name, aux_key in aux_predictions:
            ret[aux_name] = batch[aux_key]
        return ret