def __select_random_location_with_points(
            self,
            request,
            lcm_shift_roi,
            lcm_voxel_size):
        """Select a random shift such that the shifted request ROI of
        ``self.ensure_nonempty`` contains at least one point.

        Rejection sampling works as follows:

        1. Pick a point id uniformly at random from ``self.points.data``.
        2. Restrict the admissible shifts to those whose shifted request ROI
           contains that point.
        3. Draw one of those shifts.
        4. Accept it with probability ``1/num_points``, where ``num_points``
           is the number of points the upstream provider returns for the
           shifted ROI.

        The acceptance step compensates for the bias towards regions with
        many close-by points.

        Args:
            request: The batch request. ``request[self.ensure_nonempty].roi``
                is the ROI that has to end up containing a point.
            lcm_shift_roi: ROI of all admissible shifts, expressed in units
                of ``lcm_voxel_size``.
            lcm_voxel_size: Voxel size the shifts are discretized to
                (presumably the least common multiple of the voxel sizes in
                the request -- confirm with the caller).

        Returns:
            The accepted shift, as produced by
            ``self.__select_random_location``. It is applied to the
            world-unit request ROI via ``request_points_roi.shift(...)``.

        Note:
            Loops until a shift is accepted. If no point can be brought into
            the request ROI by any shift in ``lcm_shift_roi``, this never
            returns.
        """

        request_points_roi = request[self.ensure_nonempty].roi

        # The request ROI in lcm units does not depend on the sampled point,
        # so compute (and log) it once rather than on every retry.
        lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
        lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
        logger.debug("Point request ROI: %s", request_points_roi)
        logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

        # choice() indexes into its argument. A dict keys view is not
        # indexable in Python 3 (passing one raises TypeError), so materialize
        # the ids once. self.points is only read inside the loop below.
        candidate_point_ids = list(self.points.data.keys())

        while True:

            # How to pick shifts that ensure that a randomly chosen point is
            # contained in the request ROI:
            #
            #
            # request          point
            # [---------)      .
            # 0        +10     17
            #
            #         least shifted to contain point
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #
            #                  most shifted to contain point:
            #                  [---------)
            #                  17       +10
            #                  ==
            #                  point-request.begin
            #
            #         all possible shifts
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #                   ==
            #                   request.shape
            #
            # In the most shifted scenario, the point could lie exactly on
            # the lower boundary of the ROI (17 in the example). That causes
            # trouble if the batch is later mirrored: the point would end up
            # on the other boundary, which is exclusive and thus not part of
            # the ROI. Therefore we have to ensure that the point lies well
            # inside the shifted ROI, not just on its boundary:
            #
            #         all possible shifts
            #         [--------)
            #         8       +9
            #                 ==
            #                 request.shape-1

            # pick a random point
            point_id = choice(candidate_point_ids)
            point = self.points.data[point_id]

            logger.debug(
                "select random point %d at %s",
                point_id,
                point.location)

            # get the lcm voxel that contains this point
            lcm_location = Coordinate(point.location/lcm_voxel_size)
            logger.debug(
                "belongs to lcm voxel %s",
                lcm_location)

            # mark all dimensions in which the point lies on the lower boundary
            # of the lcm voxel
            on_lower_boundary = lcm_location*lcm_voxel_size == point.location
            logger.debug(
                "lies on the lower boundary of the lcm voxel in dimensions %s",
                on_lower_boundary)

            # In each of these dimensions, drop the "most shifted" position
            # from the shift ROI (see the mirroring note above).
            lower_boundary_correction = Coordinate((
                -1 if o else 0
                for o in on_lower_boundary
            ))
            logger.debug(
                "lower bound correction for shape of shift ROI %s",
                lower_boundary_correction)

            # get all possible starting points of lcm_roi_shape that contain
            # lcm_location
            lcm_shift_roi_begin = (
                lcm_location - lcm_roi_begin - lcm_roi_shape +
                Coordinate((1,)*len(lcm_location))
            )
            lcm_shift_roi_shape = (
                lcm_roi_shape + lower_boundary_correction
            )
            lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
            logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

            # intersect with total shift ROI
            if not lcm_point_shift_roi.intersects(lcm_shift_roi):
                logger.debug(
                    "reject random shift, random point %s shift ROI %s does "
                    "not intersect total shift ROI %s", point.location,
                    lcm_point_shift_roi, lcm_shift_roi)
                continue
            lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

            # select a random shift from all possible shifts
            random_shift = self.__select_random_location(
                lcm_point_shift_roi,
                lcm_voxel_size)
            logger.debug("random shift: %s", random_shift)

            # count all points inside the shifted ROI
            points_request = BatchRequest()
            points_request[self.ensure_nonempty] = PointsSpec(
                roi=request_points_roi.shift(random_shift))
            logger.debug("points request: %s", points_request)
            points_batch = self.get_upstream_provider().request_batch(points_request)

            point_ids = points_batch.points[self.ensure_nonempty].data.keys()
            assert point_id in point_ids, (
                "Requested batch to contain point %s, but got points "
                "%s"%(point_id, point_ids))
            num_points = len(point_ids)

            # accept this shift with p=1/num_points
            #
            # This is to compensate the bias introduced by close-by points.
            accept = random() <= 1.0/num_points
            if accept:
                return random_shift
Beispiel #2
0
    def __select_random_location_with_points(
            self,
            request,
            lcm_shift_roi,
            lcm_voxel_size):
        """Select a random shift such that the shifted ROI requested for
        ``self.ensure_nonempty`` contains a sampled point.

        A point is drawn from ``self.points.data`` using the cumulative
        weights ``self.cumulative_weights``. A shift is then drawn from the
        part of ``lcm_shift_roi`` that:

        * places the request ROI so that it contains the point, or
        * with ``self.ensure_centered``, places the point at the ROI's
          centre (``begin + shape/2``, in lcm voxels).

        If ``self.ensure_nonempty`` was not requested, a warning is logged
        and the total ROI of ``request`` is used instead.

        Args:
            request: The batch request.
            lcm_shift_roi: ROI of all admissible shifts, expressed in units
                of ``lcm_voxel_size``.
            lcm_voxel_size: Voxel size the shifts are discretized to
                (presumably the least common multiple of the voxel sizes in
                the request -- confirm with the caller).

        Returns:
            The shift produced by ``self.__select_random_location``. It is
            applied to the world-unit request ROI via
            ``request_points_roi.shift(...)``.

        Note:
            Loops until a sampled point admits a shift inside
            ``lcm_shift_roi``. If none ever does, this never returns.
        """

        request_points = request.graph_specs.get(self.ensure_nonempty)
        if request_points is None:
            total_roi = request.get_total_roi()
            logger.warning(
                f"Requesting non empty {self.ensure_nonempty}, however {self.ensure_nonempty} "
                f"has not been requested. Falling back on using the total roi of the "
                f"request {total_roi} for {self.ensure_nonempty}."
            )
            request_points_roi = total_roi
        else:
            request_points_roi = request_points.roi

        # The request ROI in lcm units does not depend on the sampled point,
        # so compute (and log) it once rather than on every retry.
        lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
        lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
        logger.debug("Point request ROI: %s", request_points_roi)
        logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

        while True:

            # How to pick shifts that ensure that a randomly chosen point is
            # contained in the request ROI:
            #
            #
            # request          point
            # [---------)      .
            # 0        +10     17
            #
            #         least shifted to contain point
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #
            #                  most shifted to contain point:
            #                  [---------)
            #                  17       +10
            #                  ==
            #                  point-request.begin
            #
            #         all possible shifts
            #         [---------)
            #         8        +10
            #         ==
            #         point-request.begin-request.shape+1
            #                   ==
            #                   request.shape
            #
            # Centered (ensure_centered): exactly one shift, chosen so that
            # the point lands at the ROI's centre, begin + shape/2:
            #
            #            centered on point
            #            [---------)
            #            12   17  +10
            #            ==
            #            point-request.begin-request.shape/2

            # pick a random point, weighted by self.cumulative_weights
            point = choices(self.points.data, cum_weights=self.cumulative_weights)[0]

            logger.debug("select random point at %s", point)

            # get the lcm voxel that contains this point
            lcm_location = Coordinate(point/lcm_voxel_size)
            logger.debug(
                "belongs to lcm voxel %s",
                lcm_location)

            # get all possible starting points of lcm_roi_shape that contain
            # lcm_location
            if self.ensure_centered:
                # The shifted ROI begins at lcm_location - lcm_roi_shape/2, so
                # the point's lcm voxel sits at offset lcm_roi_shape/2, the
                # centre. The previous "+1" (copied from the least-shifted
                # formula below) moved the point one lcm voxel before the
                # centre. For a ROI one lcm voxel wide, it moved the point
                # out of the ROI entirely.
                # NOTE(review): assumes Coordinate "/ 2" yields integer
                # (truncated) components -- confirm.
                lcm_shift_roi_begin = (
                    lcm_location
                    - lcm_roi_begin
                    - lcm_roi_shape / 2
                )
                lcm_shift_roi_shape = Coordinate((1,) * len(lcm_location))
            else:
                lcm_shift_roi_begin = (
                    lcm_location - lcm_roi_begin - lcm_roi_shape +
                    Coordinate((1,)*len(lcm_location))
                )
                lcm_shift_roi_shape = lcm_roi_shape
            lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)

            logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

            # intersect with total shift ROI
            if not lcm_point_shift_roi.intersects(lcm_shift_roi):
                logger.debug(
                    "reject random shift, random point %s shift ROI %s does "
                    "not intersect total shift ROI %s", point,
                    lcm_point_shift_roi, lcm_shift_roi)
                continue
            lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

            # select a random shift from all possible shifts
            random_shift = self.__select_random_location(
                lcm_point_shift_roi,
                lcm_voxel_size)
            logger.debug("random shift: %s", random_shift)

            # Sanity check: the shifted ROI must contain the sampled point.
            # No acceptance step is needed here; any sampling bias is
            # governed by self.cumulative_weights.
            points = self.__get_points_in_roi(
                request_points_roi.shift(random_shift))
            assert point in points, (
                "Requested batch to contain point %s, but got points "
                "%s"%(point, points))

            return random_shift