def createRtree(data):
    """Creates an R-Tree from the given data."""
    tree = Rtree()
    for index, pair in enumerate(data):
        tree.insert(index, (pair[3], pair[4]), obj=pair)
    return tree
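# Hedged usage sketch (not part of the original snippet): exercises createRtree
# above, assuming each row of `data` is a sequence whose indices 3 and 4 hold
# the x and y coordinates, and that Rtree is rtree.index.Rtree (an alias of
# rtree.index.Index).
from rtree.index import Rtree

sample_rows = [
    ('a', 0, 0, 1.0, 2.0),   # columns 3 and 4 are x, y
    ('b', 0, 0, 5.0, 7.5),
]
tree = createRtree(sample_rows)
# Query a small window around the first point; hits carry the full row as obj.
hits = list(tree.intersection((0.5, 1.5, 1.5, 2.5), objects=True))
print([hit.object for hit in hits])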
def generate_rtree_from_entities(self):
    """Create an rtree with all entities with bounding rectangles."""
    self.bounding_rects = {
        id_: e.bounding_rect
        for id_, e in self.entities.items()
        if hasattr(e, 'bounding_rect')
    }
    self.rtree = Rtree(
        (id_, rect, None)
        for id_, rect in self.bounding_rects.items())
def reset(self, name: Optional[str] = None):
    """Reset index and set name."""
    self.name = name
    self.id_count = count()
    self.entities = {}
    self.bounding_rects = {}
    self.rtree = Rtree()
    self.path_map = PathMap()
    self.register_updates = False
    self.simulation = Simulation()
    self.stats = defaultdict(dict)
    self._updates = set()
def _merge_points(points: Dict[int, ConflictPoint], rtree: Rtree):
    """Merge conflict points closer than MERGE_RADIUS."""
    curves = set()
    merged = set()
    for id_, point in points.items():
        if id_ in merged:
            continue
        for other_id in rtree.intersection(
                point.point.enclosing_rect(MERGE_RADIUS)):
            if other_id == id_:
                continue
            other = points[other_id]
            for curve in other.curves:
                curves.add(curve)
                curve.replace_conflict_point(other, point)
            merged.add(other_id)
            rtree.delete(other_id, other.point.bounding_rect)
    for id_ in merged:
        del points[id_]
    for curve in curves:
        curve.remove_conflict_point_duplicates()
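# Side note (a generic rtree sketch, not project code): rtree's delete() must
# be called with the same id and the same coordinates that were used at
# insertion, which is why _merge_points above passes `other.point.bounding_rect`
# when removing a merged point from the index.
from rtree.index import Rtree

t = Rtree()
t.insert(7, (0.0, 0.0, 1.0, 1.0))
t.delete(7, (0.0, 0.0, 1.0, 1.0))   # must match the bounds used at insertion
print(list(t.intersection((-10, -10, 10, 10))))  # -> []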
def _is_minimal_trajectory(self, trajectory: Trajectory,
                           prior_end_poses: index.Rtree) -> bool:
    """
    Determine whether a trajectory is a minimal trajectory.

    Uses an RTree for speedup.

    Args:
        trajectory: Trajectory
            The trajectory to check
        prior_end_poses: RTree
            An RTree holding the current minimal set of trajectories

    Returns
    -------
    bool
        True if the trajectory is a minimal trajectory otherwise false

    """
    # Iterate over line segments in the trajectory
    for x1, y1, x2, y2, yaw in zip(
        trajectory.path.xs[:-1],
        trajectory.path.ys[:-1],
        trajectory.path.xs[1:],
        trajectory.path.ys[1:],
        trajectory.path.yaws[:-1],
    ):
        p1 = np.array([x1, y1])
        p2 = np.array([x2, y2])

        # Create a bounding box search region around the line segment
        left_bb = min(x1, x2) - self.DISTANCE_THRESHOLD
        right_bb = max(x1, x2) + self.DISTANCE_THRESHOLD
        top_bb = max(y1, y2) + self.DISTANCE_THRESHOLD
        bottom_bb = min(y1, y2) - self.DISTANCE_THRESHOLD

        # For any previous end points in the search region we check the
        # distance to that point and the angle difference. If they are
        # within threshold then this trajectory can be composed from a
        # previous trajectory.
        for prior_end_pose in prior_end_poses.intersection(
                (left_bb, bottom_bb, right_bb, top_bb), objects='raw'):
            if (self._point_to_line_distance(
                    p1, p2, prior_end_pose[:-1]) < self.DISTANCE_THRESHOLD
                    and angle_difference(yaw, prior_end_pose[-1])
                    < self.ROTATION_THRESHOLD):
                return False

    return True
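# The _point_to_line_distance helper used above is not shown in this excerpt.
# A minimal point-to-segment distance sketch in the same spirit could look like
# the following; the module-level name and signature are assumptions, not the
# project's actual implementation.
import numpy as np

def _point_to_line_distance(p1: np.ndarray, p2: np.ndarray, q) -> float:
    """Distance from point q to the line segment from p1 to p2."""
    q = np.asarray(q, dtype=float)
    segment = p2 - p1
    length_squared = float(np.dot(segment, segment))
    if length_squared == 0.0:
        # Degenerate segment: fall back to point-to-point distance.
        return float(np.linalg.norm(q - p1))
    # Project q onto the segment and clamp the projection to its end points.
    t = np.clip(np.dot(q - p1, segment) / length_squared, 0.0, 1.0)
    closest = p1 + t * segment
    return float(np.linalg.norm(q - closest))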
def _fill_neighbors(points: Dict[int, ConflictPoint], rtree: Rtree,
                    skip_in_same_curve: bool = True):
    """Add conflict points closer than NEIGHBOR_RADIUS as neighbors."""
    for id_, point in points.items():
        for other_id in rtree.intersection(
                point.point.enclosing_rect(NEIGHBOR_RADIUS)):
            if other_id == id_:
                continue
            other = points[other_id]
            if skip_in_same_curve and (point.curves & other.curves):
                continue
            distance_squared = point.point.distance_squared(other.point)
            if distance_squared <= NEIGHBOR_RADIUS_SQUARED:
                point.neighbors.add(other)
def _build_conflict_points(node: Node, curves: Dict[LaneConnection, Curve]):
    """Create conflict points from lane connection curves.

    Fill conflict points in given Curve objects. Merges points that are less
    than MERGE_RADIUS apart and adds points that are within NEIGHBOR_RADIUS
    as neighbors.
    """
    id_generator = count()
    points: Dict[int, ConflictPoint] = {}
    diverge = defaultdict(set)
    merge = defaultdict(set)
    for (lanes1, curve1), (lanes2, curve2) in combinations(curves.items(), 2):
        if lanes1[0] == lanes2[0]:
            diverge[lanes1[0]].update((curve1, curve2))
        if lanes1[1] == lanes2[1]:
            merge[lanes1[1]].update((curve1, curve2))
        _add_crossing_conflict_points(id_generator, points, curve1, curve2)
    _add_diverge_merge_conflict_points(id_generator, points, diverge, merge)
    if len(points) > 1:
        rtree = Rtree((id_, p.point.bounding_rect, None)
                      for id_, p in points.items())
        _merge_points(points, rtree)
        _fill_neighbors(points, rtree)
    for point in points.values():
        point.create_lock_order()
        # This makes the conflict point positions relative to the node.
        point.point = point.point - node.position
    for curve in curves.values():
        curve.remove_redundant_conflict_point()
    # Remove from neighbors points that aren't in any curve
    points = {p for _, p in chain.from_iterable(
        c.conflict_points for c in curves.values())}
    for point in points:
        for neighbor in list(point.neighbors):
            if neighbor not in points:
                point.neighbors.remove(neighbor)
class Rtree2D(object):
    """Wrapper of `rtree.Index` for supporting friendly 2d operations.

    Also forces the uniqueness of the `id` parameter, which is different from
    the rtree module's behavior.
    """

    def __init__(self):
        self._index = Rtree()
        self._locations = {}

    @staticmethod
    def to_coords(location):
        return (location[0], location[1], location[0], location[1])

    def keys(self):
        return self._locations.keys()

    def get(self, id, objects=False):
        return self._locations.get(id)

    def set(self, id, location, obj=None):
        # Clean up previous value first if any
        old = self._locations.get(id)
        if old is not None:
            self._index.delete(id, self.to_coords(old))
        self._locations[id] = location
        self._index.insert(id, self.to_coords(location), obj=obj)

    def remove(self, id):
        self._index.delete(id, self.to_coords(self._locations[id]))
        del self._locations[id]

    def nearest(self, location, count=1, objects=False, max_distance=None):
        ids = self._index.nearest(self.to_coords(location),
                                  num_results=count, objects=objects)
        if max_distance is not None:
            ids = [id_ for id_ in ids
                   if distance(self._locations[id_], location) <= max_distance]
        return ids
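# Hedged usage sketch for the Rtree2D wrapper above (not part of the original
# module). It assumes `Rtree` is rtree.index.Index and uses integer ids, as the
# underlying rtree index requires; the max_distance path is skipped because it
# relies on a `distance` helper defined elsewhere.
index2d = Rtree2D()
index2d.set(1, (0.0, 0.0), obj='origin')
index2d.set(2, (3.0, 4.0), obj='corner')
index2d.set(1, (1.0, 1.0))              # re-setting id 1 replaces its old location
print(list(index2d.keys()))              # -> [1, 2]
print(list(index2d.nearest((0.0, 0.0), count=1)))  # -> [1]
index2d.remove(2)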
def main(query, train, query_num, qgram_size):
    logger = logging.getLogger('search_rtree')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler('./log/%s' % query)
    fh.setLevel(logging.DEBUG)
    # create console handler with a higher log level
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    # create formatter and add it to the handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.addHandler(ch)

    logger.info(
        '------------------------- Calculate common Q-grams for query trajectories -------------------------'
    )
    qgram_tag = 'q_%d' % qgram_size
    query_path = './data/processed/%s.txt' % query
    rtree_path = './data/interim/%s/my_rtree_%s' % (train, qgram_tag)
    logger.info('Query trajectory path: %s' % query_path)
    logger.info('Rtree path: %s' % rtree_path)

    query_data = load_trajectory(query_path, n=query_num)
    logger.info('Load %d query trajectories' % query_num)
    qry_qgram, qry_id_list = build_qgram(query_data, qgram_size)
    qry_id_dict = build_id_to_key(qry_id_list)  # key: query_id, value: query_key

    data_index = Rtree(rtree_path)
    conf = SparkConf().setAppName("PythonWordCount").setMaster("local")
    sc = SparkContext(conf=conf)

    all_data = []
    for qry_id, qry_qgrams in qry_qgram.items():
        qry_key = qry_id_dict[qry_id]
        data = []
        for qry_qgram in qry_qgrams:
            matches = [
                hit.object
                for hit in data_index.intersection(qry_qgram, objects=True)
            ]
            matches = set(matches)
            data.append(list(matches))
        flat_data = [item for sublist in data for item in sublist]
        # print(flat_data)
        dist_data = sc.parallelize(flat_data)
        map_data = dist_data.map(lambda x: (x, 1))
        reduce_data = map_data.reduceByKey(lambda a, b: a + b).sortBy(
            lambda x: x[1], ascending=False).collect()
        # print(reduce_data)
        all_data.append([qry_key, reduce_data])

    if not os.path.exists('./data/interim/%s' % query):
        os.mkdir('./data/interim/%s' % query)
    if not os.path.exists('./data/interim/%s/%s' % (query, train)):
        os.mkdir('./data/interim/%s/%s' % (query, train))
    candidate_traj_path = './data/interim/%s/%s/candidate_trajectory_%s.txt' % (
        query, train, qgram_tag)
    save_pickle(all_data, candidate_traj_path)
    logger.info('Output candidate_trajectory: %s' % candidate_traj_path)
    query_id_dict_path = './data/interim/%s/%s/query_id_dict_%s.txt' % (
        query, train, qgram_tag)
    logger.info('Output query_id_dict: %s' % query_id_dict_path)
    save_pickle(qry_id_dict, query_id_dict_path)
    gc.collect()
""" Deletes a page """ try: RTreePage.objects(name=self.name, page=page).delete(safe=True) except: returnError.contents.value = self.InvalidPageError hasData = property( lambda self: RTreePage.objects(name=self.name).first() is not None ) if __name__=='__main__': settings = Property() settings.writethrough= True settings.buffering_capacity=1 storage = MongoStorage('test') storage.clear() r = Rtree( storage, properties=settings) r.add(123,(0,0,1,1)) print "test 1 should be true" item = list(r.nearest((0,0), 1, objects=True))[0] print item.id print r.valid() print "test 2 should be true" r.delete(123, (0,0,1,1)) print r.valid() print "test 3 should be true" r.clearBuffer() print r.valid()
def __init__(self):
    self._index = Rtree()
    self._locations = {}
class EntityIndex:
    """Index of spatial entities.

    When an entity is added to the index, it gets a unique id and is kept in
    a way that can be queried by id or by spatial coordinates.
    """

    __slots__ = ('name', 'id_count', 'entities', 'bounding_rects', 'rtree',
                 'path_map', 'register_updates', 'simulation', 'stats',
                 '_updates')

    extension = 'shelf'
    storage_fields = 'id_count', 'entities', 'path_map'

    name: str
    id_count: count
    entities: Dict[int, Entity]
    bounding_rects: Dict[int, BoundingRect]
    rtree: Rtree
    path_map: PathMap
    register_updates: bool
    simulation: Simulation
    # TODO: Define type for stats instead of Any.
    stats: Dict[Type, Dict[Any, Any]]
    _updates: Set[int]

    def __init__(self, name: Optional[str] = None):
        self.reset(name)

    @property
    def filename(self) -> str:
        """Name with extension added."""
        if self.name.endswith(f'.{EntityIndex.extension}'):
            return self.name
        return f'{self.name}.{EntityIndex.extension}'

    def reset(self, name: Optional[str] = None):
        """Reset index and set name."""
        self.name = name
        self.id_count = count()
        self.entities = {}
        self.bounding_rects = {}
        self.rtree = Rtree()
        self.path_map = PathMap()
        self.register_updates = False
        self.simulation = Simulation()
        self.stats = defaultdict(dict)
        self._updates = set()

    def add(self, entity: Entity):
        """Add entity to index."""
        if entity.id is not None:
            raise ValueError('Entity already has an id.')
        entity.id = next(self.id_count)
        self.entities[entity.id] = entity
        log.debug('[%s] Added %s', __name__, Entity.__repr__(entity))

    def add_static(self, entity: Entity):
        """Add entity as static.

        Entity may or may not have already been added with the `add` method.
        It will be added in case it was not already.

        A static entity is an entity with geometric information
        (`bounding_rect`) that will rarely change. A spatial index is used to
        allow for quick spatial queries.

        These entities are added to the updated queue when something about
        them changes. This queue can be consumed by a front end application
        with `consume_updates` to update the representation only when needed.
        """
        if entity.id is None:
            self.add(entity)
        if entity.id in self.bounding_rects:
            raise ValueError('Entity already added as static.')
        self.bounding_rects[entity.id] = entity.bounding_rect
        self.rtree.insert(entity.id, entity.bounding_rect)
        self.updated(entity)

    def delete(self, entity: Entity):
        """Delete entity from index."""
        to_remove = {entity}
        while to_remove:
            entity = to_remove.pop()
            assert self.entities[entity.id] is entity
            del self.entities[entity.id]
            delete_result = entity.on_delete()
            to_remove.update(delete_result.cascade)
            for updated in delete_result.updated:
                self.updated(updated)
            if entity.id in self.bounding_rects:
                self.rtree.delete(entity.id, self.bounding_rects[entity.id])
                del self.bounding_rects[entity.id]
            self.updated(entity)
            self.rebuild_path_map()
            log.debug('[%s] Removed %s', __name__, entity)

    def update_bounding_rect(self, entity: Entity,
                             new_rect: Optional[BoundingRect] = None):
        """Change the bounding rectangle of an entity.

        Update the bounding rect to `entity.bounding_rect` or to `new_rect`
        if it's not None.
        """
        assert self.entities[entity.id] is entity
        if new_rect is None:
            new_rect = entity.bounding_rect
        old_rect = self.bounding_rects.get(entity.id, None)
        if old_rect is None or old_rect == new_rect:
            return
        self.rtree.delete(entity.id, old_rect)
        self.bounding_rects[entity.id] = new_rect
        self.rtree.insert(entity.id, new_rect)

    def updated(self, entity: Union[Entity, int]):
        """Mark entity as updated."""
        if self.register_updates:
            try:
                self._updates.add(entity.id)
            except AttributeError:
                self._updates.add(entity)

    def clear_updates(self):
        """Clear entity updates."""
        self._updates.clear()

    def consume_updates(self) -> Iterator[int]:
        """Get generator that pops and returns updates."""
        while self._updates:
            yield self._updates.pop()

    def generate_rtree_from_entities(self):
        """Create an rtree with all entities with bounding rectangles."""
        self.bounding_rects = {
            id_: e.bounding_rect
            for id_, e in self.entities.items()
            if hasattr(e, 'bounding_rect')
        }
        self.rtree = Rtree(
            (id_, rect, None)
            for id_, rect in self.bounding_rects.items())

    def load(self, name: Optional[str] = None):
        """Load entities from shelf.

        Load entities using this index's name. If a name is passed as an
        argument, set the index name before loading.
        """
        if name is not None:
            self.name = name
        with shelve.open(self.filename) as data:
            for key in EntityIndex.storage_fields:
                log.info('Loading %s', key)
                value = data.get(key, None)
                if value:
                    setattr(self, key, value)
        log.info('Loaded %s', self.name)
        self.generate_rtree_from_entities()
        if not hasattr(self, 'path_map'):
            self.rebuild_path_map()

    def save(self):
        """Save entities to shelf."""
        with shelve.open(self.filename) as data:
            for key in EntityIndex.storage_fields:
                log.info('Saving %s', key)
                data[key] = getattr(self, key)

    def get_all(self, of_type: Type[Entity] = None,
                where: Callable[[Entity], bool] = None) -> Iterator[Entity]:
        """Get all entities with optional filters."""
        def type_filter(entity):
            return isinstance(entity, of_type)

        filters = []
        if of_type is not None:
            filters.append(type_filter)
        if where is not None:
            filters.append(where)
        yield from filter(lambda e: all(f(e) for f in filters),
                          self.entities.values())

    def get_at(self, point: Point, of_type: Type[Entity] = None,
               where: Callable[[Entity], bool] = None) -> List[Entity]:
        """Get entities at given coordinates.

        Get a list with entities intersecting the given point. If of_type is
        not None, will return only entities of the given type. If where is
        not None, where must be a function that receives an Entity and
        returns True or False, meaning whether the entity will be returned.
        """
        def polygon_filter(entity: Entity) -> bool:
            return point_in_polygon(point, entity.polygon)

        def type_filter(entity: Entity) -> bool:
            return isinstance(entity, of_type)

        filters = [polygon_filter]
        if of_type is not None:
            filters.append(type_filter)
        if where is not None:
            filters.append(where)
        return list(
            filter(lambda e: all(f(e) for f in filters),
                   map(self.entities.get,
                       self.rtree.intersection(point.bounding_rect))))

    def rebuild_path_map(self):
        """Rebuild the path map, invalidating the old map."""
        self.path_map = PathMap()
def main(train, qgram_size):
    logger = logging.getLogger('build_rtree')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler('./log/%s' % train)
    fh.setLevel(logging.DEBUG)
    # create console handler with a higher log level
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    # create formatter and add it to the handlers
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.addHandler(ch)

    logger.info('---------------------------- Build R-tree ----------------------------')
    qgram_tag = 'q_%d' % qgram_size
    train_path = './data/processed/%s.txt' % train
    data = load_trajectory(train_path)
    logger.info('Load train trajectory: %s' % train_path)
    trajectory, id_list = build_qgram(data, qgram_size)
    id_to_key_dict = build_id_to_key(id_list)
    # order_key_dict = build_order_dict(id_list)

    # save orderId-key mapping
    # key: trajectory id in string, value: encoded key
    rtree_id_dict_path = './data/interim/%s/rtree_id_dict_%s.txt' % (train, qgram_tag)
    save_pickle(id_to_key_dict, rtree_id_dict_path)
    logger.info('Output rtree_id_dict: %s' % rtree_id_dict_path)

    # key: key, value: trajectory id in string
    # filename = '../data/processed/order_key_dict.txt'
    # outfile = open(filename, 'wb')
    # pickle.dump(order_key_dict, outfile)
    # outfile.close()

    # R-tree constructor
    # parameter: 'data_full' is the filename of R-tree storage
    # 2 files are created: data_full.dat, data_full.idx
    # return: r-tree index
    rtree_path = './data/interim/%s/my_rtree_%s' % (train, qgram_tag)
    data_idx = Rtree(rtree_path)
    logger.info('Output R-tree: %s' % rtree_path)

    # put all trajectories into r-tree in the form of bounding box
    node_id = 0
    start_time = time.time()
    for key, qgrams in trajectory.items():
        for qgram in qgrams:
            # parameters:
            # 1. node id
            # 2. bounding box (point): (x, y, x, y)
            # 3. data inside each node: trajectory's key from order_dict
            x = np.around(qgram[0], decimals=5)
            y = np.around(qgram[1], decimals=5)
            data_idx.insert(node_id, (x, y, x, y), obj=(id_to_key_dict[key]))
            node_id += 1
    del data_idx
    end_time = time.time()
    logger.info("exec time: " + str(end_time - start_time))
    logger.info('Finished building R-tree')
def searching(feature_vectors_database, feature_vectors_retrieval,
              similarity_metric, image_paths, retrieval_number, file,
              list_of_parameters, feature_extraction_method, path_database):
    '''
    feature_vectors: computed features
    labels: label of each class
    similarity_metric: which similarity measure to use
    Retrieve the k images with the smallest distance. If k = 0, the value is
    set to the size of the image's class.
    '''
    # name to save the pickle file
    parameters_name = ""
    for parameter in list_of_parameters:
        parameters_name = parameters_name + "_" + parameter
    file = (path_database + "features/sortingRTree" + "_" +
            feature_extraction_method + parameters_name + '_' +
            similarity_metric)

    feature_vectors_retrieval = preprocessing.scale(feature_vectors_retrieval)

    if not os.path.isfile(file + '.dat'):
        # normalize signatures
        feature_vectors_database = preprocessing.scale(feature_vectors_database)

        # Create an N-dimensional index
        p = index.Property()
        p.dimension = feature_vectors_database.shape[1]
        idx = index.Index(file, properties=p)

        # Create the tree
        for i, vector in enumerate(feature_vectors_database):
            idx.add(i, vector.tolist())
        # save_format = idx.dumps(idx)
        # with open(file, 'wb') as handle:
        #     pickle.dump(save_format, handle)
    else:
        # Open the existing N-dimensional index
        p = index.Property()
        p.dimension = feature_vectors_database.shape[1]
        idx = Rtree(file, properties=p)
        # with open(file, 'rb') as handle:
        #     idx = pickle.load(handle)

    # Find the closest pairs for the first N points
    small_distances = []
    for id1, query in enumerate(feature_vectors_retrieval):
        nearest = list(idx.nearest(query.tolist(), retrieval_number))
        small_distances.append(nearest)

    result = []
    for cont1, i in enumerate(small_distances):
        aux = []
        for j in i:
            aux.append(image_paths[j])
        result.append(aux)

    return result
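# Minimal self-contained sketch (an assumed example, not taken from the code
# above) of the N-dimensional rtree pattern that `searching` relies on: an
# index.Property with `dimension` set to the feature length, point-style
# insertion of feature vectors, and a nearest-neighbour query.
import numpy as np
from rtree import index

features = np.array([[0.0, 0.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0],
                     [5.0, 5.0, 5.0, 5.0]])

p = index.Property()
p.dimension = features.shape[1]      # 4-dimensional index
idx = index.Index(properties=p)      # in-memory; pass a filename to persist

for i, vector in enumerate(features):
    idx.add(i, vector.tolist())      # a bare point is treated as a degenerate box

query = [0.9, 1.1, 1.0, 0.8]
print(list(idx.nearest(query, 2)))   # ids of the 2 closest feature vectors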