Example #1
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)
    
    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        } 
    )
  
    with gp.build(source):
        source_roi = source.spec[raw].roi
    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
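
Note: AddChannelDim and RemoveChannelDim above are custom helper nodes, not part of gunpowder itself. A minimal sketch of how such nodes are commonly written against gunpowder's BatchFilter API (the axis handling is an assumption, not necessarily the exact implementation used here):

import gunpowder as gp
import numpy as np

class AddChannelDim(gp.BatchFilter):

    def __init__(self, array, axis=0):
        # array: the gp.ArrayKey to modify; axis: where to insert the new dim
        self.array = array
        self.axis = axis

    def process(self, batch, request):
        if self.array not in batch.arrays:
            return
        batch[self.array].data = np.expand_dims(
            batch[self.array].data, self.axis)

class RemoveChannelDim(gp.BatchFilter):

    def __init__(self, array, axis=0):
        self.array = array
        self.axis = axis

    def process(self, batch, request):
        if self.array not in batch.arrays:
            return
        # drop the (singleton) axis again
        batch[self.array].data = batch[self.array].data.squeeze(self.axis)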
Example #2
    def test_delete_points_in_context(self):
        points = gp.PointsKey("POINTS")
        pv_array = gp.ArrayKey("PARENT_VECTORS")
        mask = gp.ArrayKey("MASK")
        radius = [0.1, 0.1, 0.1, 0.1]
        ts = TracksSource(TEST_FILE, points)
        apv = AddParentVectors(points, pv_array, mask, radius)
        request = gp.BatchRequest()
        request.add(points, gp.Coordinate((1, 4, 4, 4)))
        request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
        request.add(mask, gp.Coordinate((1, 4, 4, 4)))

        pipeline = (ts + gp.Pad(points, None) + apv)
        with gp.build(pipeline):
            pipeline.request_batch(request)
Example #3
    def test_add_parent_vectors(self):
        points = gp.PointsKey("POINTS")
        pv_array = gp.ArrayKey("PARENT_VECTORS")
        mask = gp.ArrayKey("MASK")
        radius = [0.1, 0.1, 0.1, 0.1]
        ts = TracksSource(TEST_FILE, points)
        apv = AddParentVectors(points, pv_array, mask, radius)
        request = gp.BatchRequest()
        request.add(points, gp.Coordinate((3, 4, 4, 4)))
        request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
        request.add(mask, gp.Coordinate((1, 4, 4, 4)))

        pipeline = (ts + gp.Pad(points, None) + apv)
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)

        points = batch[points].data
        expected_mask = np.zeros(shape=(1, 4, 4, 4))
        expected_mask[0, 0, 0, 0] = 1
        expected_mask[0, 1, 2, 3] = 1

        expected_parent_vectors_z = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_z[0, 1, 2, 3] = -1.0

        expected_parent_vectors_y = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_y[0, 1, 2, 3] = -2.0

        expected_parent_vectors_x = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_x[0, 1, 2, 3] = -3.0
        self.assertListEqual(expected_mask.tolist(), batch[mask].data.tolist())

        parent_vectors = batch[pv_array].data
        self.assertListEqual(expected_parent_vectors_z.tolist(),
                             parent_vectors[0].tolist())
        self.assertListEqual(expected_parent_vectors_y.tolist(),
                             parent_vectors[1].tolist())
        self.assertListEqual(expected_parent_vectors_x.tolist(),
                             parent_vectors[2].tolist())
Example #4
def build_pipeline(data_dir, model, checkpoint_file, input_size, output_size,
                   raw, labels, affs_predicted, dataset_shape, num_samples,
                   sample_size):

    checkpoint = torch.load(checkpoint_file)
    model.load_state_dict(checkpoint['model_state_dict'])

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(affs_predicted, output_size)
    scan_request.add(labels, output_size)

    pipeline = (
        gp.ZarrSource(str(data_dir), {
            raw: 'validate/raw',
            labels: 'validate/gt'
        }) +
        gp.Pad(raw, size=None) +
        gp.Normalize(raw) +
        # raw: (s, h, w)
        # labels: (s, h, w)
        train.AddChannelDim(raw) +
        # raw: (c=1, s, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) +
        # raw: (s, c=1, h, w)
        # labels: (s, h, w)
        Predict(model=model, inputs={'x': raw}, outputs={0: affs_predicted}) +
        # raw: (s, c=1, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        train.TransposeDims(raw, (1, 0, 2, 3)) + train.RemoveChannelDim(raw) +
        # raw: (s, h, w)
        # affs_predicted: (s, c=2, h, w)
        # labels: (s, h, w)
        gp.PrintProfilingStats(every=100) + gp.Scan(scan_request))

    return pipeline
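
Since build_pipeline only assembles the graph, here is a hedged usage sketch: an empty BatchRequest makes the gp.Scan node sweep the entire provided ROI.

pipeline = build_pipeline(data_dir, model, checkpoint_file, input_size,
                          output_size, raw, labels, affs_predicted,
                          dataset_shape, num_samples, sample_size)
with gp.build(pipeline):
    pipeline.request_batch(gp.BatchRequest())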
Example #5
def predict(
    model: Model,
    raw_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    num_cpu_workers: int = 4,
    compute_context: ComputeContext = LocalTorch(),
    output_roi: Optional[Roi] = None,
):
    # get the model's input and output size

    input_voxel_size = Coordinate(raw_array.voxel_size)
    output_voxel_size = model.scale(input_voxel_size)
    input_shape = Coordinate(model.eval_input_shape)
    input_size = input_voxel_size * input_shape
    output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

    logger.info(
        "Predicting with input size %s, output size %s", input_size, output_size
    )

    # calculate input and output rois

    context = (input_size - output_size) / 2
    if output_roi is None:
        input_roi = raw_array.roi
        output_roi = input_roi.grow(-context, -context)
    else:
        input_roi = output_roi.grow(context, context)

    logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi)

    # prepare prediction dataset
    axes = ["c"] + [axis for axis in raw_array.axes if axis != "c"]
    ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        output_voxel_size,
        np.float32,
    )

    # create gunpowder keys

    raw = gp.ArrayKey("RAW")
    prediction = gp.ArrayKey("PREDICTION")

    # assemble prediction pipeline

    # prepare data source
    pipeline = DaCapoArraySource(raw_array, raw)
    # raw: (c, d, h, w)
    pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
    # raw: (c, d, h, w)
    pipeline += gp.Unsqueeze([raw])
    # raw: (1, c, d, h, w)

    gt_padding = (output_size - output_roi.shape) % output_size
    prediction_roi = output_roi.grow(gt_padding)

    # predict
    pipeline += gp_torch.Predict(
        model=model,
        inputs={"x": raw},
        outputs={0: prediction},
        array_specs={
            prediction: gp.ArraySpec(
                roi=prediction_roi, voxel_size=output_voxel_size, dtype=np.float32
            )
        },
        spawn_subprocess=False,
        device=str(compute_context.device),
    )
    # raw: (1, c, d, h, w)
    # prediction: (1, [c,] d, h, w)

    # prepare writing
    pipeline += gp.Squeeze([raw, prediction])
    # raw: (c, d, h, w)
    # prediction: (c, d, h, w)

    # write to zarr
    pipeline += gp.ZarrWrite(
        {prediction: prediction_array_identifier.dataset},
        prediction_array_identifier.container.parent,
        prediction_array_identifier.container.name,
    )

    # create reference batch request
    ref_request = gp.BatchRequest()
    ref_request.add(raw, input_size)
    ref_request.add(prediction, output_size)
    pipeline += gp.Scan(ref_request)

    # build pipeline and predict in complete output ROI

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())

    container = zarr.open(prediction_array_identifier.container)
    dataset = container[prediction_array_identifier.dataset]
    dataset.attrs["axes"] = (
        raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes
    )
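
To make the ROI arithmetic above concrete, a small numeric sketch with made-up sizes (assuming funlib.geometry-style Coordinate/Roi, as used by dacapo):

from funlib.geometry import Coordinate, Roi

input_size = Coordinate((40, 128, 128))
output_size = Coordinate((20, 88, 88))
context = (input_size - output_size) / 2       # Coordinate((10, 20, 20))

output_roi = Roi((0, 0, 0), (100, 1000, 1000))
input_roi = output_roi.grow(context, context)  # grown by context on both sides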
Example #6
def create_pipeline_3d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = max(1, task.data.raw.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = task.data.raw.train.num_samples
    # num_samples == 0 denotes a single, unindexed volume; anything else is
    # not yet supported for 3D training
    assert num_samples == 0, (
        "Multiple samples for 3D training not yet implemented")

    sources = (task.data.raw.train.get_source(raw),
               task.data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
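
A hedged sketch of how the returned pipeline and request are typically consumed (the iteration count is an assumption):

pipeline, request = create_pipeline_3d(
    task, predictor, optimizer, batch_size=1, outdir='.', snapshot_every=0)
with gp.build(pipeline):
    for i in range(10000):
        # each request triggers one training iteration in gp_torch.Train
        pipeline.request_batch(request)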
Example #7
def create_pipeline_2d(task, predictor, optimizer, batch_size, outdir,
                       snapshot_every):

    raw_channels = task.data.raw.num_channels
    filename = task.data.raw.train.filename
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = task.data.raw.train.shape
    dataset_roi = task.data.raw.train.roi
    voxel_size = task.data.raw.train.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    weights = gp.ArrayKey('WEIGHTS')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError("For 2D training, please provide a 3D array where "
                           "the first dimension indexes the samples.")

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    sources = (task.data.raw.train.get_source(raw, overwrite_spec=spec),
               task.data.gt.train.get_source(gt, overwrite_spec=spec))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, None)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    pipeline += gp.RandomLocation()
    # raw: ([c,] d=1, h, w)
    # gt: ([c,] d=1, h, w)
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    weights_node = task.loss.add_weights(target, weights)
    if weights_node:
        pipeline += weights_node
        loss_inputs = {0: prediction, 1: target, 2: weights}
    else:
        loss_inputs = {0: prediction, 1: target}
    # raw: ([c,] d=1, h, w)
    # target: ([c,] d=1, h, w)
    # [weights: ([c,] d=1, h, w)]
    # get rid of z dim:
    pipeline += Squash(dim=-3)
    # raw: ([c,] h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, h, w)
    # target: ([c,] h, w)
    # [weights: ([c,] h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    pipeline += gp_torch.Train(model=predictor,
                               loss=task.loss,
                               optimizer=optimizer,
                               inputs={'x': raw},
                               loss_inputs=loss_inputs,
                               outputs={0: prediction},
                               save_every=1e6)
    # raw: (b, c, h, w)
    # target: (b, [c,] h, w)
    # [weights: (b, [c,] h, w)]
    # prediction: (b, [c,] h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3))
        # raw: (c, b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, h, w)
        # target: ([c,] b, h, w)
        # [weights: ([c,] b, h, w)]
        # prediction: ([c,] b, h, w)
        pipeline += gp.Snapshot(dataset_names={
            raw: 'raw',
            target: 'target',
            prediction: 'prediction',
            weights: 'weights'
        },
                                every=snapshot_every,
                                output_dir=os.path.join(outdir, 'snapshots'),
                                output_filename="{iteration}.hdf")
    pipeline += gp.PrintProfilingStats(every=100)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    request.add(target, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)

    return pipeline, request
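
Squash above is another custom node. A minimal sketch of the idea, covering only the array data (a production node would also have to update the ROIs/specs of the removed spatial dimension):

class Squash(gp.BatchFilter):

    def __init__(self, dim=-3):
        # dim: the singleton dimension to remove from every array
        self.dim = dim

    def process(self, batch, request):
        for key, array in batch.arrays.items():
            array.data = np.squeeze(array.data, axis=self.dim)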
Example #8
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005, mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = (data_sources_base, data_sources_add) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #9
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open(os.path.join(kwargs['input_folder'],
                           name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'],
                           name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = (input_shape_world - output_shape_world)//2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "predict node for %s not implemented yet" % kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(os.path.join(kwargs['data_folder'],
                                    kwargs['sample'] + ".hdf"), 'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(os.path.join(kwargs['data_folder'],
                                   kwargs['sample'] + ".zarr"), 'r')
        shape = f['volumes/raw'].shape
    source = sourceNode(
        os.path.join(kwargs['data_folder'],
                     kwargs['sample'] + "." + kwargs['input_format']),
        datasets={
            raw: 'volumes/raw'
        },
        # array_specs = {
        #     raw: gp.ArraySpec(roi=gp.Roi(gp.Coordinate((0, 0, 400)),
        #                                  gp.Coordinate(input_shape_world)))
        # }
    )

    crop = []
    for d in range(-3, 0):
        if shape[d] < net_config['output_shape'][d]:
            crop.append((net_config['output_shape'][d]-shape[d])//2)
        else:
            crop.append(0)
    print("cropping", crop)
    context += gp.Coordinate(crop)

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'), mode='w')
    zf.create('volumes/pred_affs',
              shape=[int(np.prod(kwargs['patchshape']))] + list(shape),
              chunks=[int(np.prod(kwargs['patchshape']))] + list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/raw',
              shape=list(shape),
              chunks=list(shape)[:-1] + [20],
              dtype=np.float32)
    zf['volumes/raw'].attrs['offset'] = [0, 0, 0]
    zf['volumes/raw'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        # raw: '/volumes/raw',
        pred_affs: '/volumes/pred_affs',
    }


    pipeline = (
        source +
        gp.Pad(raw, context) +

        # predict affinities for each passing batch (here we use the tensor
        # names earlier stored in the net config)
        gp.tensorflow.Predict(
            graph=os.path.join(kwargs['input_folder'], name + '.meta'),
            checkpoint=kwargs['checkpoint'],
            inputs={
                net_names['raw']: raw
            },
            # array_specs={
            #     pred_affs: gp.ArraySpec(roi=gp.Roi(gp.Coordinate((46, 46, 46)),
            #                                        output_shape_world),
            #                             voxel_size=voxel_size)
            # },
            outputs=outputs) +

        # store all passing batches in the same zarr file
        gp.ZarrWrite(
            outVolumes,
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip'
        ) +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request)
    )

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
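
Convert, used in train_until above, is also a custom node. A minimal sketch of a dtype-casting filter (the original may differ in details):

class Convert(gp.BatchFilter):

    def __init__(self, array, dtype):
        self.array = array
        self.dtype = dtype

    def process(self, batch, request):
        # cast the data and keep the spec consistent with it
        array = batch.arrays[self.array]
        array.data = array.data.astype(self.dtype)
        array.spec.dtype = self.dtype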
Example #10
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   apply_voxel_size=True,
                   spawn_subprocess=True,
                   num_workers=0):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    logger.info(f"context: {context}")

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    num_additional_channels = (2 + spatial_dims) - data_dims

    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # # remove sample dimension for 3D data
    # pipeline += RemoveChannelDim(raw)
    # pipeline += RemoveChannelDim(prediction)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
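
TransposeDims, used here to move samples in front of channels, is a custom node as well. A minimal sketch, again covering only the array data and ignoring spec/ROI bookkeeping:

class TransposeDims(gp.BatchFilter):

    def __init__(self, array, permutation):
        self.array = array
        self.permutation = permutation

    def process(self, batch, request):
        # permute the axes of this array's data
        batch.arrays[self.array].data = \
            batch.arrays[self.array].data.transpose(self.permutation)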
Example #11
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
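
A hedged usage sketch of the returned dictionary: each entry is a gp.Array whose .data attribute holds the numpy array.

ret = predict_2d(raw_data, gt_data, predictor)
prediction = ret['prediction'].data   # (s, c, h, w) numpy array
if gt_data:
    target = ret['target'].data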
Example #12
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['input_folder'], name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'], name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2

    # formulate the request for what a batch should contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(pred_affs, output_shape_world)
    request.add(pred_fgbg, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "predict node for %s not implemented yet" % kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'], kwargs['sample'] + ".hdf"),
                'r') as f:
            shape = f['volumes/raw'].shape
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'], kwargs['sample'] + ".zarr"),
            'r')
        shape = f['volumes/raw'].shape
    source = sourceNode(os.path.join(
        kwargs['data_folder'],
        kwargs['sample'] + "." + kwargs['input_format']),
                        datasets={raw: 'volumes/raw'})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")
    # pre-create zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'),
                   mode='w')
    zf.create('volumes/pred_affs',
              shape=[3] + list(shape),
              chunks=[3] + list(shape),
              dtype=np.float32)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/pred_fgbg',
              shape=[1] + list(shape),
              chunks=[1] + list(shape),
              dtype=np.float32)
    zf['volumes/pred_fgbg'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_fgbg'].attrs['resolution'] = kwargs['voxel_size']

    zf.create('volumes/raw_cropped',
              shape=[1] + list(shape),
              chunks=[1] + list(shape),
              dtype=np.float32)
    zf['volumes/raw_cropped'].attrs['offset'] = [0, 0, 0]
    zf['volumes/raw_cropped'].attrs['resolution'] = kwargs['voxel_size']

    pipeline = (

        # read from HDF5 or zarr file
        source + gp.Pad(raw, context) +

        # predict affinities, fgbg and the cropped raw for each passing batch
        # (here we use the tensor names earlier stored in the net config)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs={
                                  net_names['pred_affs']: pred_affs,
                                  net_names['pred_fgbg']: pred_fgbg,
                                  net_names['raw_cropped']: raw_cropped
                              }) +

        # store all passing batches in the same zarr file
        gp.ZarrWrite(
            {
                raw_cropped: '/volumes/raw_cropped',
                pred_affs: '/volumes/pred_affs',
                pred_fgbg: '/volumes/pred_fgbg',
            },
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip') +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
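
A hedged sketch (paths are assumptions) of reading the predictions back once the scan has finished:

import zarr

zf = zarr.open('output_folder/sample.zarr', mode='r')
pred_affs_data = zf['volumes/pred_affs'][:]   # (3,) + volume shape, float32
pred_fgbg_data = zf['volumes/pred_fgbg'][:]   # (1,) + volume shape, float32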
Example #13
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties={},
            **kwargs):
    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # Hdf5Source
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_file.endswith('.zarr') or raw_file.endswith('.n5'):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))
    d_scale = parameters['d_scale'] if 'd_scale' in parameters else None
    if d_scale != 1 and d_scale is not None:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.
    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] == 'int8' or d_property[
            'dtype'] == 'float32', 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1,
                                                0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator:
        'volumes/pred_syn_indicator',
        pred_postpre_vectors:
        'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
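
This worker relies on gp.DaisyRequestBlocks, so a daisy scheduler has to hand out the blocks. A heavily hedged sketch of the scheduling side (daisy's API differs between versions; all names and sizes below are assumptions, and start_worker is a hypothetical helper that launches predict()):

import daisy

total_roi = daisy.Roi((0, 0, 0), (4000, 4000, 4000))
read_roi = daisy.Roi((0, 0, 0), (1200, 1200, 1200))
write_roi = daisy.Roi((400, 400, 400), (400, 400, 400))

daisy.run_blockwise(
    total_roi,
    read_roi,
    write_roi,
    process_function=lambda: start_worker(),  # spawns the predict() worker
    num_workers=4,
    fit='shrink')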
Example #14
def predict_frame(in_shape,
                  out_shape,
                  model_output,
                  model_configfile,
                  model_checkpoint,
                  input_dataset_file,
                  inference_frame,
                  out_dir,
                  out_filename,
                  out_key_or_index=1,
                  intermediate_layer=None,
                  dataset_raw_key="train/raw",
                  dataset_prediction_key="train/prediction",
                  dataset_intermediate_key="train/prediction_interm",
                  model_input_tensor_name="patches",
                  model_architecture="PatchedResnet",
                  num_workers=5):

    # initialize model
    if model_architecture == "PatchedResnet":
        model = PatchedResnet(1, 2, resnet_size=18)
    elif model_architecture == "unet":
        model = lisl.models.create(model_configfile)
    else:
        raise NotImplementedError(f"{model_architecture} not implemented")

    model.add_spatial_dim = True
    model.eval()

    # gp variables
    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    raw = gp.ArrayKey(f'RAW_{inference_frame}')
    prediction = gp.ArrayKey(f'PREDICTION_{inference_frame}')
    intermediate_prediction = gp.ArrayKey(f'ITERM_{inference_frame}')

    ds_key = f'{dataset_raw_key}/{inference_frame}'
    out_key = f'{dataset_prediction_key}/{inference_frame}'
    interm_key = f'{dataset_intermediate_key}/{inference_frame}'

    # build pipeline
    zsource = gp.ZarrSource(
        input_dataset_file, {raw: ds_key},
        {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))})

    pipeline = zsource
    with gp.build(zsource):
        raw_roi = zsource.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline += AddChannelDim(raw)
    pipeline += AddChannelDim(raw)

    pipeline += gp.Pad(raw, None)
    # setup prediction node
    pred_dict = {out_key_or_index: prediction}
    pred_spec = {prediction: gp.ArraySpec(roi=raw_roi)}
    if intermediate_layer is not None:
        pred_dict[intermediate_layer] = intermediate_prediction
        pred_spec[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)

    pipeline += gp.torch.Predict(model,
                                 inputs={model_input_tensor_name: raw},
                                 outputs=pred_dict,
                                 array_specs=pred_spec,
                                 checkpoint=model_checkpoint,
                                 spawn_subprocess=True)

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    zarr_dict = {prediction: out_key}
    if intermediate_layer is not None:
        zarr_dict[intermediate_prediction] = interm_key
        request.add(intermediate_prediction, out_shape)
    pipeline += gp.Scan(request, num_workers=num_workers)
    pipeline += gp.ZarrWrite(zarr_dict,
                             output_dir=out_dir,
                             output_filename=out_filename,
                             compression_type='gzip')

    total_request = gp.BatchRequest()
    total_request[prediction] = gp.ArraySpec(roi=raw_roi)
    if intermediate_layer is not None:
        total_request[intermediate_prediction] = gp.ArraySpec(roi=raw_roi)
    with gp.build(pipeline):
        pipeline.request_batch(total_request)
Example #15
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')

    loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS')

    pred_threeclass = gp.ArrayKey('PRED_THREECLASS')

    pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(loss_weights_threeclass, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(pred_threeclass, output_shape_world)
    # snapshot_request.add(pred_threeclass_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])


    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the HDF5 file
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    gt_threeclass: 'volumes/gt_threeclass',
                    anchor: 'volumes/gt_threeclass',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_threeclass: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_threeclass, None)
            + gp.Pad(anchor, gp.Coordinate((2,2,2)))


            # chose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # chose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        # TODO: check
        # gp.GrowBoundary(
        #     gt_threeclass,
        #     steps=1,
        #     only_xy=False) +

        gp.BalanceLabels(
            gt_threeclass,
            loss_weights_threeclass,
            num_classes=3) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['anchor']: anchor,
                net_names['gt_threeclass']: gt_threeclass,
                net_names['loss_weights_threeclass']: loss_weights_threeclass
            },
            outputs={
                net_names['pred_threeclass']: pred_threeclass,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_threeclass']: pred_threeclass_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: 'volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_threeclass: 'volumes/gt_threeclass',
                pred_threeclass: 'volumes/pred_threeclass',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
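
The `(... if augmentation.get('elastic') is not None else NoOp())` pattern above needs an identity node to plug into the pipeline. A minimal sketch of NoOp:

class NoOp(gp.BatchFilter):

    def process(self, batch, request):
        # pass the batch through unchanged; useful as a placeholder for
        # optional pipeline steps
        pass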
Example #16
    def build_batch_provider(self, datasets, model, task, snapshot_container=None):
        input_shape = Coordinate(model.input_shape)
        output_shape = Coordinate(model.output_shape)

        # get voxel sizes
        raw_voxel_size = datasets[0].raw.voxel_size
        prediction_voxel_size = model.scale(raw_voxel_size)

        # define input and output size:
        # switch to world units
        input_size = raw_voxel_size * input_shape
        output_size = prediction_voxel_size * output_shape

        # padding of groundtruth/mask
        gt_mask_padding = output_size + task.predictor.padding(prediction_voxel_size)

        # define keys:
        raw_key = gp.ArrayKey("RAW")
        gt_key = gp.ArrayKey("GT")
        mask_key = gp.ArrayKey("MASK")

        target_key = gp.ArrayKey("TARGET")
        weight_key = gp.ArrayKey("WEIGHT")

        # Get source nodes
        dataset_sources = []
        for dataset in datasets:

            raw_source = DaCapoArraySource(dataset.raw, raw_key)
            raw_source += gp.Pad(raw_key, None, 0)
            gt_source = DaCapoArraySource(dataset.gt, gt_key)
            gt_source += gp.Pad(gt_key, gt_mask_padding, 0)
            if dataset.mask is not None:
                mask_source = DaCapoArraySource(dataset.mask, mask_key)
            else:
                # Always provide a mask. By default it is simply an array
                # of ones with the same shape/roi as gt. Avoids making us
                # specially handle no mask case and allows padding of the
                # ground truth without worrying about training on incorrect
                # data.
                mask_source = DaCapoArraySource(OnesArray.like(dataset.gt), mask_key)
            mask_source += gp.Pad(mask_key, gt_mask_padding, 0)
            array_sources = [raw_source, gt_source, mask_source]

            dataset_source = (
                tuple(array_sources) + gp.MergeProvider() + gp.RandomLocation()
            )

            dataset_sources.append(dataset_source)
        pipeline = tuple(dataset_sources) + gp.RandomProvider()

        for augment in self.augments:
            pipeline += augment.node(raw_key, gt_key, mask_key)

        pipeline += gp.Reject(mask_key, min_masked=self.min_masked)

        # Add predictor nodes to pipeline
        pipeline += DaCapoTargetFilter(
            task.predictor,
            gt_key=gt_key,
            target_key=target_key,
            weights_key=weight_key,
            mask_key=mask_key,
        )

        # Trainer attributes:
        if self.num_data_fetchers > 1:
            pipeline += gp.PreCache(num_workers=self.num_data_fetchers)

        # stack to create a batch dimension
        pipeline += gp.Stack(self.batch_size)

        # print profiling stats
        pipeline += gp.PrintProfilingStats(every=self.print_profiling)

        # generate request for all necessary inputs to training
        request = gp.BatchRequest()
        request.add(raw_key, input_size)
        request.add(target_key, output_size)
        request.add(weight_key, output_size)
        # request additional keys for snapshots
        request.add(gt_key, output_size)
        request.add(mask_key, output_size)

        self._request = request
        self._pipeline = pipeline
        self._raw_key = raw_key
        self._gt_key = gt_key
        self._mask_key = mask_key
        self._weight_key = weight_key
        self._target_key = target_key
        self._loss = task.loss

        self.snapshot_container = snapshot_container
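The padding above relies on gp.Pad's two modes: a size of None allows unlimited padding, while a Coordinate caps the amount. A small sketch with an illustrative cap (the Coordinate value is a placeholder, not from the code above):

import gunpowder as gp

raw_key = gp.ArrayKey('RAW')
gt_key = gp.ArrayKey('GT')

pad_raw = gp.Pad(raw_key, None, 0)                      # pad arbitrarily far with zeros
pad_gt = gp.Pad(gt_key, gp.Coordinate((8, 96, 96)), 0)  # pad by at most this amount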
Example #17
0
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   checkpoint,
                   input_name='raw_0',
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   spawn_subprocess=True,
                   num_workers=0,
                   apply_voxel_size=True):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()
    is_2d = spatial_dims == 2

    in_shape = in_shape * voxel_size
    out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    # ensure raw has sample and channel dims
    #
    # n = number of samples
    # c = number of channels

    # 2D raw is either (n, y, x) or (c, n, y, x)
    # 3D raw is either (z, y, x) or (c, z, y, x)
    for _ in range((2 + spatial_dims) - data_dims):
        source += AddChannelDim(raw)

    # 2D raw: (c, n, y, x)
    # 3D raw: (c, n=1, z, y, x)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    # 2D raw: (n, c, y, x)
    # 3D raw: (n=1, c, z, y, x)

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # 2D raw       : (n, c, y, x)
    # 2D prediction: (n, c, y, x)
    # 3D raw       : (n=1, c, z, y, x)
    # 3D prediction: (n=1, c, z, y, x)

    if is_2d:

        # restore channels first for 2D data
        pipeline += TransposeDims(raw,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))
        pipeline += TransposeDims(prediction,
                                  (1, 0) + tuple(range(2, 2 + spatial_dims)))

    else:

        # remove sample dimension for 3D data
        pipeline += RemoveChannelDim(raw)
        pipeline += RemoveChannelDim(prediction)

    # 2D raw       : (c, n, y, x)
    # 2D prediction: (c, n, y, x)
    # 3D raw       : (c, z, y, x)
    # 3D prediction: (c, z, y, x)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    logger.info("Writing prediction to %s/%s[%s]", out_dir, out_filename,
                out_ds_names[0])

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
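predict_volume drives the whole pipeline with an empty BatchRequest and lets gp.Scan tile the chunk-sized reference request over the provider's ROI. A standalone illustration of that pattern, assuming a placeholder zarr file and dataset name:

import gunpowder as gp

raw = gp.ArrayKey('RAW')

# 'data.zarr' and 'volumes/raw' are placeholder names
source = gp.ZarrSource('data.zarr', {raw: 'volumes/raw'},
                       {raw: gp.ArraySpec(interpolatable=True)})

scan_request = gp.BatchRequest()
scan_request.add(raw, gp.Coordinate((8, 96, 96)))  # one chunk, in world units

pipeline = source + gp.Scan(scan_request)

with gp.build(pipeline):
    # an empty request makes Scan tile the chunk request over the full ROI
    pipeline.request_batch(gp.BatchRequest())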
Example #18
0
def train_until(max_iteration):

    in_channels = 1
    num_fmaps = 12
    fmap_inc_factors = 6
    downsample_factors = [(1, 3, 3), (1, 3, 3), (3, 3, 3)]

    unet = UNet(in_channels,
                num_fmaps,
                fmap_inc_factors,
                downsample_factors,
                constant_upsample=True)

    model = Convolve(unet, 12, 1)

    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    # start of gunpowder part:

    raw = gp.ArrayKey('RAW')
    points = gp.GraphKey('POINTS')
    groundtruth = gp.ArrayKey('RASTER')
    prediction = gp.ArrayKey('PRED_POINT')
    grad = gp.ArrayKey('GRADIENT')

    voxel_size = gp.Coordinate((40, 4, 4))

    input_shape = (96, 430, 430)
    output_shape = (60, 162, 162)

    input_size = gp.Coordinate(input_shape) * voxel_size
    output_size = gp.Coordinate(output_shape) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(points, output_size)
    request.add(groundtruth, output_size)
    request.add(prediction, output_size)
    request.add(grad, output_size)

    pos_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddCenterPoint(points, raw) + gp.Pad(raw, None) +
        gp.RandomLocation(ensure_nonempty=points)
        for filename in pos_samples) + gp.RandomProvider()
    neg_sources = tuple(
        gp.ZarrSource(filename, {raw: 'volumes/raw'},
                      {raw: gp.ArraySpec(interpolatable=True)}) +
        AddNoPoint(points, raw) + gp.RandomLocation()
        for filename in neg_samples) + gp.RandomProvider()

    data_sources = (pos_sources, neg_sources)
    data_sources += gp.RandomProvider(probabilities=[0.9, 0.1])
    data_sources += gp.Normalize(raw)

    train_pipeline = data_sources
    train_pipeline += gp.ElasticAugment(control_point_spacing=[4, 40, 40],
                                        jitter_sigma=[0, 2, 2],
                                        rotation_interval=[0, math.pi / 2.0],
                                        prob_slip=0.05,
                                        prob_shift=0.05,
                                        max_misalign=10,
                                        subsample=8)
    train_pipeline += gp.SimpleAugment(transpose_only=[1, 2])

    train_pipeline += gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1, \
            z_section_wise=True)
    train_pipeline += gp.RasterizePoints(
        points,
        groundtruth,
        array_spec=gp.ArraySpec(voxel_size=voxel_size),
        settings=gp.RasterizationSettings(radius=(100, 100, 100), mode='peak'))
    train_pipeline += gp.PreCache(cache_size=40, num_workers=10)

    train_pipeline += Reshape(raw, (1, 1) + input_shape)
    train_pipeline += Reshape(groundtruth, (1, 1) + output_shape)

    train_pipeline += gp_torch.Train(model=model,
                                     loss=loss,
                                     optimizer=optimizer,
                                     inputs={'x': raw},
                                     outputs={0: prediction},
                                     loss_inputs={
                                         0: prediction,
                                         1: groundtruth
                                     },
                                     gradients={0: grad},
                                     save_every=1000,
                                     log_dir='log')

    train_pipeline += Reshape(raw, input_shape)
    train_pipeline += Reshape(groundtruth, output_shape)
    train_pipeline += Reshape(prediction, output_shape)
    train_pipeline += Reshape(grad, output_shape)

    train_pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            groundtruth: 'volumes/groundtruth',
            prediction: 'volumes/prediction',
            grad: 'volumes/gradient'
        },
        every=500,
        output_filename='test_{iteration}.hdf')
    train_pipeline += gp.PrintProfilingStats(every=10)

    with gp.build(train_pipeline):
        for i in range(max_iteration):
            train_pipeline.request_batch(request)
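The 90/10 positive/negative sampling above comes from nesting RandomProvider nodes: one per sample collection, then one weighted mix. The mixing step in isolation, with placeholder filenames:

import gunpowder as gp

raw = gp.ArrayKey('RAW')

pos = gp.ZarrSource('pos.zarr', {raw: 'volumes/raw'},
                    {raw: gp.ArraySpec(interpolatable=True)})
neg = gp.ZarrSource('neg.zarr', {raw: 'volumes/raw'},
                    {raw: gp.ArraySpec(interpolatable=True)})

# draw 90% of batches from the positive source, 10% from the negative one
pipeline = ((pos, neg) +
            gp.RandomProvider(probabilities=[0.9, 0.1]) +
            gp.Normalize(raw))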
Example #19
0
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        embs = gp.ArrayKey('EMBS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(embs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        if is_2d:
            source = (source + AddChannelDim(raw, axis=1))
        else:
            source = (source + AddChannelDim(raw, axis=0) + AddChannelDim(raw))

        # raw      : (c=1, roi)

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (
            source + gp.Normalize(raw, factor=self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.PreCache() +
            gp.torch.Predict(self.model,
                             inputs={'raw': raw},
                             outputs={0: embs},
                             array_specs={embs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline +
                    gp.ZarrWrite({
                        embs: 'embs',
                    },
                                 output_dir=self.curr_log_dir,
                                 output_filename=self.dataset + '_embs.zarr',
                                 compression_type='gzip') + gp.Scan(request))

        return pipeline, request, embs
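AddChannelDim is a repo-local helper rather than a gunpowder core node; judging by the shape comments throughout these examples, its effect on the array payload matches numpy's expand_dims. A plain-numpy sketch of the 2D and 3D branches above (shapes are illustrative):

import numpy as np

raw_2d = np.zeros((447, 256, 256))       # (samples, y, x)
raw_2d = np.expand_dims(raw_2d, axis=1)  # -> (samples, c=1, y, x)

raw_3d = np.zeros((64, 256, 256))        # (z, y, x)
raw_3d = np.expand_dims(raw_3d, axis=0)  # -> (c=1, z, y, x)
raw_3d = np.expand_dims(raw_3d, axis=0)  # -> (n=1, c=1, z, y, x)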
Example #20
0
def create_pipeline_3d(
    task, data, predictor, optimizer, batch_size, outdir, snapshot_every
):

    raw_channels = max(1, data.raw.num_channels)
    input_shape = gp.Coordinate(task.model.input_shape)
    output_shape = gp.Coordinate(task.model.output_shape)
    voxel_size = data.raw.train.voxel_size

    task.predictor = task.predictor.to("cuda")

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey("RAW")
    gt = gp.ArrayKey("GT")
    mask = gp.ArrayKey("MASK")
    target = gp.ArrayKey("TARGET")
    weights = gp.ArrayKey("WEIGHTS")
    model_outputs = gp.ArrayKey("MODEL_OUTPUTS")
    model_output_grads = gp.ArrayKey("MODEL_OUT_GRAD")
    prediction = gp.ArrayKey("PREDICTION")
    pred_gradients = gp.ArrayKey("PRED_GRADIENTS")

    snapshot_dataset_names = {
        raw: "raw",
        model_outputs: "model_outputs",
        model_output_grads: "model_out_grad",
        target: "target",
        prediction: "prediction",
        pred_gradients: "pred_gradients",
        weights: "weights",
    }

    aux_keys = {}
    aux_grad_keys = {}
    for name, _, _ in task.aux_tasks:
        aux_keys[name] = (
            gp.ArrayKey(f"{name.upper()}_PREDICTION"),
            gp.ArrayKey(f"{name.upper()}_TARGET"),
            None,
        )
        aux_grad_keys[name] = gp.ArrayKey(f"{name.upper()}_PRED_GRAD")

        aux_pred, aux_target, _ = aux_keys[name]

        snapshot_dataset_names[aux_pred] = f"{name}_pred"
        snapshot_dataset_names[aux_target] = f"{name}_target"
        
        aux_grad = aux_grad_keys[name]
        snapshot_dataset_names[aux_grad] = f"{name}_aux_grad"

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = data.raw.train.num_samples
    assert num_samples == 0, "Multiple samples for 3D training not yet implemented"

    sources = (data.raw.train.get_source(raw), data.gt.train.get_source(gt))
    pipeline = sources + gp.MergeProvider()
    pipeline += gp.Pad(raw, input_shape / 2 * voxel_size)
    # pipeline += gp.Pad(gt, input_shape / 2 * voxel_size)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    
    mask_node = task.loss.add_mask(gt, mask)
    if mask_node is not None:
        pipeline += mask_node
        pipeline += gp.RandomLocation(min_masked=1e-6, mask=mask)
    else:
        # raw: ([c,] d, h, w)
        # gt: ([c,] d, h, w)
        pipeline += gp.RandomLocation()
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # task.augmentations is a string holding a Python list expression
    for augmentation in eval(task.augmentations):
        pipeline += augmentation
    pipeline += predictor.add_target(gt, target)
    # (don't care about gt anymore)
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    weights_node = task.loss.add_weights(target, weights)
    loss_inputs = []
    if weights_node:
        pipeline += weights_node
        loss_inputs.append({0: prediction, 1: target, 2: weights})
    else:
        loss_inputs.append({0: prediction, 1: target})

    head_outputs = []
    head_gradients = []
    for name, aux_predictor, aux_loss in task.aux_tasks:
        aux_prediction, aux_target, aux_weights = aux_keys[name]
        pipeline += aux_predictor.add_target(gt, aux_target)
        aux_weights_node = aux_loss.add_weights(aux_target, aux_weights)
        if aux_weights_node:
            aux_weights = gp.ArrayKey(f"{name.upper()}_WEIGHTS")
            aux_keys[name] = (
                aux_prediction,
                aux_target,
                aux_weights,
            )
            pipeline += aux_weights_node
            loss_inputs.append({0: aux_prediction, 1: aux_target, 2: aux_weights})
            snapshot_dataset_names[aux_weights] = f"{name}_weights"
        else:
            loss_inputs.append({0: aux_prediction, 1: aux_target})
        head_outputs.append({0: aux_prediction})
        aux_pred_gradient = aux_grad_keys[name]
        head_gradients.append({0: aux_pred_gradient})
    # raw: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # target: ([c,] d, h, w)
    # [weights: ([c,] d, h, w)]
    pipeline += gp.PreCache()
    pipeline += gp.Stack(batch_size)
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    pipeline += Train(
        model=task.model,
        heads=[("opt", predictor)]
        + [(name, aux_pred) for name, aux_pred, _ in task.aux_tasks],
        losses=[task.loss] + [loss for _, _, loss in task.aux_tasks],
        optimizer=optimizer,
        inputs={"x": raw},
        outputs={0: model_outputs},
        head_outputs=[{0: prediction}] + head_outputs,
        loss_inputs=loss_inputs,
        gradients=[{0: model_output_grads}, {0: pred_gradients}] + head_gradients,
        save_every=1e6,
    )
    # raw: (b, c, d, h, w)
    # target: (b, [c,] d, h, w)
    # [weights: (b, [c,] d, h, w)]
    # prediction: (b, [c,] d, h, w)
    if snapshot_every > 0:
        # get channels first
        pipeline += TransposeDims(raw, (1, 0, 2, 3, 4))
        if predictor.target_channels > 0:
            pipeline += TransposeDims(target, (1, 0, 2, 3, 4))
            if weights_node:
                pipeline += TransposeDims(weights, (1, 0, 2, 3, 4))
        if predictor.prediction_channels > 0:
            pipeline += TransposeDims(prediction, (1, 0, 2, 3, 4))
        # raw: (c, b, d, h, w)
        # target: ([c,] b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: ([c,] b, d, h, w)
        if channel_dims == 0:
            pipeline += RemoveChannelDim(raw)
        # raw: ([c,] b, d, h, w)
        # target: (c, b, d, h, w)
        # [weights: ([c,] b, d, h, w)]
        # prediction: (c, b, d, h, w)
        pipeline += gp.Snapshot(
            dataset_names=snapshot_dataset_names,
            every=snapshot_every,
            output_dir=os.path.join(outdir, "snapshots"),
            output_filename="{iteration}.hdf",
        )
    pipeline += gp.PrintProfilingStats(every=10)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, output_size)
    if mask_node is not None:
        request.add(mask, output_size)
    request.add(target, output_size)
    for name, _, _ in task.aux_tasks:
        aux_pred, aux_target, aux_weight = aux_keys[name]
        request.add(aux_pred, output_size)
        request.add(aux_target, output_size)
        if aux_weight is not None:
            request.add(aux_weight, output_size)
        aux_pred_grad = aux_grad_keys[name]
        request.add(aux_pred_grad, output_size)
    if weights_node:
        request.add(weights, output_size)
    request.add(prediction, output_size)
    request.add(pred_gradients, output_size)

    return pipeline, request
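gp.PreCache followed by gp.Stack(batch_size), as used above, is what introduces the leading b dimension tracked in the shape comments: the request stays per-sample, and Stack repeats it batch_size times upstream. A self-contained sketch, with placeholder file and dataset names:

import gunpowder as gp

raw = gp.ArrayKey('RAW')

source = gp.ZarrSource('data.zarr', {raw: 'volumes/raw'},
                       {raw: gp.ArraySpec(interpolatable=True)})

pipeline = (source +
            gp.RandomLocation() +  # a new random crop per stacked sample
            gp.PreCache() +
            gp.Stack(4))           # raw: (b=4, d, h, w)

request = gp.BatchRequest()
request.add(raw, gp.Coordinate((8, 96, 96)))

with gp.build(pipeline):
    batch = pipeline.request_batch(request)  # batch[raw].data.shape[0] == 4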
Example #21
0
    def create_train_pipeline(self, model):

        print(
            f"Creating training pipeline with batch size {self.params['batch_size']}"
        )

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        gt_aff = gp.ArrayKey('AFFINITIES')
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_shape = out_shape * out_voxel_size

        context = (in_shape - out_shape) / 2
        gt_labels_out_shape = out_shape
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
            gt_labels_out_shape = (1, *gt_labels_out_shape)
        else:
            aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(gt_aff, out_shape)
        request.add(predictions, out_shape)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))
        snapshot_request[gt_labels] = gp.ArraySpec(
            roi=gp.Roi(context, gt_labels_out_shape))

        source = (
            gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                          array_specs={
                              raw:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size,
                                           interpolatable=True),
                              gt_labels:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size)
                          }) + gp.Normalize(raw, self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.Pad(gt_labels, context) +
            gp.RandomLocation()
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
            gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
            SetDtype(gt_aff, np.float32) +
            # raw      : (l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
            AddChannelDim(raw)
            # raw      : (c=1, l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
        )

        if is_2d:
            pipeline = (
                pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_aff)
                # raw      : (c=1, h, w)
                # gt_aff   : (c=2, h, w)
            )

        pipeline = (
            pipeline + gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, h, w)
            # gt_aff   : (b, c=2, h, w)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={'raw': raw},
                loss_inputs={
                    0: predictions,
                    1: gt_aff
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(voxel_size=out_voxel_size),
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, h, w)
            # gt_aff     : (b, c=2, h, w)
            # predictions: (b, c=2, h, w)

            # Crop GT to look at labels
            gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            gt_aff: 'gt_aff',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
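The snapshot_request above is how Snapshot obtains arrays (emb, gt_labels) that the training request never asks for: additional_request is merged into the upstream request only on snapshot iterations. The mechanism in isolation, with an illustrative ROI:

import gunpowder as gp

raw = gp.ArrayKey('RAW')
emb = gp.ArrayKey('EMBEDDING')

snapshot_request = gp.BatchRequest()
snapshot_request[emb] = gp.ArraySpec(roi=gp.Roi((0, 0), (256, 256)))

snapshot = gp.Snapshot(
    dataset_names={raw: 'raw', emb: 'emb'},
    output_dir='snapshots',
    output_filename='it{iteration}.hdf',
    every=500,
    additional_request=snapshot_request)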
Example #22
0
def random_point_pairs_pipeline(model,
                                loss,
                                optimizer,
                                dataset,
                                augmentation_parameters,
                                point_density,
                                out_dir,
                                normalize_factor=None,
                                checkpoint_interval=5000,
                                snapshot_interval=5000):

    raw_0 = gp.ArrayKey('RAW_0')
    points_0 = gp.GraphKey('POINTS_0')
    locations_0 = gp.ArrayKey('LOCATIONS_0')
    emb_0 = gp.ArrayKey('EMBEDDING_0')
    raw_1 = gp.ArrayKey('RAW_1')
    points_1 = gp.GraphKey('POINTS_1')
    locations_1 = gp.ArrayKey('LOCATIONS_1')
    emb_1 = gp.ArrayKey('EMBEDDING_1')

    # TODO parse this key from somewhere
    key = 'train/raw/0'

    data = daisy.open_ds(dataset.filename, key)
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    emb_voxel_size = voxel_size

    # Get in and out shape
    in_shape = gp.Coordinate(model.in_shape)
    out_shape = gp.Coordinate(model.out_shape)

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw_0, in_shape)
    request.add(raw_1, in_shape)
    request.add(points_0, out_shape)
    request.add(points_1, out_shape)
    request[locations_0] = gp.ArraySpec(nonspatial=True)
    request[locations_1] = gp.ArraySpec(nonspatial=True)

    snapshot_request = gp.BatchRequest()
    snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
    snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

    # Let's hardcode this for now
    # TODO read actual number from zarr file keys
    n_samples = 447
    batch_size = 1
    dim = 2
    padding = (100, 100)

    sources = []
    for i in range(n_samples):

        ds_key = f'train/raw/{i}'
        image_sources = tuple(
            gp.ZarrSource(
                dataset.filename, {raw: ds_key},
                {raw: gp.ArraySpec(interpolatable=True, voxel_size=(1, 1))}) +
            gp.Pad(raw, None) for raw in [raw_0, raw_1])

        random_point_generator = RandomPointGenerator(density=point_density,
                                                      repetitions=2)

        point_sources = tuple(
            (RandomPointSource(points_0,
                               dim,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               dim,
                               random_point_generator=random_point_generator)))

        # TODO: get augmentation parameters from some config file!
        points_and_image_sources = tuple(
            (img_source, point_source) + gp.MergeProvider() + \
            gp.SimpleAugment() + \
            gp.ElasticAugment(
                spatial_dims=2,
                control_point_spacing=(10, 10),
                jitter_sigma=(0.0, 0.0),
                rotation_interval=(0, math.pi/2)) + \
            gp.IntensityAugment(r,
                                scale_min=0.8,
                                scale_max=1.2,
                                shift_min=-0.2,
                                shift_max=0.2,
                                clip=False) + \
            gp.NoiseAugment(r, var=0.01, clip=False)
            for r, img_source, point_source
            in zip([raw_0, raw_1], image_sources, point_sources))

        sample_source = points_and_image_sources + gp.MergeProvider()

        data = daisy.open_ds(dataset.filename, ds_key)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        sample_source += gp.Crop(raw_0, source_roi)
        sample_source += gp.Crop(raw_1, source_roi)
        sample_source += gp.Pad(raw_0, padding)
        sample_source += gp.Pad(raw_1, padding)
        sample_source += gp.RandomLocation()
        sources.append(sample_source)

    sources = tuple(sources)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Unsqueeze([raw_0, raw_1])

    pipeline += PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                             locations_1)

    # TODO: clarify how PrepareBatch interacts with Stack
    pipeline += RejectArray(ensure_nonempty=locations_1)
    pipeline += RejectArray(ensure_nonempty=locations_0)

    # batch content
    # raw_0:          (1, h, w)
    # raw_1:          (1, h, w)
    # locations_0:    (n, 2)
    # locations_1:    (n, 2)

    pipeline += gp.Stack(batch_size)

    # batch content
    # raw_0:          (b, 1, h, w)
    # raw_1:          (b, 1, h, w)
    # locations_0:    (b, n, 2)
    # locations_1:    (b, n, 2)

    pipeline += gp.PreCache(num_workers=10)

    pipeline += gp.torch.Train(
        model,
        loss,
        optimizer,
        inputs={
            'raw_0': raw_0,
            'raw_1': raw_1
        },
        loss_inputs={
            'emb_0': emb_0,
            'emb_1': emb_1,
            'locations_0': locations_0,
            'locations_1': locations_1
        },
        outputs={
            2: emb_0,
            3: emb_1
        },
        array_specs={
            emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
            emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
        },
        checkpoint_basename=os.path.join(out_dir, 'model'),
        save_every=checkpoint_interval)

    pipeline += gp.Snapshot(
        {
            raw_0: 'raw_0',
            raw_1: 'raw_1',
            emb_0: 'emb_0',
            emb_1: 'emb_1',
            # locations_0 : 'locations_0',
            # locations_1 : 'locations_1',
        },
        every=snapshot_interval,
        additional_request=snapshot_request)

    return pipeline, request
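locations_0 and locations_1 above are declared with nonspatial=True, so they carry no ROI and pass untouched through spatial nodes such as RandomLocation and the augmentations. The request pattern in isolation:

import gunpowder as gp

raw = gp.ArrayKey('RAW')
locations = gp.ArrayKey('LOCATIONS')

request = gp.BatchRequest()
request.add(raw, gp.Coordinate((256, 256)))         # spatial: sized in world units
request[locations] = gp.ArraySpec(nonspatial=True)  # no ROI attached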
Example #23
0
    def create_train_pipeline(self, model):

        print(f"Creating training pipeline with batch size \
              {self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        context = (in_shape - out_roi) / 2
        gt_labels_out_shape = out_roi
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            gt_labels_out_shape = (1, *gt_labels_out_shape)

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))

        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=1, source_roi)
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
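SetDtype is a repo-local node; the cast to np.int64 above presumably exists because torch's classification losses expect int64 targets. The equivalent cast on a bare array:

import numpy as np

gt_labels = np.zeros((1, 256, 256), dtype=np.uint64)  # illustrative labels
gt_labels = gt_labels.astype(np.int64)  # torch.nn.CrossEntropyLoss wants int64 targets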
Example #24
0
def train_until(**kwargs):
    print("cuda visibile devices", os.environ["CUDA_VISIBLE_DEVICES"])
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet" %
                                  kwargs['input_format'])

    raw_key = kwargs.get('raw_key', 'volumes/raw')
    print('raw key: ', raw_key)

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')[raw_key]
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')[raw_key]
        # print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    neighborhood = []
    psH = np.array(kwargs['patchshape'])//2
    for i in range(-psH[1], psH[1]+1, kwargs['patchstride'][1]):
        for j in range(-psH[2], psH[2]+1, kwargs['patchstride'][2]):
            neighborhood.append([i,j])

    datasets = {
        raw: raw_key,
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances'
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False)
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        # net_names['loss_weights_affs']: loss_weights_affs,
    }

    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: '/volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
        pred_affs_gradients: '/volumes/pred_affs_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = 'volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[gt_numinst] = '/volumes/gt_numinst'
        snapshot[pred_numinst] = '/volumes/pred_numinst'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    source_fg = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # choose a random location for each requested batch
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_fg += gp.RandomProvider()

    source_overlap = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # choose a random location for each requested batch
        nl.MaskCloseDistanceToOverlap(
            gt_labels, gt_sample_mask,
            sampling['overlap_min_dist'],
            sampling['overlap_max_dist']
        ) +
        gp.RandomLocation(
            min_masked=sampling['min_masked_overlap'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_overlap += gp.RandomProvider()

    pipeline = (
        (source_fg, source_overlap) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider(probabilities=[sampling['probability_fg'],
                                         sampling['probability_overlap']]) +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0]) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        gp.IntensityScaleShift(raw, 2, -1) +

        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst']) +

        # pre-cache batches from the upstream pipeline
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every `profiling` iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
            # exit()
    print("Training finished")
Example #25
0
    def make_pipeline(self):
        raw = gp.ArrayKey('RAW')
        pred_affs = gp.ArrayKey('PREDICTIONS')

        source_shape = zarr.open(self.data_file)[self.dataset].shape
        raw_roi = gp.Roi(np.zeros(len(source_shape[1:])), source_shape[1:])

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(self.model.in_shape)
        out_shape = gp.Coordinate(self.model.out_shape[2:])

        is_2d = in_shape.dims() == 2

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(pred_affs, out_shape)

        context = (in_shape - out_shape) / 2

        source = (gp.ZarrSource(self.data_file, {
            raw: self.dataset,
        },
                                array_specs={
                                    raw:
                                    gp.ArraySpec(roi=source_roi,
                                                 interpolatable=False)
                                }))

        in_dims = len(self.model.in_shape)
        if is_2d:
            # 2D: [samples, y, x] or [samples, channels, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 1)
            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=1))
            # raw [samples, channels, y, x]
        else:
            # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
            needs_channel_fix = (len(data.shape) - in_dims == 0)
            needs_batch_fix = (len(data.shape) - in_dims <= 1)

            if needs_channel_fix:
                source = (source + AddChannelDim(raw, axis=0))
            # Batch fix
            if needs_batch_fix:
                source = (source + AddChannelDim(raw))
            # raw: [sample, channels, z, y, x]

        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = (source +
                    gp.Normalize(raw, factor=self.params['norm_factor']) +
                    gp.Pad(raw, context) + gp.PreCache() + gp.torch.Predict(
                        self.model,
                        inputs={'raw': raw},
                        outputs={0: pred_affs},
                        array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)}))

        pipeline = (pipeline + gp.ZarrWrite({
            pred_affs: 'predictions',
        },
                                            output_dir=self.curr_log_dir,
                                            output_filename='predictions.zarr',
                                            compression_type='gzip') +
                    gp.Scan(request))

        return pipeline, request, pred_affs
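The needs_channel_fix/needs_batch_fix logic above only compares the on-disk rank with the model's spatial rank. Spelled out for the 2D branch, with illustrative shapes:

data_shape = (447, 256, 256)  # 2D dataset stored as (samples, y, x)
in_dims = 2                   # model.in_shape is 2D

needs_channel_fix = (len(data_shape) - in_dims == 1)  # True: no channel axis yet
# AddChannelDim(raw, axis=1) would then yield (samples, c=1, y, x)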
Example #26
0
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):

    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    print(
        'checkpoint: ',
        os.path.join(train_dir, train_net_name + '_checkpoint_%d' % iteration))

    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)

    with open(os.path.join(train_dir, test_net_name + '_config.json'),
              'r') as f:
        net_config = json.load(f)

    with open(os.path.join(train_dir, test_net_name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    source = (gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
        },
    ) + gp.Pad(raw, context) + nl.Clip(raw, 0, clip_max) +
              gp.Normalize(raw, factor=1.0 / clip_max) +
              gp.IntensityScaleShift(raw, 2, -1))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')

    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    zf['volumes/pred_mask'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_mask'].attrs['resolution'] = [1, 1, 1]

    pipeline = (
        source + gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask:
                gp.ArraySpec(roi=raw_roi.grow(-context, -context),
                             voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) + gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +

        # show a summary of the time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):

        pipeline.request_batch(gp.BatchRequest())
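Creating the output dataset with chunks equal to one network output block, as done above, means each Scan worker writes whole chunks and never contends on partial ones. The creation step in isolation (filename and shapes are illustrative):

import numpy as np
import zarr

output_shape = (44, 228, 228)   # one output block
sample_shape = (88, 456, 456)   # whole volume

zf = zarr.open('sample.zarr', mode='w')
zf.create('volumes/pred_mask',
          shape=sample_shape,
          chunks=output_shape,
          dtype=np.float16)
zf['volumes/pred_mask'].attrs['offset'] = [0, 0, 0]
zf['volumes/pred_mask'].attrs['resolution'] = [1, 1, 1]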
Example #27
0
def predict_3d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
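# AddChannelDim and RemoveChannelDim are custom helpers, not part of
# gunpowder; a minimal sketch, assuming they only reshape the array data:
import gunpowder as gp

class AddChannelDim(gp.BatchFilter):

    def __init__(self, array):
        self.array = array

    def process(self, batch, request):
        if self.array in batch.arrays:
            array = batch.arrays[self.array]
            # prepend a size-1 axis (channel dims are not part of the ROI)
            array.data = array.data[None, ...]

class RemoveChannelDim(gp.BatchFilter):

    def __init__(self, array):
        self.array = array

    def process(self, batch, request):
        if self.array in batch.arrays:
            array = batch.arrays[self.array]
            assert array.data.shape[0] == 1, "can only remove a size-1 axis"
            array.data = array.data[0]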
Example #28
0
def predict(**kwargs):
    name = kwargs['name']

    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')

    with open(os.path.join(kwargs['input_folder'], name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'], name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2
    chunksize = list(np.asarray(output_shape_world) // 2)

    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape_world, voxel_size=voxel_size)
    request.add(pred_affs, output_shape_world, voxel_size=voxel_size)
    if kwargs['overlapping_inst']:
        request.add(pred_numinst, output_shape_world, voxel_size=voxel_size)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("predict node for %s not implemented yet",
                                  kwargs['input_format'])
    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
        with h5py.File(
                os.path.join(kwargs['data_folder'], kwargs['sample'] + ".hdf"),
                'r') as f:
            shape = f[raw_key].shape[1:]
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource
        f = zarr.open(
            os.path.join(kwargs['data_folder'], kwargs['sample'] + ".zarr"),
            'r')
        shape = f[raw_key].shape[1:]
    source = sourceNode(os.path.join(
        kwargs['data_folder'],
        kwargs['sample'] + "." + kwargs['input_format']),
                        datasets={raw: raw_key})

    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    # open zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'),
                   mode='w')
    zf.create('volumes/pred_affs',
              shape=[int(np.prod(kwargs['patchshape']))] + list(shape),
              chunks=[int(np.prod(kwargs['patchshape']))] + list(chunksize),
              dtype=np.float16)
    zf['volumes/pred_affs'].attrs['offset'] = [0, 0]
    zf['volumes/pred_affs'].attrs['resolution'] = kwargs['voxel_size']

    if kwargs['overlapping_inst']:
        zf.create('volumes/pred_numinst',
                  shape=[int(kwargs['max_num_inst']) + 1] + list(shape),
                  chunks=[int(kwargs['max_num_inst']) + 1] + list(chunksize),
                  dtype=np.float16)
        zf['volumes/pred_numinst'].attrs['offset'] = [0, 0]
        zf['volumes/pred_numinst'].attrs['resolution'] = kwargs['voxel_size']

    outputs = {
        net_names['pred_affs']: pred_affs,
    }
    outVolumes = {
        pred_affs: '/volumes/pred_affs',
    }
    if kwargs['overlapping_inst']:
        outputs[net_names['pred_numinst']] = pred_numinst
        outVolumes[pred_numinst] = '/volumes/pred_numinst'

    pipeline = (
        source + gp.Pad(raw, context) + gp.IntensityScaleShift(raw, 2, -1) +
        # run one prediction for each passing batch (here we use the tensor
        # names stored earlier in the _names.json file)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs=outputs) +

        # store all passing batches in the same zarr file
        gp.ZarrWrite(outVolumes,
                     output_dir=kwargs['output_folder'],
                     output_filename=kwargs['sample'] + ".zarr",
                     compression_type='gzip') +

        # show a summary of time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
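# A hypothetical invocation of predict(); only the keyword names are taken
# from the function body above, all values are placeholders:
config = {
    'name': 'test_net',
    'input_folder': 'experiments/setup01',
    'checkpoint': 'experiments/setup01/train_net_checkpoint_100000',
    'data_folder': 'data',
    'sample': 'sample_A',
    'raw_key': 'volumes/raw',
    'input_format': 'zarr',
    'output_format': 'zarr',
    'output_folder': 'predictions',
    'voxel_size': [1, 1],
    'patchshape': [1, 25, 25],
    'max_num_inst': 2,
    'overlapping_inst': True,
}
predict(**config)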
Example #29
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(
            tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    # the early-return check belongs outside the else branch; otherwise a
    # resumed run could never stop at max_iteration
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) // 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
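    # note: data_dir and data_files are assumed to be defined at module level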
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
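# DistanceTransform is a custom node, not a gunpowder built-in; a minimal
# sketch, assuming it computes a Euclidean distance transform of the mask,
# clipped to a maximum distance (the third constructor argument above):
import numpy as np
import gunpowder as gp
from scipy import ndimage

class DistanceTransform(gp.BatchFilter):

    def __init__(self, mask, distance, max_distance):
        self.mask = mask
        self.distance = distance
        self.max_distance = max_distance

    def setup(self):
        # the distance array has the same spec as the mask, but is float
        spec = self.spec[self.mask].copy()
        spec.dtype = np.float32
        self.provides(self.distance, spec)

    def prepare(self, request):
        # the mask is needed wherever the distance transform is requested
        if self.mask not in request:
            request[self.mask] = request[self.distance].copy()

    def process(self, batch, request):
        mask = batch.arrays[self.mask].data
        dt = ndimage.distance_transform_edt(mask).astype(np.float32)
        dt = np.clip(dt, 0, self.max_distance)
        spec = self.spec[self.distance].copy()
        spec.roi = request[self.distance].roi.copy()
        batch.arrays[self.distance] = gp.Array(dt, spec)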
Example #30
0
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')

    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet",
                                  kwargs['input_format'])

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }),
         gp.CsvIDPointsSource(fls[t] + ".csv",
                              points,
                              points_spec=gp.PointsSpec(
                                  roi=gp.Roi(gp.Coordinate((
                                      0, 0, 0)), gp.Coordinate(shapes[t]))))) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None) + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation() for t in range(ln))
    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min'] * np.pi / 180.0,
             augmentation['elastic']['rotation_max'] * np.pi / 180.0],
            subsample=augmentation['elastic'].get('subsample', 1))
         if augmentation.get('elastic') is not None else NoOp()) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False)
         if augmentation.get('intensity') is not None and
            augmentation.get('intensity') != {} else NoOp()) +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names stored earlier in the _names.json file)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node at the configured interval
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")