def insert(self, key, data):
    """
    Inserts a key/data pair into the tree.

    If ``self.search`` finds a node for ``key``, only that node's data is
    replaced. If the root slot is still empty (its key is None), the pair
    is stored directly in the root. Otherwise a new ``TreeNode`` is
    appended to ``self.tree`` and linked as a leaf below the node reached
    by walking down from the root.

    Parameters
    ==========

    key
        The key to insert. Ordering uses ``self.comparator(a, b)``: when it
        is true the walk goes left, otherwise right.
    data
        The data associated with ``key``.

    Returns
    =======

    None
    """
    res = self.search(key)
    if res is not None:
        self.tree[res].data = data
        return None

    walk = self.root_idx
    if self.tree[walk].key is None:
        # The root placeholder has not been filled yet.
        self.tree[walk].key = key
        self.tree[walk].data = data
        return None

    new_node = TreeNode(key, data)
    while True:
        go_right = not self.comparator(key, self.tree[walk].key)
        child = self.tree[walk].right if go_right else self.tree[walk].left
        if child is not None:
            walk = child
            continue
        # The new leaf hangs directly below ``walk``. Using ``walk`` here
        # rather than a separately tracked index seeded with 0 keeps the
        # parent link correct even when root_idx is not slot 0.
        new_node.parent = walk
        self.tree.append(new_node)
        # self.size is used as the index the appended node occupies.
        if go_right:
            self.tree[walk].right = self.size
        else:
            self.tree[walk].left = self.size
        walk = self.size
        self.size += 1
        break
    # Called with the new node's index; presumably refreshes per-node
    # bookkeeping upwards from it -- see _update_size.
    self._update_size(walk)
def join(self, other):
    """
    Joins two trees current and other such that all elements of
    the current splay tree are smaller than the elements of the other tree.

    The maximum node of ``self`` is splayed to the root. Then every slot
    of ``other.tree`` is copied onto the end of ``self.tree``, with child
    and parent indices shifted by the offset at which the copies start.
    Finally, other's (copied) root is linked as the right child of the
    new root of ``self``.

    Parameters
    ==========

    other: SplayTree
        SplayTree which needs to be joined with the self tree.

    Raises
    ======

    ValueError
        If the largest key of ``self`` does not compare less than the
        smallest key of ``other`` under ``self.comparator``.
    """
    maxm = self.root_idx
    while self.tree[maxm].right is not None:
        maxm = self.tree[maxm].right
    minm = other.root_idx
    while other.tree[minm].left is not None:
        minm = other.tree[minm].left
    if not self.comparator(self.tree[maxm].key, other.tree[minm].key):
        raise ValueError("Elements of %s aren't less "
                         "than that of %s" % (self, other))
    # After splaying, the maximum of self is the root and has no right child.
    self.splay(maxm, self.tree[maxm].parent)

    idx_update = self.tree._size
    appended = 0
    for node in other.tree:
        if node is not None:
            node_copy = TreeNode(node.key, node.data)
            if node.left is not None:
                node_copy.left = node.left + idx_update
            if node.right is not None:
                node_copy.right = node.right + idx_update
            # Preserve parent links too; previously the copies lost them.
            if node.parent is not None:
                node_copy.parent = node.parent + idx_update
            self.tree.append(node_copy)
        else:
            self.tree.append(node)
        appended += 1

    other_root = other.root_idx + idx_update
    self.tree[self.root_idx].right = other_root
    # Other's root had no parent; it now hangs under self's root.
    self.tree[other_root].parent = self.root_idx
    # insert() uses self.size as the index of the next appended slot, so
    # account for every slot appended above.
    self.size += appended
def __new__(cls, key=None, root_data=None, comp=None,
            is_order_statistic=False):
    """
    Creates a tree containing a single root node.

    Parameters
    ==========

    key
        Key of the root node. It is stored as None whenever
        ``root_data`` is None.
    root_data
        Data of the root node.
    comp: callable, optional
        Binary predicate ``comp(key1, key2)``. It returns True when
        ``key1`` should be placed before (to the left of) ``key2``.
        Defaults to ``key1 < key2``.
    is_order_statistic: bool
        Stored on the instance as ``is_order_statistic``.

    Raises
    ======

    ValueError
        If ``root_data`` is given without a ``key``.
    """
    obj = object.__new__(cls)
    if key is None and root_data is not None:
        raise ValueError('Key required.')
    key = None if root_data is None else key
    root = TreeNode(key, root_data)
    root.is_root = True
    obj.root_idx = 0
    obj.tree, obj.size = ArrayForTrees(TreeNode, [root]), 1
    # The lambda must be parenthesized. Without the parentheses the
    # conditional expression became the lambda's body, so a user-supplied
    # ``comp`` was returned (always truthy) instead of being called.
    obj.comparator = (lambda key1, key2: key1 < key2) \
        if comp is None else comp
    obj.is_order_statistic = is_order_statistic
    return obj
def build(self):
    """
    Builds the segment tree from the segments, using
    iterative algorithm based on queues.

    The leaves are the elementary intervals induced by the sorted
    endpoints of ``self.segments``. Internal nodes are created by merging
    neighbouring nodes pairwise with ``self._union``. All nodes are stored
    in the list ``self.tree``, with the root appended last
    (``self.root_idx = -1``). Each segment is then recorded in the
    ``data`` list of the nodes selected by ``self._contains``. The method
    does nothing if ``self.cache`` is already truthy, and sets it to True
    when it finishes.

    NOTE(review): if ``self.segments`` is empty, ``endpoints[0]`` raises
    IndexError. Duplicate endpoints are not removed, so a repeated
    endpoint yields repeated point leaves and an empty open gap (e, e).
    """
    if self.cache:
        return None
    # Collect every segment endpoint and sort them.
    endpoints = []
    for segment in self.segments:
        endpoints.extend(segment)
    endpoints.sort()

    # Elementary intervals, in left-to-right order. The key layout used
    # here is [left_closed, start, end, right_closed]: [True, e, e, True]
    # is the single point e and [False, a, b, False] is the open gap (a, b).
    elem_int = Queue()
    # Open gap of width 1 before the first endpoint.
    elem_int.append(
        TreeNode([False, endpoints[0] - 1, endpoints[0], False], None))
    i = 0
    while i < len(endpoints) - 1:
        # Point endpoints[i], then the open gap up to the next endpoint.
        elem_int.append(
            TreeNode([True, endpoints[i], endpoints[i], True], None))
        elem_int.append(
            TreeNode([False, endpoints[i], endpoints[i + 1], False], None))
        i += 1
    # Last endpoint as a point, then an open gap of width 1 after it.
    elem_int.append(
        TreeNode([True, endpoints[i], endpoints[i], True], None))
    elem_int.append(
        TreeNode([False, endpoints[i], endpoints[i] + 1, False], None))

    # Bottom-up construction. Each pass over the queue merges consecutive
    # pairs. Both children are appended to self.tree, so at merge time
    # their indices are len(self.tree) and len(self.tree) + 1. The merged
    # parent is queued for the next pass.
    self.tree = []
    while len(elem_int) > 1:
        m = len(elem_int)
        while m >= 2:
            I1 = elem_int.popleft()
            I2 = elem_int.popleft()
            I = self._union(I1, I2)
            I.left = len(self.tree)
            I.right = len(self.tree) + 1
            # Appends I1 then I2; the tuple produced by the comma is discarded.
            self.tree.append(I1), self.tree.append(I2)
            elem_int.append(I)
            m -= 2
        # A single unpaired node left at the front is moved behind the
        # parents created in this pass, which keeps left-to-right order.
        if m & 1 == 1:
            Il = elem_int.popleft()
            elem_int.append(Il)
    # The one remaining node is the root. Its children are the last pair
    # appended above, which sit at indices -3 and -2 once the root itself
    # is appended at -1.
    Ir = elem_int.popleft()
    Ir.left, Ir.right = -3, -2
    self.tree.append(Ir)
    self.root_idx = -1

    # Record each segment (as a closed interval) in the nodes for which
    # self._contains(I, node) holds, and do not descend below them.
    # Presumably this means the node's interval lies inside I -- confirm
    # in _contains. Otherwise self._iterate decides which indices to visit.
    for segment in self.segments:
        I = TreeNode([True, segment[0], segment[1], True], None)
        calls = [self.root_idx]
        while calls:
            idx = calls.pop()
            if self._contains(I, self.tree[idx]):
                if self.tree[idx].data is None:
                    self.tree[idx].data = []
                self.tree[idx].data.append(I)
                continue
            calls = self._iterate(calls, I, idx)
    self.cache = True
def query(self, qx, init_node=None):
    """
    Finds the stored intervals that contain a query point.

    If the tree has not been built yet, ``self.build()`` is called first.
    The search starts at ``init_node`` (or at the root) and uses
    ``self._iterate`` to decide which nodes to visit next. It gathers the
    interval lists stored on the visited nodes.

    Parameters
    ==========

    qx: int/float
        The query point.
    init_node: int
        Index of the node where the search begins. Defaults to
        ``self.root_idx``.

    Returns
    =======

    intervals: set
        The set of the intervals which contain the query point.

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Segment_tree
    """
    if not self.cache:
        self.build()
    start = self.root_idx if init_node is None else init_node
    # The query point is represented as the closed interval [qx, qx].
    point = TreeNode([True, qx, qx, True], None)
    found = set()
    pending = [start]
    while pending:
        current = pending.pop()
        stored = self.tree[current].data
        if _check_type(stored, list):
            found.update(stored)
        pending = self._iterate(pending, point, current)
    return found
def _union(self, i1, i2):
    """
    Helper that merges two intervals into one node spanning both.

    The left boundary flag and start come from ``i1``, and the end and
    right boundary flag come from ``i2``, so ``i1`` is expected to lie to
    the left of ``i2``. The returned node carries no data.
    """
    left_closed, start = i1.key[0], i1.key[1]
    end, right_closed = i2.key[2], i2.key[3]
    return TreeNode([left_closed, start, end, right_closed], None)