def __select_random_location_with_points(
            self,
            request,
            lcm_shift_roi,
            lcm_voxel_size):
        '''Pick a random shift for which the shifted request ROI of
        ``self.ensure_nonempty`` contains at least one point.

        Rejection sampling: a point id is drawn uniformly from
        ``self.points.data``. Among the lcm-grid shifts that keep this point
        inside the request ROI (intersected with ``lcm_shift_roi``), one is
        chosen with ``self.__select_random_location``. The shifted ROI is then
        requested from upstream and the shift is accepted with probability
        ``1/num_points``, where ``num_points`` is the number of points
        returned. This compensates for the bias towards dense regions.
        Otherwise a new point is drawn.

        Args:
            request: the batch request; ``request[self.ensure_nonempty].roi``
                is the ROI that must contain a point.
            lcm_shift_roi: ROI of all admissible shifts, in lcm voxel units.
            lcm_voxel_size: voxel size of the lcm grid shifts are snapped to.

        Returns:
            The accepted shift as produced by ``self.__select_random_location``
            (it is applied via ``request_points_roi.shift(...)``).
        '''

        request_points_roi = request[self.ensure_nonempty].roi

        # dict key views are not sequences in Python 3, so random.choice()
        # needs a list. The candidate ids do not change between attempts,
        # so build the list once.
        all_point_ids = list(self.points.data.keys())

        # The request ROI expressed in lcm voxels is the same for every attempt.
        lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
        lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
        logger.debug("Point request ROI: %s", request_points_roi)
        logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

        while True:

            # How to pick shifts that ensure that a randomly chosen point is
            # contained in the request ROI:
            #
            #
            # request          point
            # [---------)      .
            # 0        +10     17
            #
            #         least shifted to contain point
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #
            #                  most shifted to contain point:
            #                  [---------)
            #                  17       +10
            #                  ==
            #                  point-request.begin
            #
            #         all possible shifts
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #                   ==
            #                   request.shape
            #
            # In the most shifted scenario, it could happen that the point lies
            # exactly at the lower boundary (17 in the example). This will cause
            # trouble if later we mirror the batch. The point would end up lying
            # on the other boundary, which is exclusive and thus not part of the
            # ROI. Therefore, we have to ensure that the point is well inside
            # the shifted ROI, not just on the boundary:
            #
            #         all possible shifts
            #         [--------)
            #         8       +9
            #                 ==
            #                 request.shape-1

            # pick a random point
            point_id = choice(all_point_ids)
            point = self.points.data[point_id]

            logger.debug(
                "select random point %d at %s",
                point_id,
                point.location)

            # get the lcm voxel that contains this point
            lcm_location = Coordinate(point.location/lcm_voxel_size)
            logger.debug(
                "belongs to lcm voxel %s",
                lcm_location)

            # mark all dimensions in which the point lies on the lower boundary
            # of the lcm voxel
            on_lower_boundary = lcm_location*lcm_voxel_size == point.location
            logger.debug(
                "lies on the lower boundary of the lcm voxel in dimensions %s",
                on_lower_boundary)

            # for each of these dimensions, shrink the shift ROI by one so the
            # point cannot end up on the (exclusive) boundary after mirroring
            lower_boundary_correction = Coordinate((
                -1 if o else 0
                for o in on_lower_boundary
            ))
            logger.debug(
                "lower bound correction for shape of shift ROI %s",
                lower_boundary_correction)

            # get all possible starting points of lcm_roi_shape that contain
            # lcm_location
            lcm_shift_roi_begin = (
                lcm_location - lcm_roi_begin - lcm_roi_shape +
                Coordinate((1,)*len(lcm_location))
            )
            lcm_shift_roi_shape = (
                lcm_roi_shape + lower_boundary_correction
            )
            lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
            logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

            # intersect with total shift ROI; retry with another point if no
            # admissible shift contains this one
            if not lcm_point_shift_roi.intersects(lcm_shift_roi):
                logger.debug(
                    "reject random shift, random point %s shift ROI %s does "
                    "not intersect total shift ROI %s", point.location,
                    lcm_point_shift_roi, lcm_shift_roi)
                continue
            lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

            # select a random shift from all possible shifts
            random_shift = self.__select_random_location(
                lcm_point_shift_roi,
                lcm_voxel_size)
            logger.debug("random shift: %s", random_shift)

            # count all points inside the shifted ROI
            points_request = BatchRequest()
            points_request[self.ensure_nonempty] = PointsSpec(
                roi=request_points_roi.shift(random_shift))
            logger.debug("points request: %s", points_request)
            points_batch = self.get_upstream_provider().request_batch(points_request)

            point_ids = points_batch.points[self.ensure_nonempty].data.keys()
            assert point_id in point_ids, (
                "Requested batch to contain point %s, but got points "
                "%s"%(point_id, point_ids))
            num_points = len(point_ids)

            # accept this shift with p=1/num_points
            #
            # This is to compensate the bias introduced by close-by points.
            accept = random() <= 1.0/num_points
            if accept:
                return random_shift
Example #2
0
    def __rasterize(self, points, data_roi, voxel_size, dtype, settings, mask_array=None):
        """Draw ``points`` into a fresh array covering ``data_roi``.

        With ``settings.mode == 'ball'`` and no ``inner_radius_fraction``, a
        ball kernel from ``create_ball_kernel`` is OR-ed in around every point
        (the fast path). Otherwise each point sets a single voxel to 1. The
        result is then grown with ``enlarge_binary_map`` ('ball' mode) or
        smoothed with ``gaussian_filter`` and rescaled so that its maximum is 1
        (any other mode).

        Args:
            points: point collection; ``points.spec.roi`` is logged and the
                ``location`` of every value in ``points.data`` is drawn.
            data_roi: ROI of the output in voxel units. Its shape is the
                output shape, and its begin is subtracted from every point's
                voxel coordinate.
            voxel_size: size of one voxel; point locations are divided by it.
            dtype: dtype of the returned array.
            settings: reads ``mode``, ``radius`` and ``inner_radius_fraction``.
            mask_array: optional array. Points whose voxel is falsy in
                ``mask_array.data`` are skipped, and the final result is
                combined with the mask (``&`` in 'ball' mode, ``*`` otherwise).

        Returns:
            The rasterized array of shape ``data_roi.get_shape()``.
        """
        mask = None if mask_array is None else mask_array.data

        logger.debug("Rasterizing points in %s", points.spec.roi)

        canvas = np.zeros(data_roi.get_shape(), dtype=dtype)
        origin = data_roi.get_begin()

        # only 'ball' without an inner radius has a kernel-stamping shortcut
        fast_ball = settings.mode == 'ball' and settings.inner_radius_fraction is None

        if fast_ball:
            ndim = len(canvas.shape)
            zeros = Coordinate((0,) * ndim)
            kernel = create_ball_kernel(settings.radius, voxel_size)
            half_kernel = Coordinate(np.array(kernel.shape) / 2)
            canvas_roi = Roi(offset=zeros, shape=Coordinate(canvas.shape))
            kernel_roi = Roi(offset=zeros, shape=Coordinate(kernel.shape))

        for point in points.data.values():

            # Coordinate() truncates to integer voxel indices; make them
            # relative to the start of the output array
            voxel = Coordinate(point.location / voxel_size) - origin

            if mask is not None and not mask[voxel]:
                continue

            logger.debug(
                "Rasterizing point %s at %s",
                point.location,
                point.location / voxel_size - origin)

            if not fast_ball:
                canvas[voxel] = 1
                continue

            # overlap between the kernel placed at this point and the canvas,
            # expressed once in canvas indices and once in kernel indices
            corner = voxel - half_kernel
            target = canvas_roi.intersect(kernel_roi.shift(corner))
            source = kernel_roi.intersect(canvas_roi.shift(-corner))
            target_idx = target.get_bounding_box()
            source_idx = source.get_bounding_box()

            canvas[target_idx] = np.logical_or(
                kernel[source_idx],
                canvas[target_idx])

        if not fast_ball:
            if settings.mode == 'ball':
                enlarge_binary_map(
                    canvas,
                    settings.radius,
                    voxel_size,
                    1.0 - settings.inner_radius_fraction,
                    in_place=True)
            else:
                gaussian_filter(
                    canvas,
                    settings.radius / voxel_size,
                    output=canvas,
                    mode='constant')

                # rescale so the strongest response becomes 1
                peak = np.max(canvas)
                if peak > 0:
                    canvas /= peak

        if mask_array is not None:
            # bitwise AND is cheaper for the binary 'ball' output
            if settings.mode == 'ball':
                canvas &= mask
            else:
                canvas *= mask

        return canvas
    def __select_random_location_with_points(
            self,
            request,
            lcm_shift_roi,
            lcm_voxel_size):
        """Pick a random shift for which the shifted request ROI contains a
        randomly chosen point.

        The ROI that must contain the point is the graph spec of
        ``self.ensure_nonempty`` in ``request``. If that spec was not
        requested, the request's total ROI is used and a warning is logged.
        A point is drawn from ``self.points.data`` with cumulative weights
        ``self.cumulative_weights``.

        If ``self.ensure_centered`` is set, the only candidate shift is the one
        that puts the point's lcm voxel at index ``lcm_roi_shape / 2`` of the
        shifted ROI. Otherwise every lcm-grid shift that keeps the point inside
        the ROI is a candidate. Candidates are restricted to ``lcm_shift_roi``.
        If none remain, another point is drawn; the loop repeats until one
        succeeds.

        Args:
            request: the batch request.
            lcm_shift_roi: ROI of all admissible shifts, in lcm voxel units.
            lcm_voxel_size: voxel size of the lcm grid shifts are snapped to.

        Returns:
            The shift chosen by ``self.__select_random_location``.
        """

        request_points = request.graph_specs.get(self.ensure_nonempty)
        if request_points is None:
            total_roi = request.get_total_roi()
            logger.warning(
                f"Requesting non empty {self.ensure_nonempty}, however {self.ensure_nonempty} "
                f"has not been requested. Falling back on using the total roi of the "
                f"request {total_roi} for {self.ensure_nonempty}."
            )
            request_points_roi = total_roi
        else:
            request_points_roi = request_points.roi

        # The request ROI expressed in lcm voxels is the same for every attempt.
        lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
        lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
        logger.debug("Point request ROI: %s", request_points_roi)
        logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

        while True:

            # How to pick shifts that ensure that a randomly chosen point is
            # contained in the request ROI:
            #
            #
            # request          point
            # [---------)      .
            # 0        +10     17
            #
            #         least shifted to contain point
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #
            #                  most shifted to contain point:
            #                  [---------)
            #                  17       +10
            #                  ==
            #                  point-request.begin
            #
            #         all possible shifts
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #                   ==
            #                   request.shape

            # pick a random point (weighted)
            point = choices(self.points.data, cum_weights=self.cumulative_weights)[0]

            logger.debug("select random point at %s", point)

            # get the lcm voxel that contains this point
            lcm_location = Coordinate(point/lcm_voxel_size)
            logger.debug(
                "belongs to lcm voxel %s",
                lcm_location)

            ones = Coordinate((1,) * len(lcm_location))

            # get all possible starting points of lcm_roi_shape that contain
            # lcm_location
            if self.ensure_centered:
                # Shifted begin == lcm_location - lcm_roi_shape/2, so the point's
                # lcm voxel sits at index lcm_roi_shape/2, the voxel containing
                # the ROI's center. (An extra +1 here would move it one voxel
                # off-center and, for shape 1, outside the ROI.)
                # Assumes Coordinate division truncates to int — TODO confirm.
                lcm_shift_roi_begin = (
                    lcm_location
                    - lcm_roi_begin
                    - lcm_roi_shape / 2
                )
                lcm_shift_roi_shape = ones
            else:
                lcm_shift_roi_begin = (
                    lcm_location - lcm_roi_begin - lcm_roi_shape + ones
                )
                lcm_shift_roi_shape = lcm_roi_shape
            lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)

            logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

            # intersect with total shift ROI; retry with another point if no
            # admissible shift contains this one
            if not lcm_point_shift_roi.intersects(lcm_shift_roi):
                logger.debug(
                    "reject random shift, random point %s shift ROI %s does "
                    "not intersect total shift ROI %s", point,
                    lcm_point_shift_roi, lcm_shift_roi)
                continue
            lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

            # select a random shift from all possible shifts
            random_shift = self.__select_random_location(
                lcm_point_shift_roi,
                lcm_voxel_size)
            logger.debug("random shift: %s", random_shift)

            # sanity check: the chosen point must lie inside the shifted ROI
            points = self.__get_points_in_roi(
                request_points_roi.shift(random_shift))
            assert point in points, (
                "Requested batch to contain point %s, but got points "
                "%s"%(point, points))

            return random_shift
Example #4
0
    def __rasterize(self, graph, data_roi, voxel_size, dtype, settings, mask_array=None):
        """Rasterize ``graph`` into an array with the given ``voxel_size``.

        Nodes, and edges if ``settings.edges`` is set, are drawn into a
        zero-initialized array covering ``data_roi``. The result is then grown
        ('ball' mode via ``enlarge_binary_map``; other modes via
        ``gaussian_filter`` followed by rescaling to a maximum of 1) and
        optionally masked.

        Args:
            graph: graph to draw. Reads ``graph.spec.roi``, ``graph.nodes``,
                ``graph.edges`` and ``graph.node(id)``.
            data_roi: ROI of the output in voxel units. Its shape is the output
                shape, and its begin is subtracted from every voxel coordinate
                so drawing happens relative to the array start.
            voxel_size: size of one voxel; locations are divided by it.
            dtype: dtype of the returned array.
            settings: reads ``mode``, ``radius``, ``inner_radius_fraction``,
                ``color_attr``, ``fg_value`` and ``edges``.
            mask_array: optional array. Nodes on voxels that are falsy in
                ``mask_array.data`` are skipped, and the final result is
                combined with the mask (``&`` in 'ball' mode, ``*`` otherwise).

        Returns:
            numpy array of shape ``data_roi.get_shape()`` and dtype ``dtype``.
        """

        mask = mask_array.data if mask_array is not None else None

        logger.debug("Rasterizing graph in %s", graph.spec.roi)

        # prepare output array
        rasterized_graph = np.zeros(data_roi.get_shape(), dtype=dtype)
        array_begin = data_roi.get_begin()

        # Fast rasterization currently only implemented for mode ball without
        # inner radius set, and only for graphs without edges
        use_fast_rasterization = (
            settings.mode == "ball"
            and settings.inner_radius_fraction is None
            and len(list(graph.edges)) == 0
        )

        if use_fast_rasterization:

            dims = len(rasterized_graph.shape)

            # get structuring element for mode ball
            ball_kernel = create_ball_kernel(settings.radius, voxel_size)
            radius_voxel = Coordinate(np.array(ball_kernel.shape)/2)
            data_roi_base = Roi(
                    offset=Coordinate((0,)*dims),
                    shape=Coordinate(rasterized_graph.shape))
            kernel_roi_base = Roi(
                    offset=Coordinate((0,)*dims),
                    shape=Coordinate(ball_kernel.shape))

        # Rasterize volume either with single voxel or with defined struct element
        for node in graph.nodes:

            # get the voxel coordinate, 'Coordinate' ensures integer
            v = Coordinate(node.location/voxel_size)

            # get the voxel coordinate relative to output array start
            v -= array_begin

            # skip graph outside of mask
            if mask is not None and not mask[v]:
                continue

            logger.debug(
                "Rasterizing node %s at %s",
                node.location,
                node.location/voxel_size - array_begin)

            if use_fast_rasterization:

                # Calculate where to crop the kernel mask and the rasterized array
                shifted_kernel = kernel_roi_base.shift(v - radius_voxel)
                shifted_data = data_roi_base.shift(-(v - radius_voxel))
                arr_crop = data_roi_base.intersect(shifted_kernel)
                kernel_crop = kernel_roi_base.intersect(shifted_data)
                arr_crop_ind = arr_crop.get_bounding_box()
                kernel_crop_ind = kernel_crop.get_bounding_box()

                rasterized_graph[arr_crop_ind] = np.logical_or(
                    ball_kernel[kernel_crop_ind], rasterized_graph[arr_crop_ind]
                )

            else:

                if settings.color_attr is not None:
                    # NOTE(review): graph.nodes is iterated for node objects
                    # above and indexed by a node here — confirm the graph type
                    # supports both.
                    c = graph.nodes[node].get(settings.color_attr)
                    if c is None:
                        logger.debug(f"Skipping node: {node}")
                        continue
                    elif np.isclose(c, 1) and not np.isclose(settings.fg_value, 1):
                        logger.warning(
                            f"Node {node} is being colored with color {c} according to "
                            f"attribute {settings.color_attr} "
                            f"but color 1 will be replaced with fg_value: {settings.fg_value}"
                            )
                else:
                    c = 1
                rasterized_graph[v] = c

        if settings.edges:
            array_shape = rasterized_graph.shape
            for e in graph.edges:
                if settings.color_attr is not None:
                    c = graph.edges[e].get(settings.color_attr)
                    if c is None:
                        continue
                    elif np.isclose(c, 1) and not np.isclose(settings.fg_value, 1):
                        logger.warning(
                            f"Edge {e} is being colored with color {c} according to "
                            f"attribute {settings.color_attr} "
                            f"but color 1 will be replaced with fg_value: {settings.fg_value}"
                            )
                else:
                    # same default color as nodes; never reuse a stale value
                    # left over from the node loop
                    c = 1

                start_node = graph.node(e.u)
                end_node = graph.node(e.v)
                # voxel coordinates relative to the output array start, exactly
                # as done for nodes above
                start = Coordinate(start_node.location / voxel_size) - array_begin
                end = Coordinate(end_node.location / voxel_size) - array_begin
                line = draw.line_nd(start, end, endpoint=True)

                # drop line voxels outside the array: negative indices would
                # otherwise silently wrap around, large ones raise IndexError
                inside = np.ones(len(line[0]), dtype=bool)
                for axis_coords, axis_size in zip(line, array_shape):
                    inside &= (axis_coords >= 0) & (axis_coords < axis_size)
                rasterized_graph[tuple(axis_coords[inside] for axis_coords in line)] = c

        # grow graph
        if not use_fast_rasterization:

            if settings.mode == "ball":

                enlarge_binary_map(
                    rasterized_graph,
                    settings.radius,
                    voxel_size,
                    settings.inner_radius_fraction,
                    in_place=True)

            else:

                sigmas = settings.radius/voxel_size

                gaussian_filter(
                    rasterized_graph, sigmas, output=rasterized_graph, mode="constant"
                )

                # renormalize to have 1 be the highest value
                max_value = np.max(rasterized_graph)
                if max_value > 0:
                    rasterized_graph /= max_value

        if mask_array is not None:
            # use more efficient bitwise operation when possible
            if settings.mode == "ball":
                rasterized_graph &= mask
            else:
                rasterized_graph *= mask

        return rasterized_graph