예제 #1
0
    def process(self, batch, request):
        """Warp every volume of ``batch`` and reset it to the requested ROI.

        Each volume's data is replaced by the output of
        ``augment.apply_transformation``. That call receives the transformation
        stored in ``self.transformations`` for the volume's type, and
        interpolates according to the type's ``interpolate`` attribute. The
        volume's ROI is then set to ``request.volumes[<type>]``.

        Args:
            batch: batch whose ``volumes`` mapping is modified in place.
            request: request providing the ROI each volume is restored to.
        """
        for vtype, vol in batch.volumes.items():
            warped = augment.apply_transformation(
                vol.data,
                self.transformations[vtype],
                interpolate=vtype.interpolate,
            )
            vol.data = warped

            # hand back the ROI that was originally requested for this volume
            vol.roi = request.volumes[vtype]
예제 #2
0
    def __fast_point_projection(self, transformation, nodes, source_roi,
                                target_roi):
        """Project ``nodes`` through ``transformation`` by rasterizing them.

        Every node whose location lies inside ``source_roi`` is written into a
        voxel array as its id. The array is warped channel by channel with
        nearest-neighbour interpolation. Each node's new location is the
        center of mass of the voxels that still carry its id, converted back
        to world units relative to the begin of ``target_roi``. Successfully
        projected nodes have their location updated in place.

        Args:
            transformation: transformation handed to
                ``augment.apply_transformation``.
            nodes: nodes to project.
            source_roi: ROI the node locations are given in; the rasterized
                array spans this ROI (in voxels).
            target_roi: ROI whose begin is added to the projected locations.

        Returns:
            list: nodes that could not be projected this way. These are nodes
            outside ``source_roi``, plus rasterized nodes whose center of mass
            came out as NaN (their id no longer appears in the warped array,
            e.g. because another node overwrote the same voxel).
        """
        if len(nodes) == 0:
            return []

        # only nodes inside the source ROI can be rasterized
        contained = [
            node for node in nodes if source_roi.contains(node.location)
        ]
        if not contained:
            # nothing to rasterize -- every node counts as missing
            return list(nodes)

        # rasterize the points into an array; voxel indices are relative to
        # the begin of the source ROI
        ids = np.array([node.id for node in contained])
        voxel_locs = [
            (np.floor(node.location).astype(int) - source_roi.get_begin()) //
            self.voxel_size for node in contained
        ]
        # transpose per-node coordinates into per-axis index tuples for
        # numpy advanced indexing
        locs = tuple(zip(*voxel_locs))
        points_array = np.zeros(source_roi.get_shape() / self.voxel_size,
                                dtype=np.int64)
        # NOTE(review): a later node in the same voxel overwrites an earlier
        # one; ids <= 0 are indistinguishable from background below --
        # presumably ids are positive, confirm
        points_array[locs] = ids

        # reshape array data into (channels,) + spatial dims
        shape = points_array.shape
        data = points_array.reshape((-1, ) + shape[-self.spatial_dims:])

        # apply transformation on each channel; nearest-neighbour keeps the
        # id values intact
        data = np.array([
            augment.apply_transformation(data[c],
                                         transformation,
                                         interpolate="nearest")
            for c in range(data.shape[0])
        ])

        # per id: mean voxel position of the voxels labelled with that id
        # (NaN if the id vanished); drop the channel axis and convert to
        # world units in the target ROI
        projected_locs = ndimage.center_of_mass(data > 0, data, ids)
        projected_locs = [
            np.array(loc[-self.spatial_dims:]) * self.voxel_size +
            target_roi.get_begin() for loc in projected_locs
        ]

        missing_points = []
        node_dict = {node.id: node for node in nodes}
        for point_id, proj_loc in zip(ids, projected_locs):
            point = node_dict.pop(point_id)
            if np.isnan(proj_loc).any():
                missing_points.append(point)
                continue
            assert (
                len(proj_loc) == self.spatial_dims
            ), "projected location has wrong number of dimensions: {}, expected: {}".format(
                len(proj_loc), self.spatial_dims)
            point.location[-self.spatial_dims:] = proj_loc

        # nodes that were never rasterized (outside the source ROI) are
        # missing as well
        missing_points.extend(node_dict.values())

        logger.debug("%d of %d points lost in fast points projection",
                     len(missing_points), len(ids))

        return missing_points
예제 #3
0
    def process(self, batch, request):
        """Apply the stored per-key transformations to arrays and graphs.

        Arrays with an entry in ``self.target_rois`` are warped channel by
        channel with ``augment.apply_transformation``. They are then reshaped
        to the spatial shape of the requested ROI, and their spec ROI is set
        to the requested ROI.

        Graph nodes are projected into the target ROI in one of two ways:

        * With the fast rasterized projection (``use_fast_points_transform``).
          Missed nodes are either removed, or re-projected one by one when
          ``recompute_missing_points`` is set.
        * Otherwise, node by node via ``self.__project``.

        Nodes that cannot be projected, or that land outside the requested
        ROI, are removed with ``retain_connectivity=True``. Each processed
        graph's spec ROI is then set to the requested ROI. Keys without a
        target ROI are left untouched.

        Args:
            batch: batch whose arrays and graphs are modified in place.
            request: request providing the ROI each output is restored to.
        """
        for (array_key, array) in batch.arrays.items():

            if array_key not in self.target_rois:
                continue

            target_roi = self.target_rois[array_key]
            request_roi = request[array_key].roi

            # for arrays, the target ROI and the requested ROI should be the
            # same in spatial coordinates
            assert (
                target_roi.get_begin() ==
                request_roi.get_begin()[-self.spatial_dims:]
            ), "Target roi offset {} does not match request roi offset {}".format(
                target_roi.get_begin(),
                request_roi.get_begin()[-self.spatial_dims:],
            )

            assert (
                target_roi.get_shape() ==
                request_roi.get_shape()[-self.spatial_dims:]
            ), "Target roi shape {} does not match request roi shape {}".format(
                target_roi.get_shape(),
                request_roi.get_shape()[-self.spatial_dims:],
            )

            # reshape array data into (channels,) + spatial dims; all leading
            # non-spatial axes are flattened into a single channel axis
            shape = array.data.shape
            channel_shape = shape[:-self.spatial_dims]
            data = array.data.reshape((-1, ) + shape[-self.spatial_dims:])

            # apply transformation on each channel
            data = np.array([
                augment.apply_transformation(
                    data[c],
                    self.transformations[array_key],
                    interpolate=self.spec[array_key].interpolatable,
                ) for c in range(data.shape[0])
            ])

            # restore the leading axes; the spatial part now follows the
            # requested ROI measured in voxels
            data_roi = request_roi / self.spec[array_key].voxel_size
            array.data = data.reshape(
                channel_shape + data_roi.get_shape()[-self.spatial_dims:])

            # restore original ROIs
            array.spec.roi = request_roi

        for (graph_key, graph) in batch.graphs.items():

            # like arrays above: without a target ROI there is nothing to
            # project into, so leave the graph untouched instead of failing
            # with a KeyError below
            if graph_key not in self.target_rois:
                continue

            nodes = list(graph.nodes)

            if self.use_fast_points_transform:
                missed_nodes = self.__fast_point_projection(
                    self.transformations[graph_key],
                    nodes,
                    graph.spec.roi,
                    target_roi=self.target_rois[graph_key],
                )
                if not self.recompute_missing_points:
                    for node in set(missed_nodes):
                        graph.remove_node(node, retain_connectivity=True)
                    missed_nodes = []
            else:
                missed_nodes = nodes

            # project remaining nodes one at a time; iterating a separate
            # list makes removing nodes from the graph safe
            for node in missed_nodes:

                # get location relative to beginning of upstream ROI
                location = node.location - graph.spec.roi.get_begin()
                logger.debug("relative to upstream ROI: %s", location)

                # get spatial coordinates of node in voxels
                location_voxels = (location[-self.spatial_dims:] /
                                   self.voxel_size)

                # get projected location in transformation data space, this
                # yields voxel coordinates relative to target ROI
                projected_voxels = self.__project(
                    self.transformations[graph_key], location_voxels)

                logger.debug("projected in voxels, relative to target ROI: %s",
                             projected_voxels)

                # a None projection is treated as "outside the target"
                if projected_voxels is None:
                    logger.debug("node outside of target, skipping")
                    graph.remove_node(node, retain_connectivity=True)
                    continue

                # convert to world units (now in float again)
                projected = projected_voxels * np.array(self.voxel_size)

                logger.debug(
                    "projected in world units, relative to target ROI: %s",
                    projected)

                # get global coordinates
                projected += np.array(self.target_rois[graph_key].get_begin())

                # update spatial coordinates of node location
                node.location[-self.spatial_dims:] = projected

                logger.debug("final location: %s", node.location)

                # finally, it can happen that a node no longer is contained in
                # the requested ROI (because larger ROIs than necessary have
                # been requested upstream)
                if not request[graph_key].roi.contains(node.location):
                    logger.debug("node outside of target, skipping")
                    graph.remove_node(node, retain_connectivity=True)
                    continue

            # restore original ROIs
            graph.spec.roi = request[graph_key].roi
예제 #4
0
    def process(self, batch, request):
        """Apply the stored per-key transformations to arrays and points.

        Every array is warped channel by channel with
        ``augment.apply_transformation``. It is then reshaped to the spatial
        shape of the requested ROI, and its spec ROI is set to the requested
        ROI.

        Every point is projected via ``self.__project`` into the target ROI
        and converted to world units. Points with no projection, or that land
        outside the requested ROI, are deleted from ``points.data``. Each
        points spec ROI is then set to the requested ROI.

        Args:
            batch: batch whose arrays and points are modified in place.
            request: request providing the ROI each output is restored to.
        """
        for (array_key, array) in batch.arrays.items():

            target_roi = self.target_rois[array_key]
            request_roi = request[array_key].roi

            # for arrays, the target ROI and the requested ROI should be the
            # same in spatial coordinates
            assert (
                target_roi.get_begin() ==
                request_roi.get_begin()[-self.spatial_dims:]
            ), "Target roi offset {} does not match request roi offset {}".format(
                target_roi.get_begin(),
                request_roi.get_begin()[-self.spatial_dims:],
            )
            assert (
                target_roi.get_shape() ==
                request_roi.get_shape()[-self.spatial_dims:]
            ), "Target roi shape {} does not match request roi shape {}".format(
                target_roi.get_shape(),
                request_roi.get_shape()[-self.spatial_dims:],
            )

            # reshape array data into (channels,) + spatial dims; all leading
            # non-spatial axes are flattened into a single channel axis
            shape = array.data.shape
            channel_shape = shape[:-self.spatial_dims]
            data = array.data.reshape((-1, ) + shape[-self.spatial_dims:])

            # apply transformation on each channel
            data = np.array([
                augment.apply_transformation(
                    data[c],
                    self.transformations[array_key],
                    interpolate=self.spec[array_key].interpolatable)
                for c in range(data.shape[0])
            ])

            # restore the leading axes; only the spatial part of the requested
            # ROI (in voxels) is appended -- the non-spatial axes are already
            # in channel_shape
            data_roi = request_roi / self.spec[array_key].voxel_size
            array.data = data.reshape(
                channel_shape + data_roi.get_shape()[-self.spatial_dims:])

            # restore original ROIs
            array.spec.roi = request_roi

        for (points_key, points) in batch.points.items():

            # iterate over a copy so points can be deleted while looping
            for point_id, point in list(points.data.items()):

                logger.debug("projecting %s", point.location)

                # get location relative to beginning of upstream ROI
                location = point.location - points.spec.roi.get_begin()
                logger.debug("relative to upstream ROI: %s", location)

                # get spatial coordinates of point in voxels
                location_voxels = (location[-self.spatial_dims:] /
                                   self.voxel_size)
                logger.debug("relative to upstream ROI in voxels: %s",
                             location_voxels)

                # get projected location in transformation data space, this
                # yields voxel coordinates relative to target ROI
                projected_voxels = self.__project(
                    self.transformations[points_key], location_voxels)

                logger.debug("projected in voxels, relative to target ROI: %s",
                             projected_voxels)

                # a None projection is treated as "outside the target"
                if projected_voxels is None:
                    logger.debug("point outside of target, skipping")
                    del points.data[point_id]
                    continue

                # convert to world units (now in float again)
                projected = projected_voxels * np.array(self.voxel_size)

                logger.debug(
                    "projected in world units, relative to target ROI: %s",
                    projected)

                # get global coordinates
                projected += np.array(self.target_rois[points_key].get_begin())

                # update spatial coordinates of point location
                point.location[-self.spatial_dims:] = projected

                logger.debug("final location: %s", point.location)

                # finally, it can happen that a point no longer is contained in
                # the requested ROI (because larger ROIs than necessary have
                # been requested upstream)
                if not request[points_key].roi.contains(point.location):
                    logger.debug("point outside of target, skipping")
                    del points.data[point_id]
                    continue

            # restore original ROIs
            points.spec.roi = request[points_key].roi