Example #1
0
    def process(self, batch, request):
        """Apply CLAHE to each input array and store it under its output key.

        For every ``(in_key, out_key)`` pair the input data must already lie
        in ``[0, 1]``. Constant arrays are passed through unchanged. Otherwise
        the data is optionally min/max-normalized (``self.normalize``), and
        ``clahe`` is applied to every ``len(self.kernel_size)``-dimensional
        slice, iterating over the leading dimensions. In both cases the result
        is cropped to the ROI requested for ``out_key`` and returned in a new
        Batch.

        The input array held by ``batch`` is not modified.
        """
        output = Batch()

        for in_key, out_key in zip(self.arrays, self.output_arrays):
            array = batch[in_key]
            out_roi = request[out_key].roi
            # Work on a copy. The per-slice assignment below would otherwise
            # overwrite the input array that ``batch`` still holds.
            data = array.data.copy()
            d_min = data.min()
            d_max = data.max()
            assert (
                d_min >= 0 and d_max <= 1
            ), f"Clahe expects data in range (0,1), got ({d_min}, {d_max})"
            if np.isclose(d_max, d_min):
                # Nothing to equalize, but still honour the requested ROI,
                # exactly like the regular path below.
                output[out_key] = Array(data, array.spec).crop(out_roi)
                continue
            if self.normalize:
                data = (data - d_min) / (d_max - d_min)

            shape = data.shape
            extra_dims = len(shape) - len(self.kernel_size)
            # The kernel is given in world units and clahe gets it in voxels.
            # This is loop-invariant, so compute it once per array.
            kernel_in_voxels = Coordinate(
                self.kernel_size / array.spec.voxel_size)

            for index in itertools.product(*[range(s) for s in shape[:extra_dims]]):
                data[index] = clahe(
                    data[index],
                    kernel_size=kernel_in_voxels,
                    clip_limit=self.clip_limit,
                    nbins=self.nbins,
                )
            assert (
                data.min() >= 0 and data.max() <= 1
            ), f"Clahe should output data in range (0,1), got ({data.min()}, {data.max()})"
            output[out_key] = Array(data, array.spec).crop(out_roi)
        return output
Example #2
0
    def process(self, batch, request):
        """Upsample ``self.source`` into ``self.target`` with ``ndimage.zoom``.

        Interpolatable sources use cubic splines (order 3); all others use
        order 0. The zoomed data gets the full source ROI and is then cropped
        to the ROI requested for the target. If the source was requested with
        a smaller ROI than it arrived with, it is cropped back to that ROI.
        Does nothing if the target was not requested.
        """
        if self.target not in request:
            return

        source_array = batch.arrays[self.source]
        source_roi = source_array.spec.roi
        target_roi = request[self.target].roi

        assert source_roi.contains(target_roi)

        if source_array.spec.interpolatable:
            spline_order = 3
        else:
            spline_order = 0
        zoomed = ndimage.zoom(source_array.data,
                              np.array(self.factor),
                              order=spline_order)

        # The zoomed data covers the whole source ROI; crop it to the target.
        target_spec = self.spec[self.target].copy()
        target_spec.roi = source_roi
        batch.arrays[self.target] = Array(zoomed, target_spec).crop(target_roi)

        if self.source not in request:
            return

        # Shrink the source back to the ROI that was originally requested.
        source_request_roi = request[self.source].roi
        if source_roi != source_request_roi:
            assert source_roi.contains(source_request_roi)

            logger.debug("restoring original request roi %s of %s from %s",
                         source_request_roi, self.source, source_roi)
            batch.arrays[self.source] = batch.arrays[self.source].crop(
                source_request_roi)
Example #3
0
    def train_step(self, batch, request):
        """Run one solver step and store the requested outputs and gradients.

        Each entry of ``self.inputs`` is resolved to network input data:

        - an ``ArrayKey`` gives the data of that batch array; it is skipped
          with a warning if the array is absent;
        - an ``np.ndarray`` is used as-is;
        - a ``str`` names an attribute of ``batch``.

        Network outputs (``self.outputs``) and output gradients
        (``self.gradients``) are only fetched if at least one of their arrays
        was requested. The leading batch dimension is stripped before they
        are stored in ``batch``. Sets ``batch.loss`` and ``batch.iteration``.

        Raises:
            Exception: if an entry of ``self.inputs`` has an unsupported type.
        """
        data = {}
        for input_name, network_input in self.inputs.items():
            if isinstance(network_input, ArrayKey):
                if network_input in batch.arrays:
                    data[input_name] = batch.arrays[network_input].data
                else:
                    # ``Logger.warn`` is a deprecated alias of ``warning``.
                    logger.warning("batch does not contain %s, input %s will not "
                                   "be set", network_input, input_name)
            elif isinstance(network_input, np.ndarray):
                data[input_name] = network_input
            elif isinstance(network_input, str):
                data[input_name] = getattr(batch, network_input)
            else:
                raise Exception(
                    "Unknown network input type {}, can't be given to "
                    "network".format(network_input))
        self.net_io.set_inputs(data)

        loss = self.solver.step(1)

        # Outputs and gradients are stored the same way. The getter is only
        # called (lazily) if one of its arrays was actually requested.
        for name_to_key, fetch_blobs in (
                (self.outputs, lambda: self.net_io.get_outputs()),
                (self.gradients, lambda: self.net_io.get_output_diffs())):

            requested = {
                name: array_key
                for name, array_key in name_to_key.items()
                if array_key in request.array_specs}
            if not requested:
                continue

            blobs = fetch_blobs()
            for blob_name, array_key in requested.items():
                spec = self.spec[array_key].copy()
                spec.roi = request[array_key].roi
                batch.arrays[array_key] = Array(
                    blobs[blob_name][0],  # strip #batch dimension
                    spec)

        batch.loss = loss
        batch.iteration = self.solver.iter
Example #4
0
    def provide(self, request):
        """Read ``self.array`` for the ROI of the first requested array spec.

        The request ROI (world units) is converted to voxel units and shifted
        by the dataset's own offset before the data is read with
        ``self.__read``. Time spent is recorded in the batch profiling stats.

        Raises:
            ValueError: if ``request`` contains no array specs.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        # Dict views are not indexable in Python 3 (``items()[0]`` raises
        # TypeError), so take the first requested spec explicitly.
        request_spec = next(iter(request.array_specs.values()), None)
        if request_spec is None:
            raise ValueError("request does not contain any array specs")

        logger.debug("Reading %s in %s...", self.array, request_spec.roi)

        voxel_size = self.spec[self.array].voxel_size

        # scale request roi to voxel units
        dataset_roi = request_spec.roi/voxel_size

        # shift request roi into dataset
        dataset_roi = dataset_roi - self.spec[self.array].roi.get_offset()/voxel_size

        # create array spec
        array_spec = self.spec[self.array].copy()
        array_spec.roi = request_spec.roi

        # add array to batch
        batch.arrays[self.array] = Array(
            self.__read(dataset_roi),
            array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #5
0
    def process(self, batch, request):
        """Apply multi-dimensional CLAHE (``mclahe``) to each input array.

        With ``self.slice_wise`` set, ``mclahe`` runs separately on every
        ``len(self.kernel_size)``-dimensional slice, iterating over the
        leading dimensions. Otherwise it runs once on the whole array with a
        kernel of size 1 along the leading dimensions, and the result is cast
        to the dtype of ``self.spec[out_key]``.

        Results are cropped to the ROI requested for ``out_key`` and returned
        in a new Batch. The input arrays held by ``batch`` are not modified.
        """
        output = Batch()

        for in_key, out_key in zip(self.arrays, self.output_arrays):
            array = batch[in_key]
            data = array.data
            shape = data.shape
            extra_dims = len(shape) - len(self.kernel_size)
            if self.slice_wise:
                # Work on a copy. Assigning slices in place would otherwise
                # overwrite the input array that ``batch`` still holds.
                data = data.copy()
                for index in itertools.product(
                        *[range(s) for s in shape[:extra_dims]]):
                    data[index] = mclahe(
                        data[index],
                        kernel_size=self.kernel_size,
                        clip_limit=self.clip_limit,
                        n_bins=self.nbins,
                        use_gpu=False,
                        adaptive_hist_range=self.adaptive_hist_range,
                    )
            else:
                # kernel of size 1 along the leading (non-kernel) dimensions
                full_kernel = np.array(
                    (1, ) * extra_dims + tuple(self.kernel_size), dtype=int)
                # NOTE(review): unlike the slice-wise branch, neither
                # ``use_gpu`` nor ``adaptive_hist_range`` is passed here.
                # Confirm this is intended.
                data = mclahe(
                    data,
                    kernel_size=full_kernel,
                    clip_limit=self.clip_limit,
                    n_bins=self.nbins,
                ).astype(self.spec[out_key].dtype)
            output[out_key] = Array(data,
                                    array.spec).crop(request[out_key].roi)
        return output
Example #6
0
    def __setup_batch(self, batch_spec, chunk):
        '''Allocate a batch matching the sizes of ``batch_spec``, using
        ``chunk`` as template.

        Each array keeps the leading (non-spatial) dimensions of the
        corresponding array in ``chunk`` and gets the spatial shape of the
        requested ROI in voxels. Arrays are zero-filled with the dtype
        declared in ``self.spec`` (float64 if no dtype is declared). Points
        are created empty. All specs are copies of ``self.spec`` with the
        requested ROI.
        '''
        batch = Batch()

        for array_key, request_spec in batch_spec.array_specs.items():
            roi = request_spec.roi
            voxel_size = self.spec[array_key].voxel_size

            # get the 'non-spatial' shape of the chunk-batch
            # and append the shape of the request to it
            chunk_shape = chunk.arrays[array_key].data.shape
            # slice by explicit length: ``shape[:-0]`` would drop everything
            shape = chunk_shape[:len(chunk_shape) - roi.dims()]
            shape += (roi.get_shape() // voxel_size)

            spec = self.spec[array_key].copy()
            spec.roi = roi
            logger.info("allocating array of shape %s for %s", shape,
                        array_key)
            # match the declared dtype; ``dtype=None`` falls back to float64
            batch.arrays[array_key] = Array(
                data=np.zeros(shape, dtype=spec.dtype), spec=spec)

        for points_key, request_spec in batch_spec.points_specs.items():
            spec = self.spec[points_key].copy()
            spec.roi = request_spec.roi
            batch.points[points_key] = Points(data={}, spec=spec)

        logger.debug("setup batch to fill %s", batch)

        return batch
Example #7
0
    def process(self, batch, request):
        """Add requested vector maps, then restore the requested target points.

        The first pass computes a vector map (via ``self.__get_vector_map``)
        for every requested array in ``self.array_to_src_trg_points``.

        The second pass runs for each requested vector map. If its target
        points were requested, points outside the requested ROI are dropped
        and the points ROI is shrunk by ``self.pad_for_partners`` on both
        sides. If they were not requested, they are removed from the batch.
        """
        # create vector map and add it to batch
        for array_key, (_src_points_key, _trg_points_key) in \
                self.array_to_src_trg_points.items():
            if array_key in request:
                vector_map = self.__get_vector_map(
                    batch=batch,
                    request=request,
                    vector_map_array_key=array_key)
                spec = self.spec[array_key].copy()
                spec.roi = request[array_key].roi
                batch.arrays[array_key] = Array(data=vector_map, spec=spec)

        # restore request / remove not requested points in
        # padding-for-neighbors region & shrink batch roi
        for array_key, (_src_points_key, trg_points_key) in \
                self.array_to_src_trg_points.items():
            if array_key not in request:
                continue
            if trg_points_key in request:
                trg_points = batch.points[trg_points_key]
                requested_roi = request[trg_points_key].roi
                # Iterate over a snapshot. Deleting from a dict while
                # iterating it raises RuntimeError.
                for loc_id, point in list(trg_points.data.items()):
                    if not requested_roi.contains(Coordinate(point.location)):
                        del trg_points.data[loc_id]
                neg_pad_for_partners = Coordinate(
                    (self.pad_for_partners * np.asarray([-1])).tolist())
                trg_points.spec.roi = trg_points.spec.roi.grow(
                    neg_pad_for_partners, neg_pad_for_partners)
            elif trg_points_key in batch.points:
                del batch.points[trg_points_key]
Example #8
0
    def process(self, batch, request):
        """Upsample ``self.source`` into ``self.target`` by voxel repetition.

        The source is cropped to the ROI requested for the target, and every
        spatial axis is repeated by its factor. ``self.factor`` may be a
        scalar (applied to every spatial axis) or a per-axis sequence.
        Spatial axes are taken to be the trailing ones, so leading axes
        (e.g. channels) are left untouched.

        Returns:
            A new Batch containing only the target array, or None if the
            target was not requested.
        """
        outputs = Batch()

        if self.target not in request:
            return

        input_roi = batch.arrays[self.source].spec.roi
        request_roi = request[self.target].roi

        assert input_roi.contains(request_roi)

        # upsample

        logger.debug("upsampling %s with %s", self.source, self.factor)

        crop = batch.arrays[self.source].crop(request_roi)
        data = crop.data

        if np.ndim(self.factor) == 0:
            factors = (self.factor,) * request_roi.dims()
        else:
            factors = tuple(self.factor)

        # spatial axes are the trailing ones; skip leading (channel) axes
        first_spatial_axis = data.ndim - len(factors)
        for d, f in enumerate(factors):
            data = np.repeat(data, f, axis=first_spatial_axis + d)

        # create output array
        spec = self.spec[self.target].copy()
        spec.roi = request_roi
        outputs.arrays[self.target] = Array(data, spec)
        return outputs
Example #9
0
    def provide(self, request):
        """Read every requested array from ``self.filename`` into a new batch.

        Each request ROI (world units) is converted to voxel units relative to
        the dataset's own offset, then read with ``self.__read`` using the
        configured dataset name and channel ids. Time spent is recorded in
        the batch profiling stats.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        with self._open_file(self.filename) as data_file:
            for key, requested_spec in request.array_specs.items():
                provided_spec = self.spec[key]
                voxel_size = provided_spec.voxel_size

                # world-unit request -> voxels, relative to the dataset start
                voxel_roi = (requested_spec.roi / voxel_size
                             - provided_spec.roi.get_offset() / voxel_size)

                out_spec = self.spec[key].copy()
                out_spec.roi = requested_spec.roi

                payload = self.__read(data_file, self.datasets[key],
                                      voxel_roi, self.channel_ids[key])
                batch.arrays[key] = Array(payload, out_spec)

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #10
0
    def process(self, batch, request):
        """Add a per-voxel scale array for ``self.labels`` under ``self.scales``.

        If the labels contain any foreground (label > 0), background voxels
        get ``self.factors[0]`` and foreground voxels get
        ``self.factors[1]``. Otherwise every scale stays 1. The scales share
        the labels' ROI.
        """
        labels = batch.arrays[self.labels]
        scales = np.ones(labels.data.shape, dtype=np.float32)

        # check if fg in batch: this must look at the labels. ``scales`` is
        # all ones at this point, so testing it would always be true.
        if labels.data.size > 0 and np.max(labels.data) > 0:
            scales = scales * self.factors[0]
            scales[labels.data > 0] = self.factors[1]

        spec = self.spec[self.scales].copy()
        spec.roi = labels.spec.roi
        batch.arrays[self.scales] = Array(scales, spec)
Example #11
0
    def process(self, batch, request):
        """Relabel excluded GT labels and optionally build an ignore mask.

        Every voxel of ``self.labels`` whose label is in ``self.exclude`` is
        set to ``self.background_value`` in place.

        If ``self.ignore_mask`` is set and requested, a uint8 mask is added
        for the requested ROI. It is 1 within ``self.ignore_mask_erode``
        (distance computed with ``sampling=voxel_size``) of any non-excluded
        voxel, and 0 elsewhere. This includes parts of the requested ROI
        that lie outside the labels ROI.
        """
        gt = batch.arrays[self.labels]

        gt_labels = np.unique(gt.data)
        logger.debug("batch contains GT labels: %s", gt_labels)
        for label in gt_labels:
            if label in self.exclude:
                logger.debug("excluding label %s", label)

        # Decide on the ORIGINAL labels, before any relabelling. The result
        # then does not depend on the order in which labels are visited,
        # e.g. when background_value is itself a non-excluded label.
        excluded = np.isin(gt.data, list(self.exclude))

        # 0 marks included regions (to be used directly with distance
        # transform later)
        include_mask = excluded.astype(np.float64)
        gt.data[excluded] = self.background_value

        # if no ignore mask is provided or requested, we are done
        if not self.ignore_mask or self.ignore_mask not in request:
            return

        voxel_size = self.spec[self.labels].voxel_size
        distance_to_include = distance_transform_edt(include_mask,
                                                     sampling=voxel_size)
        logger.debug("max distance to foreground is %s",
                     distance_to_include.max())

        # 1 marks included regions, plus a context area around them
        include_mask = distance_to_include < self.ignore_mask_erode

        # include mask was computed on labels ROI, we need to copy it to
        # the requested ignore_mask ROI
        gt_ignore_roi = request[self.ignore_mask].roi

        intersection = gt.spec.roi.intersect(gt_ignore_roi)
        intersection_in_gt = intersection - gt.spec.roi.get_offset()
        intersection_in_gt_ignore = intersection - gt_ignore_roi.get_offset()

        # to voxel coordinates
        intersection_in_gt //= voxel_size
        intersection_in_gt_ignore //= voxel_size

        gt_ignore = np.zeros((gt_ignore_roi // voxel_size).get_shape(),
                             dtype=np.uint8)
        gt_ignore[intersection_in_gt_ignore.get_bounding_box()] = include_mask[
            intersection_in_gt.get_bounding_box()]

        spec = self.spec[self.labels].copy()
        spec.roi = gt_ignore_roi
        spec.dtype = np.uint8
        batch.arrays[self.ignore_mask] = Array(gt_ignore, spec)
Example #12
0
    def process(self, batch, request):
        """Compute a vector map for every requested array and add it to batch.

        The map is produced by ``self.__get_vector_map`` and stored with a
        copy of the array's spec, restricted to the requested ROI.
        """
        mapping = self.array_to_src_trg_points
        for map_key, (_src_key, _trg_key) in mapping.items():
            if map_key not in request:
                continue
            vectors = self.__get_vector_map(
                batch=batch, request=request, vector_map_array_key=map_key)
            out_spec = self.spec[map_key].copy()
            out_spec.roi = request[map_key].roi
            batch.arrays[map_key] = Array(data=vectors, spec=out_spec)
Example #13
0
    def process(self, batch, request):
        """Compute class-balancing error scales for binary ``self.labels``.

        The scale array starts at 1 and is multiplied by every mask in
        ``self.masks``, so masked-out voxels get 0. It is then balanced by
        ``self.__balance`` independently inside each slab.

        In ``self.slab``, an entry of -1 means the whole extent of that
        dimension; a falsy ``self.slab`` balances the whole array at once.
        The result is stored under ``self.scales`` with the labels' ROI.
        """
        labels = batch.arrays[self.labels]

        assert len(np.unique(labels.data)) <= 2, (
            "Found more than two labels in %s."%self.labels)
        assert np.min(labels.data) in [0.0, 1.0], (
            "Labels %s are not binary."%self.labels)
        assert np.max(labels.data) in [0.0, 1.0], (
            "Labels %s are not binary."%self.labels)

        # initialize error scale with 1s
        error_scale = np.ones(labels.data.shape, dtype=np.float32)

        # set error_scale to 0 in masked-out areas
        for key in self.masks:
            mask = batch.arrays[key]
            # report the mask's key (not the whole Array) next to the
            # labels' key
            assert labels.data.shape == mask.data.shape, (
                "Shape of mask %s %s does not match %s %s"%(
                    key,
                    mask.data.shape,
                    self.labels,
                    labels.data.shape))
            error_scale *= mask.data

        if not self.slab:
            slab = error_scale.shape
        else:
            # slab with -1 replaced by shape
            slab = tuple(
                m if s == -1 else s
                for m, s in zip(error_scale.shape, self.slab))

        slab_ranges = (
            range(0, m, s)
            for m, s in zip(error_scale.shape, slab))

        for start in itertools.product(*slab_ranges):
            slices = tuple(
                slice(begin, begin + width)
                for begin, width in zip(start, slab))
            self.__balance(
                labels.data[slices],
                error_scale[slices])

        spec = self.spec[self.scales].copy()
        spec.roi = labels.spec.roi
        batch.arrays[self.scales] = Array(error_scale, spec)
Example #14
0
    def train_step(self, batch, request):
        """Run one optimizer step, store requested outputs, log, checkpoint.

        ``self.summary`` may be:

        - None: no summaries;
        - a single summary (``str``): computed every step but written only
          every ``self.log_every`` iterations and at iteration 1;
        - a dict ``{name: (fetch, every)}``: each fetch is computed and
          written only on iterations divisible by ``every``.

        A checkpoint is saved every ``self.save_every`` iterations.
        """
        array_outputs = self.__collect_requested_outputs(request)
        inputs = self.__collect_provided_inputs(batch)

        to_compute = {
            'optimizer': self.optimizer,
            'loss': self.loss,
            'iteration': self.iteration_increment
        }
        to_compute.update(array_outputs)

        # compute outputs, gradients, and update variables
        if isinstance(self.summary, str):
            to_compute["summaries"] = self.summary
        elif isinstance(self.summary, dict):
            # current_step is the last stored iteration; +1 is presumably
            # the iteration this step produces
            for k, (v, f) in self.summary.items():
                if int(self.current_step + 1) % f == 0:
                    to_compute[k] = v
        outputs = self.session.run(to_compute, feed_dict=inputs)

        for array_key in array_outputs:
            spec = self.spec[array_key].copy()
            spec.roi = request[array_key].roi
            batch.arrays[array_key] = Array(outputs[array_key], spec)

        batch.loss = outputs['loss']
        batch.iteration = outputs['iteration'][0]
        self.current_step = batch.iteration

        if isinstance(self.summary, str):
            # Keep the log_every check nested. A str summary on a
            # non-logging iteration must not reach the dict branch, where
            # ``str.items()`` would raise AttributeError.
            if (batch.iteration % self.log_every == 0
                    or batch.iteration == 1):
                self.summary_saver.add_summary(outputs['summaries'],
                                               batch.iteration)
        elif isinstance(self.summary, dict):
            for k, (_, f) in self.summary.items():
                if int(self.current_step) % f == 0:
                    self.summary_saver.add_summary(outputs[k],
                                                   batch.iteration)

        if batch.iteration % self.save_every == 0:

            checkpoint_name = (self.meta_graph_filename +
                               '_checkpoint_%i' % batch.iteration)

            logger.info("Creating checkpoint %s", checkpoint_name)

            self.full_saver.save(self.session, checkpoint_name)
Example #15
0
    def process(self, batch, request):
        """Add a skeleton-weighted scale array for ``self.labels``.

        If any label is > 0, all voxels get ``self.factors[0]``, except
        voxels on the 3D skeleton of the combined foreground, which get
        ``self.factors[1]``. Otherwise every scale stays 1. The scales share
        the labels' ROI and are stored under ``self.scales``.
        """
        labels = batch.arrays[self.labels]
        label_data = labels.data
        scales = np.ones(label_data.shape, dtype=np.float32)

        has_foreground = np.max(np.unique(label_data)) > 0
        if has_foreground:
            # heads up: one skeleton for all foreground classes together
            on_skeleton = skeletonize_3d(label_data > 0) > 0
            scales = scales * self.factors[0]
            scales[on_skeleton] = self.factors[1]

        out_spec = self.spec[self.scales].copy()
        out_spec.roi = labels.spec.roi
        batch.arrays[self.scales] = Array(scales, out_spec)
Example #16
0
    def predict(self, batch, request):
        """Run the session on the batch inputs and store the requested outputs.

        Each output array is stored with a copy of its spec, restricted to
        the requested ROI.
        """
        logger.debug("predicting in batch %i", batch.id)

        requested = self.__collect_requested_outputs(request)
        feed = self.__collect_provided_inputs(batch)

        results = self.session.run(requested, feed_dict=feed)

        for key in requested:
            out_spec = self.spec[key].copy()
            out_spec.roi = request[key].roi
            batch.arrays[key] = Array(results[key], out_spec)

        logger.debug("predicted in batch %i", batch.id)
Example #17
0
    def process(self, batch, request):
        """Duplicate each requested array in ``self.to_duplicate``.

        For every ``(old_key, new_key)`` pair whose ``old_key`` was
        requested, a new array is added under ``new_key``. It gets a copy of
        the data, the same spec, and a deep copy of the attrs.

        Raises:
            AssertionError: if ``new_key`` is already in the batch or is not
                an ``ArrayKey``.
        """
        for source_key, copy_key in self.to_duplicate.items():
            if source_key not in request:
                continue

            assert copy_key not in batch.arrays, \
                "key {} already present in batch".format(copy_key)
            assert isinstance(copy_key, ArrayKey), \
                "Can only duplicate array data"

            source_array = batch.arrays[source_key]
            batch.arrays[copy_key] = Array(
                source_array.data.copy(),
                spec=source_array.spec,
                attrs=copy.deepcopy(source_array.attrs))
Example #18
0
    def predict(self, batch, request):
        """Run a forward pass and store every network output in ``batch``.

        Inputs come from the batch arrays named in ``self.inputs``. Outputs
        are stored for every entry of ``self.outputs``, with the leading
        batch dimension stripped and the requested ROI.

        Returns:
            The same ``batch`` object.
        """
        feed = {}
        for blob_name, key in self.inputs.items():
            feed[blob_name] = batch.arrays[key].data
        self.net_io.set_inputs(feed)

        self.net.forward()
        blobs = self.net_io.get_outputs()

        for blob_name, key in self.outputs.items():
            out_spec = self.spec[key].copy()
            out_spec.roi = request[key].roi
            # drop the leading #batch dimension of the network blob
            batch.arrays[key] = Array(blobs[blob_name][0], out_spec)

        return batch
Example #19
0
    def process(self, batch, request):
        """Rasterize partner vectors from ``self.src_points`` to ``self.trg_points``.

        The vectors (and a point mask) are drawn by
        ``self.__draw_partner_vectors`` on a voxel grid covering the source
        points' ROI snapped to ``voxel_size``. They are then cropped to the
        requested ROIs of ``self.array`` and, if requested,
        ``self.pointmask``. Afterwards the source/target points and
        ``self.mask`` are restored to their requested ROIs.
        """

        src_points = batch.points[self.src_points]
        voxel_size = self.spec[self.array].voxel_size

        # get roi used for creating the new array (points_roi does not
        # necessarily align with voxel size)
        enlarged_vol_roi = src_points.spec.roi.snap_to_grid(voxel_size)
        offset = enlarged_vol_roi.get_begin() / voxel_size
        shape = enlarged_vol_roi.get_shape() / voxel_size
        # data_roi is in voxel units; multiply by voxel_size for world units
        data_roi = Roi(offset, shape)

        logger.debug("Src points in %s", src_points.spec.roi)
        for i, point in src_points.data.items():
            logger.debug("%d, %s", i, point.location)
        logger.debug("Data roi in voxels: %s", data_roi)
        logger.debug("Data roi in world units: %s", data_roi * voxel_size)

        # optional object mask, restricted to the enlarged ROI (None if no
        # mask is configured)
        mask_array = None if self.mask is None else batch.arrays[
            self.mask].crop(enlarged_vol_roi)

        partner_vectors_data, pointmask = self.__draw_partner_vectors(
            src_points, batch.points[self.trg_points], data_roi, voxel_size,
            enlarged_vol_roi.get_begin(), self.radius, mask_array)

        # create array and crop it to requested roi
        spec = self.spec[self.array].copy()
        spec.roi = data_roi * voxel_size
        partner_vectors = Array(data=partner_vectors_data, spec=spec)
        logger.debug("Cropping partner vectors to %s", request[self.array].roi)
        batch.arrays[self.array] = partner_vectors.crop(
            request[self.array].roi)

        if self.pointmask is not None and self.pointmask in request:
            # NOTE(review): the point mask reuses the spec (and dtype) of
            # self.array rather than self.spec[self.pointmask]. Confirm this
            # is intended.
            spec = self.spec[self.array].copy()
            spec.roi = data_roi * voxel_size
            pointmask = Array(data=np.array(pointmask, dtype=spec.dtype),
                              spec=spec)
            batch.arrays[self.pointmask] = pointmask.crop(
                request[self.pointmask].roi)

        # restore requested ROI of src and target points.
        if self.src_points in request:
            self.__restore_points_roi(request, self.src_points,
                                      batch.points[self.src_points])
        if self.trg_points in request:
            self.__restore_points_roi(request, self.trg_points,
                                      batch.points[self.trg_points])
        # restore requested objectmask
        # NOTE(review): request[self.mask] is indexed without checking that
        # the mask was requested. Confirm the upstream prepare step always
        # adds it.
        if self.mask is not None:
            batch.arrays[self.mask] = batch.arrays[self.mask].crop(
                request[self.mask].roi)
Example #20
0
    def process(self, batch, request):
        """Downsample ``self.source`` into ``self.target`` by strided slicing.

        The source is cropped to the ROI requested for the target. Then every
        spatial axis is subsampled with a stride of ``self.factor``, which is
        either a per-axis tuple or a single value for all spatial axes.
        Spatial axes are taken to be the trailing ones, so leading axes
        (e.g. channels) are kept whole.

        If the source was requested with a smaller ROI than it arrived with,
        it is cropped back to that ROI. Does nothing if the target was not
        requested.
        """
        if self.target not in request:
            return

        input_roi = batch.arrays[self.source].spec.roi
        request_roi = request[self.target].roi

        assert input_roi.contains(request_roi)

        # downsample
        if isinstance(self.factor, tuple):
            slices = tuple(
                slice(None, None, k)
                for k in self.factor)
        else:
            slices = tuple(
                slice(None, None, self.factor)
                for _ in range(input_roi.dims()))

        logger.debug("downsampling %s with %s", self.source, slices)

        crop = batch.arrays[self.source].crop(request_roi)
        # keep leading (non-spatial) axes whole instead of striding them
        channel_dims = crop.data.ndim - len(slices)
        data = crop.data[(slice(None),) * channel_dims + slices]

        # create output array
        spec = self.spec[self.target].copy()
        spec.roi = request_roi
        batch.arrays[self.target] = Array(data, spec)

        if self.source in request:

            # restore requested rois
            request_roi = request[self.source].roi

            if input_roi != request_roi:

                assert input_roi.contains(request_roi)

                logger.debug(
                    "restoring original request roi %s of %s from %s",
                    request_roi, self.source, input_roi)
                cropped = batch.arrays[self.source].crop(request_roi)
                batch.arrays[self.source] = cropped
Example #21
0
    def process(self, batch, request):
        """Rasterize parent vectors and a mask from ``self.points``.

        Vectors and mask are drawn by ``self.__draw_parent_vectors`` on a
        voxel grid covering the points' ROI snapped to ``voxel_size``, with
        one voxel shaved off each end of the first (time) axis. They are
        then cropped to the requested ROIs of ``self.array`` and
        ``self.mask``.

        If ``self.points`` was requested, its ROI is reset to the requested
        ROI and points outside it are removed. A warning is logged if none
        remain.
        """
        points = batch.points[self.points]
        voxel_size = self.spec[self.array].voxel_size

        # get roi used for creating the new array (points_roi does not
        # necessarily align with voxel size)
        enlarged_vol_roi = points.spec.roi.snap_to_grid(voxel_size)
        offset = enlarged_vol_roi.get_begin() / voxel_size
        shape = enlarged_vol_roi.get_shape() / voxel_size
        data_roi = Roi(offset, shape)

        # points ROI is at least +- 1 in t of requested array ROI, we can save
        # some time by shaving the excess off
        # NOTE: the grow vectors are hard-coded for 4D (t + 3 spatial) data.
        data_roi = data_roi.grow((-1, 0, 0, 0), (-1, 0, 0, 0))

        logger.debug("Points in %s", points.spec.roi)
        for i, point in points.data.items():
            logger.debug("%d, %s", i, point.location)
        logger.debug("Data roi in voxels: %s", data_roi)
        logger.debug("Data roi in world units: %s", data_roi * voxel_size)

        parent_vectors_data, mask_data = self.__draw_parent_vectors(
            points, data_roi, voxel_size, enlarged_vol_roi.get_begin(),
            self.radius)

        # create array and crop it to requested roi
        spec = self.spec[self.array].copy()
        spec.roi = data_roi * voxel_size
        parent_vectors = Array(data=parent_vectors_data, spec=spec)
        logger.debug("Cropping parent vectors to %s", request[self.array].roi)
        batch.arrays[self.array] = parent_vectors.crop(request[self.array].roi)

        # create mask and crop it to requested roi
        spec = self.spec[self.mask].copy()
        spec.roi = data_roi * voxel_size
        mask = Array(data=mask_data, spec=spec)
        logger.debug("Cropping mask to %s", request[self.mask].roi)
        batch.arrays[self.mask] = mask.crop(request[self.mask].roi)

        # restore requested ROI of points
        if self.points in request:
            request_roi = request[self.points].roi
            points.spec.roi = request_roi
            # iterate over a snapshot because entries are deleted
            for i, p in list(points.data.items()):
                if not request_roi.contains(p.location):
                    del points.data[i]

            if not points.data:
                # lazy %-args: formatted only if the record is emitted
                logger.warning("Returning empty batch for key %s and roi %s",
                               self.points, request_roi)
Example #22
0
    def provide(self, request):
        """Read every requested volume or mask into a new batch.

        Each request ROI (world units) is converted to voxel units relative
        to the dataset's own offset. Keys in ``self.datasets`` are read with
        ``self.__read_array``; keys in ``self.masks`` with
        ``self.__read_mask``. Time spent is recorded in the batch profiling
        stats.

        Raises:
            RuntimeError: if a requested key is neither a dataset nor a mask.
        """
        timing = Timing(self)
        timing.start()

        batch = Batch()

        for (array_key, request_spec) in request.array_specs.items():

            logger.debug("Reading %s in %s...", array_key, request_spec.roi)

            voxel_size = self.spec[array_key].voxel_size

            # scale request roi to voxel units
            dataset_roi = request_spec.roi / voxel_size

            # shift request roi into dataset
            dataset_roi = (dataset_roi
                           - self.spec[array_key].roi.get_offset() / voxel_size)

            # create array spec
            array_spec = self.spec[array_key].copy()
            array_spec.roi = request_spec.roi

            # read the data
            if array_key in self.datasets:
                data = self.__read_array(self.datasets[array_key], dataset_roi)
            elif array_key in self.masks:
                data = self.__read_mask(self.masks[array_key], dataset_roi)
            else:
                # ``assert False`` is stripped under ``python -O``, which
                # would leave ``data`` unbound; raise explicitly instead
                raise RuntimeError(
                    "Encountered a request for %s that is neither a volume "
                    "nor a mask." % array_key)

            # add array to batch
            batch.arrays[array_key] = Array(data, array_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #23
0
    def train_step(self, batch, request):
        """Run one optimizer step, store outputs, write summary, checkpoint.

        If ``self.summary`` is set, it is evaluated together with the
        training fetches and written every step. A checkpoint named
        ``<meta_graph_filename>_checkpoint_<iteration>`` is saved every
        ``self.save_every`` iterations.
        """
        requested = self.__collect_requested_outputs(request)
        feed = self.__collect_provided_inputs(batch)

        fetches = dict(
            optimizer=self.optimizer,
            loss=self.loss,
            iteration=self.iteration_increment)
        fetches.update(requested)

        # compute outputs, gradients, and update variables
        with_summary = self.summary is not None
        if with_summary:
            results, summaries = self.session.run(
                [fetches, self.summary], feed_dict=feed)
        else:
            results = self.session.run(fetches, feed_dict=feed)

        for key in requested:
            out_spec = self.spec[key].copy()
            out_spec.roi = request[key].roi
            batch.arrays[key] = Array(results[key], out_spec)

        batch.loss = results['loss']
        batch.iteration = results['iteration'][0]
        if with_summary:
            self.summary_saver.add_summary(summaries, batch.iteration)

        if batch.iteration % self.save_every != 0:
            return

        checkpoint_name = (self.meta_graph_filename
                           + '_checkpoint_%i' % batch.iteration)
        logger.info("Creating checkpoint %s", checkpoint_name)
        self.full_saver.save(self.session, checkpoint_name)
    def process(self, batch, request):
        """Create an exclusive-zone mask for every key in ``self.EZ_masks_to_create``.

        Each mask is computed by ``self.__get_exclusivezone_mask`` from its
        associated binary map and stored with the requested ROI and the
        binary map's resolution. If ``self.skip_next`` is set, the flag is
        cleared and nothing is done for this batch (per the original
        comment: when no GT binary maps were requested).
        """
        if self.skip_next:
            self.skip_next = False
            return

        for ez_mask_key in self.EZ_masks_to_create:
            map_key = self.EZ_masks_to_binary_map[ez_mask_key]
            map_array = batch.arrays[map_key]
            map_data = map_array.data
            resolution = map_array.resolution
            requested_roi = request.arrays[ez_mask_key]

            ez_mask = self.__get_exclusivezone_mask(
                map_data,
                shape_EZ_mask=requested_roi.get_shape(),
                resolution=resolution)

            batch.arrays[ez_mask_key] = Array(
                data=ez_mask,
                roi=request.arrays[ez_mask_key],
                resolution=resolution)
Example #25
0
    def process(self, batch, request):
        """Downsample ``self.source`` into ``self.target`` by strided slicing.

        Every spatial axis is subsampled with a stride of ``self.factor``,
        which is either a per-axis tuple or a single value for all spatial
        axes. Spatial axes are taken to be the trailing ones, so leading
        axes (e.g. channels) are kept whole.

        The source is not cropped here. Its data is assumed to already cover
        exactly the ROI requested for the target (TODO confirm in
        ``prepare``).

        Returns:
            A new Batch containing only the target array.
        """
        outputs = Batch()

        # downsample
        if isinstance(self.factor, tuple):
            slices = tuple(slice(None, None, k) for k in self.factor)
        else:
            slices = tuple(
                slice(None, None, self.factor)
                for _ in range(batch[self.source].spec.roi.dims()))

        logger.debug("downsampling %s with %s", self.source, slices)

        source_data = batch.arrays[self.source].data
        # keep leading (non-spatial) axes whole instead of striding them
        channel_dims = source_data.ndim - len(slices)
        data = source_data[(slice(None),) * channel_dims + slices]

        # create output array
        spec = self.spec[self.target].copy()
        spec.roi = request[self.target].roi
        outputs.arrays[self.target] = Array(data, spec)

        return outputs
Example #26
0
    def process(self, batch, request):
        """Build the two-pass component labels stored under ``self.malis_comp_array_key``.

        Index 1 holds the labels unchanged (positive pass). Index 0 holds the
        negative pass. If an ignore array is configured and present, voxels
        where it is 0 are set to a fresh id (max label + 1) in the negative
        pass; otherwise the negative pass equals the labels. The result uses
        the ROI requested for the labels.
        """
        labels = batch.arrays[self.labels_array_key]
        pos_pass = labels.data
        fresh_id = pos_pass.max() + 1

        if self.ignore_array_key and self.ignore_array_key in batch.arrays:
            neg_pass = np.array(pos_pass)
            ignored = batch.arrays[self.ignore_array_key].data == 0
            neg_pass[ignored] = fresh_id
        else:
            neg_pass = pos_pass

        out_spec = self.spec[self.malis_comp_array_key].copy()
        out_spec.roi = request[self.labels_array_key].roi
        stacked = np.array([neg_pass, pos_pass])
        batch.arrays[self.malis_comp_array_key] = Array(stacked, out_spec)
Example #27
0
    def provide(self, request):
        """Read the requested ROI of ``self.array_key`` through CloudVolume.

        A ``CloudVolume`` handle for ``self.cloudvolume_url`` at mip level
        ``self.mip`` is created on every call. The requested ROI is converted
        to voxel units relative to the offset of this source's provided ROI,
        then read with ``__read``.

        Returns:
            A ``Batch`` holding a single array under ``self.array_key``. Its
            spec is a copy of ``self.array_spec`` with the requested ROI.
        """

        timing = Timing(self)
        timing.start()

        batch = Batch()

        volume = CloudVolume(self.cloudvolume_url, use_https=True, mip=self.mip)

        key = self.array_key
        request_spec = request.array_specs[key]
        logger.debug("Reading %s in %s...", key, request_spec.roi)

        voxel_size = self.array_spec.voxel_size

        # world units -> voxel units, expressed relative to the dataset origin
        voxel_roi = request_spec.roi / voxel_size
        voxel_roi = voxel_roi - self.spec[key].roi.get_offset() / voxel_size

        out_spec = self.array_spec.copy()
        out_spec.roi = request_spec.roi

        batch.arrays[key] = Array(self.__read(volume, voxel_roi), out_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #28
0
    def provide(self, request):
        """Read every requested array from the HDF5 file ``self.filename``.

        For each requested key, the requested ROI is converted to voxel units
        relative to the offset of that key's provided ROI. The data is then
        read from the dataset named ``self.datasets[key]``. The file is opened
        read-only in a ``with`` block, so it is closed once all keys are read.

        Returns:
            A ``Batch`` with one array per requested key. Each array's spec is
            a copy of the provided spec, restricted to the requested ROI.
        """

        timing = Timing(self)
        timing.start()

        batch = Batch()

        with h5py.File(self.filename, 'r') as hdf_file:
            for key, key_request in request.array_specs.items():

                logger.debug("Reading %s in %s...", key, key_request.roi)

                provided = self.spec[key]
                voxel_size = provided.voxel_size

                # world units -> voxel units relative to the dataset origin
                voxel_roi = (key_request.roi / voxel_size
                             - provided.roi.get_offset() / voxel_size)

                out_spec = provided.copy()
                out_spec.roi = key_request.roi

                batch.arrays[key] = Array(
                    self.__read(hdf_file, self.datasets[key], voxel_roi),
                    out_spec)

        logger.debug("done")

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #29
0
    def provide(self, request):
        """Request ``self.num_repetitions`` upstream batches and stack them.

        Every requested array is stacked along a new leading axis, with one
        entry per repetition. Each stacked array gets a copy of the first
        upstream batch's spec for that key. Requested points are taken from
        the first upstream batch only.

        Only the stacking itself is timed. The upstream requests finish before
        the timer starts, so upstream work is not attributed to this node in
        the returned profiling stats.
        """

        batches = [
            self.get_upstream_provider().request_batch(request)
            for _ in range(self.num_repetitions)
        ]

        timing = Timing(self)
        timing.start()

        batch = Batch()
        for key in request.array_specs:
            data = np.stack([b[key].data for b in batches])
            batch[key] = Array(data, batches[0][key].spec.copy())

        # copy points of first batch requested
        for key in request.points_specs:
            batch[key] = batches[0][key]

        timing.stop()
        batch.profiling_stats.add(timing)

        return batch
Example #30
0
    def process(self, batch, request):
        """Compute GT affinities (and optionally an affinities mask) from labels.

        Affinities are computed with ``malis.seg_to_affgraph`` on the batch's
        labels (cast to int32), using ``self.affinity_neighborhood``. They are
        converted to float32, cropped to the ROI requested for ``self.labels``,
        and stored under ``self.affinities`` with that ROI.

        If ``self.affinities_mask`` is set and requested, a float32 mask is
        added under that key, using the same spec as the affinities:

        * it starts as ``seg_to_affgraph`` of ``self.labels_mask`` if that is
          set, or all ones otherwise;
        * if ``self.unlabelled`` is set, it is multiplied by
          ``1 - seg_to_affgraph(1 - unlabelled)``.

        Finally, the labels (and the labels mask / unlabelled arrays, if
        configured) are cropped back to the requested labels ROI, and
        ``batch.affinity_neighborhood`` is set.
        """

        labels_roi = request[self.labels].roi

        logger.debug("computing ground-truth affinities from labels")
        affinities = malis.seg_to_affgraph(
                batch.arrays[self.labels].data.astype(np.int32),
                self.affinity_neighborhood
        ).astype(np.float32)


        # crop affinities to original label ROI
        # NOTE(review): assumes the labels array in the batch starts at
        # labels_roi.get_offset() + self.padding_neg, i.e. it was requested
        # with extra context (padding_neg presumably non-positive). Confirm
        # against prepare(). Shifting by -offset - padding_neg expresses the
        # requested ROI relative to that origin. Dividing by the voxel size
        # turns it into voxel index ranges.
        offset = labels_roi.get_offset()
        shift = -offset - self.padding_neg
        crop_roi = labels_roi.shift(shift)
        crop_roi /= self.spec[self.labels].voxel_size
        crop = crop_roi.get_bounding_box()

        logger.debug("cropping with " + str(crop))
        # keep the leading affinity axis whole (presumably one entry per
        # neighborhood offset) and crop only the spatial axes
        affinities = affinities[(slice(None),)+crop]

        spec = self.spec[self.affinities].copy()
        spec.roi = labels_roi
        batch.arrays[self.affinities] = Array(affinities, spec)

        if self.affinities_mask and self.affinities_mask in request:

            if self.labels_mask:

                logger.debug("computing ground-truth affinities mask from "
                             "labels mask")
                affinities_mask = malis.seg_to_affgraph(
                    batch.arrays[self.labels_mask].data.astype(np.int32),
                    self.affinity_neighborhood)
                affinities_mask = affinities_mask[(slice(None),)+crop]

            else:

                # no labels mask: every affinity starts out as valid
                affinities_mask = np.ones_like(affinities)

            if self.unlabelled:

                # 1 for all affinities between unlabelled voxels
                # (the inversion suggests the unlabelled array is nonzero
                # where voxels ARE labelled; TODO confirm its convention)
                unlabelled = (1 - batch.arrays[self.unlabelled].data)
                unlabelled_mask = malis.seg_to_affgraph(
                    unlabelled.astype(np.int32),
                    self.affinity_neighborhood)
                unlabelled_mask = unlabelled_mask[(slice(None),)+crop]

                # 0 for all affinities between unlabelled voxels
                unlabelled_mask = (1 - unlabelled_mask)

                # combine with mask
                affinities_mask = affinities_mask*unlabelled_mask

            affinities_mask = affinities_mask.astype(np.float32)
            # NOTE(review): the mask reuses the affinities spec (same ROI);
            # self.spec[self.affinities_mask] is not consulted here
            batch.arrays[self.affinities_mask] = Array(affinities_mask, spec)

        else:

            if self.labels_mask is not None:
                logger.warning("GT labels does have a mask, but affinities "
                               "mask is not requested.")

        # crop labels to original label ROI
        batch.arrays[self.labels] = batch.arrays[self.labels].crop(labels_roi)

        # same for label mask
        if self.labels_mask:
            batch.arrays[self.labels_mask] = batch.arrays[self.labels_mask].crop(labels_roi)
        # and unlabelled mask
        if self.unlabelled:
            batch.arrays[self.unlabelled] = batch.arrays[self.unlabelled].crop(labels_roi)

        # expose the neighborhood used, so downstream code can interpret the
        # affinity axis
        batch.affinity_neighborhood = self.affinity_neighborhood