def __select_random_location_with_points(
        self,
        request,
        lcm_shift_roi,
        lcm_voxel_size):
    """Pick a random shift such that the shifted request contains a point.

    A point id is drawn uniformly from ``self.points.data``. Among all shifts
    (in lcm voxel units) that place this point inside the shifted
    ``request[self.ensure_nonempty].roi`` and that also lie in
    ``lcm_shift_roi``, one is picked via ``self.__select_random_location``.
    The shifted ROI is then requested upstream and the shift is accepted
    with probability ``1 / num_points``, where ``num_points`` is the number
    of points returned. This compensates the bias towards regions with many
    close-by points. Rejected draws are retried.

    Args:
        request: The batch request. ``request[self.ensure_nonempty].roi``
            is the ROI that must contain the selected point.
        lcm_shift_roi: ROI of admissible shifts, expressed in lcm voxels.
        lcm_voxel_size: The voxel size used to convert world coordinates
            to lcm voxel coordinates.

    Returns:
        The accepted shift, as returned by ``self.__select_random_location``.

    Note:
        The retry loop has no iteration cap. It never terminates if no
        point's shift ROI intersects ``lcm_shift_roi``.
    """

    request_points_roi = request[self.ensure_nonempty].roi

    # The request ROI in lcm voxel units does not depend on the randomly
    # chosen point, so compute it once instead of on every retry.
    lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
    lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
    logger.debug("Point request ROI: %s", request_points_roi)
    logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

    while True:

        # How to pick shifts that ensure that a randomly chosen point is
        # contained in the request ROI:
        #
        #
        # request          point
        # [---------)      .
        # 0        +10     17
        #
        #         least shifted to contain point
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #
        #                  most shifted to contain point:
        #                  [---------)
        #                  17       +10
        #                  ==
        #                  point-request.begin
        #
        #         all possible shifts
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #                   ==
        #                   request.shape
        #
        # In the most shifted scenario, it could happen that the point lies
        # exactly at the lower boundary (17 in the example). This will cause
        # trouble if later we mirror the batch. The point would end up lying
        # on the other boundary, which is exclusive and thus not part of the
        # ROI. Therefore, we have to ensure that the point is well inside
        # the shifted ROI, not just on the boundary:
        #
        #         all possible shifts
        #         [--------)
        #         8       +9
        #                 ==
        #                 request.shape-1

        # pick a random point
        #
        # random.choice() indexes its argument. In Python 3, dict.keys() is
        # a non-subscriptable view and raises TypeError, so materialize the
        # ids as a list first.
        point_id = choice(list(self.points.data.keys()))
        point = self.points.data[point_id]
        logger.debug(
            "select random point %d at %s",
            point_id,
            point.location)

        # get the lcm voxel that contains this point
        lcm_location = Coordinate(point.location/lcm_voxel_size)
        logger.debug(
            "belongs to lcm voxel %s",
            lcm_location)

        # mark all dimensions in which the point lies on the lower boundary
        # of the lcm voxel
        on_lower_boundary = lcm_location*lcm_voxel_size == point.location
        logger.debug(
            "lies on the lower boundary of the lcm voxel in dimensions %s",
            on_lower_boundary)

        # for each of these dimensions, shrink the shape of the shift ROI by
        # one so the point cannot end up exactly on the ROI's lower boundary
        # (see the "request.shape-1" case above)
        lower_boundary_correction = Coordinate((
            -1 if o else 0
            for o in on_lower_boundary
        ))
        logger.debug(
            "lower bound correction for shape of shift ROI %s",
            lower_boundary_correction)

        # get all possible starting points of lcm_roi_shape that contain
        # lcm_location
        lcm_shift_roi_begin = (
            lcm_location - lcm_roi_begin - lcm_roi_shape +
            Coordinate((1,)*len(lcm_location))
        )
        lcm_shift_roi_shape = (
            lcm_roi_shape + lower_boundary_correction
        )
        lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
        logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

        # intersect with total shift ROI
        if not lcm_point_shift_roi.intersects(lcm_shift_roi):
            logger.debug(
                "reject random shift, random point %s shift ROI %s does "
                "not intersect total shift ROI %s", point.location,
                lcm_point_shift_roi, lcm_shift_roi)
            continue
        lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

        # select a random shift from all possible shifts
        random_shift = self.__select_random_location(
            lcm_point_shift_roi,
            lcm_voxel_size)
        logger.debug("random shift: %s", random_shift)

        # count all points inside the shifted ROI
        points_request = BatchRequest()
        points_request[self.ensure_nonempty] = PointsSpec(
            roi=request_points_roi.shift(random_shift))
        logger.debug("points request: %s", points_request)
        points_batch = self.get_upstream_provider().request_batch(
            points_request)

        point_ids = points_batch.points[self.ensure_nonempty].data.keys()
        assert point_id in point_ids, (
            "Requested batch to contain point %s, but got points "
            "%s"%(point_id, point_ids))
        num_points = len(point_ids)

        # accept this shift with p=1/num_points
        #
        # This is to compensate the bias introduced by close-by points.
        accept = random() <= 1.0/num_points
        if accept:
            return random_shift
def __select_random_location_with_points(
        self,
        request,
        lcm_shift_roi,
        lcm_voxel_size):
    """Pick a random shift such that the shifted request contains a point.

    A point is drawn from ``self.points.data`` with ``random.choices`` using
    ``self.cumulative_weights`` as cumulative weights. The shift is then
    chosen so that the point lies inside the shifted request ROI:

    * If ``self.ensure_centered`` is set, exactly one shift is possible. It
      places the point's lcm voxel near the middle of the ROI.
    * Otherwise, any shift that keeps the point inside the ROI is possible.

    The candidate shifts are intersected with ``lcm_shift_roi``. If the
    intersection is empty, a new point is drawn. Otherwise a shift is
    selected via ``self.__select_random_location`` and returned.

    If ``self.ensure_nonempty`` is not among ``request.graph_specs``, a
    warning is logged and the request's total ROI is used instead.

    Args:
        request: The batch request.
        lcm_shift_roi: ROI of admissible shifts, expressed in lcm voxels.
        lcm_voxel_size: The voxel size used to convert world coordinates
            to lcm voxel coordinates.

    Returns:
        The selected shift, as returned by ``self.__select_random_location``.

    Note:
        The retry loop has no iteration cap. It never terminates if no
        point's shift ROI intersects ``lcm_shift_roi``.
    """

    request_points = request.graph_specs.get(self.ensure_nonempty)
    if request_points is None:
        total_roi = request.get_total_roi()
        logger.warning(
            f"Requesting non empty {self.ensure_nonempty}, however {self.ensure_nonempty} "
            f"has not been requested. Falling back on using the total roi of the "
            f"request {total_roi} for {self.ensure_nonempty}."
        )
        request_points_roi = total_roi
    else:
        request_points_roi = request_points.roi

    # The request ROI in lcm voxel units does not depend on the randomly
    # chosen point, so compute it once instead of on every retry.
    lcm_roi_begin = request_points_roi.get_begin()/lcm_voxel_size
    lcm_roi_shape = request_points_roi.get_shape()/lcm_voxel_size
    logger.debug("Point request ROI: %s", request_points_roi)
    logger.debug("Point request lcm ROI shape: %s", lcm_roi_shape)

    while True:

        # How to pick shifts that ensure that a randomly chosen point is
        # contained in the request ROI:
        #
        #
        # request          point
        # [---------)      .
        # 0        +10     17
        #
        #         least shifted to contain point
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #
        #                  most shifted to contain point:
        #                  [---------)
        #                  17       +10
        #                  ==
        #                  point-request.begin
        #
        #         all possible shifts
        #         [---------)
        #         8        +10
        #         ==
        #         point-request.begin-request.shape+1
        #                   ==
        #                   request.shape

        # pick a random point (weighted by self.cumulative_weights)
        point = choices(
            self.points.data,
            cum_weights=self.cumulative_weights)[0]
        logger.debug("select random point at %s", point)

        # get the lcm voxel that contains this point
        lcm_location = Coordinate(point/lcm_voxel_size)
        logger.debug(
            "belongs to lcm voxel %s",
            lcm_location)

        ones = Coordinate((1,) * len(lcm_location))

        if self.ensure_centered:
            # A single admissible shift. NOTE(review): with "- shape/2 + 1"
            # the point lands at lcm index shape/2 - 1 of the shifted ROI.
            # That is the lower of the two middle voxels for even shapes,
            # and one below the middle voxel for odd shapes if "/" floors.
            # The result also depends on Coordinate's division semantics.
            # Confirm this offset is intended.
            lcm_shift_roi_begin = (
                lcm_location - lcm_roi_begin - lcm_roi_shape / 2 + ones
            )
            lcm_shift_roi_shape = ones
        else:
            # all possible starting points of lcm_roi_shape that contain
            # lcm_location (see diagram above)
            lcm_shift_roi_begin = (
                lcm_location - lcm_roi_begin - lcm_roi_shape + ones
            )
            lcm_shift_roi_shape = lcm_roi_shape
        lcm_point_shift_roi = Roi(lcm_shift_roi_begin, lcm_shift_roi_shape)
        logger.debug("lcm point shift roi: %s", lcm_point_shift_roi)

        # intersect with total shift ROI
        if not lcm_point_shift_roi.intersects(lcm_shift_roi):
            logger.debug(
                "reject random shift, random point %s shift ROI %s does "
                "not intersect total shift ROI %s", point,
                lcm_point_shift_roi, lcm_shift_roi)
            continue
        lcm_point_shift_roi = lcm_point_shift_roi.intersect(lcm_shift_roi)

        # select a random shift from all possible shifts
        random_shift = self.__select_random_location(
            lcm_point_shift_roi,
            lcm_voxel_size)
        logger.debug("random shift: %s", random_shift)

        # sanity check: the shifted ROI must contain the chosen point
        points = self.__get_points_in_roi(
            request_points_roi.shift(random_shift))
        assert point in points, (
            "Requested batch to contain point %s, but got points "
            "%s"%(point, points))

        return random_shift