Example #1
    def add_target(self, gt, target):

        return (gp.AddAffinities(affinity_neighborhood=self.neighborhood,
                                 labels=gt,
                                 affinities=target) +
                # ensure affs are float
                gp.Normalize(target, factor=1.0))
Example #2
    def add_target(self, gt, target):

        return (
            gp.AddAffinities(affinity_neighborhood=self.neighborhood,
                             labels=gt,
                             affinities=target)
            # TODO: Fix Error: Found dtype Byte but expected Float
            # This can occur when backpropagating through MSE where
            # the predictions are floats but the targets are uint8
        )
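
One way to address the TODO above, assuming the same gunpowder API as in Example 1, is to cast the affinities to float right after gp.AddAffinities. gp.Normalize with factor=1.0 leaves the values unchanged but converts the array to float32, which avoids the dtype mismatch when backpropagating through MSE. A minimal sketch:

    def add_target(self, gt, target):

        return (
            gp.AddAffinities(affinity_neighborhood=self.neighborhood,
                             labels=gt,
                             affinities=target) +
            # cast the uint8 affinities to float32 without rescaling them,
            # so they match the float predictions in the MSE loss
            gp.Normalize(target, factor=1.0)
        )
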
Example #3
def train(iterations):

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss w.r.t. the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((8, 8, 8))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    pipeline = (

        # a tuple of sources, one for each sample (A, B, and C) provided by the
        # CREMI challenge
        tuple(

            # read batches from the HDF5 file
            gp.Hdf5Source(os.path.join(data_dir, 'fib.hdf'),
                          datasets={
                              raw: 'volumes/raw',
                              gt_labels: 'volumes/labels/neuron_ids'
                          }) +

            # convert raw to float in [0, 1]
            gp.Normalize(raw) +

            # choose a random location for each requested batch
            gp.RandomLocation()) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment([8, 8, 8], [0, 2, 2], [0, math.pi / 2.0],
                          prob_slip=0.05,
                          prob_shift=0.05,
                          max_misalign=25) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(raw,
                            scale_min=0.9,
                            scale_max=1.1,
                            shift_min=-0.1,
                            shift_max=0.1,
                            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=3, only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels,
                         gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(gt_affs, loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(cache_size=10, num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names stored earlier in train_net_config.json)
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=10000) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=1000,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 1000 iterations
        gp.PrintProfilingStats(every=1000))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for i in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
Example #4
    def create_train_pipeline(self, model):

        print(
            f"Creating training pipeline with batch size {self.params['batch_size']}"
        )

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        gt_aff = gp.ArrayKey('AFFINITIES')
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_shape = gp.Coordinate(model.out_shape[2:])
        is_2d = in_shape.dims() == 2

        in_shape = in_shape * out_voxel_size
        out_shape = out_shape * out_voxel_size

        context = (in_shape - out_shape) / 2
        gt_labels_out_shape = out_shape
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
            gt_labels_out_shape = (1, *gt_labels_out_shape)
        else:
            aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(gt_aff, out_shape)
        request.add(predictions, out_shape)

        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))
        snapshot_request[gt_labels] = gp.ArraySpec(
            roi=gp.Roi(context, gt_labels_out_shape))

        source = (
            gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                          array_specs={
                              raw:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size,
                                           interpolatable=True),
                              gt_labels:
                              gp.ArraySpec(roi=source_roi,
                                           voxel_size=source_voxel_size)
                          }) + gp.Normalize(raw, self.params['norm_factor']) +
            gp.Pad(raw, context) + gp.Pad(gt_labels, context) +
            gp.RandomLocation()
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # raw      : (l=1, h, w)
            # gt_labels: (l=1, h, w)
            gp.AddAffinities(aff_neighborhood, gt_labels, gt_aff) +
            SetDtype(gt_aff, np.float32) +
            # raw      : (l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
            AddChannelDim(raw)
            # raw      : (c=1, l=1, h, w)
            # gt_aff   : (c=2, l=1, h, w)
        )

        if is_2d:
            pipeline = (
                pipeline + RemoveSpatialDim(raw) + RemoveSpatialDim(gt_aff)
                # raw      : (c=1, h, w)
                # gt_aff   : (c=2, h, w)
            )

        pipeline = (
            pipeline + gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, h, w)
            # gt_aff   : (b, c=2, h, w)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={'raw': raw},
                loss_inputs={
                    0: predictions,
                    1: gt_aff
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(voxel_size=out_voxel_size),
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, h, w)
            # gt_aff     : (b, c=2, h, w)
            # predictions: (b, c=2, h, w)

            # crop gt_labels to the output ROI so it can be inspected in the snapshot
            gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape)) +
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            gt_aff: 'gt_aff',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            gp.PrintProfilingStats(every=500))

        return pipeline, request
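
A minimal sketch of how the returned pipeline and request might be driven, following the training loop used in the other examples; `trainer`, `model`, and the 'num_iterations' parameter name below are placeholders, not taken from the original code.

# Hypothetical usage: `trainer` is an instance of the class defining
# create_train_pipeline above, `model` a matching torch model.
pipeline, request = trainer.create_train_pipeline(model)

with gp.build(pipeline):
    for i in range(trainer.params['num_iterations']):  # assumed parameter name
        pipeline.request_batch(request)
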
Example #5
def train_until(**kwargs):
    checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if checkpoint:
        trained_until = int(checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')

    loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet" %
                                  kwargs['input_format'])

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    gt_labels: 'volumes/gt_labels',
                    gt_fgbg: 'volumes/gt_fgbg',
                    anchor: 'volumes/gt_fgbg',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_labels: gp.ArraySpec(interpolatable=False),
                    gt_fgbg: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.Pad(raw, None)
            + gp.Pad(gt_labels, None)
            + gp.Pad(gt_fgbg, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights_affs) +

        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names loaded earlier from the network's _names.json file)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every `profiling` iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #6
def train(until):

    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    pipeline = (
        (
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample1/raw',
                    labels: 'train/sample1/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample2/raw',
                    labels: 'train/sample2/labels'
                }),
            gp.ZarrSource(
                'data/20200201.zarr',
                {
                    raw: 'train/sample3/raw',
                    labels: 'train/sample3/labels'
                })
        ) +
        gp.RandomProvider() +
        gp.Normalize(raw) +
        gp.RandomLocation() +
        gp.SimpleAugment(transpose_only=(1, 2)) +
        gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi]) +
        gp.AddAffinities(
            [(1, 0, 0), (0, 1, 0), (0, 0, 1)],
            labels,
            affs) +
        gp.Normalize(affs, factor=1.0) +
        #gp.PreCache(num_workers=1) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        # affs: (1, 3, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        # affs: (1, 3, d, h, w)
        gp_torch.Train(
            model,
            loss,
            optimizer,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            loss_inputs={0: affs_predicted, 1: affs},
            save_every=10000) +
        # raw still carries both the stack and the channel dim, hence two removals
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs: (3, d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Snapshot(
            {
                raw: 'raw',
                labels: 'labels',
                affs: 'affs',
                affs_predicted: 'affs_predicted'
            },
            every=500,
            output_filename='iteration_{iteration}.hdf')
    )

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(affs, input_size)
    request.add(affs_predicted, input_size)

    with gp.build(pipeline):
        for i in range(until):
            pipeline.request_batch(request)
Example #7
def build_pipeline(
        data_dir,  
        model, 
        save_every,
        batch_size, 
        input_size, 
        output_size,
        raw, 
        labels,
        affs,
        affs_predicted,
        lr=1e-5): 

    dataset_shape = zarr.open(str(data_dir))['train/raw'].shape
    num_samples = dataset_shape[0]
    sample_size = dataset_shape[1:]

    loss = torch.nn.MSELoss()
    optimizer = RAdam(model.parameters(), lr=lr)
    
    pipeline = (
            gp.ZarrSource(
                data_dir,
                {
                    raw: 'train/raw',
                    labels: 'train/gt'
                },
                array_specs={
                    raw: gp.ArraySpec(
                        roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                        voxel_size=(1, 1, 1)),
                    labels: gp.ArraySpec(
                        roi=gp.Roi((0, 0, 0), (num_samples,) + sample_size),
                        voxel_size=(1, 1, 1))
                }) +
            # raw: (d=1, h, w)
            # labels: (d=1, h, w)
            gp.RandomLocation() +
            # raw: (d=1, h, w)
            # labels: (d=1, h, w)
            gp.AddAffinities(
                affinity_neighborhood=[(0, 1, 0), (0, 0, 1)],
                labels=labels,
                affinities=affs) +
            gp.Normalize(affs, factor=1.0) +
            # raw: (d=1, h, w)
            # affs: (c=2, d=1, h, w)
            Squash(dim=-3) +
            # get rid of z dim
            # raw: (h, w)
            # affs: (c=2, h, w)
            AddChannelDim(raw) +
            # raw: (c=1, h, w)
            # affs: (c=2, h, w)
            gp.PreCache() +
            gp.Stack(batch_size) +
            # raw: (b=10, c=1, h, w)
            # affs: (b=10, c=2, h, w)
            Train(
                model=model,
                loss=loss,
                optimizer=optimizer,
                inputs={'x': raw},
                target=affs,
                output=affs_predicted,
                save_every=save_every,
                log_dir='log') +
            # raw: (b=10, c=1, h, w)
            # affs: (b=10, c=2, h, w)
            # affs_predicted: (b=10, c=2, h, w)
            TransposeDims(raw,(1, 0, 2, 3)) +
            TransposeDims(affs,(1, 0, 2, 3)) +
            TransposeDims(affs_predicted,(1, 0, 2, 3)) +
            # raw: (c=1, b=10, h, w)
            # affs: (c=2, b=10, h, w)
            # affs_predicted: (c=2, b=10, h, w)
            RemoveChannelDim(raw) +
            # raw: (b=10, h, w)
            # affs: (c=2, b=10, h, w)
            # affs_predicted: (c=2, b=10, h, w)
            gp.Snapshot(
                dataset_names={
                    raw: 'raw',
                    labels: 'labels',
                    affs: 'affs',
                    affs_predicted: 'affs_predicted'
                },
                every=100) +
            gp.PrintProfilingStats(every=100)
        )
    return pipeline 
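
A hedged sketch of how build_pipeline might be wired up, based only on its signature and the request pattern used in the other examples; the model class, zarr path, sizes, and iteration count below are illustrative placeholders.

# Hypothetical usage: UNet2D, 'data/train.zarr', the sizes and the iteration
# count are placeholders, not taken from the original code.
raw = gp.ArrayKey('RAW')
labels = gp.ArrayKey('LABELS')
affs = gp.ArrayKey('AFFS')
affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

model = UNet2D()  # placeholder model exposing the 'x' input expected by Train
input_size = (1, 128, 128)
output_size = (1, 88, 88)

pipeline = build_pipeline(
    'data/train.zarr', model,
    save_every=1000, batch_size=10,
    input_size=input_size, output_size=output_size,
    raw=raw, labels=labels, affs=affs, affs_predicted=affs_predicted)

request = gp.BatchRequest()
request.add(raw, input_size)
request.add(labels, input_size)
request.add(affs, output_size)
request.add(affs_predicted, output_size)

with gp.build(pipeline):
    for i in range(10000):
        pipeline.request_batch(request)
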
def train_until(**kwargs):
    print("cuda visibile devices", os.environ["CUDA_VISIBLE_DEVICES"])
    checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    if checkpoint:
        trained_until = int(checkpoint.split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    snapshot_request.add(gt_affs, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet" %
                                  kwargs['input_format'])

    fls = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    neighborhood = []
    psH = np.array(kwargs['patchshape']) // 2
    for i in range(-psH[0], psH[0] + 1, kwargs['patchstride'][0]):
        for j in range(-psH[1], psH[1] + 1, kwargs['patchstride'][1]):
            for k in range(-psH[2], psH[2] + 1, kwargs['patchstride'][2]):
                neighborhood.append([i, j, k])

    datasets = {
        raw: 'volumes/raw',
        gt_labels: 'volumes/gt_labels',
        anchor: 'volumes/gt_fgbg',
    }
    input_specs = {
        raw:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(input_shape_world),
                                input_shape_world),
                     interpolatable=True,
                     dtype=np.float32),
        gt_labels:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint16),
        anchor:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint8),
        gt_affs:
        gp.ArraySpec(roi=gp.Roi((0, ) * len(output_shape_world),
                                output_shape_world),
                     interpolatable=False,
                     dtype=np.uint8)
    }
    inputs = {
        net_names['raw']: raw,
        net_names['gt_affs']: gt_affs,
        net_names['anchor']: anchor,
    }

    outputs = {
        net_names['pred_affs']: pred_affs,
        net_names['raw_cropped']: raw_cropped,
    }
    snapshot = {
        raw_cropped: '/volumes/raw_cropped',
        gt_affs: '/volumes/gt_affs',
        pred_affs: '/volumes/pred_affs',
    }

    optimizer_args = None
    if kwargs['auto_mixed_precision']:
        optimizer_args = (kwargs['optimizer'], {
            'args': kwargs['args'],
            'kwargs': kwargs['kwargs']
        })
    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            sourceNode(
                fls[t] + "." + kwargs['input_format'],
                datasets=datasets,
                # array_specs=array_specs
            )
            + gp.Pad(raw, None)
            + gp.Pad(gt_labels, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=4) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            neighborhood,
            gt_labels,
            gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        # gp.BalanceLabels(
        #     gt_affs,
        #     loss_weights_affs) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # pre-fetch batches from the point upstream
        (gp.tensorflow.TFData() \
         if kwargs.get('use_tf_data') else NoOp()) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names loaded earlier from the network's _names.json file)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs=inputs,
            outputs=outputs,
            array_specs=input_specs,
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
            },
            auto_mixed_precision=kwargs['auto_mixed_precision'],
            optimizer_args=optimizer_args,
            use_tf_data=kwargs['use_tf_data'],
            save_every=kwargs['checkpoints'],
            snapshot_every=kwargs['snapshots']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            snapshot,
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every `profiling` iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    try:
        with gp.build(pipeline):
            print(pipeline)
            for i in range(trained_until, kwargs['max_iteration']):
                start = time.time()
                pipeline.request_batch(request)
                time_of_iteration = time.time() - start

                logger.info("Batch: iteration=%d, time=%f", i,
                            time_of_iteration)
            # exit()
    except KeyboardInterrupt:
        sys.exit()
    print("Training finished")