Example #1
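
The five snippets below are excerpts from larger gunpowder training scripts. A minimal sketch of the imports they assume follows (module aliases inferred from usage); project-specific nodes such as Convert, DistanceTransform, BinarizeLabels, PointsLabelsSource, DaCapoArraySource, the nl helper module, and module-level globals like data_dir and data_files are defined elsewhere in the respective repositories and are not shown here.

# Imports inferred from how the snippets use them; treat as an assumption.
import json
import logging
import math
import os
import time

import daisy              # Example #2: daisy.open_ds
import h5py
import numpy as np
import zarr               # Example #5: zarr input volumes
import gunpowder as gp
import tensorflow as tf   # TF 1.x graph API (tf.train.latest_checkpoint, gp.tensorflow.Train)

logger = logging.getLogger(__name__)
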
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1,2], transpose_only=[1,2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2,-1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
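
A hypothetical invocation of the function above; it assumes the network graph was already created, that its train_net_config.json and train_net_names.json files sit in the output folder, and that the module-level data_dir and data_files globals point at the HDF5 training volumes (all values below are illustrative placeholders):

# Illustrative call only; the output folder and iteration count are placeholders.
if __name__ == '__main__':
    train_until(
        max_iteration=100000,
        name='train_net',
        output_folder='experiments/setup01',
        clip_max=2000)
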
Example #2
    def create_train_pipeline(self, model):

        print(f"Creating training pipeline with batch size \
              {self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        context = (in_shape - out_roi) / 2
        gt_labels_out_shape = out_roi
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            gt_labels_out_shape = (1, *gt_labels_out_shape)

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))

        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=2, source_roi)
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
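
The method returns the pipeline and request without consuming them; a minimal sketch of the surrounding training loop, assuming a trainer instance exposing create_train_pipeline and a placeholder iteration count (neither is part of the original class):

# Hypothetical driver; `trainer`, `model` and `num_iterations` are placeholders.
pipeline, request = trainer.create_train_pipeline(model)
num_iterations = 100000
with gp.build(pipeline):
    for i in range(num_iterations):
        batch = pipeline.request_batch(request)
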
Example #3
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_mask = gp.ArrayKey('GT_MASK')
    pred_mask = gp.ArrayKey('PRED_MASK')
    #loss_weights = gp.ArrayKey('LOSS_WEIGHTS')
    loss_gradients = gp.ArrayKey('LOSS_GRADIENTS')

    # array keys for base and add volume
    raw_base = gp.ArrayKey('RAW_BASE')
    gt_instances_base = gp.ArrayKey('GT_INSTANCES_BASE')
    gt_mask_base = gp.ArrayKey('GT_MASK_BASE')
    raw_add = gp.ArrayKey('RAW_ADD')
    gt_instances_add = gp.ArrayKey('GT_INSTANCES_ADD')
    gt_mask_add = gp.ArrayKey('GT_MASK_ADD')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_instances, output_shape)
    request.add(gt_mask, output_shape)
    #request.add(loss_weights, output_shape)
    request.add(raw_base, input_shape)
    request.add(raw_add, input_shape)
    request.add(gt_mask_base, output_shape)
    request.add(gt_mask_add, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    #snapshot_request.add(raw_base, input_shape)
    #snapshot_request.add(raw_add, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    #snapshot_request.add(gt_mask_base, output_shape)
    #snapshot_request.add(gt_mask_add, output_shape)
    snapshot_request.add(pred_mask, output_shape)
    snapshot_request.add(loss_gradients, output_shape)

    # specify data source
    # data source for base volume
    data_sources_base = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_base += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_base: sample + '/raw',
                        gt_instances_base: sample + '/gt',
                        gt_mask_base: sample + '/fg',
                    },
                    array_specs={
                        raw_base: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_base: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_base: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_base, np.uint8) +
                gp.Pad(raw_base, context) +
                gp.Pad(gt_instances_base, context) +
                gp.Pad(gt_mask_base, context) +
                gp.RandomLocation(min_masked=0.005,  mask=gt_mask_base)
                #gp.Reject(gt_mask_base, min_masked=0.005, reject_probability=1.)
                for sample in f)
    data_sources_base += gp.RandomProvider()

    # data source for add volume
    data_sources_add = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources_add += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw_add: sample + '/raw',
                        gt_instances_add: sample + '/gt',
                        gt_mask_add: sample + '/fg',
                    },
                    array_specs={
                        raw_add: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_instances_add: gp.ArraySpec(interpolatable=False, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask_add: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask_add, np.uint8) +
                gp.Pad(raw_add, context) +
                gp.Pad(gt_instances_add, context) +
                gp.Pad(gt_mask_add, context) +
                gp.RandomLocation() +
                gp.Reject(gt_mask_add, min_masked=0.005, reject_probability=0.95)
                for sample in f)
    data_sources_add += gp.RandomProvider()
    data_sources = tuple([data_sources_base, data_sources_add]) + gp.MergeProvider()

    pipeline = (
            data_sources +
            nl.FusionAugment(
                raw_base, raw_add, gt_instances_base, gt_instances_add, raw, gt_instances,
                blend_mode='labels_mask', blend_smoothness=5, num_blended_objects=0
            ) +
            BinarizeLabels(gt_instances, gt_mask) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +
            #gp.BalanceLabels(gt_mask, loss_weights) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=10) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt']: gt_mask,
                    #net_names['loss_weights']: loss_weights,
                },
                outputs={
                    net_names['pred']: pred_mask,
                },
                gradients={
                    net_names['output']: loss_gradients,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    pred_mask: 'volumes/pred_mask',
                    gt_mask: 'volumes/gt_mask',
                    #loss_weights: 'volumes/loss_weights',
                    loss_gradients: 'volumes/loss_gradients',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2500) +
            gp.PrintProfilingStats(every=1000)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #4
    def build_batch_provider(self, datasets, model, task, snapshot_container=None):
        input_shape = Coordinate(model.input_shape)
        output_shape = Coordinate(model.output_shape)

        # get voxel sizes
        raw_voxel_size = datasets[0].raw.voxel_size
        prediction_voxel_size = model.scale(raw_voxel_size)

        # define input and output size:
        # switch to world units
        input_size = raw_voxel_size * input_shape
        output_size = prediction_voxel_size * output_shape

        # padding of groundtruth/mask
        gt_mask_padding = output_size + task.predictor.padding(prediction_voxel_size)

        # define keys:
        raw_key = gp.ArrayKey("RAW")
        gt_key = gp.ArrayKey("GT")
        mask_key = gp.ArrayKey("MASK")

        target_key = gp.ArrayKey("TARGET")
        weight_key = gp.ArrayKey("WEIGHT")

        # Get source nodes
        dataset_sources = []
        for dataset in datasets:

            raw_source = DaCapoArraySource(dataset.raw, raw_key)
            raw_source += gp.Pad(raw_key, None, 0)
            gt_source = DaCapoArraySource(dataset.gt, gt_key)
            gt_source += gp.Pad(gt_key, gt_mask_padding, 0)
            if dataset.mask is not None:
                mask_source = DaCapoArraySource(dataset.mask, mask_key)
            else:
                # Always provide a mask. By default it is simply an array
                # of ones with the same shape/roi as gt. Avoids making us
                # specially handle no mask case and allows padding of the
                # ground truth without worrying about training on incorrect
                # data.
                mask_source = DaCapoArraySource(OnesArray.like(dataset.gt), mask_key)
            mask_source += gp.Pad(mask_key, gt_mask_padding, 0)
            array_sources = [raw_source, gt_source, mask_source]

            dataset_source = (
                tuple(array_sources) + gp.MergeProvider() + gp.RandomLocation()
            )

            dataset_sources.append(dataset_source)
        pipeline = tuple(dataset_sources) + gp.RandomProvider()

        for augment in self.augments:
            pipeline += augment.node(raw_key, gt_key, mask_key)

        pipeline += gp.Reject(mask_key, min_masked=self.min_masked)

        # Add predictor nodes to pipeline
        pipeline += DaCapoTargetFilter(
            task.predictor,
            gt_key=gt_key,
            target_key=target_key,
            weights_key=weight_key,
            mask_key=mask_key,
        )

        # Trainer attributes:
        if self.num_data_fetchers > 1:
            pipeline += gp.PreCache(num_workers=self.num_data_fetchers)

        # stack to create a batch dimension
        pipeline += gp.Stack(self.batch_size)

        # print profiling stats
        pipeline += gp.PrintProfilingStats(every=self.print_profiling)

        # generate request for all necessary inputs to training
        request = gp.BatchRequest()
        request.add(raw_key, input_size)
        request.add(target_key, output_size)
        request.add(weight_key, output_size)
        # request additional keys for snapshots
        request.add(gt_key, output_size)
        request.add(mask_key, output_size)

        self._request = request
        self._pipeline = pipeline
        self._raw_key = raw_key
        self._gt_key = gt_key
        self._mask_key = mask_key
        self._weight_key = weight_key
        self._target_key = target_key
        self._loss = task.loss

        self.snapshot_container = snapshot_container
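
The method only stores the pipeline and request on the trainer; a minimal sketch of how they might be consumed later, under the assumption of a generic batch iterator (this is not DaCapo's actual training loop):

# Hypothetical helper showing how the stored pipeline and request could be
# consumed; the real trainer additionally converts batches to framework
# tensors and applies task.loss.
def iterate_batches(trainer, num_iterations):
    with gp.build(trainer._pipeline):
        for _ in range(num_iterations):
            yield trainer._pipeline.request_batch(trainer._request)
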
Example #5
def train_until(**kwargs):
    print("cuda visibile devices", os.environ["CUDA_VISIBLE_DEVICES"])
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_instances = gp.ArrayKey('GT_INSTANCES')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_numinst = gp.ArrayKey('GT_NUMINST')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_sample_mask = gp.ArrayKey('GT_SAMPLE_MASK')

    pred_code = gp.ArrayKey('PRED_CODE')
    # pred_code_gradients = gp.ArrayKey('PRED_CODE_GRADIENTS')
    pred_numinst = gp.ArrayKey('PRED_NUMINST')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['output_folder'],
                           kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape'])*voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape'])*voxel_size
    context = gp.Coordinate(input_shape_world - output_shape_world) / 2

    raw_key = kwargs.get('raw_key', 'volumes/raw')

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_instances, output_shape_world)
    request.add(gt_sample_mask, output_shape_world)
    request.add(gt_affs, output_shape_world)
    if kwargs['overlapping_inst']:
        request.add(gt_numinst, output_shape_world)
    else:
        request.add(gt_fgbg, output_shape_world)
    # request.add(loss_weights_affs, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_code, output_shape_world)
    # snapshot_request.add(pred_code_gradients, output_shape_world)
    if kwargs['overlapping_inst']:
        snapshot_request.add(pred_numinst, output_shape_world)
    else:
        snapshot_request.add(pred_fgbg, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError(
            "train node for %s not implemented yet" % kwargs['input_format'])

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw_bf']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw_bf']
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    neighborhood = []
    psH = np.array(kwargs['patchshape'])//2
    for i in range(-psH[1], psH[1]+1, kwargs['patchstride'][1]):
        for j in range(-psH[2], psH[2]+1, kwargs['patchstride'][2]):
            neighborhood.append([i,j])

    datasets = {
        raw: 'volumes/raw_bf',
        gt_labels: 'volumes/gt_labels',
        gt_instances: 'volumes/gt_instances',
    }
    array_specs = {
        raw: gp.ArraySpec(interpolatable=True),
        gt_labels: gp.ArraySpec(interpolatable=False),
        gt_instances: gp.ArraySpec(interpolatable=False),
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
    }

    outputs = {
        net_names['pred_code']: pred_code,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw: '/volumes/raw',
        raw_cropped: 'volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_code: '/volumes/pred_code',
        # pred_code_gradients: '/volumes/pred_code_gradients',
    }
    if kwargs['overlapping_inst']:
        datasets[gt_numinst] = '/volumes/gt_numinst'
        array_specs[gt_numinst] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_numinst']] = gt_numinst
        outputs[net_names['pred_numinst']] = pred_numinst
        snapshot[pred_numinst] = '/volumes/pred_numinst'
    else:
        datasets[gt_fgbg] = '/volumes/gt_fgbg'
        array_specs[gt_fgbg] = gp.ArraySpec(interpolatable=False)
        inputs[net_names['gt_fgbg']] = gt_fgbg
        outputs[net_names['pred_fgbg']] = pred_fgbg
        snapshot[pred_fgbg] = '/volumes/pred_fgbg'

    augmentation = kwargs['augmentation']
    sampling = kwargs['sampling']

    source_fg = tuple(
        sourceNode(
            fls[t] + "." + kwargs['input_format'],
            datasets=datasets,
            array_specs=array_specs
        ) +
        gp.Pad(raw, context) +

        # chose a random location for each requested batch
        nl.CountOverlap(gt_labels, gt_sample_mask, maxnuminst=1) +
        gp.RandomLocation(
            min_masked=sampling['min_masked'],
            mask=gt_sample_mask
        )
        for t in range(ln)
    )
    source_fg += gp.RandomProvider()

    if kwargs['overlapping_inst']:
        source_overlap = tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets=datasets,
                array_specs=array_specs
            ) +
            gp.Pad(raw, context) +

            # chose a random location for each requested batch
            nl.MaskCloseDistanceToOverlap(
                gt_labels, gt_sample_mask,
                sampling['overlap_min_dist'],
                sampling['overlap_max_dist']
            ) +
            gp.RandomLocation(
                min_masked=sampling['min_masked_overlap'],
                mask=gt_sample_mask
            )
            for t in range(ln)
        )
        source_overlap += gp.RandomProvider()

        source = (
            (source_fg, source_overlap) +

            # chose a random source (i.e., sample) from the above
            gp.RandomProvider(probabilities=[sampling['probability_fg'],
                                             sampling['probability_overlap']]))
    else:
        source = source_fg

    pipeline = (
        source +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0]) +

        gp.Reject(gt_sample_mask, min_masked=0.002, reject_probability=1) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(
            mirror_only=augmentation['simple'].get("mirror"),
            transpose_only=augmentation['simple'].get("transpose")) +

        # # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        gp.IntensityScaleShift(raw, 2, -1) +

        # convert labels into affinities between voxels
        nl.AddAffinities(
            neighborhood,
            gt_labels if kwargs['overlapping_inst'] else gt_instances,
            gt_affs,
            multiple_labels=kwargs['overlapping_inst']) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            gradients={
                # net_names['pred_code']: pred_code_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spend in each node every 10 iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info(
                "Batch: iteration=%d, time=%f",
                i, time_of_iteration)
            # exit()
    print("Training finished")