예제 #1
0
    def test_unsqueeze(self):
        """Check the array shapes produced by two chained ``Unsqueeze`` nodes.

        RAW and LABELS both pass through the first ``Unsqueeze``; RAW also
        passes through a second one with ``axis=1``. The test asserts that
        LABELS ends up with one extra leading axis of length 1 and RAW with
        two.
        """
        raw_key = gp.ArrayKey("RAW")
        labels_key = gp.ArrayKey("LABELS")

        voxel_size = gp.Coordinate((50, 5, 5))
        shape_in_voxels = gp.Coordinate((10, 10, 10))
        # The request size is the voxel count scaled by the voxel size.
        request_size = shape_in_voxels * voxel_size

        request = gp.BatchRequest()
        for key in (raw_key, labels_key):
            request.add(key, request_size)

        source = ExampleSourceUnsqueeze(voxel_size)
        unsqueeze_both = gp.Unsqueeze([raw_key, labels_key])
        unsqueeze_raw_again = gp.Unsqueeze([raw_key], axis=1)
        pipeline = source + unsqueeze_both + unsqueeze_raw_again

        with gp.build(pipeline) as built:
            batch = built.request_batch(request)

            raw_shape = batch[raw_key].data.shape
            assert raw_shape == (1,) + (1,) + shape_in_voxels

            labels_shape = batch[labels_key].data.shape
            assert labels_shape == (1,) + shape_in_voxels
예제 #2
0
def test_gp_dacapo_array_source(array_config):
    """Check that ``DaCapoArraySource`` serves the data held by an array.

    An array is created from ``array_config``, wrapped in a
    ``DaCapoArraySource`` and requested over its full ROI. The test passes
    when the difference between the served data and ``array[array.roi]``
    sums to zero.
    """
    # Instantiate the array described by the config.
    array = array_config.array_type(array_config)

    test_key = gp.ArrayKey("TEST")
    source = DaCapoArraySource(array, test_key)

    with gp.build(source):
        # Ask for everything the array covers.
        full_request = gp.BatchRequest()
        full_request[test_key] = gp.ArraySpec(roi=array.roi)

        batch = source.request_batch(full_request)
        served = batch[test_key].data
        difference = served - array[array.roi]
        assert difference.sum() == 0
예제 #3
0
    def test_pipeline3(self):
        """Run ``ShiftAugment`` on an array together with points.

        After one batch has been requested, the array value found at every
        returned point location must be one of the values that the fake
        data holds at the fake point positions.
        """
        array_key = gp.ArrayKey("TEST_ARRAY")
        points_key = gp.PointsKey("TEST_POINTS")
        voxel_size = gp.Coordinate((1, 1))
        array_spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

        array_source = gp.Hdf5Source(
            self.fake_data_file,
            {array_key: 'testdata'},
            array_specs={array_key: array_spec},
        )
        points_roi = gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))
        points_source = gp.CsvPointsSource(
            self.fake_points_file,
            points_key,
            gp.PointsSpec(roi=points_roi),
        )

        request = gp.BatchRequest()
        request_shape = gp.Coordinate((60, 60))
        request.add(array_key, request_shape, voxel_size=gp.Coordinate((1, 1)))
        request.add(points_key, request_shape)

        augment = gp.ShiftAugment(
            prob_slip=0.2,
            prob_shift=0.2,
            sigma=5,
            shift_axis=0,
        )
        merged = (array_source, points_source) + gp.MergeProvider()
        located = merged + gp.RandomLocation(ensure_nonempty=points_key)
        pipeline = located + augment

        with gp.build(pipeline) as built:
            batch = built.request_batch(request)

        # Values the source data holds at the known point positions.
        expected_values = []
        for point in self.fake_points:
            expected_values.append(self.fake_data[point[0]][point[1]])

        augmented_data = batch[array_key].data
        augmented_points = batch[points_key].data

        # Values of the returned array at the returned point locations.
        observed_values = []
        for point in augmented_points.values():
            row = int(point.location[0])
            column = int(point.location[1])
            observed_values.append(augmented_data[row][column])

        for observed in observed_values:
            failure_message = (
                "result value {} at points {} not in target values {} at points {}"
                .format(observed, list(augmented_points.values()),
                        expected_values, self.fake_points))
            self.assertTrue(observed in expected_values, msg=failure_message)
예제 #4
0
def add_fg_pred(config, pipeline, raw, block, model):
    """Append a foreground-prediction stage to ``pipeline``.

    Args:
        config: mapping providing ``"FG_MODEL_CHECKPOINT"`` and, optionally,
            ``"DEVICE"`` (``"cuda"`` is used when the key is absent).
        pipeline: the pipeline to extend.
        raw: key of the array passed to the model as its ``"raw"`` input.
        block: identifier interpolated into the name of the output key.
        model: model handed to ``gp.torch.Predict``.

    Returns:
        Tuple ``(pipeline, fg_pred)``: the extended pipeline and the newly
        created ``FG_PRED_<block>`` array key holding model output ``0``.
    """
    checkpoint = config["FG_MODEL_CHECKPOINT"]
    device = config.get("DEVICE", "cuda")
    fg_pred = gp.ArrayKey(f"FG_PRED_{block}")

    helpers = nl.gunpowder.nodes.helpers
    # UnSqueeze runs before the model and Squeeze after it; presumably they
    # add and remove a leading dimension -- confirm against the helpers.
    stages = (
        helpers.UnSqueeze(raw),
        gp.torch.Predict(
            model,
            inputs={"raw": raw},
            outputs={0: fg_pred},
            checkpoint=checkpoint,
            device=device,
        ),
        helpers.Squeeze(raw),
        helpers.Squeeze(fg_pred),
    )
    for stage in stages:
        pipeline = pipeline + stage

    return pipeline, fg_pred
예제 #5
0
    def test_prepare1(self):
        """Check the state ``ShiftAugment.prepare`` sets up for a 2D request.

        With ``sigma=1`` and ``shift_axis=0`` the node must report
        ``ndim == 2`` and ``shift_sigmas == (0.0, 1.0)`` after ``prepare``.
        """
        array_key = gp.ArrayKey("TEST_ARRAY")
        source_spec = gp.ArraySpec(
            voxel_size=gp.Coordinate((1, 1)),
            interpolatable=True,
        )
        source = gp.Hdf5Source(
            self.fake_data_file,
            {array_key: 'testdata'},
            array_specs={array_key: source_spec},
        )

        request = gp.BatchRequest()
        request_shape = gp.Coordinate((3, 3))
        request.add(array_key, request_shape, voxel_size=gp.Coordinate((1, 1)))

        augment = gp.ShiftAugment(sigma=1, shift_axis=0)
        pipeline = source + augment
        with gp.build(pipeline):
            augment.prepare(request)
            self.assertTrue(augment.ndim == 2)
            self.assertTrue(augment.shift_sigmas == (0.0, 1.0))
예제 #6
0
    def test_pipeline2(self):
        """Smoke-test ``ShiftAugment`` on data with voxel size ``(3, 1)``.

        The test requests a single batch through the pipeline and makes no
        assertions; it passes as long as no exception is raised.
        """
        array_key = gp.ArrayKey("TEST_ARRAY")
        source_spec = gp.ArraySpec(
            voxel_size=gp.Coordinate((3, 1)),
            interpolatable=True,
        )
        source = gp.Hdf5Source(
            self.fake_data_file,
            {array_key: 'testdata'},
            array_specs={array_key: source_spec},
        )

        request = gp.BatchRequest()
        request_shape = gp.Coordinate((3, 3))
        request.add(array_key, request_shape, voxel_size=gp.Coordinate((3, 1)))

        augment_settings = dict(prob_slip=0.2,
                                prob_shift=0.2,
                                sigma=1,
                                shift_axis=0)
        augment = gp.ShiftAugment(**augment_settings)
        pipeline = source + augment
        with gp.build(pipeline) as built:
            built.request_batch(request)
예제 #7
0
def get_labels_snapshot_source(config, blocks):
    """Create one ``SnapshotSource`` per validation block.

    Args:
        config: mapping whose ``"VALIDATION_BLOCKS"`` entry is the directory
            that contains the ``block_<id>.hdf`` files.
        blocks: iterable of block identifiers.

    Returns:
        Tuple ``(sources, (labels, gt))``: the list of sources, in the order
        of ``blocks``, and the array key and graph key they are set up with.
    """
    block_dir = Path(config["VALIDATION_BLOCKS"])

    labels = gp.ArrayKey("LABELS")
    gt = gp.GraphKey("GT")

    def source_for(block):
        # Every block reads the same two datasets from its own file.
        datasets = {labels: "volumes/labels", gt: "points/gt"}
        return SnapshotSource(
            block_dir / f"block_{block}.hdf",
            datasets,
            directed={gt: True},
        )

    sources = [source_for(block) for block in blocks]
    return sources, (labels, gt)
예제 #8
0
def add_nms(pipeline, config, foreground):
    """Append a non-max-suppression stage to ``pipeline``.

    The model configuration starts as a deep copy of ``DEFAULT_CONFIG`` and
    is updated with the JSON file named by ``config["EMB_MODEL_CONFIG"]``.

    Args:
        pipeline: the pipeline to extend.
        config: mapping providing the ``"EMB_MODEL_CONFIG"`` file path.
        foreground: key of the array handed to ``NonMaxSuppression``.

    Returns:
        Tuple ``(pipeline, maxima)``: the extended pipeline and the newly
        created ``MAXIMA`` array key.
    """
    model_config = copy.deepcopy(DEFAULT_CONFIG)
    # Read the file inside a context manager so the handle is closed right
    # away instead of being left to the garbage collector.
    with open(config["EMB_MODEL_CONFIG"]) as config_file:
        model_config.update(json.load(config_file))

    # Data properties
    voxel_size = gp.Coordinate(model_config["VOXEL_SIZE"])
    micron_scale = voxel_size[0]

    # Config options; the window size is scaled by the first voxel-size
    # component.
    window_size = gp.Coordinate(model_config["NMS_WINDOW_SIZE"]) * micron_scale
    threshold = model_config["NMS_THRESHOLD"]

    # New array key
    maxima = gp.ArrayKey("MAXIMA")

    pipeline = (pipeline + nl.gunpowder.nodes.helpers.UnSqueeze(foreground) +
                nl.gunpowder.nodes.NonMaxSuppression(foreground, maxima,
                                                     window_size, threshold) +
                nl.gunpowder.nodes.helpers.Squeeze(foreground) +
                nl.gunpowder.nodes.helpers.Squeeze(maxima))

    return pipeline, maxima
예제 #9
0
def add_emb_preds(config, pipelines, raw):
    """Append an embedding-prediction stage to each pipeline.

    The model configuration starts as a deep copy of ``DEFAULT_CONFIG`` and
    is updated with the JSON file named by ``config["EMB_MODEL_CONFIG"]``.
    A separate ``EmbeddingUnet`` is constructed for every pipeline.

    Args:
        config: mapping providing ``"EMB_MODEL_CONFIG"``,
            ``"EMB_MODEL_CHECKPOINT"`` and, optionally, ``"DEVICE"``
            (``"cuda"`` is used when the key is absent).
        pipelines: iterable of pipelines to extend.
        raw: key of the array passed to the model as its ``"raw"`` input.

    Returns:
        Tuple ``(emb_pipelines, (emb_pred,))``: the extended pipelines in
        input order and a one-element tuple with the ``EMB_PRED`` array key.
    """
    model_config = copy.deepcopy(DEFAULT_CONFIG)
    # Read the file inside a context manager so the handle is closed right
    # away instead of being left to the garbage collector.
    with open(config["EMB_MODEL_CONFIG"]) as config_file:
        model_config.update(json.load(config_file))
    device = config.get("DEVICE", "cuda")

    checkpoint_file = config["EMB_MODEL_CHECKPOINT"]

    emb_pred = gp.ArrayKey("EMB_PRED")

    emb_pipelines = []
    for pipeline in pipelines:
        emb_pipelines.append(
            pipeline + nl.gunpowder.nodes.helpers.UnSqueeze(raw) +
            gp.torch.Predict(
                nl.networks.pytorch.EmbeddingUnet(model_config),
                inputs={"raw": raw},
                outputs={0: emb_pred},
                checkpoint=checkpoint_file,
                device=device,
            ) + nl.gunpowder.nodes.helpers.Squeeze(raw) +
            nl.gunpowder.nodes.helpers.Squeeze(emb_pred))

    return emb_pipelines, (emb_pred, )
예제 #10
0
def predict_2d(raw_data, gt_data, predictor):
    """Run ``predictor`` over a stack of 2D samples and collect the results.

    The source ROI is overwritten so that the sample index is treated as an
    additional leading (z) dimension, and the whole stack is processed by a
    ``Scan`` node in chunks matching the predictor's input/output shapes.

    Args:
        raw_data: dataset providing ``num_channels``, ``shape``, ``roi``,
            ``voxel_size``, ``num_samples`` and ``get_source``.
        gt_data: optional ground-truth dataset with ``get_source``; when it
            is falsy no ground truth or target is requested.
        predictor: model providing ``input_shape``, ``output_shape``,
            ``add_target`` and ``target_channels``.

    Returns:
        Dict with the batch entries ``'raw'`` and ``'prediction'`` and, when
        ``gt_data`` is given, additionally ``'gt'`` and ``'target'``.

    Raises:
        RuntimeError: if the data (channel dimension excluded) is not 3D.
    """
    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims != 3:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    # Everything after the (optional) channel axis and the sample axis.
    sample_shape = dataset_shape[channel_dims + 1:]

    # The sample count comes from the dataset itself; an earlier assignment
    # from dataset_shape[0] was always overwritten by this value.
    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    # One request covering a full sample; Scan splits it into chunks.
    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
예제 #11
0
    fmap_inc_factor=4,  # this needs to be increased later (3)
    downsample_factors=[
        [1, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    kernel_size_down=[[[3, 3, 3], [3, 3, 3]]] * 4,
    kernel_size_up=[[[3, 3, 3], [3, 3, 3]]] * 3,
    padding='valid')
# Full network: the U-Net followed by a ConvPass with a sigmoid activation.
# NOTE(review): ConvPass is called with 24 and 1 and a (1, 1, 1) entry --
# presumably input channels, output channels and kernel size; confirm
# against the ConvPass signature.
model = torch.nn.Sequential(unet,
                            ConvPass(24, 1, [(1, 1, 1)], activation='Sigmoid'))
# Binary cross-entropy loss; Adam is created with its default settings.
loss = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters())

# Declare the gunpowder array keys used below.
raw = gp.ArrayKey('RAW')
seg = gp.ArrayKey('SEGMENTATION')
out_cage_map = gp.ArrayKey('OUT_CAGE_MAP')
out_density_map = gp.ArrayKey('OUT_DENSITY_MAP')
prediction = gp.ArrayKey('PREDICTION')


class PrepareTrainingData(gp.BatchFilter):
    """Batch filter that converts the ``out_cage_map`` array to float32."""

    def process(self, batch, request):
        """Cast the cage map data in ``batch`` to ``np.float32`` in place.

        Both the array data and the dtype recorded in its spec are updated.
        ``request`` is not used.
        """
        cage_map = batch[out_cage_map]
        cage_map.data = cage_map.data.astype(np.float32)
        cage_map.spec.dtype = np.float32


# assemble pipeline
sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', {
예제 #12
0
def predict(iteration):
    """Predict affinities for a whole sample and write them to HDF5.

    The network description is read from ``test_net_config.json``; the
    weights come from ``train_net_checkpoint_<iteration>``. Raw data and
    predictions are written to ``predictions_sample_A.hdf``.

    Args:
        iteration: training iteration whose checkpoint is loaded.
    """
    # Array keys: raw intensities in, predicted affinities out.
    raw = gp.ArrayKey('RAW')
    pred_affs = gp.ArrayKey('PRED_AFFS')

    with open('test_net_config.json', 'r') as config_file:
        net_config = json.load(config_file)

    # Input and output size in world units (nm, in this case).
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size
    # NOTE(review): this is the full input/output size difference and it is
    # subtracted on both sides of the raw ROI below -- confirm whether half
    # the difference per side was intended.
    context = input_size - output_size

    # What a single chunk passed through the network should contain.
    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_affs, output_size)

    # Read from the HDF5 file.
    source = gp.Hdf5Source('sample_A_padded_20160501.hdf',
                           datasets={raw: 'volumes/raw'})

    # Look up the ROI provided for raw; it is needed to derive the ROI in
    # which predictions can be made.
    with gp.build(source):
        raw_roi = source.spec[raw].roi

    # Convert raw to float in [0, 1].
    normalize = gp.Normalize(raw)

    # Run the network on each passing batch, using the tensor names stored
    # in the network config.
    prediction_roi = raw_roi.grow(-context, -context)
    run_network = gp.tensorflow.Predict(
        graph='test_net.meta',
        checkpoint='train_net_checkpoint_%d' % iteration,
        inputs={net_config['raw']: raw},
        outputs={net_config['pred_affs']: pred_affs},
        array_specs={pred_affs: gp.ArraySpec(roi=prediction_roi)},
    )

    # Store all passing batches in the same HDF5 file.
    write_output = gp.Hdf5Write(
        {
            raw: '/volumes/raw',
            pred_affs: '/volumes/pred_affs',
        },
        output_filename='predictions_sample_A.hdf',
        compression_type='gzip',
    )

    # Print a summary of the time spent in each node every 10 iterations.
    profiling = gp.PrintProfilingStats(every=10)

    # Iterate over the whole dataset in a scanning fashion, emitting
    # requests that match the size of the network.
    scan = gp.Scan(reference=chunk_request)

    pipeline = (source + normalize + run_network + write_output + profiling +
                scan)

    with gp.build(pipeline):
        # An empty request makes Scan process the dataset without keeping
        # the complete dataset in memory.
        pipeline.request_batch(gp.BatchRequest())
예제 #13
0
def predict(iteration,
            raw_file,
            raw_dataset,
            out_file,
            db_host,
            db_name,
            worker_config,
            network_config,
            out_properties=None,
            **kwargs):
    """Predict post-synaptic indicators and partner vectors block by block.

    The network description ``<network_config>_net_config.json``, the graph
    ``<network_config>_net.meta``, the checkpoint and an optional
    ``parameter.json`` are all looked up in the directory of this file.
    Results are written to ``out_file`` with ``gp.ZarrWrite``.

    Args:
        iteration: training iteration whose checkpoint is loaded.
        raw_file: input container; must end in ``.hdf``, ``.zarr`` or
            ``.n5``.
        raw_dataset: name of the raw dataset inside ``raw_file``.
        out_file: output file name handed to ``gp.ZarrWrite``.
        db_host: forwarded to ``block_done_callback``.
        db_name: forwarded to ``block_done_callback``.
        worker_config: mapping providing ``'num_cache_workers'``; also
            forwarded to ``block_done_callback``.
        network_config: prefix of the network config and meta file names.
        out_properties: optional mapping; its ``'pred_partner_vectors'`` and
            ``'pred_syn_indicator_out'`` entries may carry ``'scale'`` and
            (for the vectors) ``'dtype'`` settings. ``None`` is treated as
            an empty mapping.
        **kwargs: accepted and ignored.

    Raises:
        RuntimeError: if ``raw_file`` has an unsupported extension.
    """
    # A None default replaces the shared mutable ``{}`` default.
    if out_properties is None:
        out_properties = {}

    setup_dir = os.path.dirname(os.path.realpath(__file__))

    with open(
            os.path.join(setup_dir,
                         '{}_net_config.json'.format(network_config)),
            'r') as f:
        net_config = json.load(f)

    # voxels
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])

    # nm
    voxel_size = gp.Coordinate((40, 4, 4))
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size

    # Optional extra parameters stored next to this script.
    parameterfile = os.path.join(setup_dir, 'parameter.json')
    if os.path.exists(parameterfile):
        with open(parameterfile, 'r') as f:
            parameters = json.load(f)
    else:
        parameters = {}

    raw = gp.ArrayKey('RAW')
    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    chunk_request = gp.BatchRequest()
    chunk_request.add(raw, input_size)
    chunk_request.add(pred_postpre_vectors, output_size)
    chunk_request.add(pred_post_indicator, output_size)

    d_property = out_properties.get('pred_partner_vectors')
    m_property = out_properties.get('pred_syn_indicator_out')

    # Pick the source node from the file extension.
    if raw_file.endswith('.hdf'):
        pipeline = gp.Hdf5Source(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    elif raw_file.endswith(('.zarr', '.n5')):
        pipeline = gp.ZarrSource(raw_file,
                                 datasets={raw: raw_dataset},
                                 array_specs={
                                     raw: gp.ArraySpec(interpolatable=True),
                                 })
    else:
        raise RuntimeError('unknown input data format {}'.format(raw_file))

    pipeline += gp.Pad(raw, size=None)

    pipeline += gp.Normalize(raw)

    pipeline += gp.IntensityScaleShift(raw, 2, -1)

    pipeline += gp.tensorflow.Predict(
        os.path.join(setup_dir, 'train_net_checkpoint_%d' % iteration),
        inputs={net_config['raw']: raw},
        outputs={
            net_config['pred_syn_indicator_out']: pred_post_indicator,
            net_config['pred_partner_vectors']: pred_postpre_vectors
        },
        graph=os.path.join(setup_dir, '{}_net.meta'.format(network_config)))

    d_scale = parameters.get('d_scale')
    if d_scale is not None and d_scale != 1:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors, 1. / d_scale,
                                           0)  # Map back to nm world.
    if (m_property is not None and 'scale' in m_property
            and m_property['scale'] != 1):
        pipeline += gp.IntensityScaleShift(pred_post_indicator,
                                           m_property['scale'], 0)
    if d_property is not None and 'scale' in d_property:
        pipeline += gp.IntensityScaleShift(pred_postpre_vectors,
                                           d_property['scale'], 0)
    if d_property is not None and 'dtype' in d_property:
        assert d_property['dtype'] in (
            'int8', 'float32'), 'predict not adapted to dtype {}'.format(
                d_property['dtype'])
        if d_property['dtype'] == 'int8':
            # Clip to the int8 value range; scale 1 and shift 0 leave the
            # values otherwise unchanged.
            pipeline += IntensityScaleShiftClip(pred_postpre_vectors,
                                                1,
                                                0,
                                                clip=(-128, 127))

    pipeline += gp.ZarrWrite(dataset_names={
        pred_post_indicator:
        'volumes/pred_syn_indicator',
        pred_postpre_vectors:
        'volumes/pred_partner_vectors',
    },
                             output_filename=out_file)

    pipeline += gp.PrintProfilingStats(every=10)

    pipeline += gp.DaisyRequestBlocks(
        chunk_request,
        roi_map={
            raw: 'read_roi',
            pred_postpre_vectors: 'write_roi',
            pred_post_indicator: 'write_roi'
        },
        num_workers=worker_config['num_cache_workers'],
        block_done_callback=lambda b, s, d: block_done_callback(
            db_host, db_name, worker_config, b, s, d))

    print("Starting prediction...")
    with gp.build(pipeline):
        # An empty request lets DaisyRequestBlocks drive the processing.
        pipeline.request_batch(gp.BatchRequest())
    print("Prediction finished")
    print("Prediction finished")
예제 #14
0
파일: train.py 프로젝트: funkelab/synful
def build_pipeline(parameter, augment=True):
    """Assemble the training pipeline and run it.

    The network description is read from ``train_net_config.json`` in the
    current working directory. Despite its name, the function does not
    return the pipeline: it builds it and requests
    ``parameter['max_iteration']`` batches.

    Args:
        parameter: mapping of training settings. The keys read here are
            ``'voxel_size'``, ``'blob_radius'``, ``'blob_mode'``,
            ``'d_blob_radius'``, ``'cliprange'``, ``'d_scale'`` and
            ``'max_iteration'``; the mapping is also forwarded to
            ``create_source``.
        augment: when true, ``ElasticAugment``, ``SimpleAugment`` and
            ``IntensityAugment`` nodes are inserted after the random
            provider.
    """
    voxel_size = gp.Coordinate(parameter['voxel_size'])

    # Array Specifications.
    raw = gp.ArrayKey('RAW')
    gt_neurons = gp.ArrayKey('GT_NEURONS')
    gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS')
    gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR')
    post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT')
    vectors_mask = gp.ArrayKey('VECTORS_MASK')

    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR')
    grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS')

    # Points specifications
    dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')
    postsyn = gp.PointsKey('POSTSYN')
    presyn = gp.PointsKey('PRESYN')
    trg_context = 140  # AddPartnerVectorMap context in nm - pre-post distance

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_neurons, output_size)
    request.add(gt_postpre_vectors, output_size)
    request.add(gt_post_indicator, output_size)
    request.add(post_loss_weight, output_size)
    request.add(vectors_mask, output_size)
    request.add(dummypostsyn, output_size)

    # Print every requested key with its ROI.
    for (key, request_spec) in request.items():
        print(key)
        print(request_spec.roi)

    # Network outputs and gradients are only needed for snapshots; they are
    # requested with the same spec as the ground-truth vectors.
    snapshot_request = gp.BatchRequest({
        pred_post_indicator:
        request[gt_postpre_vectors],
        pred_postpre_vectors:
        request[gt_postpre_vectors],
        grad_syn_indicator:
        request[gt_postpre_vectors],
        grad_partner_vectors:
        request[gt_postpre_vectors],
        vectors_mask:
        request[gt_postpre_vectors]
    })

    postsyn_rastersetting = gp.RasterizationSettings(
        radius=parameter['blob_radius'],
        mask=gt_neurons,
        mode=parameter['blob_mode'])

    # NOTE(review): `samples` is not defined in this function; presumably a
    # module-level name -- confirm.
    pipeline = tuple(
        create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                      gt_neurons) for sample in samples)

    pipeline += gp.RandomProvider()
    if augment:
        pipeline += gp.ElasticAugment([4, 40, 40], [0, 2, 2],
                                      [0, math.pi / 2.0],
                                      prob_slip=0.05,
                                      prob_shift=0.05,
                                      max_misalign=10,
                                      subsample=8)
        pipeline += gp.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        pipeline += gp.IntensityAugment(raw,
                                        0.9,
                                        1.1,
                                        -0.1,
                                        0.1,
                                        z_section_wise=True)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.RasterizePoints(
        postsyn, gt_post_indicator,
        gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32),
        postsyn_rastersetting)
    spec = gp.ArraySpec(voxel_size=voxel_size)
    pipeline += AddPartnerVectorMap(
        src_points=postsyn,
        trg_points=presyn,
        array=gt_postpre_vectors,
        radius=parameter['d_blob_radius'],
        trg_context=trg_context,  # enlarge
        array_spec=spec,
        mask=gt_neurons,
        pointmask=vectors_mask)
    pipeline += gp.BalanceLabels(labels=gt_post_indicator,
                                 scales=post_loss_weight,
                                 slab=(-1, -1, -1),
                                 clipmin=parameter['cliprange'][0],
                                 clipmax=parameter['cliprange'][1])
    if parameter['d_scale'] != 1:
        pipeline += gp.IntensityScaleShift(gt_postpre_vectors,
                                           scale=parameter['d_scale'],
                                           shift=0)
    pipeline += gp.PreCache(cache_size=40, num_workers=10)
    pipeline += gp.tensorflow.Train(
        './train_net',
        optimizer=net_config['optimizer'],
        loss=net_config['loss'],
        summary=net_config['summary'],
        log_dir='./tensorboard/',
        save_every=30000,  # 10000
        log_every=100,
        inputs={
            net_config['raw']:
            raw,
            net_config['gt_partner_vectors']:
            gt_postpre_vectors,
            net_config['gt_syn_indicator']:
            gt_post_indicator,
            net_config['vectors_mask']:
            vectors_mask,
            # Loss weights --> mask
            net_config['indicator_weight']:
            post_loss_weight,  # Loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )
    # Visualize: undo the earlier (2, -1) scale/shift of raw before the
    # snapshot is written.
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask'
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)
    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        # One requested batch per training iteration.
        for _ in range(max_iteration):
            b.request_batch(request)
예제 #15
0
def predict(**kwargs):
    """Run the trained network over one sample and store the result as zarr.

    The network description is read from ``<name>_config.json`` and
    ``<name>_names.json`` in ``input_folder``. The ``volumes/raw`` dataset of
    the sample is scanned chunk by chunk, and the predicted affinities, the
    foreground/background prediction and the cropped raw data are written to
    ``<sample>.zarr`` in ``output_folder``.

    Keyword Args:
        name: Prefix of the ``*_config.json`` / ``*_names.json`` files and of
            the ``.meta`` graph file.
        input_folder: Folder holding the network files.
        data_folder: Folder holding the input sample.
        sample: Sample name (file name without extension).
        input_format: ``"hdf"`` or ``"zarr"``.
        output_format: Only ``"zarr"`` is supported.
        output_folder: Folder in which ``<sample>.zarr`` is created.
        voxel_size: Voxel size; used to convert the network shapes into
            request sizes and stored as ``resolution`` on every output
            dataset.
        checkpoint: Checkpoint handed to the predict node.

    Raises:
        NotImplementedError: If ``input_format`` is neither ``"hdf"`` nor
            ``"zarr"``, or if ``output_format`` is not ``"zarr"``.
    """
    name = kwargs['name']

    # Validate the requested formats up front so that an unsupported
    # configuration fails before any file is read or created.
    input_format = kwargs['input_format']
    if input_format not in ("hdf", "zarr"):
        # The format has to be interpolated explicitly: exceptions do not
        # apply %-style arguments the way logging calls do.
        raise NotImplementedError(
            "predict node for %s not implemented yet" % input_format)
    if kwargs['output_format'] != "zarr":
        raise NotImplementedError("Please use zarr as prediction output")

    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')

    with open(os.path.join(kwargs['input_folder'], name + '_config.json'),
              'r') as f:
        net_config = json.load(f)
    with open(os.path.join(kwargs['input_folder'], name + '_names.json'),
              'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size
    context = (input_shape_world - output_shape_world) // 2

    # formulate the request for what a batch should contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(pred_affs, output_shape_world)
    request.add(pred_fgbg, output_shape_world)

    # look up the shape of the raw volume; it sizes the output datasets
    sample_path = os.path.join(kwargs['data_folder'],
                               kwargs['sample'] + "." + input_format)
    if input_format == "hdf":
        source_node = gp.Hdf5Source
        with h5py.File(sample_path, 'r') as f:
            shape = f['volumes/raw'].shape
    else:
        source_node = gp.ZarrSource
        shape = zarr.open(sample_path, 'r')['volumes/raw'].shape
    source = source_node(sample_path, datasets={raw: 'volumes/raw'})

    # pre-create zarr file
    zf = zarr.open(os.path.join(kwargs['output_folder'],
                                kwargs['sample'] + '.zarr'),
                   mode='w')
    # every output dataset holds the whole volume in a single chunk, with a
    # leading channel axis
    for ds_name, num_channels in (('volumes/pred_affs', 3),
                                  ('volumes/pred_fgbg', 1),
                                  ('volumes/raw_cropped', 1)):
        zf.create(ds_name,
                  shape=[num_channels] + list(shape),
                  chunks=[num_channels] + list(shape),
                  dtype=np.float32)
        zf[ds_name].attrs['offset'] = [0, 0, 0]
        zf[ds_name].attrs['resolution'] = kwargs['voxel_size']

    pipeline = (

        # read raw from the input container and pad it by the network context
        source + gp.Pad(raw, context) +

        # run the network on each passing batch (here we use the tensor
        # names stored in <name>_names.json)
        gp.tensorflow.Predict(graph=os.path.join(kwargs['input_folder'],
                                                 name + '.meta'),
                              checkpoint=kwargs['checkpoint'],
                              inputs={net_names['raw']: raw},
                              outputs={
                                  net_names['pred_affs']: pred_affs,
                                  net_names['pred_fgbg']: pred_fgbg,
                                  net_names['raw_cropped']: raw_cropped
                              }) +

        # store all passing batches in the same zarr container
        gp.ZarrWrite(
            {
                raw_cropped: '/volumes/raw_cropped',
                pred_affs: '/volumes/pred_affs',
                pred_fgbg: '/volumes/pred_fgbg',
            },
            output_dir=kwargs['output_folder'],
            output_filename=kwargs['sample'] + ".zarr",
            compression_type='gzip') +

        # show a summary of time spent in each node every 10 iterations
        gp.PrintProfilingStats(every=10) +

        # iterate over the whole dataset in a scanning fashion, emitting
        # requests that match the size of the network
        gp.Scan(reference=request))

    with gp.build(pipeline):
        # request an empty batch from Scan to trigger scanning of the dataset
        # without keeping the complete dataset in memory
        pipeline.request_batch(gp.BatchRequest())
예제 #16
0
    def create_train_pipeline(self, model):
        """Assemble the gunpowder pipeline that trains ``model`` on affinities.

        Args:
            model: Network to train. Its ``in_shape``, ``out_shape`` and
                ``base_encoder.out_shape`` size the requests, and its
                parameters are handed to the configured optimizer.

        Returns:
            A ``(pipeline, request)`` tuple: the assembled pipeline and the
            batch request to issue against it.
        """
        cfg = self.params
        print(
            f"Creating training pipeline with batch size {cfg['batch_size']}"
        )

        container = cfg['data_file']
        train_datasets = cfg['dataset']['train']
        raw_name = train_datasets['raw']
        gt_name = train_datasets['gt']

        make_optimizer = cfg['optimizer']
        optimizer = make_optimizer(model.parameters(),
                                   **cfg['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        gt_aff = gp.ArrayKey('AFFINITIES')
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_ds = daisy.open_ds(container, raw_name)
        source_roi = gp.Roi(raw_ds.roi.get_offset(), raw_ds.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_ds.voxel_size)
        out_voxel_size = gp.Coordinate(raw_ds.voxel_size)

        # Network input/output shapes, first in voxels, then scaled by the
        # voxel size.
        net_in = gp.Coordinate(model.in_shape)
        net_out = gp.Coordinate(model.out_shape[2:])
        is_2d = net_in.dims() == 2

        in_shape = net_in * out_voxel_size
        out_shape = net_out * out_voxel_size

        context = (in_shape - out_shape) / 2
        gt_labels_out_shape = out_shape
        aff_neighborhood = [[-1, 0, 0], [0, -1, 0], [0, 0, -1]]
        if is_2d:
            # Prepend a fake third dimension to everything that describes
            # the source.
            source_voxel_size = gp.Coordinate((1, ) + tuple(source_voxel_size))
            source_roi = gp.Roi(
                (0, ) + tuple(source_roi.get_offset()),
                (raw_ds.shape[0], ) + tuple(source_roi.get_shape()))
            context = gp.Coordinate((0, ) + tuple(context))
            aff_neighborhood = [[0, -1, 0], [0, 0, -1]]
            gt_labels_out_shape = (1, ) + tuple(gt_labels_out_shape)

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(gt_aff, out_shape)
        request.add(predictions, out_shape)

        # Arrays that are only requested when a snapshot is written.
        emb_roi = gp.Roi(
            (0, ) * in_shape.dims(),
            gp.Coordinate(tuple(model.base_encoder.out_shape[2:])) *
            out_voxel_size)
        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(roi=emb_roi)
        snapshot_request[gt_labels] = gp.ArraySpec(
            roi=gp.Roi(context, gt_labels_out_shape))

        source_specs = {
            raw:
            gp.ArraySpec(roi=source_roi,
                         voxel_size=source_voxel_size,
                         interpolatable=True),
            gt_labels:
            gp.ArraySpec(roi=source_roi, voxel_size=source_voxel_size),
        }
        source = gp.ZarrSource(container, {
            raw: raw_name,
            gt_labels: gt_name
        },
                               array_specs=source_specs)
        source += gp.Normalize(raw, cfg['norm_factor'])
        source += gp.Pad(raw, context)
        source += gp.Pad(gt_labels, context)
        source += gp.RandomLocation()
        # raw      : (l=1, h, w)
        # gt_labels: (l=1, h, w)
        source = self._augmentation_pipeline(raw, source)

        pipeline = source + gp.AddAffinities(aff_neighborhood, gt_labels,
                                             gt_aff)
        pipeline += SetDtype(gt_aff, np.float32)
        # raw      : (l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)
        pipeline += AddChannelDim(raw)
        # raw      : (c=1, l=1, h, w)
        # gt_aff   : (c=2, l=1, h, w)

        if is_2d:
            pipeline += RemoveSpatialDim(raw)
            pipeline += RemoveSpatialDim(gt_aff)
            # raw      : (c=1, h, w)
            # gt_aff   : (c=2, h, w)

        pipeline += gp.Stack(cfg['batch_size'])
        pipeline += gp.PreCache()
        # raw      : (b, c=1, h, w)
        # gt_aff   : (b, c=2, h, w)
        # (which is what train requires)
        pipeline += gp.torch.Train(
            model,
            self.loss,
            optimizer,
            inputs={'raw': raw},
            loss_inputs={
                0: predictions,
                1: gt_aff
            },
            outputs={
                0: predictions,
                1: emb
            },
            array_specs={
                predictions: gp.ArraySpec(voxel_size=out_voxel_size),
            },
            checkpoint_basename=self.logdir + '/checkpoints/model',
            save_every=cfg['save_every'],
            log_dir=self.logdir,
            log_every=self.log_every)
        # everything is 2D at this point, plus extra dimensions for
        # channels and batch
        # raw        : (b, c=1, h, w)
        # gt_aff     : (b, c=2, h, w)
        # predictions: (b, c=2, h, w)

        # Crop GT to look at labels
        pipeline += gp.Crop(gt_labels, gp.Roi(context, gt_labels_out_shape))
        pipeline += gp.Snapshot(output_dir=self.logdir + '/snapshots',
                                output_filename='it{iteration}.hdf',
                                dataset_names={
                                    raw: 'raw',
                                    gt_labels: 'gt_labels',
                                    predictions: 'predictions',
                                    gt_aff: 'gt_aff',
                                    emb: 'emb'
                                },
                                additional_request=snapshot_request,
                                every=cfg['save_every'])
        pipeline += gp.PrintProfilingStats(every=500)

        return pipeline, request
예제 #17
0
                output_roi.get_end() / input_spec.voxel_size,
            ))]
        output_spec = copy.deepcopy(input_spec)
        output_spec.roi = output_roi

        output_array = gp.Array(output_data, output_spec)

        batch[self.output_array] = output_array


# Sizes requested from the pipeline for the network input and output.
# NOTE(review): the unit (voxels vs. world units) is not visible in this
# fragment -- confirm against the code that builds the requests.
input_size = Coordinate([74, 260, 260])
output_size = Coordinate([42, 168, 168])
# Root folder of the data set.
path_to_data = Path("/nrs/funke/mouselight-v2")

# keys for the data sources (raw and labels are array keys, swcs is a
# points key)
raw = gp.ArrayKey("RAW")
swcs = gp.PointsKey("SWCS")
labels = gp.ArrayKey("LABELS")

# keys for the base volume (arrays for raw/labels, points for the swc)
raw_base = gp.ArrayKey("RAW_BASE")
labels_base = gp.ArrayKey("LABELS_BASE")
swc_base = gp.PointsKey("SWC_BASE")

# keys for the add volume (arrays for raw/labels, points for the swc)
raw_add = gp.ArrayKey("RAW_ADD")
labels_add = gp.ArrayKey("LABELS_ADD")
swc_add = gp.PointsKey("SWC_ADD")

# array keys for the fused volume
raw_fused = gp.ArrayKey("RAW_FUSED")
예제 #18
0
파일: predict.py 프로젝트: pattonw/dacapo
def predict(
    model: Model,
    raw_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    num_cpu_workers: int = 4,
    compute_context: ComputeContext = LocalTorch(),
    output_roi: Optional[Roi] = None,
):
    """Predict block-wise over ``raw_array`` and write the result to zarr.

    Args:
        model: Model to evaluate; also supplies the evaluation input shape,
            the output voxel size and the number of output channels.
        raw_array: Array the raw data is read from.
        prediction_array_identifier: Container/dataset the prediction is
            created in and written to.
        num_cpu_workers: Not used in this function body.
        compute_context: Its ``device`` is passed to the predict node.
        output_roi: ROI to predict. When ``None``, the ROI of ``raw_array``
            shrunk by the network context is used.
    """
    # input and output size of the model
    in_voxel_size = Coordinate(raw_array.voxel_size)
    out_voxel_size = model.scale(in_voxel_size)
    eval_shape = Coordinate(model.eval_input_shape)
    input_size = in_voxel_size * eval_shape
    output_size = out_voxel_size * model.compute_output_shape(eval_shape)[1]

    logger.info(
        "Predicting with input size %s, output size %s", input_size, output_size
    )

    # total input and output rois
    context = (input_size - output_size) / 2
    if output_roi is not None:
        input_roi = output_roi.grow(context, context)
    else:
        input_roi = raw_array.roi
        output_roi = input_roi.grow(-context, -context)

    logger.info("Total input ROI: %s, output ROI: %s", input_roi, output_roi)

    # create the dataset the prediction is written to, channel axis first
    axes = ["c", *(axis for axis in raw_array.axes if axis != "c")]
    ZarrArray.create_from_array_identifier(
        prediction_array_identifier,
        axes,
        output_roi,
        model.num_out_channels,
        out_voxel_size,
        np.float32,
    )

    # gunpowder keys
    raw = gp.ArrayKey("RAW")
    prediction = gp.ArrayKey("PREDICTION")

    # ROI announced for the prediction: the output ROI grown by the padding
    # computed from the block output size
    gt_padding = (output_size - output_roi.shape) % output_size
    prediction_roi = output_roi.grow(gt_padding)
    prediction_spec = gp.ArraySpec(
        roi=prediction_roi, voxel_size=out_voxel_size, dtype=np.float32
    )

    # reference request that is scanned over the whole output ROI
    ref_request = gp.BatchRequest()
    ref_request.add(raw, input_size)
    ref_request.add(prediction, output_size)

    pipeline = (
        # data source
        # raw: (c, d, h, w)
        DaCapoArraySource(raw_array, raw)
        + gp.Pad(raw, Coordinate((None,) * in_voxel_size.dims))
        # add the batch dimension
        # raw: (1, c, d, h, w)
        + gp.Unsqueeze([raw])
        # predict
        # raw: (1, c, d, h, w)
        # prediction: (1, [c,] d, h, w)
        + gp_torch.Predict(
            model=model,
            inputs={"x": raw},
            outputs={0: prediction},
            array_specs={prediction: prediction_spec},
            spawn_subprocess=False,
            device=str(compute_context.device),
        )
        # drop the batch dimension again before writing
        # raw: (c, d, h, w)
        # prediction: (c, d, h, w)
        + gp.Squeeze([raw, prediction])
        # write to zarr
        + gp.ZarrWrite(
            {prediction: prediction_array_identifier.dataset},
            prediction_array_identifier.container.parent,
            prediction_array_identifier.container.name,
        )
        + gp.Scan(ref_request)
    )

    # build the pipeline; the empty request makes Scan cover the whole ROI
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())

    # record the axis names on the written dataset
    container = zarr.open(prediction_array_identifier.container)
    dataset = container[prediction_array_identifier.dataset]
    if "c" in raw_array.axes:
        dataset.attrs["axes"] = raw_array.axes
    else:
        dataset.attrs["axes"] = ["c"] + raw_array.axes
예제 #19
0
def train(iterations):
    """Train the affinity network for ``iterations`` batches.

    The tensor names, optimizer and loss are read from
    ``train_net_config.json`` in the working directory. Snapshots are written
    to ``snapshots/`` every 1000 iterations.

    Args:
        iterations: Number of batches to request from the pipeline.
    """

    ##################
    # DECLARE ARRAYS #
    ##################

    # raw intensities
    raw = gp.ArrayKey('RAW')

    # objects labelled with unique IDs
    gt_labels = gp.ArrayKey('LABELS')

    # array of per-voxel affinities to direct neighbors
    gt_affs = gp.ArrayKey('AFFINITIES')

    # weights to use to balance the loss
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    # the predicted affinities
    pred_affs = gp.ArrayKey('PRED_AFFS')

    # the gradient of the loss wrt the predicted affinities
    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')

    ####################
    # DECLARE REQUESTS #
    ####################

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    # get the input and output size in world units (nm, in this case)
    voxel_size = gp.Coordinate((8, 8, 8))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_affs, output_size)
    request.add(loss_weights, output_size)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request[pred_affs] = request[gt_affs]
    snapshot_request[pred_affs_gradients] = request[gt_affs]

    ##############################
    # ASSEMBLE TRAINING PIPELINE #
    ##############################

    source = (

        # read batches from the HDF5 file
        gp.Hdf5Source(os.path.join(data_dir, 'fib.hdf'),
                      datasets={
                          raw: 'volumes/raw',
                          gt_labels: 'volumes/labels/neuron_ids'
                      }) +

        # convert raw to float in [0, 1]
        gp.Normalize(raw) +

        # choose a random location for each requested batch
        gp.RandomLocation())

    pipeline = (

        # a tuple of sources for RandomProvider to choose from. There is only
        # one sample here, so it is wrapped in a one-element tuple; calling
        # tuple() on the source pipeline itself would try to iterate it.
        (source, ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        gp.ElasticAugment([8, 8, 8], [0, 2, 2], [0, math.pi / 2.0],
                          prob_slip=0.05,
                          prob_shift=0.05,
                          max_misalign=25) +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(transpose_only=[1, 2]) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(raw,
                            scale_min=0.9,
                            scale_max=1.1,
                            shift_min=-0.1,
                            shift_max=0.1,
                            z_section_wise=True) +

        # grow a boundary between labels
        gp.GrowBoundary(gt_labels, steps=3, only_xy=True) +

        # convert labels into affinities between voxels
        gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels,
                         gt_affs) +

        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(gt_affs, loss_weights) +

        # pre-cache batches from the point upstream
        gp.PreCache(cache_size=10, num_workers=5) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net_config.json)
        gp.tensorflow.Train(
            'train_net',
            net_config['optimizer'],
            net_config['loss'],
            inputs={
                net_config['raw']: raw,
                net_config['gt_affs']: gt_affs,
                net_config['loss_weights']: loss_weights
            },
            outputs={net_config['pred_affs']: pred_affs},
            gradients={net_config['pred_affs']: pred_affs_gradients},
            save_every=10000) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                gt_labels: '/volumes/labels/neuron_ids',
                gt_affs: '/volumes/labels/affs',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients'
            },
            output_dir='snapshots',
            output_filename='batch_{iteration}.hdf',
            every=1000,
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node every 1000 iterations
        gp.PrintProfilingStats(every=1000))

    #########
    # TRAIN #
    #########

    print("Training for", iterations, "iterations")

    with gp.build(pipeline):
        for i in range(iterations):
            pipeline.request_batch(request)

    print("Finished")
예제 #20
0
    "cc_threshold": 0.50,
    "loc_type": "edt",
    "score_thr": 0,
    "score_type": "sum",
    "nms_radius": None
}

# Build the synapse extraction parameters from the plain settings dict
# defined above; every field is passed through unchanged.
parameters = synful.detection.SynapseExtractionParameters(
    extract_type=parameter_dic['extract_type'],
    cc_threshold=parameter_dic['cc_threshold'],
    loc_type=parameter_dic['loc_type'],
    score_thr=parameter_dic['score_thr'],
    score_type=parameter_dic['score_type'],
    nms_radius=parameter_dic['nms_radius'])

# The keys are created without keeping the returned objects. TestSource
# below refers to one of them as gp.ArrayKeys.M_PRED, so presumably
# constructing a key registers it under that name -- confirm against the
# gunpowder version in use.
gp.ArrayKey('M_PRED')
gp.ArrayKey('D_PRED')
gp.PointsKey('PRESYN')
gp.PointsKey('POSTSYN')


class TestSource(gp.BatchProvider):
    def __init__(self, m_pred, d_pred, voxel_size):
        self.voxel_size = voxel_size
        self.m_pred = m_pred
        self.d_pred = d_pred

    def setup(self):
        self.provides(
            gp.ArrayKeys.M_PRED,
            gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
예제 #21
0
    def create_train_pipeline(self, model):
        """Assemble the gunpowder pipeline that trains ``model`` on points.

        Raw data and labels are read from the zarr container named in
        ``self.params['data_file']`` and merged with a ``PointsLabelsSource``
        built from ``self.data``.

        Args:
            model: Network to train. Its ``in_shape``, ``out_shape`` and
                ``base_encoder.out_shape`` size the requests, and its
                parameters are handed to the configured optimizer.

        Returns:
            A ``(pipeline, request)`` tuple: the assembled pipeline and the
            batch request to issue against it.
        """

        print(f"Creating training pipeline with batch size \
              {self.params['batch_size']}")

        filename = self.params['data_file']
        raw_dataset = self.params['dataset']['train']['raw']
        gt_dataset = self.params['dataset']['train']['gt']

        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])

        raw = gp.ArrayKey('RAW')
        gt_labels = gp.ArrayKey('LABELS')
        points = gp.GraphKey("POINTS")
        locations = gp.ArrayKey("LOCATIONS")
        predictions = gp.ArrayKey('PREDICTIONS')
        emb = gp.ArrayKey('EMBEDDING')

        raw_data = daisy.open_ds(filename, raw_dataset)
        source_roi = gp.Roi(raw_data.roi.get_offset(),
                            raw_data.roi.get_shape())
        source_voxel_size = gp.Coordinate(raw_data.voxel_size)
        out_voxel_size = gp.Coordinate(raw_data.voxel_size)

        # Get in and out shape
        in_shape = gp.Coordinate(model.in_shape)
        out_roi = gp.Coordinate(model.base_encoder.out_shape[2:])
        is_2d = in_shape.dims() == 2

        # Scale the network shapes by the voxel size.
        in_shape = in_shape * out_voxel_size
        out_roi = out_roi * out_voxel_size
        # NOTE(review): out_shape is only used in the log output below.
        out_shape = gp.Coordinate(
            (self.params["num_points"], *model.out_shape[2:]))

        context = (in_shape - out_roi) / 2
        gt_labels_out_shape = out_roi
        # Add fake 3rd dim
        if is_2d:
            source_voxel_size = gp.Coordinate((1, *source_voxel_size))
            source_roi = gp.Roi((0, *source_roi.get_offset()),
                                (raw_data.shape[0], *source_roi.get_shape()))
            context = gp.Coordinate((0, *context))
            gt_labels_out_shape = (1, *gt_labels_out_shape)

            points_roi = out_voxel_size * tuple((*self.params["point_roi"], ))
            points_pad = (0, *points_roi)
            # NOTE(review): this replaces the context computed above; the
            # value derived from in_shape and out_roi is never used for
            # padding.
            context = gp.Coordinate((0, None, None))
        else:
            points_roi = source_voxel_size * tuple(self.params["point_roi"])
            points_pad = points_roi
            # NOTE(review): as in the 2D branch, the computed context is
            # replaced before it is used.
            context = gp.Coordinate((None, None, None))

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {out_voxel_size}")
        logger.info(f"context: {context}")
        logger.info(f"out_voxel_size: {out_voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(points, points_roi)
        request.add(gt_labels, out_roi)
        # locations and predictions carry no ROI; they are requested as
        # non-spatial arrays
        request[locations] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)

        # The embedding is only requested when a snapshot is written.
        snapshot_request = gp.BatchRequest()
        snapshot_request[emb] = gp.ArraySpec(
            roi=gp.Roi((0, ) * in_shape.dims(),
                       gp.Coordinate((*model.base_encoder.out_shape[2:], )) *
                       out_voxel_size))

        # Two providers (zarr arrays and points) merged into one source.
        source = (
            (gp.ZarrSource(filename, {
                raw: raw_dataset,
                gt_labels: gt_dataset
            },
                           array_specs={
                               raw:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size,
                                            interpolatable=True),
                               gt_labels:
                               gp.ArraySpec(roi=source_roi,
                                            voxel_size=source_voxel_size)
                           }),
             PointsLabelsSource(points, self.data, scale=source_voxel_size)) +
            gp.MergeProvider() + gp.Pad(raw, context) +
            gp.Pad(gt_labels, context) + gp.Pad(points, points_pad) +
            gp.RandomLocation(ensure_nonempty=points) +
            gp.Normalize(raw, self.params['norm_factor'])
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            # If 2d then source_roi = (1, input_shape) in order to select a RL
        )
        source = self._augmentation_pipeline(raw, source)

        pipeline = (
            source +
            # Batches seem to be rejected because points are chosen near the
            # edge of the points ROI and the augmentations remove them.
            # TODO: Figure out if this is an actual issue, and if anything can
            # be done.
            gp.Reject(ensure_nonempty=points) + SetDtype(gt_labels, np.int64) +
            # raw      : (source_roi)
            # gt_labels: (source_roi)
            # points   : (c=1, source_locations_shape)
            AddChannelDim(raw) + AddChannelDim(gt_labels)
            # raw      : (c=1, source_roi)
            # gt_labels: (c=2, source_roi)
            # NOTE(review): the later shape comments list gt_labels with
            # c=1 -- confirm which is right.
            # points   : (c=1, source_locations_shape)
        )

        if is_2d:
            pipeline = (
                # Remove extra dim the 2d roi had
                pipeline + RemoveSpatialDim(raw) +
                RemoveSpatialDim(gt_labels) + RemoveSpatialDim(points)
                # raw      : (c=1, roi)
                # gt_labels: (c=1, roi)
                # points   : (c=1, locations_shape)
            )

        pipeline = (
            pipeline +
            # NOTE(review): is_2d is passed as a literal False here even when
            # the local is_2d is True -- confirm this is intended.
            FillLocations(raw, points, locations, is_2d=False, max_points=1) +
            gp.Stack(self.params['batch_size']) + gp.PreCache() +
            # raw      : (b, c=1, roi)
            # gt_labels: (b, c=1, roi)
            # locations: (b, c=1, locations_shape)
            # (which is what train requires)
            gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={
                    'raw': raw,
                    'points': locations
                },
                loss_inputs={
                    0: predictions,
                    1: gt_labels,
                    2: locations
                },
                outputs={
                    0: predictions,
                    1: emb
                },
                array_specs={
                    predictions: gp.ArraySpec(nonspatial=True),
                    emb: gp.ArraySpec(voxel_size=out_voxel_size)
                },
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every) +
            # everything is 2D at this point, plus extra dimensions for
            # channels and batch
            # raw        : (b, c=1, roi)
            # gt_labels  : (b, c=1, roi)
            # predictions: (b, num_points)
            gp.Snapshot(output_dir=self.logdir + '/snapshots',
                        output_filename='it{iteration}.hdf',
                        dataset_names={
                            raw: 'raw',
                            gt_labels: 'gt_labels',
                            predictions: 'predictions',
                            emb: 'emb'
                        },
                        additional_request=snapshot_request,
                        every=self.params['save_every']) +
            InspectBatch('END') + gp.PrintProfilingStats(every=500))

        return pipeline, request
예제 #22
0
def train(until):
    """Train a ``SpineUNet`` on affinities for ``until`` iterations.

    Args:
        until: Number of batches to request from the training pipeline.
    """
    model = SpineUNet()
    loss = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    input_size = (8, 96, 96)

    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    affs = gp.ArrayKey('AFFS')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    # one source per training sample, all in the same container
    sample_datasets = (
        ('train/sample1/raw', 'train/sample1/labels'),
        ('train/sample2/raw', 'train/sample2/labels'),
        ('train/sample3/raw', 'train/sample3/labels'),
    )
    sources = tuple(
        gp.ZarrSource('data/20200201.zarr', {
            raw: raw_path,
            labels: labels_path
        }) for raw_path, labels_path in sample_datasets)

    pipeline = sources + gp.RandomProvider()
    pipeline += gp.Normalize(raw)
    pipeline += gp.RandomLocation()
    pipeline += gp.SimpleAugment(transpose_only=(1, 2))
    pipeline += gp.ElasticAugment((2, 10, 10), (0.0, 0.5, 0.5), [0, math.pi])
    pipeline += gp.AddAffinities([(1, 0, 0), (0, 1, 0), (0, 0, 1)], labels,
                                 affs)
    pipeline += gp.Normalize(affs, factor=1.0)
    # pre-caching is currently disabled (was: gp.PreCache(num_workers=1))
    # raw: (d, h, w)
    # affs: (3, d, h, w)
    pipeline += gp.Stack(1)
    # raw: (1, d, h, w)
    # affs: (1, 3, d, h, w)
    pipeline += AddChannelDim(raw)
    # raw: (1, 1, d, h, w)
    # affs: (1, 3, d, h, w)
    pipeline += gp_torch.Train(model,
                               loss,
                               optimizer,
                               inputs={'x': raw},
                               outputs={0: affs_predicted},
                               loss_inputs={
                                   0: affs_predicted,
                                   1: affs
                               },
                               save_every=10000)
    # raw carries two leading singleton dimensions at this point, the other
    # arrays one
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(affs)
    pipeline += RemoveChannelDim(affs_predicted)
    # raw: (d, h, w)
    # affs: (3, d, h, w)
    # affs_predicted: (3, d, h, w)
    pipeline += gp.Snapshot(
        {
            raw: 'raw',
            labels: 'labels',
            affs: 'affs',
            affs_predicted: 'affs_predicted'
        },
        every=500,
        output_filename='iteration_{iteration}.hdf')

    # every array is requested with the same size
    request = gp.BatchRequest()
    for key in (raw, labels, affs, affs_predicted):
        request.add(key, input_size)

    with gp.build(pipeline):
        for _ in range(until):
            pipeline.request_batch(request)
예제 #23
0
    def make_pipeline(self):
        """Assemble the prediction pipeline for ``self.model``.

        Raw data is read from ``self.dataset`` in ``self.data_file``,
        normalized, padded and fed to the model; predictions are written to
        ``predictions.zarr`` in ``self.curr_log_dir``.

        Returns:
            A ``(pipeline, request, pred_affs)`` tuple: the assembled
            pipeline, the reference request used for scanning, and the array
            key of the predictions.
        """
        raw = gp.ArrayKey('RAW')
        pred_affs = gp.ArrayKey('PREDICTIONS')

        # NOTE: this value of raw_roi is replaced further down by the ROI the
        # built source reports.
        spatial_shape = zarr.open(self.data_file)[self.dataset].shape[1:]
        raw_roi = gp.Roi(np.zeros(len(spatial_shape)), spatial_shape)

        data = daisy.open_ds(self.data_file, self.dataset)
        source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        voxel_size = gp.Coordinate(data.voxel_size)

        # Network input/output shapes, first in voxels, then scaled by the
        # voxel size.
        net_in = gp.Coordinate(self.model.in_shape)
        net_out = gp.Coordinate(self.model.out_shape[2:])
        is_2d = net_in.dims() == 2

        in_shape = net_in * voxel_size
        out_shape = net_out * voxel_size

        logger.info(f"source roi: {source_roi}")
        logger.info(f"in_shape: {in_shape}")
        logger.info(f"out_shape: {out_shape}")
        logger.info(f"voxel_size: {voxel_size}")

        request = gp.BatchRequest()
        request.add(raw, in_shape)
        request.add(pred_affs, out_shape)

        context = (in_shape - out_shape) / 2

        raw_spec = gp.ArraySpec(roi=source_roi, interpolatable=False)
        source = gp.ZarrSource(self.data_file, {raw: self.dataset},
                               array_specs={raw: raw_spec})

        # Number of dimensions the stored data has beyond the model input.
        extra_dims = len(data.shape) - len(self.model.in_shape)
        if is_2d:
            # 2D: [samples, y, x] or [samples, channels, y, x]
            if extra_dims == 1:
                source = source + AddChannelDim(raw, axis=1)
            # raw [samples, channels, y, x]
        else:
            # 3D: [z, y, x] or [channel, z, y, x] or [sample, channel, z, y, x]
            if extra_dims == 0:
                # channel fix
                source = source + AddChannelDim(raw, axis=0)
            if extra_dims <= 1:
                # batch fix
                source = source + AddChannelDim(raw)
            # raw: [sample, channels, z, y, x]

        # Build the source once to learn the ROI it provides for raw.
        with gp.build(source):
            raw_roi = source.spec[raw].roi
            logger.info(f"raw_roi: {raw_roi}")

        pipeline = source + gp.Normalize(raw,
                                         factor=self.params['norm_factor'])
        pipeline += gp.Pad(raw, context)
        pipeline += gp.PreCache()
        pipeline += gp.torch.Predict(
            self.model,
            inputs={'raw': raw},
            outputs={0: pred_affs},
            array_specs={pred_affs: gp.ArraySpec(roi=raw_roi)})
        pipeline += gp.ZarrWrite({pred_affs: 'predictions'},
                                 output_dir=self.curr_log_dir,
                                 output_filename='predictions.zarr',
                                 compression_type='gzip')
        pipeline += gp.Scan(request)

        return pipeline, request, pred_affs
예제 #24
0
# Command-line options: the working directory and an optional maximum
# distance (only used below to name the result dataset).
args = parser.parse_args()
directory = args.directory
max_dist = args.max_dist

# download some test data
# (network access required; the file is stored as sample_A.hdf in directory)
url = "https://cremi.org/static/data/sample_A_20160501.hdf"
urllib.request.urlretrieve(url, os.path.join(directory, 'sample_A.hdf'))

# configure where to store results
result_file = "sample_A.n5"
ds_name = "neuron_ids_stardists_downsampled"
# encode the maximum distance in the dataset name when one was given
if max_dist is not None:
    ds_name += "_max{0:}".format(max_dist)

# declare arrays to use
raw = gp.ArrayKey("RAW")
labels = gp.ArrayKey("LABELS")
stardists = gp.ArrayKey("STARDIST")

# prepare requests for scanning (i.e. chunks) and overall
# NOTE(review): a gp.Roi is assigned directly as the request entry here and
# below -- confirm the gunpowder version in use accepts a Roi in place of
# an ArraySpec.
scan_request = gp.BatchRequest()
scan_request[stardists] = gp.Roi(
    gp.Coordinate((0, 0, 0)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)))
# NOTE(review): the ROIs in this fragment are scaled by (40, 8, 8) while
# voxel_size is (40, 4, 4) -- confirm this is intended.
voxel_size = gp.Coordinate((40, 4, 4))
request = gp.BatchRequest(
)  # starts out empty, but is given an explicit ROI right below
# overall ROI: offset (40, 200, 200) and shape (80, 200, 200), both scaled
# by (40, 8, 8)
request[stardists] = gp.Roi(
    gp.Coordinate((40, 200, 200)) * gp.Coordinate((40, 8, 8)),
    gp.Coordinate((40, 100, 100)) * gp.Coordinate((40, 8, 8)) * gp.Coordinate(
        (2, 2, 2)))
예제 #25
0
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    """Build and run the skeleton-matching training pipeline.

    For every selected sample directory a source branch is created that
    reads the raw volume and two graphs, matches them, rasterizes the
    matched skeleton into labels and augments the result. One branch is
    picked at random per batch, and the batch is fed to a tensorflow
    training node with periodic snapshots.

    Args:
        n_iterations: Currently unused. The number of requested batches is
            read from ``setup_config["NUM_ITERATIONS"]`` instead.
            NOTE(review): confirm which of the two is meant to win.
        setup_config: Mapping providing the keys ``INPUT_SHAPE``,
            ``OUTPUT_SHAPE``, ``VOXEL_SIZE``, ``NUM_ITERATIONS``,
            ``CACHE_SIZE``, ``NUM_WORKERS``, ``SNAPSHOT_EVERY``,
            ``CHECKPOINT_EVERY``, ``PROFILE_EVERY``, ``SEPERATE_BY``,
            ``GAP_CROSSING_DIST``, ``MATCH_DISTANCE_THRESHOLD``,
            ``POINT_BALANCE_RADIUS``, ``NEURON_RADIUS``, ``SAMPLES_PATH``
            and ``MONGO_URL``.
        mknet_tensor_names: Mapping from logical names (``raw``,
            ``gt_labels``, ``loss_weights``, ``embedding``, ``fg``) to the
            tensor names of the network graph.
        loss_tensor_names: Mapping from logical names of the loss graph
            (``fg_pred``, ``maxima``, ``gt_fg``, ``emst``, ...) to tensor
            names.

    Returns:
        None. The function runs for its side effects (training, checkpoints,
        snapshots, profiling output).
    """
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    # switch from voxel counts to world units
    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    # (the original added labels_fg to `request` a second time here, which
    # left the snapshot request empty)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests (currently disabled)
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    # One source branch per selected sample directory: raw volume plus the
    # consensus and skeletonization graphs, merged, randomly located,
    # matched, rasterized and augmented.
    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            # the optimizer is supplied as a custom loss/optimizer callable,
            # hence no separate loss tensor name
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            # the first placeholder is filled now; "{id}" is re-inserted
            # literally so it is still present in the resulting filename
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
예제 #26
0
def train_until(**kwargs):
    """Train the three-class network until ``max_iteration`` is reached.

    Training resumes from the latest checkpoint found in ``output_folder``;
    if that checkpoint is already at or beyond ``max_iteration`` the function
    returns immediately.

    Keyword Args:
        output_folder: Directory holding checkpoints, the
            ``<name>_config.json`` / ``<name>_names.json`` files, logs and
            the ``snapshots`` sub-directory.
        name: Base name of the network files inside ``output_folder``.
        max_iteration: Iteration at which training stops.
        voxel_size: Voxel size used to convert network shapes to world units.
        input_format: Either ``"hdf"`` or ``"zarr"``.
        data_files: Paths of the training containers. Each must provide
            ``volumes/raw`` and ``volumes/gt_threeclass``.
        augmentation: Mapping with the entries ``simple`` and ``intensity``
            and an optional entry ``elastic``.
        cache_size, num_workers: Settings passed to ``gp.PreCache``.
        checkpoints: Checkpoint interval passed to the train node.
        snapshots: Snapshot interval.
        profiling: Interval for printing profiling statistics.

    Raises:
        NotImplementedError: If ``input_format`` is neither ``"hdf"`` nor
            ``"zarr"``.
    """
    # look the checkpoint up once; its name ends in "_<iteration>"
    checkpoint = tf.train.latest_checkpoint(kwargs['output_folder'])
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_threeclass = gp.ArrayKey('GT_THREECLASS')

    loss_weights_threeclass = gp.ArrayKey('LOSS_WEIGHTS_THREECLASS')

    pred_threeclass = gp.ArrayKey('PRED_THREECLASS')

    pred_threeclass_gradients = gp.ArrayKey('PRED_THREECLASS_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_threeclass, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(loss_weights_threeclass, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the prediction (the gradient request is currently disabled)
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_threeclass, output_shape_world)
    snapshot_request.add(pred_threeclass, output_shape_world)
    # snapshot_request.add(pred_threeclass_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))
    is_hdf = kwargs['input_format'] == "hdf"

    # Collect the extension-less file names and report shape/dtype of every
    # raw volume so that wrongly typed inputs are noticed early.
    fls = []
    for data_file in kwargs['data_files']:
        fls.append(os.path.splitext(data_file)[0])
        if is_hdf:
            # close the HDF5 handle again once the metadata has been read
            with h5py.File(data_file, 'r') as h5_file:
                vol = h5_file['volumes/raw']
                vol_shape, vol_dtype = vol.shape, vol.dtype
        else:
            vol = zarr.open(data_file, 'r')['volumes/raw']
            vol_shape, vol_dtype = vol.shape, vol.dtype
        print(data_file, vol_shape, vol_dtype)
        if vol_dtype != np.float32:
            print("please convert to float32")
    print("first 5 files: ", fls[0:5])

    # padR = 46
    # padGT = 32

    sourceNode = gp.Hdf5Source if is_hdf else gp.ZarrSource

    augmentation = kwargs['augmentation']

    # elastic deformation is optional; fall back to a pass-through node
    elastic = augmentation.get('elastic')
    if elastic is not None:
        elastic_augment = gp.ElasticAugment(
            elastic['control_point_spacing'],
            elastic['jitter_sigma'],
            [elastic['rotation_min']*np.pi/180.0,
             elastic['rotation_max']*np.pi/180.0],
            subsample=elastic.get('subsample', 1))
    else:
        elastic_augment = NoOp()

    pipeline = (
        tuple(
            # read batches from the container; the file name is rebuilt from
            # the stem and the input format
            sourceNode(
                fl + "." + kwargs['input_format'],
                datasets={
                    raw: 'volumes/raw',
                    gt_threeclass: 'volumes/gt_threeclass',
                    # the anchor array is read from the same dataset as the
                    # ground truth
                    anchor: 'volumes/gt_threeclass',
                },
                array_specs={
                    raw: gp.ArraySpec(interpolatable=True),
                    gt_threeclass: gp.ArraySpec(interpolatable=False),
                    anchor: gp.ArraySpec(interpolatable=False)
                }
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_threeclass, None)
            + gp.Pad(anchor, gp.Coordinate((2,2,2)))

            # chose a random location for each requested batch
            + gp.RandomLocation()

            for fl in fls
        ) +

        # chose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch (or pass through)
        elastic_augment +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        # TODO: check
        # gp.GrowBoundary(
        #     gt_threeclass,
        #     steps=1,
        #     only_xy=False) +

        gp.BalanceLabels(
            gt_threeclass,
            loss_weights_threeclass,
            num_classes=3) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in <name>_names.json)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['anchor']: anchor,
                net_names['gt_threeclass']: gt_threeclass,
                net_names['loss_weights_threeclass']: loss_weights_threeclass
            },
            outputs={
                net_names['pred_threeclass']: pred_threeclass,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_threeclass']: pred_threeclass_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: 'volumes/raw_cropped',
                gt_threeclass: '/volumes/gt_threeclass',
                pred_threeclass: '/volumes/pred_threeclass',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of time spent in each node
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
    print("Training finished")
예제 #27
0
def predict_3d(raw_data, gt_data, predictor):
    """Run ``predictor`` chunk-wise over a 3D dataset.

    Args:
        raw_data: Dataset providing ``num_channels``, ``num_samples``,
            ``voxel_size``, ``roi`` and ``get_source(key)``.
        gt_data: Optional ground-truth dataset with ``get_source(key)``.
            When falsy, neither ground truth nor target is produced.
        predictor: Model passed to the torch predict node. It provides
            ``input_shape``, ``output_shape`` and ``add_target(gt, target)``.

    Returns:
        dict with the entries ``'raw'`` and ``'prediction'`` and, when
        ``gt_data`` is given, additionally ``'gt'`` and ``'target'``.
    """
    raw_channels = max(1, raw_data.num_channels)
    net_input_shape = predictor.input_shape
    net_output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # network shapes expressed in world units
    input_size = voxel_size * net_input_shape
    output_size = voxel_size * net_output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    # a single-channel raw array has no channel axis of its own
    needs_channel_axis = (0 if raw_channels == 1 else 1) == 0

    assert raw_data.num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    # request describing one chunk processed by the scan node
    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    has_gt = bool(gt_data)
    if has_gt:
        for key in (gt, target):
            scan_request.add(key, output_size)

    if has_gt:
        pipeline = ((raw_data.get_source(raw), gt_data.get_source(gt)) +
                    gp.MergeProvider())
    else:
        pipeline = raw_data.get_source(raw)

    # Downstream stages in the order they are appended. Array layouts:
    #   raw / gt / target: ([c,] d, h, w) until a channel axis is added
    stages = [gp.Pad(raw, None)]
    if has_gt:
        stages.append(gp.Pad(gt, None))
    stages.append(gp.Normalize(raw))
    if has_gt:
        stages.append(predictor.add_target(gt, target))
    if needs_channel_axis:
        # raw: (c, d, h, w)
        stages.append(AddChannelDim(raw))
    # add a "batch" dimension -> raw: (1, c, d, h, w)
    stages.append(AddChannelDim(raw))
    stages.append(
        gp_torch.Predict(model=predictor,
                         inputs={'x': raw},
                         outputs={0: prediction}))
    # remove the "batch" dimension again
    stages.append(RemoveChannelDim(raw))
    stages.append(RemoveChannelDim(prediction))
    if needs_channel_axis:
        # back to raw: (d, h, w)
        stages.append(RemoveChannelDim(raw))
    stages.append(gp.Scan(scan_request))

    for stage in stages:
        pipeline += stage

    # ensure validation ROI is at least the size of the network input
    half_input = input_size / 2
    roi = raw_data.roi.grow(half_input, half_input)

    requested_keys = [raw, prediction]
    if has_gt:
        requested_keys += [gt, target]
    total_request = gp.BatchRequest()
    for key in requested_keys:
        total_request[key] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        result = {'raw': batch[raw], 'prediction': batch[prediction]}
        if has_gt:
            result['gt'] = batch[gt]
            result['target'] = batch[target]
        return result
예제 #28
0
from __future__ import print_function
import gunpowder as gp
from gunpowder import *
from gunpowder.ext import malis
import tensorflow as tf
import mala

############
# Training #
############

# Declare the array keys used by the pipeline.
# NOTE(review): the key names mix upper and lower case ('RAW', 'GT' vs.
# 'mask', 'prediction', 'gradient').
raw = gp.ArrayKey('RAW')
gt = gp.ArrayKey('GT')
mask = gp.ArrayKey('mask')
prediction = gp.ArrayKey('prediction')
grad = gp.ArrayKey('gradient')

# Define training values.
# NOTE(review): presumably input_shape is in voxels and voxel_size in world
# units per voxel, both ordered (z, y, x) -- confirm.
input_shape = (40, 300, 300)
voxel_size = (40, 4, 4)

# Define the network input. The placeholder is what the feed dict of the
# network will be keyed on; it is reshaped to add two leading singleton
# axes (presumably batch and channel) in front of the spatial dimensions.
raw_tf = tf.placeholder(tf.float32, shape=input_shape)
raw_batched = tf.reshape(raw_tf, (1, 1) + input_shape)

# Build the U-Net on the batched input.
# NOTE(review): the positional arguments are presumably the initial number
# of feature maps (3), the feature-map increase factor (3) and the
# per-level downsample factors -- verify against mala.networks.unet.
unet = mala.networks.unet(raw_batched, 3, 3, [[1, 1, 1], [1, 1, 1], [3, 3, 3]])

#since we want binary predictions we will be using 1 feature map per
output = mala.networks.conv_pass(unet,
                                 kernel_size=1,
예제 #29
0
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):
    """Train the distance-transform network until ``max_iteration``.

    Training resumes from the latest checkpoint in ``output_folder``. If that
    checkpoint is already at or beyond ``max_iteration``, the function returns
    without building the pipeline.

    The training data is taken from the module-level ``data_files`` and
    ``data_dir``; every top-level group of each HDF5 file is used as one
    sample providing ``<sample>/raw`` and ``<sample>/fg``.

    Args:
        max_iteration: Iteration at which training stops.
        name: Base name of the network files (``<name>_config.json``,
            ``<name>_names.json`` and the checkpoints) in ``output_folder``.
        output_folder: Directory with the network files; snapshots are
            written to its ``snapshots`` sub-directory.
        clip_max: Upper bound the raw intensities are clipped to before they
            are normalized by ``1 / clip_max``.
    """
    # get the latest checkpoint; its name ends in "_<iteration>"
    checkpoint = tf.train.latest_checkpoint(output_folder)
    trained_until = int(checkpoint.split('_')[-1]) if checkpoint else 0
    # This check has to run for resumed trainings as well (it used to be
    # nested in the "no checkpoint" branch, where it could only fire for
    # max_iteration <= 0).
    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    # with a unit voxel size, voxel shapes double as world-unit sizes
    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    # margin between network input and output on each side
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source: one branch per sample group of every data file
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        # the file is only opened here to enumerate its top-level groups
        with h5py.File(current_path, 'r') as h5_file:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        # builtin bool instead of the deprecated/removed
                        # alias np.bool
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=bool, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in h5_file)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            # with reject_probability=1, batches below the mask ratio are
            # always rejected
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1,2], transpose_only=[1,2]) +

            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            # scale by 2 and shift by -1
            gp.IntensityScaleShift(raw, 2,-1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):

        print("Starting training...")
        # only the iterations still missing are requested
        for _ in range(max_iteration - trained_until):
            pipeline.request_batch(request)
def batch__aug_data_generator(input_path,
                              batch_size=12,
                              voxel_shape=[1, 1, 1],
                              input_shape=[240, 240, 4],
                              output_shape=[240, 240, 4],
                              without_background=False,
                              mix_output=False,
                              validate=False,
                              seq=None):
    """Endlessly yield ``(images, masks)`` batches of random crops.

    Every entry of ``input_path`` is opened as a zarr container with the
    datasets ``raw`` and ``ground_truth``. For each batch a container is
    chosen at random and a random location is cropped from it.

    Args:
        input_path: Directory whose entries are the zarr containers.
        batch_size: Number of crops stacked into one yielded batch.
        voxel_shape: Voxel size given to both array specs.
        input_shape: Size requested for the raw and the ground-truth crop.
        output_shape: Only index 1 is used to derive the margin cropped from
            the masks when it is smaller than ``input_shape[1]``.
        without_background: Keep only entries 1..3 of the third mask axis.
        mix_output: Replace the mask by its argmax over axis 3, as float.
        validate: Re-draw a crop until ``validate_mask`` accepts its mask.
        seq: Optional augmentation sequence applied to the whole batch via
            ``augmentation``.

    Yields:
        Tuple ``(imgs, masks)`` of numpy arrays.
    """
    raw = gp.ArrayKey('raw')
    gt = gp.ArrayKey('ground_truth')

    def _random_crop_source(container):
        # one zarr source followed by a random-location node
        float_type = np.dtype('float32')
        return gp.ZarrSource(
            container,
            {raw: 'raw', gt: 'ground_truth'},
            {
                raw: gp.ArraySpec(interpolatable=True,
                                  dtype=float_type,
                                  voxel_size=voxel_shape),
                gt: gp.ArraySpec(interpolatable=True,
                                 dtype=float_type,
                                 voxel_size=voxel_shape),
            },
        ) + gp.RandomLocation()

    containers = [
        os.path.join(input_path, entry) for entry in os.listdir(input_path)
    ]
    branches = tuple(_random_crop_source(path) for path in containers)
    pipeline = branches + gp.RandomProvider()
    #    +gp.Stack(batch_size)

    input_size = gp.Coordinate(input_shape)
    output_size = gp.Coordinate(output_shape)

    # both arrays are requested at input size; the masks are cropped below
    request = gp.BatchRequest()
    for key in (raw, gt):
        request.add(key, input_size)

    margin = int((input_shape[1] - output_shape[1]) / 2)
    upper = input_shape[1] - margin
    needs_crop = margin > 0
    if needs_crop:
        print('Difference padding: {}'.format(margin))

    with gp.build(pipeline):
        while True:
            imgs = []
            masks = []
            while len(imgs) < batch_size:
                batch = pipeline.request_batch(request)
                if validate:
                    # keep drawing until the mask is accepted; the explicit
                    # comparison with False is kept on purpose
                    while validate_mask(batch[gt].data) == False:  # noqa: E712
                        batch = pipeline.request_batch(request)

                mask = batch[gt].data
                if without_background:
                    mask = mask[:, :, 1:4]
                if mix_output:
                    mask = mask.argmax(axis=3).astype(float)
                imgs.append(batch[raw].data)
                masks.append(mask)

            imgs = np.asarray(imgs)
            masks = np.asarray(masks)
            if seq is not None:
                imgs, masks = augmentation(imgs, masks, seq)
            if needs_crop:
                masks = np.asarray(
                    [m[margin:upper, margin:upper, :] for m in masks])
            yield imgs, masks