Example #1
    def create_train_pipeline(self, model):
        optimizer = self.params['optimizer'](model.parameters(),
                                             **self.params['optimizer_kwargs'])
        points = gp.ArrayKey('POINTS')
        predictions = gp.ArrayKey('PREDICTIONS')
        gt_labels = gp.ArrayKey('LABELS')

        request = gp.BatchRequest()
        # Because of PointsLabelsSource we can keep everything as nonspatial
        request[points] = gp.ArraySpec(nonspatial=True)
        request[predictions] = gp.ArraySpec(nonspatial=True)
        request[gt_labels] = gp.ArraySpec(nonspatial=True)

        pipeline = (
            PointsLabelsSource(points, self.data, gt_labels, self.labels, 1) +
            gp.Stack(self.params['batch_size']) + gp.torch.Train(
                model,
                self.loss,
                optimizer,
                inputs={'points': points},
                loss_inputs={
                    0: predictions,
                    1: gt_labels
                },
                outputs={0: predictions},
                checkpoint_basename=self.logdir + '/checkpoints/model',
                save_every=self.params['save_every'],
                log_dir=self.logdir,
                log_every=self.log_every))

        return pipeline, request
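
A minimal driver sketch for the pipeline returned above (the names `trainer`, `model`, and `n_iterations` are assumed here, not taken from the original source):

import gunpowder as gp

pipeline, request = trainer.create_train_pipeline(model)
with gp.build(pipeline):
    for _ in range(n_iterations):
        batch = pipeline.request_batch(request)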
Example #2
    def build_source(self):
        data = daisy.open_ds(self.filename, self.key)

        if self.time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += self.time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = self.time_window[1] - self.time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        return gp.ZarrSource(self.filename,
                             {
                                 self.raw_0: self.key,
                                 self.raw_1: self.key
                             },
                             array_specs={
                                 self.raw_0: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True),
                                 self.raw_1: gp.ArraySpec(
                                     roi=source_roi,
                                     voxel_size=voxel_size,
                                     interpolatable=True)
                             })
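
A usage sketch for the source above (assuming `loader` is an instance of the surrounding class): building the source on its own allows inspecting the spec it provides before wiring it into a larger pipeline.

import gunpowder as gp

source = loader.build_source()
with gp.build(source):
    spec = source.spec[loader.raw_0]
    print(spec.roi, spec.voxel_size)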
Example #3
def create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                  gt_neurons):
    data_sources = tuple((
        Hdf5PointsSource(os.path.join(data_dir_syn, sample + '.hdf'),
                         datasets={
                             presyn: 'annotations',
                             postsyn: 'annotations'
                         },
                         rois={
                             presyn: cremi_roi,
                             postsyn: cremi_roi
                         }),
        Hdf5PointsSource(
            os.path.join(data_dir_syn, sample + '.hdf'),
            datasets={dummypostsyn: 'annotations'},
            rois={
                # presyn: cremi_roi,
                dummypostsyn: cremi_roi
            },
            kind='postsyn'),
        gp.Hdf5Source(os.path.join(data_dir, sample + '.hdf'),
                      datasets={
                          raw: 'volumes/raw',
                          gt_neurons: 'volumes/labels/neuron_ids',
                      },
                      array_specs={
                          raw: gp.ArraySpec(interpolatable=True),
                          gt_neurons: gp.ArraySpec(interpolatable=False),
                      })))
    source_pip = data_sources + gp.MergeProvider() + gp.Normalize(
        raw) + gp.RandomLocation(ensure_nonempty=dummypostsyn,
                                 p_nonempty=parameter['reject_probability'])
    return source_pip
Example #4
    def setup(self):

        self.ndims = self.data.shape[1]

        if self.points_spec is not None:
            self.provides(self.points, self.points_spec)
        elif isinstance(self.points, gp.ArrayKey):
            self.provides(self.points, gp.ArraySpec(voxel_size=(1,)))
        elif isinstance(self.points, gp.GraphKey):
            logger.debug("ndims: %d", self.ndims)
            min_bb = gp.Coordinate(
                np.floor(np.amin(self.data[:, :self.ndims], 0)))
            max_bb = gp.Coordinate(
                np.ceil(np.amax(self.data[:, :self.ndims], 0)) + 1)

            roi = gp.Roi(min_bb, max_bb - min_bb)
            logger.debug(f"Bounding Box: {roi}")

            self.provides(self.points, gp.GraphSpec(roi=roi))

        if self.labels is not None:
            assert isinstance(self.labels, gp.ArrayKey), \
                f"Label key must be an ArrayKey, was given {type(self.labels)}"

            if self.labels_spec is not None:
                self.provides(self.labels, self.labels_spec)
            else:
                self.provides(self.labels, gp.ArraySpec(voxel_size=(1,)))
Example #5
def predict(iteration, path_to_dataGP):

    input_size = (8, 96, 96)
    output_size = (4, 64, 64)
    amount_size = gp.Coordinate((2, 16, 16))
    model = SpineUNet(crop_output='output_size')

    raw = gp.ArrayKey('RAW')
    affs_predicted = gp.ArrayKey('AFFS_PREDICTED')

    reference_request = gp.BatchRequest()
    reference_request.add(raw, input_size)
    reference_request.add(affs_predicted, output_size)

    source = gp.ZarrSource(
        path_to_dataGP,
        {
            raw: 'validate/sample1/raw'
        }
    )

    with gp.build(source):
        source_roi = source.spec[raw].roi

    request = gp.BatchRequest()
    request[raw] = gp.ArraySpec(roi=source_roi)
    request[affs_predicted] = gp.ArraySpec(roi=source_roi)

    pipeline = (
        source +
        gp.Pad(raw, amount_size) +
        gp.Normalize(raw) +
        # raw: (d, h, w)
        gp.Stack(1) +
        # raw: (1, d, h, w)
        AddChannelDim(raw) +
        # raw: (1, 1, d, h, w)
        gp_torch.Predict(
            model,
            inputs={'x': raw},
            outputs={0: affs_predicted},
            checkpoint=f'C:/Users/filip/spine_yodl/model_checkpoint_{iteration}') +
        RemoveChannelDim(raw) +
        RemoveChannelDim(raw) +
        RemoveChannelDim(affs_predicted) +
        # raw: (d, h, w)
        # affs_predicted: (3, d, h, w)
        gp.Scan(reference_request)
    )

    with gp.build(pipeline):
        prediction = pipeline.request_batch(request)

    return prediction[raw].data, prediction[affs_predicted].data
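
A hypothetical invocation of the prediction function above; the iteration number and zarr path are placeholders:

raw_data, affs_data = predict(
    iteration=10000,
    path_to_dataGP='path/to/validation.zarr')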
Example #6
 def setup(self):
     self.provides(
         gp.ArrayKeys.M_PRED,
         gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
                      voxel_size=self.voxel_size,
                      interpolatable=False))
     self.provides(
         gp.ArrayKeys.D_PRED,
         gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
                      voxel_size=self.voxel_size,
                      interpolatable=False))
Example #7
    def provide(self, request):
        roi_array = request[gp.ArrayKeys.M_PRED].roi
        slices = (roi_array / self.voxel_size).to_slices()

        batch = gp.Batch()
        batch.arrays[gp.ArrayKeys.M_PRED] = gp.Array(
            self.m_pred[slices],
            spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))
        batch.arrays[gp.ArrayKeys.D_PRED] = gp.Array(
            self.d_pred[:, slices[0], slices[1], slices[2]],
            spec=gp.ArraySpec(roi=roi_array, voxel_size=self.voxel_size))

        return batch
Example #8
    def setup(self):

        # we provide cage maps wherever we have a segmentation:
        roi = self.spec[self.seg].roi.copy()
        voxel_size = self.spec[self.seg].voxel_size
        self.provides(
            self.cage_map,
            gp.ArraySpec(roi=roi, dtype=np.uint16, voxel_size=voxel_size))

        # same for the density map
        self.provides(
            self.density_map,
            gp.ArraySpec(roi=roi, dtype=np.float32, voxel_size=voxel_size))
Example #9
    def setup(self):

        self.provides(
            self.raw,
            gp.ArraySpec(roi=gp.Roi((0, 0), (1000, 1000)),
                         dtype=np.uint8,
                         interpolatable=True,
                         voxel_size=(1, 1)))
        self.provides(
            self.gt,
            gp.ArraySpec(roi=gp.Roi((0, 0), (1000, 1000)),
                         dtype=np.uint64,
                         interpolatable=False,
                         voxel_size=(1, 1)))
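
A sketch, assumed rather than taken from the original source, of a provide() method that could pair with the setup() above, filling each requested sub-ROI with random data:

import numpy as np
import gunpowder as gp

def provide(self, request):
    batch = gp.Batch()
    for key, dtype in [(self.raw, np.uint8), (self.gt, np.uint64)]:
        if key not in request:
            continue
        roi = request[key].roi
        spec = self.spec[key].copy()
        spec.roi = roi
        # voxel size is (1, 1), so the ROI shape equals the array shape
        data = np.random.randint(0, 256, size=roi.get_shape()).astype(dtype)
        batch[key] = gp.Array(data, spec)
    return batch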
Example #10
    def setup(self):
        provided_spec = gp.ArraySpec(
            roi=self.spec[self.gt_key].roi,
            voxel_size=self.spec[self.gt_key].voxel_size,
            interpolatable=self.predictor.output_array_type.interpolatable,
        )
        self.provides(self.target_key, provided_spec)

        provided_spec = gp.ArraySpec(
            roi=self.spec[self.gt_key].roi,
            voxel_size=self.spec[self.gt_key].voxel_size,
            interpolatable=True,
        )
        self.provides(self.weights_key, provided_spec)
Example #11
    def __init__(self, voxel_size):
        self.voxel_size = gp.Coordinate(voxel_size)
        self.roi = gp.Roi((0, 0, 0), (10, 10, 10)) * self.voxel_size

        self.raw = gp.ArrayKey("RAW")
        self.labels = gp.ArrayKey("LABELS")

        self.array_spec_raw = gp.ArraySpec(roi=self.roi,
                                           voxel_size=self.voxel_size,
                                           dtype='uint8',
                                           interpolatable=True)

        self.array_spec_labels = gp.ArraySpec(roi=self.roi,
                                              voxel_size=self.voxel_size,
                                              dtype='uint64',
                                              interpolatable=False)
Example #12
 def setup(self):
     self.enable_autoskip()
     self.provides(self.output, gp.ArraySpec(nonspatial=True))
     if self.details is not None:
         self.provides(self.details, self.spec[self.mst].copy())
     if self.output_graph is not None:
         self.provides(self.output_graph, self.spec[self.mst].copy())
Example #13
    def setup(self):

        self.provides(
            self.array_key,
            gp.ArraySpec(
                roi=gp.Roi(
                    offset=gp.Coordinate((-10000, -10000, -10000)),
                    shape=gp.Coordinate((20000, 20000, 20000))),
                voxel_size=(1, 1, 1)))
Example #14
    def __init__(self, filename,
                 key,
                 density=None,
                 channels=0,
                 shape=(16, 256, 256),
                 time_window=None,
                 add_sparse_mosaic_channel=True,
                 random_rot=False):

        self.filename = filename
        self.key = key
        self.shape = shape
        self.density = density
        self.raw = gp.ArrayKey('RAW_0')
        self.add_sparse_mosaic_channel = add_sparse_mosaic_channel
        self.random_rot = random_rot
        self.channels = channels

        data = daisy.open_ds(filename, key)

        if time_window is None:
            source_roi = gp.Roi(data.roi.get_offset(), data.roi.get_shape())
        else:
            offs = list(data.roi.get_offset())
            offs[1] += time_window[0]
            sh = list(data.roi.get_shape())
            sh[1] = time_window[1] - time_window[0]
            source_roi = gp.Roi(tuple(offs), tuple(sh))

        voxel_size = gp.Coordinate(data.voxel_size)

        self.pipeline = gp.ZarrSource(
            filename,
            {
                self.raw: key
            },
            array_specs={
                self.raw: gp.ArraySpec(
                    roi=source_roi,
                    voxel_size=voxel_size,
                    interpolatable=True)
            }) + gp.RandomLocation() + IntensityDiffFilter(
                self.raw, 0, min_distance=0.1, channels=Slice(None))

        # add augmentations
        self.pipeline = self.pipeline + gp.ElasticAugment([40, 40],
                                                          [2, 2],
                                                          [0, math.pi / 2.0],
                                                          prob_slip=-1,
                                                          spatial_dims=2)

        self.pipeline.setup()
        np.random.seed(os.getpid() + int(time.time()))
Example #15
    def prepare(self, request):

        context = self.context
        dims = request[self.srcpoints].roi.dims()

        assert isinstance(context, list)
        if len(context) == 1:
            context = context * dims

        # request array in a larger area to get predictions from outside
        # write roi
        m_roi = request[self.srcpoints].roi.grow(gp.Coordinate(context),
                                                 gp.Coordinate(context))

        # however, restrict the request to the array actually provided
        # m_roi = m_roi.intersect(self.spec[self.m_array].roi)
        request[self.m_array] = gp.ArraySpec(roi=m_roi)

        # Do the same for the direction vector array.
        request[self.d_array] = gp.ArraySpec(roi=m_roi)
Example #16
def get_requests(config, blocks, raw, emb_pred, labels, gt):
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape
    diff = input_size - output_size

    cube_rois = [get_cube_roi(config, block) for block in blocks]

    requests = []
    for cube_roi in cube_rois:
        context_roi = cube_roi.grow(diff // 2, diff // 2)
        request = gp.BatchRequest()
        request[raw] = gp.ArraySpec(roi=context_roi)
        request[emb_pred] = gp.ArraySpec(roi=cube_roi)
        request[labels] = gp.ArraySpec(roi=cube_roi)
        request[gt] = gp.GraphSpec(roi=cube_roi)
        requests.append(request)
    return requests
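
A minimal sketch of consuming the per-cube requests above, assuming `pipeline` is a prediction pipeline (built elsewhere) that provides `raw`, `emb_pred`, `labels`, and `gt`:

import gunpowder as gp

requests = get_requests(config, blocks, raw, emb_pred, labels, gt)
with gp.build(pipeline):
    batches = [pipeline.request_batch(request) for request in requests]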
Example #17
    def process(self, batch, request):

        spec = self.spec[self.fg].copy()
        voxel_size = (1, ) + spec.voxel_size
        merged = np.stack([batch[self.fg].data, batch[self.bg].data], axis=0)

        batch[self.raw] = gp.Array(
            data=merged.astype(spec.dtype),
            spec=gp.ArraySpec(dtype=spec.dtype,
                              roi=gp.Roi((0, 0, 0, 0), merged.shape) * voxel_size,
                              interpolatable=True,
                              voxel_size=voxel_size))
Example #18
def validation_data_sources_from_snapshots(config, blocks):
    validation_blocks = Path(config["VALIDATION_BLOCKS"])

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    block_pipelines = []
    for block in blocks:

        pipelines = (
            SnapshotSource(
                validation_blocks / f"block_{block}.hdf",
                {
                    labels: "volumes/labels",
                    ground_truth: "points/gt"
                },
                directed={ground_truth: True},
            ),
            SnapshotSource(validation_blocks / f"block_{block}.hdf",
                           {raw: "volumes/raw"}),
        )

        cube_roi = get_cube_roi(config, block)

        request = gp.BatchRequest()
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)
        request[raw] = gp.ArraySpec(input_roi)
        request[ground_truth] = gp.GraphSpec(cube_roi)
        request[labels] = gp.ArraySpec(cube_roi)

        block_pipelines.append((pipelines, request))
    return block_pipelines, (raw, labels, ground_truth)
Example #19
 def process(self, batch, request):
     final_scores = {}
     for key, array in batch.items():
         if "SCORE" in str(key):
             block = int(str(key).split("_")[1])
             final_scores[block] = array.data
     final_scores = [
         final_scores[block] for block in range(1, 26)
         if block in final_scores
     ]
     outputs = gp.Batch()
     outputs[self.output] = gp.Array(np.array(final_scores),
                                     gp.ArraySpec(nonspatial=True))
     return outputs
Example #20
def evaluate_affs(pred_labels, gt_labels, return_results=False):

    results = rand_voi(gt_labels.data, pred_labels.data)
    results["voi_sum"] = results["voi_split"] + results["voi_merge"]

    scores = {"sample": results, "average": results}

    if return_results:
        results = {
            "pred_labels":
            gp.Array(
                pred_labels.data.astype(np.uint64),
                gp.ArraySpec(roi=pred_labels.spec.roi,
                             voxel_size=pred_labels.spec.voxel_size)),
            "gt_labels":
            gp.Array(
                gt_labels.data.astype(np.uint64),
                gp.ArraySpec(roi=gt_labels.spec.roi,
                             voxel_size=gt_labels.spec.voxel_size)),
        }

        return scores, results

    return scores
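
A hypothetical call of the evaluation above, assuming `pred_labels` and `gt_labels` are gp.Array objects holding label volumes:

scores = evaluate_affs(pred_labels, gt_labels, return_results=False)
print(scores["sample"]["voi_sum"])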
Example #21
def test_gp_dacapo_array_source(array_config):

    # Create Array from config
    array = array_config.array_type(array_config)

    # Make sure the DaCapoArraySource can properly read
    # the data in `array`
    key = gp.ArrayKey("TEST")
    source_node = DaCapoArraySource(array, key)

    with gp.build(source_node):
        request = gp.BatchRequest()
        request[key] = gp.ArraySpec(roi=array.roi)
        batch = source_node.request_batch(request)
        data = batch[key].data
        assert (data - array[array.roi]).sum() == 0
Example #22
    def test_pipeline3(self):
        array_key = gp.ArrayKey("TEST_ARRAY")
        points_key = gp.PointsKey("TEST_POINTS")
        voxel_size = gp.Coordinate((1, 1))
        spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                    {array_key: 'testdata'},
                                    array_specs={array_key: spec})
        csv_source = gp.CsvPointsSource(
            self.fake_points_file, points_key,
            gp.PointsSpec(
                roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

        request = gp.BatchRequest()
        shape = gp.Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = gp.ShiftAugment(prob_slip=0.2,
                                     prob_shift=0.2,
                                     sigma=5,
                                     shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                    gp.RandomLocation(ensure_nonempty=points_key) + shift_node)
        with gp.build(pipeline) as b:
            batch = b.request_batch(request)

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = batch[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(result_val, list(result_points.values()), target_vals,
                        self.fake_points))
Example #23
    def test_prepare1(self):

        key = gp.ArrayKey("TEST_ARRAY")
        spec = gp.ArraySpec(voxel_size=gp.Coordinate((1, 1)),
                            interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file, {key: 'testdata'},
                                    array_specs={key: spec})

        request = gp.BatchRequest()
        shape = gp.Coordinate((3, 3))
        request.add(key, shape, voxel_size=gp.Coordinate((1, 1)))

        shift_node = gp.ShiftAugment(sigma=1, shift_axis=0)
        with gp.build((hdf5_source + shift_node)):
            shift_node.prepare(request)
            self.assertTrue(shift_node.ndim == 2)
            self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0]))
Example #24
    def test_pipeline2(self):

        key = gp.ArrayKey("TEST_ARRAY")
        spec = gp.ArraySpec(voxel_size=gp.Coordinate((3, 1)),
                            interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file, {key: 'testdata'},
                                    array_specs={key: spec})

        request = gp.BatchRequest()
        shape = gp.Coordinate((3, 3))
        request.add(key, shape, voxel_size=gp.Coordinate((3, 1)))

        shift_node = gp.ShiftAugment(prob_slip=0.2,
                                     prob_shift=0.2,
                                     sigma=1,
                                     shift_axis=0)
        with gp.build((hdf5_source + shift_node)) as b:
            b.request_batch(request)
Example #25
    def __read_spec(self, array_key):

        if array_key in self.array_specs:
            spec = self.array_specs[array_key].copy()
        else:
            spec = gp.ArraySpec()
        assert spec.voxel_size is not None, "Voxel size needs to be given"

        self.ndims = len(spec.voxel_size)

        if spec.roi is None:
            roi = gp.Roi(gp.Coordinate((0, ) * self.ndims),
                         shape=gp.Coordinate((1, ) * self.ndims))
            roi.set_shape(None)
            spec.roi = roi

        arr = self.func((2, ) * self.ndims)
        if spec.dtype is not None:
            assert spec.dtype == arr.dtype, (
                "dtype %s provided in array_specs for %s, "
                "but differs from function output %s dtype %s" %
                (self.array_specs[array_key].dtype, array_key, self.func,
                 arr.dtype))
        else:
            spec.dtype = arr.dtype

        if spec.interpolatable is None:
            spec.interpolatable = spec.dtype in [
                np.float32,
                np.float64,
                np.float128,
                np.uint8  # assuming this is not used for labels
            ]
            logger.warning(
                "You didn't set 'interpolatable' for %s "
                "(func %s). Based on the dtype %s, it has been "
                "set to %s. This might not be what you want.", array_key,
                self.func, spec.dtype, spec.interpolatable)

        return spec
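
A sketch, assumed rather than taken from the original source, of how a setup() method might use __read_spec above for a single output key (`self.array_key` is hypothetical):

def setup(self):
    self.provides(self.array_key, self.__read_spec(self.array_key))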
Example #26
def train_until(max_iteration, name='train_net', output_folder='.', clip_max=2000):

    # get the latest checkpoint
    if tf.train.latest_checkpoint(output_folder):
        trained_until = int(tf.train.latest_checkpoint(output_folder).split('_')[-1])
    else:
        trained_until = 0

    if trained_until >= max_iteration:
        return

    with open(os.path.join(output_folder, name + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(os.path.join(output_folder, name + '_names.json'), 'r') as f:
        net_names = json.load(f)

    # array keys
    raw = gp.ArrayKey('RAW')
    gt_mask = gp.ArrayKey('GT_MASK')
    gt_dt = gp.ArrayKey('GT_DT')
    pred_dt = gp.ArrayKey('PRED_DT')
    loss_gradient = gp.ArrayKey('LOSS_GRADIENT')

    voxel_size = gp.Coordinate((1, 1, 1))
    input_shape = gp.Coordinate(net_config['input_shape'])
    output_shape = gp.Coordinate(net_config['output_shape'])
    context = gp.Coordinate(input_shape - output_shape) / 2

    request = gp.BatchRequest()
    request.add(raw, input_shape)
    request.add(gt_mask, output_shape)
    request.add(gt_dt, output_shape)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, input_shape)
    snapshot_request.add(gt_mask, output_shape)
    snapshot_request.add(gt_dt, output_shape)
    snapshot_request.add(pred_dt, output_shape)
    snapshot_request.add(loss_gradient, output_shape)

    # specify data source
    data_sources = tuple()
    for data_file in data_files:
        current_path = os.path.join(data_dir, data_file)
        with h5py.File(current_path, 'r') as f:
            data_sources += tuple(
                gp.Hdf5Source(
                    current_path,
                    datasets={
                        raw: sample + '/raw',
                        gt_mask: sample + '/fg'
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True, dtype=np.uint16, voxel_size=voxel_size),
                        gt_mask: gp.ArraySpec(interpolatable=False, dtype=np.bool_, voxel_size=voxel_size),
                    }
                ) +
                Convert(gt_mask, np.uint8) +
                gp.Pad(raw, context) +
                gp.Pad(gt_mask, context) +
                gp.RandomLocation()
                for sample in f)

    pipeline = (
            data_sources +
            gp.RandomProvider() +
            gp.Reject(gt_mask, min_masked=0.005, reject_probability=1.) +
            DistanceTransform(gt_mask, gt_dt, 3) +
            nl.Clip(raw, 0, clip_max) +
            gp.Normalize(raw, factor=1.0/clip_max) +
            gp.ElasticAugment(
                control_point_spacing=[20, 20, 20],
                jitter_sigma=[1, 1, 1],
                rotation_interval=[0, math.pi/2.0],
                subsample=4) +
            gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
            gp.IntensityAugment(raw, 0.9, 1.1, -0.1, 0.1) +
            gp.IntensityScaleShift(raw, 2, -1) +

            # train
            gp.PreCache(
                cache_size=40,
                num_workers=5) +
            gp.tensorflow.Train(
                os.path.join(output_folder, name),
                optimizer=net_names['optimizer'],
                loss=net_names['loss'],
                inputs={
                    net_names['raw']: raw,
                    net_names['gt_dt']: gt_dt,
                },
                outputs={
                    net_names['pred_dt']: pred_dt,
                },
                gradients={
                    net_names['pred_dt']: loss_gradient,
                },
                save_every=5000) +

            # visualize
            gp.Snapshot({
                    raw: 'volumes/raw',
                    gt_mask: 'volumes/gt_mask',
                    gt_dt: 'volumes/gt_dt',
                    pred_dt: 'volumes/pred_dt',
                    loss_gradient: 'volumes/gradient',
                },
                output_filename=os.path.join(output_folder, 'snapshots', 'batch_{iteration}.hdf'),
                additional_request=snapshot_request,
                every=2000) +
            gp.PrintProfilingStats(every=500)
    )

    with gp.build(pipeline):
        
        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #27
def predict_2d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    dataset_shape = raw_data.shape
    dataset_roi = raw_data.roi
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1
    data_dims = len(dataset_shape) - channel_dims

    if data_dims == 3:
        sample_shape = dataset_shape[channel_dims + 1:]
    else:
        raise RuntimeError(
            "For 2D validation, please provide a 3D array where the first "
            "dimension indexes the samples.")

    num_samples = raw_data.num_samples

    sample_shape = gp.Coordinate(sample_shape)
    sample_size = sample_shape * voxel_size

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    # overwrite source ROI to treat samples as z dimension
    spec = gp.ArraySpec(roi=gp.Roi((0, ) + dataset_roi.get_begin(),
                                   (num_samples, ) + sample_size),
                        voxel_size=(1, ) + voxel_size)
    if gt_data:
        sources = (raw_data.get_source(raw, overwrite_spec=spec),
                   gt_data.get_source(gt, overwrite_spec=spec))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw, overwrite_spec=spec)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] s, h, w)
    # gt: ([c,] s, h, w)
    # target: ([c,] s, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    if gt_data and predictor.target_channels == 0:
        pipeline += AddChannelDim(target)
    # raw: (c, s, h, w)
    # gt: ([c,] s, h, w)
    # target: (c, s, h, w)
    pipeline += TransposeDims(raw, (1, 0, 2, 3))
    if gt_data:
        pipeline += TransposeDims(target, (1, 0, 2, 3))
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # raw: (s, c, h, w)
    # gt: ([c,] s, h, w)
    # target: (s, c, h, w)
    # prediction: (s, c, h, w)
    pipeline += gp.Scan(scan_request)

    total_request = gp.BatchRequest()
    total_request.add(raw, sample_size)
    total_request.add(prediction, sample_size)
    if gt_data:
        total_request.add(gt, sample_size)
        total_request.add(target, sample_size)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #28
def predict_3d(raw_data, gt_data, predictor):

    raw_channels = max(1, raw_data.num_channels)
    input_shape = predictor.input_shape
    output_shape = predictor.output_shape
    voxel_size = raw_data.voxel_size

    # switch to world units
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    raw = gp.ArrayKey('RAW')
    gt = gp.ArrayKey('GT')
    target = gp.ArrayKey('TARGET')
    prediction = gp.ArrayKey('PREDICTION')

    channel_dims = 0 if raw_channels == 1 else 1

    num_samples = raw_data.num_samples
    assert num_samples == 0, (
        "Multiple samples for 3D validation not yet implemented")

    scan_request = gp.BatchRequest()
    scan_request.add(raw, input_size)
    scan_request.add(prediction, output_size)
    if gt_data:
        scan_request.add(gt, output_size)
        scan_request.add(target, output_size)

    if gt_data:
        sources = (raw_data.get_source(raw), gt_data.get_source(gt))
        pipeline = sources + gp.MergeProvider()
    else:
        pipeline = raw_data.get_source(raw)
    pipeline += gp.Pad(raw, None)
    if gt_data:
        pipeline += gp.Pad(gt, None)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    pipeline += gp.Normalize(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    if gt_data:
        pipeline += predictor.add_target(gt, target)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += AddChannelDim(raw)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # add a "batch" dimension
    pipeline += AddChannelDim(raw)
    # raw: (1, c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    pipeline += gp_torch.Predict(model=predictor,
                                 inputs={'x': raw},
                                 outputs={0: prediction})
    # remove "batch" dimension
    pipeline += RemoveChannelDim(raw)
    pipeline += RemoveChannelDim(prediction)
    # raw: (c, d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    if channel_dims == 0:
        pipeline += RemoveChannelDim(raw)
    # raw: ([c,] d, h, w)
    # gt: ([c,] d, h, w)
    # target: ([c,] d, h, w)
    # prediction: ([c,] d, h, w)
    pipeline += gp.Scan(scan_request)

    # ensure validation ROI is at least the size of the network input
    roi = raw_data.roi.grow(input_size / 2, input_size / 2)

    total_request = gp.BatchRequest()
    total_request[raw] = gp.ArraySpec(roi=roi)
    total_request[prediction] = gp.ArraySpec(roi=roi)
    if gt_data:
        total_request[gt] = gp.ArraySpec(roi=roi)
        total_request[target] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline):
        batch = pipeline.request_batch(total_request)
        ret = {'raw': batch[raw], 'prediction': batch[prediction]}
        if gt_data:
            ret.update({'gt': batch[gt], 'target': batch[target]})
        return ret
Example #29
prediction = gp.ArrayKey('PREDICTION')


class PrepareTrainingData(gp.BatchFilter):
    def process(self, batch, request):

        batch[out_cage_map].data = batch[out_cage_map].data.astype(np.float32)
        batch[out_cage_map].spec.dtype = np.float32


# assemble pipeline
sourceA = gp.ZarrSource('../data/cropped_sample_A.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
sourceB = gp.ZarrSource('../data/cropped_sample_B.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
sourceC = gp.ZarrSource('../data/cropped_sample_C.zarr', {
    raw: 'raw',
    seg: 'segmentation'
}, {
    raw: gp.ArraySpec(interpolatable=True),
    seg: gp.ArraySpec(interpolatable=False)
})
Example #30
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = np.array(seperate_by).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)