Example #1
    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.array] = request[self.array].copy()
        return deps
Example #2
    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.gt] = request[self.gt_binary].copy()
        return deps
Example #3
    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.array] = copy.deepcopy(request[self.mask])
        return deps
Example #4
    def prepare(self, request):
        deps = gp.BatchRequest()
        deps[self.points] = request[self.points].copy()
        return deps
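The prepare() snippets above come from custom gunpowder nodes. For context, a minimal sketch of a complete BatchFilter using this prepare/process pattern might look as follows; the node name, its keys, and the thresholding logic are purely illustrative and not taken from any of the examples:

import gunpowder as gp
import numpy as np


class ThresholdArray(gp.BatchFilter):
    # illustrative node: writes a binary mask of `array` into `output`

    def __init__(self, array, output, threshold=0.5):
        self.array = array
        self.output = output
        self.threshold = threshold

    def setup(self):
        # the new output array covers the same ROI as the input array
        self.provides(self.output, self.spec[self.array].copy())

    def prepare(self, request):
        # declare the upstream dependency: the input array in the ROI in
        # which the output was requested
        deps = gp.BatchRequest()
        deps[self.array] = request[self.output].copy()
        return deps

    def process(self, batch, request):
        # threshold the input and return the new array
        outputs = gp.Batch()
        spec = batch[self.array].spec.copy()
        spec.dtype = np.uint8
        data = (batch[self.array].data > self.threshold).astype(np.uint8)
        outputs[self.output] = gp.Array(data, spec)
        return outputs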
Example #5
def predict_volume(model,
                   dataset,
                   out_dir,
                   out_filename,
                   out_ds_names,
                   input_name='0/raw',
                   checkpoint=None,
                   normalize_factor=None,
                   model_output=0,
                   in_shape=None,
                   out_shape=None,
                   apply_voxel_size=True,
                   spawn_subprocess=True,
                   num_workers=0):

    raw = gp.ArrayKey('RAW')
    prediction = gp.ArrayKey('PREDICTION')

    data = daisy.open_ds(dataset.filename, dataset.ds_names[0])
    source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
    voxel_size = gp.Coordinate(data.voxel_size)
    data_dims = len(data.shape)

    # Get in and out shape
    if in_shape is None:
        in_shape = model.in_shape
    if out_shape is None:
        out_shape = model.out_shape

    in_shape = gp.Coordinate(in_shape)
    out_shape = gp.Coordinate(out_shape)
    spatial_dims = in_shape.dims()

    if apply_voxel_size:
        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

    logger.info(f"source roi: {source_roi}")
    logger.info(f"in_shape: {in_shape}")
    logger.info(f"out_shape: {out_shape}")
    logger.info(f"voxel_size: {voxel_size}")

    request = gp.BatchRequest()
    request.add(raw, in_shape)
    request.add(prediction, out_shape)

    context = (in_shape - out_shape) / 2

    print("context", context, in_shape, out_shape)

    source = (gp.ZarrSource(
        dataset.filename, {
            raw: dataset.ds_names[0],
        },
        array_specs={raw: gp.ArraySpec(roi=source_roi, interpolatable=True)}))

    num_additional_channels = (2 + spatial_dims) - data_dims

    for _ in range(num_additional_channels):
        source += AddChannelDim(raw)

    # prediction requires samples first, channels second
    source += TransposeDims(raw, (1, 0) + tuple(range(2, 2 + spatial_dims)))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        logger.info(f"raw_roi: {raw_roi}")

    pipeline = source

    if normalize_factor != "skip":
        pipeline = pipeline + gp.Normalize(raw, factor=normalize_factor)

    pipeline = pipeline + (gp.Pad(raw, context) + gp.torch.Predict(
        model,
        inputs={input_name: raw},
        outputs={model_output: prediction},
        array_specs={prediction: gp.ArraySpec(roi=raw_roi)},
        checkpoint=checkpoint,
        spawn_subprocess=spawn_subprocess))

    # # remove sample dimension for 3D data
    # pipeline += RemoveChannelDim(raw)
    # pipeline += RemoveChannelDim(prediction)

    pipeline += (gp.ZarrWrite({
        prediction: out_ds_names[0],
    },
                              output_dir=out_dir,
                              output_filename=out_filename,
                              compression_type='gzip') +
                 gp.Scan(request, num_workers=num_workers))

    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
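A hypothetical call to predict_volume, assuming `model` is a trained torch.nn.Module and `dataset` is a small object exposing .filename and .ds_names as used above; all names and paths below are placeholders:

# illustrative usage only; paths, dataset names and the input name are placeholders
predict_volume(
    model,
    dataset,
    out_dir='predictions',
    out_filename='prediction.zarr',
    out_ds_names=['volumes/prediction'],
    input_name='raw',
    checkpoint=None,
    num_workers=4)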
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')

    # loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_affs, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    gt_labels: 'volumes/gt_labels',
                    gt_fgbg: 'volumes/gt_fgbg',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_labels: gp.ArraySpec(interpolatable=False),
                    gt_fgbg: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.Pad(raw, None)
            + gp.Pad(gt_labels, None)
            + gp.Pad(gt_fgbg, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        # gp.BalanceLabels(
        #     gt_affs,
        #     loss_weights_affs) +

        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_labels']: gt_labels,
                # net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
            },
            malis=True,
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node at regular intervals
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')

    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }),
         gp.CsvIDPointsSource(fls[t] + ".csv",
                              points,
                              points_spec=gp.PointsSpec(
                                  roi=gp.Roi(gp.Coordinate((
                                      0, 0, 0)), gp.Coordinate(shapes[t]))))) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None) + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation() for t in range(ln))
    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +
        # (gp.SimpleAugment(
        #     mirror_only=augmentation['simple'].get("mirror"),
        #     transpose_only=augmentation['simple'].get("transpose")) \
        # if augmentation.get('simple') is not None and \
        #    augmentation.get('simple') != {} else NoOp())  +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) \
        if augmentation.get('intensity') is not None and \
           augmentation.get('intensity') != {} else NoOp())  +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node at regular intervals
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #8
def train(iterations):

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss w.r.t. the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    pipeline = (

        # a tuple of sources, one for each sample (A, B, and C) provided by the
        # CREMI challenge
        tuple(

            # read batches from the HDF5 file
            gp.Hdf5Source('sample_' + s + '_padded_20160501.hdf',
                          datasets={
                              raw: 'volumes/raw',
                              gt_labels: 'volumes/labels/neuron_ids'
                          }) +

            # convert raw to float in [0, 1]
            gp.Normalize(raw) +

            # choose a random location for each requested batch
            gp.RandomLocation() for s in ['A', 'B', 'C']) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment([4, 40, 40], [0, 2, 2], [0, math.pi / 2.0],
                          prob_slip=0.05,
                          prob_shift=0.05,
                          max_misalign=25) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(raw,
                            scale_min=0.9,
                            scale_max=1.1,
                            shift_min=-0.1,
                            shift_max=0.1,
                            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=3, only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels,
                         gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(gt_affs, loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(cache_size=10, num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=1) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=100,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for i in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
Example #9
    def create_train_pipeline(self, model):

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        filename = self.params['data_file']
        datasets = self.params['dataset']

        raw_0 = gp.ArrayKey('RAW_0')
        points_0 = gp.GraphKey('POINTS_0')
        locations_0 = gp.ArrayKey('LOCATIONS_0')
        emb_0 = gp.ArrayKey('EMBEDDING_0')
        raw_1 = gp.ArrayKey('RAW_1')
        points_1 = gp.GraphKey('POINTS_1')
        locations_1 = gp.ArrayKey('LOCATIONS_1')
        emb_1 = gp.ArrayKey('EMBEDDING_1')

        data = daisy.open_ds(filename, datasets[0])
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        emb_voxel_size = voxel_size

        cv_loss = ContrastiveVolumeLoss(self.params['temperature'],
                                        self.params['point_density'],
                                        out_shape * voxel_size)

        # Add fake 3rd dim
        if is_2d:
            in_shape = gp.Coordinate((1, *in_shape))
            out_shape = gp.Coordinate((1, *out_shape))
            voxel_size = gp.Coordinate((1, *voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (data.shape[0], *source_roi.get_shape()))

        in_shape = in_shape * voxel_size
        out_shape = out_shape * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw_0, in_shape)
        request.add(raw_1, in_shape)
        request.add(points_0, out_shape)
        request.add(points_1, out_shape)
        request[locations_0] = gp.ArraySpec(nonspatial=True)
        request[locations_1] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb_0] = gp.ArraySpec(roi=request[points_0].roi)
        snapshot_request[emb_1] = gp.ArraySpec(roi=request[points_1].roi)

        random_point_generator = RandomPointGenerator(
            density=self.params['point_density'], repetitions=2)

        # Use the volume of each dataset to calculate probabilities;
        # RandomSourceGenerator will normalize them to probabilities
        probabilities = np.array([
            np.product(daisy.open_ds(filename, dataset).shape)
            for dataset in datasets
        ])
        random_source_generator = RandomSourceGenerator(
            num_sources=len(datasets),
            probabilities=probabilities,
            repetitions=2)

        array_sources = tuple(
            tuple(
                gp.ZarrSource(
                    filename,
                    {raw: dataset},
                    # fake 3D data
                    array_specs={
                        raw:
                        gp.ArraySpec(roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                    }) for dataset in datasets) for raw in [raw_0, raw_1])

        # Choose a random dataset to pull from
        array_sources = \
            tuple(arrays +
                  RandomMultiBranchSource(random_source_generator) +
                  gp.Normalize(raw, self.params['norm_factor']) +
                  gp.Pad(raw, None)
                  for raw, arrays
                  in zip([raw_0, raw_1], array_sources))

        point_sources = tuple(
            (RandomPointSource(points_0,
                               random_point_generator=random_point_generator),
             RandomPointSource(points_1,
                               random_point_generator=random_point_generator)))

        # Merge the point and array sources together.
        # There is one array and point source per branch.
        sources = tuple((array_source, point_source) + gp.MergeProvider()
                        for array_source, point_source in zip(
                            array_sources, point_sources))

        sources = tuple(
            self._make_train_augmentation_pipeline(raw, source)
            for raw, source in zip([raw_0, raw_1], sources))

        pipeline = (sources + gp.MergeProvider() + gp.Crop(raw_0, source_roi) +
                    gp.Crop(raw_1, source_roi) + gp.RandomLocation() +
                    PrepareBatch(raw_0, raw_1, points_0, points_1, locations_0,
                                 locations_1, is_2d) +
                    RejectArray(ensure_nonempty=locations_0) +
                    RejectArray(ensure_nonempty=locations_1))

        if not is_2d:
            pipeline = (pipeline + AddChannelDim(raw_0) + AddChannelDim(raw_1))

        pipeline = (pipeline + gp.PreCache() + gp.torch.Train(
            model,
            cv_loss,
            optimizer,
            inputs={
                'raw_0': raw_0,
                'raw_1': raw_1
            },
            loss_inputs={
                'emb_0': emb_0,
                'emb_1': emb_1,
                'locations_0': locations_0,
                'locations_1': locations_1
            },
            outputs={
                2: emb_0,
                3: emb_1
            },
            array_specs={
                emb_0: gp.ArraySpec(voxel_size=emb_voxel_size),
                emb_1: gp.ArraySpec(voxel_size=emb_voxel_size)
            },
            checkpoint_basename=self.logdir + '/contrastive/checkpoints/model',
            save_every=self.params['save_every'],
            log_dir=self.logdir + "/contrastive",
            log_every=self.log_every))

        if is_2d:
            pipeline = (
                pipeline +
                # everything is 3D, except emb_0 and emb_1
                AddSpatialDim(emb_0) + AddSpatialDim(emb_1))

        pipeline = (
            pipeline +
            # now everything is 3D
            RemoveChannelDim(raw_0) + RemoveChannelDim(raw_1) +
            RemoveChannelDim(emb_0) + RemoveChannelDim(emb_1) +
            gp.Snapshot(output_dir=self.logdir + '/contrastive/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw_0: 'raw_0',
                            raw_1: 'raw_1',
                            locations_0: 'locations_0',
                            locations_1: 'locations_1',
                            emb_0: 'emb_0',
                            emb_1: 'emb_1'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
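create_train_pipeline returns both the pipeline and its matching request; a minimal sketch of how a caller could drive it, where `trainer` stands in for the object defining this method and `num_iterations` for the number of training steps chosen by the caller:

# minimal sketch of running the returned training pipeline
pipeline, request = trainer.create_train_pipeline(model)
with gp.build(pipeline):
    for iteration in range(num_iterations):
        batch = pipeline.request_batch(request)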
Example #10
def train(n_iterations):

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')
    embedding = gp.ArrayKey('EMBEDDING')
    fg = gp.ArrayKey('FG')
    maxima = gp.ArrayKey('MAXIMA')
    gradient_embedding = gp.ArrayKey('GRADIENT_EMBEDDING')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    emst = gp.ArrayKey('EMST')
    edges_u = gp.ArrayKey('EDGES_U')
    edges_v = gp.ArrayKey('EDGES_V')

    request = gp.BatchRequest()
    request.add(raw, (200, 200))
    request.add(gt, (160, 160))

    snapshot_request = gp.BatchRequest()
    snapshot_request[embedding] = request[gt]
    snapshot_request[fg] = request[gt]
    snapshot_request[gt_fg] = request[gt]
    snapshot_request[maxima] = request[gt]
    snapshot_request[gradient_embedding] = request[gt]
    snapshot_request[gradient_fg] = request[gt]
    snapshot_request[emst] = gp.ArraySpec()
    snapshot_request[edges_u] = gp.ArraySpec()
    snapshot_request[edges_v] = gp.ArraySpec()

    pipeline = (
        Synthetic2DSource(raw, gt) + gp.Normalize(raw) +
        gp.tensorflow.Train('train_net',
                            optimizer=add_loss,
                            loss=None,
                            inputs={
                                tensor_names['raw']: raw,
                                tensor_names['gt_labels']: gt,
                            },
                            outputs={
                                tensor_names['embedding']: embedding,
                                tensor_names['fg']: fg,
                                'maxima:0': maxima,
                                'gt_fg:0': gt_fg,
                                emst_name: emst,
                                edges_u_name: edges_u,
                                edges_v_name: edges_v,
                            },
                            gradients={
                                tensor_names['embedding']: gradient_embedding,
                                tensor_names['fg']: gradient_fg,
                            }) +
        gp.Snapshot(output_filename='{iteration}.hdf',
                    dataset_names={
                        raw: 'volumes/raw',
                        gt: 'volumes/gt',
                        embedding: 'volumes/embedding',
                        fg: 'volumes/fg',
                        maxima: 'volumes/maxima',
                        gt_fg: 'volumes/gt_fg',
                        gradient_embedding: 'volumes/gradient_embedding',
                        gradient_fg: 'volumes/gradient_fg',
                        emst: 'emst',
                        edges_u: 'edges_u',
                        edges_v: 'edges_v',
                    },
                    dataset_dtypes={
                        maxima: np.float32,
                        gt_fg: np.float32
                    },
                    every=100,
                    additional_request=snapshot_request))

    with gp.build(pipeline):
        for i in range(n_iterations):
            pipeline.request_batch(request)
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        num_samples = dataset_shape[0]
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
def predict_3d(raw_data, gt_data, model, predictor, aux_tasks):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = model.input_shape
    output_shape = model.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    model_output = gp.ArrayKey('MODEL_OUTPUT')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=model,
                                 inputs={'x': raw},
                                 outputs={0: model_output})
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': model_output},
                                 outputs={0: prediction})
    aux_predictions = []
    for aux_name, aux_predictor, _ in aux_tasks:
        aux_pred_key = gp.ArrayKey(f"PRED_{aux_name.upper()}")
        pipeline += gp_torch.Predict(model=aux_predictor,
                                     inputs={'x': model_output},
                                     outputs={0: aux_pred_key})
        aux_predictions.append((aux_name, aux_pred_key))
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(model_output, output_size)
    scan_request.add(prediction, output_size)
    for aux_name, aux_key in aux_predictions:
        scan_request.add(aux_key, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # only output where the gt exists
    context = (input_size - output_size) / 2

    output_roi = gt_data.roi.intersect(raw_data.roi.grow(-context, -context))
    input_roi = output_roi.grow(context, context)

    assert all([a > b for a, b in zip(input_roi.get_shape(), input_size)])
    assert all([a > b for a, b in zip(output_roi.get_shape(), output_size)])

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=input_roi)
    total_request[model_output] = gp.ArraySpec(roi=output_roi)
    total_request[prediction] = gp.ArraySpec(roi=output_roi)
    for aux_name, aux_key in aux_predictions:
        total_request[aux_key] = gp.ArraySpec(roi=output_roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=output_roi)
        total_request[target] = gp.ArraySpec(roi=output_roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {
            'raw': batch[raw],
            'model_out': batch[model_output],
            'prediction': batch[prediction]
        }
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        for aux_name, aux_key in aux_predictions:
            ret[aux_name] = batch[aux_key]
        return ret
def train_until(**kwargs):
    print("CUDA visible devices", os.environ["CUDA_VISIBLE_DEVICES"])
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    snapshot_request.add(gt_affs, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
    ln = len(fls)
    print("first 5 files: ", fls[0:5])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    neighborhood = []
    psH = np.array(kwargs['patchshape']) // 2
    for i in range(-psH[0], psH[0] + 1, kwargs['patchstride'][0]):
        for j in range(-psH[1], psH[1] + 1, kwargs['patchstride'][1]):
            for k in range(-psH[2], psH[2] + 1, kwargs['patchstride'][2]):
                neighborhood.append([i, j, k])

    datasets = {
        raw: 'volumes/raw',
        gt_labels: 'volumes/gt_labels',
        anchor: 'volumes/gt_fgbg',
    }
    input_specs = {
        raw:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(input_shape_world),
                                input_shape_world),
                     interpolatable=True,
                     dtype=np.float32),
        gt_labels:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint16),
        anchor:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint8),
        gt_affs:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint8)
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        net_names['anchor']: anchor,
    }

    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
    }

    optimizer_args = None
    if kwargs['auto_mixed_precision']:
        optimizer_args = (kwargs['optimizer'], {
            'args': kwargs['args'],
            'kwargs': kwargs['kwargs']
        })
    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets=datasets,
                # array_specs=array_specs
            )
            + gp.Pad(raw, None)
            + gp.Pad(gt_labels, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=4) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        # gp.BalanceLabels(
        #     gt_affs,
        #     loss_weights_affs) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # pre-fetch batches from the point upstream
        (gp.tensorflow.TFData() \
         if kwargs.get('use_tf_data') else NoOp()) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            array_specs=input_specs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            auto_mixed_precision=kwargs['auto_mixed_precision'],
            optimizer_args=optimizer_args,
            use_tf_data=kwargs['use_tf_data'],
            save_every=kwargs['checkpoints'],
            snapshot_every=kwargs['snapshots']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node at regular intervals
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    try:
        with gp.build(pipeline):
            print(pipeline)
            for i in range(trained_until, kwargs['max_iteration']):
                start = time.time()
                pipeline.request_batch(request)
                time_of_iteration = time.time() - start

                logger.info("Batch: iteration=%d, time=%f", i,
                            time_of_iteration)
            # exit()
    except KeyboardInterrupt:
        sys.exit()
    print("Training finished")
Example #14
def batch_data_aug_generator(input_path,
                             batch_size=12,
                             voxel_shape=[1, 1, 1],
                             input_shape=[240, 240, 4],
                             output_shape=[240, 240, 4],
                             without_background=False,
                             mix_output=False,
                             validate=False,
                             aug=seq):
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')
    files = os.listdir(input_path)
    files = [os.path.join(input_path, f) for f in files]
    pipeline = (
        tuple(
            gp.ZarrSource(
                files[t],  # the zarr container
                {
                    raw: 'raw',
                    gt: 'ground_truth'
                },  # which dataset to associate to the array key
                {
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape),
                    gt:
                    gp.ArraySpec(interpolatable=True,
                                 dtype=np.dtype('float32'),
                                 voxel_size=voxel_shape)
                }  # meta-information
            ) + gp.RandomLocation()
            for t in range(len(files))) + gp.RandomProvider()
        #    +gp.Stack(batch_size)
    )
    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt, input_size)
    diff = input_shape[1] - output_shape[1]
    diff = int(diff / 2)
    max_p = input_shape[1] - diff
    different_shape = diff > 0
    if different_shape:
        print('Difference padding: {}'.format(diff))
    with gp.build(pipeline):
        while True:
            b = 0
            imgs = []
            masks = []
            while b < batch_size:
                valid = False
                batch = pipeline.request_batch(request)
                if validate:
                    valid = validate_mask(batch[gt].data)
                else:
                    valid = True

                while not valid:
                    batch = pipeline.request_batch(request)
                    valid = validate_mask(batch[gt].data)
                im = batch[raw].data
                out = batch[gt].data
                # im,out = augmentation(im,out,seq)
                if different_shape:
                    out = out[diff:max_p, diff:max_p, :]
                if without_background:
                    out = out[:, :, 1:4]
                if mix_output:
                    out = out.argmax(axis=3).astype(float)
                imgs.append(im)
                masks.append(out)
                b = b + 1
            yield augmentation(np.asarray(imgs), np.asarray(masks), aug)
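batch_data_aug_generator yields augmented (images, masks) batches indefinitely, so it can be plugged straight into a generator-based training loop; a hypothetical usage, where the path and shapes are placeholders and seq, validate_mask and augmentation are assumed to be defined in the surrounding module:

# hypothetical usage only; 'data/train' is a placeholder path
gen = batch_data_aug_generator('data/train',
                               batch_size=8,
                               input_shape=[240, 240, 4],
                               output_shape=[240, 240, 4])
imgs, masks = next(gen)
print(imgs.shape, masks.shape)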
Example #15
def train_until(**kwargs):
    print("CUDA visible devices", os.environ["CUDA_VISIBLE_DEVICES"])
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')

    pred_code = gp.ArrayKey('PRED_CODE')
    # pred_code_gradients = gp.ArrayKey('PRED_CODE_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)
    else:
        request.add(gt_fgbg, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_code, output_shape_world)
    # snapshot_request.add(pred_code_gradients, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)
    else:
        snapshot_request.add(pred_fgbg, output_shape_world)

    if kwargs['input_format'] not in ("hdf", "zarr"):
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw_bf']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw_bf']
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert %s to float32" % f)
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    neighborhood = []
    psH = np.array(kwargs['patchshape'])//2
    for i in range(-psH[1], psH[1]+1, kwargs['patchstride'][1]):
        for j in range(-psH[2], psH[2]+1, kwargs['patchstride'][2]):
            neighborhood.append([i,j])

    datasets = {
        raw: 'volumes/raw_bf',
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances',
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False),
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
    }

    outputs = {
        net_names['pred_code']: pred_code,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: '/volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_code: '/volumes/pred_code',
        # pred_code_gradients: '/volumes/pred_code_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = '/volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[pred_numinst] = '/volumes/pred_numinst'
    else:
        datasets[gt_fgbg] = '/volumes/gt_fgbg'
        array_specs[gt_fgbg] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_fgbg']] = gt_fgbg
        outputs[net_names['pred_fgbg']] = pred_fgbg
        snapshot[pred_fgbg] = '/volumes/pred_fgbg'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    source_fg = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # choose a random location for each requested batch
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_fg += gp.RandomProvider()

    if kwargs['overlapping_inst']:
        source_overlap = tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets=datasets,
                array_specs=array_specs
            ) +
            gp.Pad(raw, context) +

            # choose a random location for each requested batch
            nl.MaskCloseDistanceToOverlap(
                gt_labels, gt_sample_mask,
                sampling['overlap_min_dist'],
                sampling['overlap_max_dist']
            ) +
            gp.RandomLocation(
                min_masked=sampling['min_masked_overlap'],
                mask=gt_sample_mask
            )
            for t in range(ln)
        )
        source_overlap += gp.RandomProvider()

        source = (
            (source_fg, source_overlap) +

            # choose a random source (i.e., sample) from the above
            gp.RandomProvider(probabilities=[sampling['probability_fg'],
                                             sampling['probability_overlap']]))
    else:
        source = source_fg

    pipeline = (
        source +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0]) +

        gp.Reject(gt_sample_mask, min_masked=0.002, reject_probability=1) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        gp.IntensityScaleShift(raw, 2, -1) +

        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels if kwargs['overlapping_inst'] else gt_instances,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst']) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                # net_names['pred_code']: pred_code_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every kwargs['profiling'] iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
            # exit()
    print("Training finished")
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(pred_mask, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),  # np.bool was removed in recent NumPy
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=0.98) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
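
This variant reads `data_dir` and `data_files` from module scope and relies on the custom `Convert` and `nl.Clip` nodes; a call might look like the following sketch (all values are placeholders):

# Hypothetical invocation; data_dir and data_files must already be defined
# at module level for the data sources above to be assembled.
train_until(max_iteration=100000,
            name='train_net',
            output_folder='./experiment',
            clip_max=2000)
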
def predict(data_dir,
            train_dir,
            iteration,
            sample,
            test_net_name='train_net',
            train_net_name='train_net',
            output_dir='.',
            clip_max=1000):

    if "hdf" not in data_dir:
        return

    print("Predicting ", sample)
    print(
        'checkpoint: ',
        os.path.join(train_dir, train_net_name + '_checkpoint_%d' % iteration))

    checkpoint = os.path.join(train_dir,
                              train_net_name + '_checkpoint_%d' % iteration)

    with open(os.path.join(train_dir, test_net_name + '_config.json'),
              'r') as f:
        net_config = json.load(f)

    with open(os.path.join(train_dir, test_net_name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    # ArrayKeys
    raw = gp.ArrayKey('RAW')
    pred_mask = gp.ArrayKey('PRED_MASK')

    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    voxel_size = gp.Coordinate((1, 1, 1))
    context = gp.Coordinate(input_shape - output_shape) / 2

    # add ArrayKeys to batch request
    request = gp.BatchRequest()
    request.add(raw, input_shape, voxel_size=voxel_size)
    request.add(pred_mask, output_shape, voxel_size=voxel_size)

    print("chunk request %s" % request)

    source = (gp.Hdf5Source(
        data_dir,
        datasets={
            raw: sample + '/raw',
        },
        array_specs={
            raw:
            gp.ArraySpec(
                interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
        },
    ) + gp.Pad(raw, context) + nl.Clip(raw, 0, clip_max) +
              gp.Normalize(raw, factor=1.0 / clip_max) +
              gp.IntensityScaleShift(raw, 2, -1))

    with gp.build(source):
        raw_roi = source.spec[raw].roi
        print("raw_roi: %s" % raw_roi)
        sample_shape = raw_roi.grow(-context, -context).get_shape()

    print(sample_shape)

    # create zarr file with corresponding chunk size
    zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='w')

    zf.create('volumes/pred_mask',
              shape=sample_shape,
              chunks=output_shape,
              dtype=np.float16)
    zf['volumes/pred_mask'].attrs['offset'] = [0, 0, 0]
    zf['volumes/pred_mask'].attrs['resolution'] = [1, 1, 1]

    pipeline = (
        source + gp.tensorflow.Predict(
            graph=os.path.join(train_dir, test_net_name + '.meta'),
            checkpoint=checkpoint,
            inputs={
                net_names['raw']: raw,
            },
            outputs={
                net_names['pred']: pred_mask,
            },
            array_specs={
                pred_mask:
                gp.ArraySpec(roi=raw_roi.grow(-context, -context),
                             voxel_size=voxel_size),
            },
            max_shared_memory=1024 * 1024 * 1024) +
        Convert(pred_mask, np.float16) + gp.ZarrWrite(
            dataset_names={
                pred_mask: 'volumes/pred_mask',
            },
            output_dir=output_dir,
            output_filename=sample + '.zarr',
            compression_type='gzip',
            dataset_dtypes={pred_mask: np.float16}) +

        # show a summary of the time spent in each node every 100 iterations
        gp.PrintProfilingStats(every=100) +
        gp.Scan(reference=request, num_workers=5, cache_size=50))

    with gp.build(pipeline):

        pipeline.request_batch(gp.BatchRequest())
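
Because `gp.Scan` tiles the full sample ROI with the chunk-sized request and `gp.ZarrWrite` stores every chunk, the stitched prediction can be read back from the zarr container once the scan has finished. A minimal sketch, assuming the same `output_dir` and `sample` that were passed to `predict` above:

import os
import zarr

# placeholders for the arguments that were passed to predict(...)
output_dir = '.'
sample = 'sample_A'

zf = zarr.open(os.path.join(output_dir, sample + '.zarr'), mode='r')
pred = zf['volumes/pred_mask'][:]   # full-sample prediction, stored as float16
print(pred.shape, pred.dtype)
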
Example #18
def predict(
        iteration,
        raw_file,
        raw_dataset,
        out_file,
        db_host,
        db_name,
        worker_config,
        network_config,
        out_properties={},
        **kwargs):

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(os.path.join(setup_dir,
                           '{}_net_config.json'.format(network_config)),
              'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    raw = gp.ArrayKey('RAW')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_post_indicator, output_size)

    m_property = out_properties.get('pred_syn_indicator_out')

    # select the source node based on the file suffix
    # Hdf5Source
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(
            raw_file,
            datasets={
                raw: raw_dataset
            },
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
            }
        )
    elif raw_file.endswith('.zarr') or raw_file.endswith('.n5'):
        pipeline = gp.ZarrSource(
            raw_file,
            datasets={
                raw: raw_dataset
            },
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
            }
        )
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={
            net_config['raw']: raw
        },
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config))
    )

    if m_property is not None and 'scale' in m_property:
        if m_property['scale'] != 1:
            pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                               m_property['scale'], 0)

    pipeline += gp.ZarrWrite(
        dataset_names={
            pred_post_indicator: 'volumes/pred_syn_indicator',
        },
        output_filename=out_file
    )

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host,
            db_name,
            worker_config,
            b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),  # np.bool was removed in recent NumPy
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005,  mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),  # np.bool was removed in recent NumPy
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add]) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
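
To inspect what `nl.FusionAugment` produces without running full training, a single batch can be requested from the augmentation part of the pipeline alone. A rough sketch, assuming the nodes above up to and including `BinarizeLabels` are assembled as `inspection_pipeline` (a hypothetical name, not part of the original code):

# Rough sketch: inspection_pipeline is assumed to be
#   data_sources + nl.FusionAugment(...) + BinarizeLabels(gt_instances, gt_mask),
# i.e. the pipeline above without the training and snapshot nodes.
with gp.build(inspection_pipeline):
    batch = inspection_pipeline.request_batch(request)
    fused_raw = batch[raw].data       # base volume with the blended-in "add" instances
    fused_mask = batch[gt_mask].data  # binarized foreground of the fused labels
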
Example #20
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
        if trained_until >= max_iteration:
            return

    # array keys for fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # output data
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_center_add, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_base: "/volume"},
                array_specs={
                    raw_base:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_base, swc_base),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for filename in files)

    # data source for "add" volume
    data_sources_add = tuple(
        (
            gp.Hdf5Source(
                file,
                datasets={raw_add: "/volume"},
                array_specs={
                    raw_add:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=file,
                dataset="/reconstruction",
                points=(swc_center_add, swc_add),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for file in files)
    data_sources = (
        (data_sources_base + gp.RandomProvider()),
        (data_sources_add + gp.RandomProvider()),
    ) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        ) +
        # augment
        gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                          subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +
        # train
        gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        ) +
        # visualize
        gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels: "volumes/labels",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                fg: "volumes/fg",
                labels_fg: "volumes/labels_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            additional_request=snapshot_request,
            every=100,
        ) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
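
This example pulls `net_config`, `net_names`, and `files` from module scope, so the entry point reduces to a single call; the iteration count below is a placeholder. With `save_every=100000`, checkpoints are written next to `./train_net`, matching the `train_net_checkpoint_<iteration>` naming that the prediction examples above load.

# Hypothetical invocation; net_config, net_names and files must be defined
# at module level before calling this.
train_until(max_iteration=200000)
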
Example #21
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
        if trained_until >= max_iteration:
            return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple((
        gp.N5Source(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
            ).absolute()),
            datasets={raw: "volume"},
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True,
                             voxel_size=voxel_size,
                             dtype=np.uint16)
            },
        ),
        MouselightSwcFileSource(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
            ).absolute()),
            points=(swcs, ),
            scale=voxel_size,
            transpose=(2, 1, 0),
            transform_file=str((filename / "transform.txt").absolute()),
        ),
    ) + gp.MergeProvider() + gp.RandomLocation(ensure_nonempty=swcs,
                                               ensure_centered=True)
                         for filename in Path(sample_dir).iterdir()
                         if "2018-08-01" in filename.name)

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(batch[raw].data,
                                  points_to_graph(batch[swcs].data),
                                  np.array(voxel_size))