def collect_matches(initial_summoner_name="GustavEnk", region="EUW", patch_version="8.9"):
    """Crawl matches by snowballing outward from one summoner.

    Starting from ``initial_summoner_name``, repeatedly pick a random
    not-yet-pulled summoner, queue the ids of the matches that
    ``filter_match_history`` returns for the given patch, then load every
    queued match and queue its participants as new summoners. The crawl runs
    until no unpulled summoners remain.

    :param initial_summoner_name: name of the summoner the crawl starts from.
    :param region: region passed to every ``Summoner``, ``Patch`` and ``Match``.
    :param patch_version: patch string handed to ``Patch.from_str``.
    """
    summoner = Summoner(name=initial_summoner_name, region=region)
    patch = Patch.from_str(patch_version, region=region)

    unpulled_summoner_ids = SortedList([summoner.id])
    pulled_summoner_ids = SortedList()

    unpulled_match_ids = SortedList()
    pulled_match_ids = SortedList()

    while unpulled_summoner_ids:
        # Get a random summoner from our list of unpulled summoners and pull their match history
        new_summoner_id = random.choice(unpulled_summoner_ids)
        new_summoner = Summoner(id=new_summoner_id, region=region)
        matches = filter_match_history(new_summoner, patch)
        # SortedList keeps duplicates, so only queue match ids that were not
        # already pulled or queued; otherwise a match shared by several
        # summoners would be fetched again every time one of them is visited.
        for match in matches:
            match_id = match.id
            if match_id not in pulled_match_ids and match_id not in unpulled_match_ids:
                unpulled_match_ids.add(match_id)
        unpulled_summoner_ids.remove(new_summoner_id)
        pulled_summoner_ids.add(new_summoner_id)

        while unpulled_match_ids:
            # Get a random match from our list of matches
            new_match_id = random.choice(unpulled_match_ids)
            new_match = Match(id=new_match_id, region=region)
            for participant in new_match.participants:
                participant_id = participant.summoner.id
                if participant_id not in pulled_summoner_ids and participant_id not in unpulled_summoner_ids:
                    unpulled_summoner_ids.add(participant_id)
            # The above lines will trigger the match to load its data by iterating over all the participants.
            # If you have a database in your datapipeline, the match will automatically be stored in it.
            unpulled_match_ids.remove(new_match_id)
            pulled_match_ids.add(new_match_id)
class Episode_scores(object):
    """Sliding window of recent episode scores.

    Keeps at most ``options.score_averaging_length`` scores (seeded with a
    single 0 so the first average never divides by zero). It reports their
    average and whether a new score ranks within the top
    ``options.score_highest_ratio`` fraction of the window.
    """

    def __init__(self, options):
        self.maxlen = options.score_averaging_length
        self.threshold = options.score_highest_ratio
        self.episode_scores = deque()
        self.episode_scores.append(0)  # to avoid 0-div in first averaging
        self.episode_scores_sum = 0
        # Holds the *negated* scores so that ascending order means
        # highest score first.
        self.sorted_scores = SortedList()
        self.sorted_scores.add(0)  # align to episode_scores
        self.num_episode = 0
        self.options = options

    def add(self, n, global_t, thread_index):
        """Record score ``n``, evict the oldest score when the window is full,
        and print the running average every
        ``options.average_score_log_interval`` episodes."""
        self.episode_scores_sum += n
        self.episode_scores.append(n)
        self.sorted_scores.add(-n)  # trick to use SortedList in reverse order
        if len(self.episode_scores) > self.maxlen:
            oldest = self.episode_scores.popleft()
            self.sorted_scores.remove(-oldest)
            self.episode_scores_sum -= oldest
        self.num_episode += 1
        if self.num_episode % self.options.average_score_log_interval == 0:
            print("@@@ Average Episode score = {:.6f}, s={:9d},th={}".format(self.average(), global_t, thread_index))

    def average(self):
        """Return the mean of the scores currently in the window."""
        return self.episode_scores_sum / len(self.episode_scores)

    def is_highscore(self, n):
        """Return True if ``n`` would rank within the top ``threshold`` fraction.

        The rank is 1 + the number of stored scores strictly greater than
        ``n``, divided by the number of stored scores.
        """
        sorted_scores = self.sorted_scores
        # Scores are stored negated, so bisect_left(-n) counts the stored
        # scores strictly greater than n. This is the same position that
        # add(-n) followed by index(-n) reports, but without mutating the list.
        index = sorted_scores.bisect_left(-n)
        highest_ratio = (index + 1) / len(sorted_scores)
        return highest_ratio <= self.threshold
class InMemoryBackend(object):
    """
    The backend that keeps the results in the memory.

    Results are kept in a dict keyed by a random hex identifier and, in
    parallel, in a ``SortedList`` ordered by each result's parsed
    ``'timestamp'`` field, so that queries can walk them newest-first.
    """

    def __init__(self, *args, **kwargs):
        # Any extra arguments are accepted and ignored.
        def get_timestamp(result):
            return timestamp_parser.parse(result['timestamp'])

        self._results = dict()
        self._sorted = SortedList(key=get_timestamp)

    @staticmethod
    def _items_view(mapping):
        """
        Return a set-like view of ``mapping``'s items on Python 2 and 3.

        Python 2 dicts only provide set-like item views via ``viewitems``;
        their ``items()`` is a list, whose ``<=`` is lexicographic. Python 3
        dicts have no ``viewitems``, but ``items()`` is already a view.
        """
        viewitems = getattr(mapping, 'viewitems', None)
        if viewitems is not None:
            return viewitems()
        return mapping.items()

    def disconnect(self):
        return succeed(None)

    def store(self, result):
        """
        Store a single benchmarking result and return its identifier.

        :param dict result: The result in the JSON compatible format.
        :return: A Deferred that produces an identifier for the stored
            result.
        """
        result_id = uuid4().hex
        self._results[result_id] = result
        self._sorted.add(result)
        return succeed(result_id)

    def retrieve(self, id):
        """
        Retrieve a result by the given identifier.

        :return: A Deferred that produces the result, or fails with
            ``ResultNotFound`` if the identifier is unknown.
        """
        try:
            result = self._results[id]
        except KeyError:
            return fail(ResultNotFound(id))
        return succeed(result)

    def query(self, filter, limit=None):
        """
        Return matching results, newest first.

        :param dict filter: Key/value pairs that a result must contain.
        :param limit: Maximum number of results, or ``None`` for no limit.
        :return: A Deferred that produces a list of matching results.
        """
        wanted = self._items_view(filter)
        matching = []
        for result in reversed(self._sorted):
            if len(matching) == limit:
                break
            if wanted <= self._items_view(result):
                matching.append(result)
        return succeed(matching)

    def delete(self, id):
        """
        Delete a result by the given identifier.

        :return: A Deferred that fires with ``None``, or fails with
            ``ResultNotFound`` if the identifier is unknown.
        """
        try:
            result = self._results.pop(id)
        except KeyError:
            return fail(ResultNotFound(id))
        self._sorted.remove(result)
        return succeed(None)
def test_delete():
    """Removing every element one by one leaves an internally empty list."""
    slt = SortedList(range(20))
    # Use a small load factor, set via _reset as the other tests do, since
    # newer sortedcontainers no longer accepts a ``load=`` constructor argument.
    slt._reset(4)
    slt._check()
    for val in range(20):
        slt.remove(val)
        slt._check()
    assert len(slt) == 0
    assert slt._maxes == []
    assert slt._lists == []
def test_remove():
    """discard on an empty list is a no-op; remove drops exactly one copy."""
    slt = SortedList()
    assert slt.discard(0) is None
    assert len(slt) == 0
    slt._check()

    slt = SortedList([1, 2, 2, 2, 3, 3, 5])
    slt._reset(4)
    slt.remove(2)
    slt._check()
    # Full list comparison: unlike zip(), this also catches a length mismatch.
    assert list(slt) == [1, 2, 2, 3, 3, 5]
class Episode_scores(object):
    """Sliding window of recent episode scores.

    Keeps at most ``options.score_averaging_length`` scores (seeded with a
    single 0 so the first average never divides by zero). It reports their
    average and whether a new score ranks within the top
    ``options.score_highest_ratio`` fraction of the window.
    """

    def __init__(self, options):
        self.maxlen = options.score_averaging_length
        self.threshold = options.score_highest_ratio
        self.episode_scores = deque()
        self.episode_scores.append(0)  # to avoid 0-div in first averaging
        self.episode_scores_sum = 0
        # Holds the *negated* scores so that ascending order means
        # highest score first.
        self.sorted_scores = SortedList()
        self.sorted_scores.add(0)  # align to episode_scores
        self.num_episode = 0
        self.options = options

    def add(self, n, global_t, thread_index):
        """Record score ``n``, evict the oldest score when the window is full,
        and print the running average every
        ``options.average_score_log_interval`` episodes."""
        self.episode_scores_sum += n
        self.episode_scores.append(n)
        self.sorted_scores.add(-n)  # trick to use SortedList in reverse order
        if len(self.episode_scores) > self.maxlen:
            oldest = self.episode_scores.popleft()
            self.sorted_scores.remove(-oldest)
            self.episode_scores_sum -= oldest
        self.num_episode += 1
        if self.num_episode % self.options.average_score_log_interval == 0:
            print("@@@ Average Episode score = {:.6f}, s={:9d},th={}".format(
                self.average(), global_t, thread_index))

    def average(self):
        """Return the mean of the scores currently in the window."""
        return self.episode_scores_sum / len(self.episode_scores)

    def is_highscore(self, n):
        """Return True if ``n`` would rank within the top ``threshold`` fraction.

        The rank is 1 + the number of stored scores strictly greater than
        ``n``, divided by the number of stored scores.
        """
        sorted_scores = self.sorted_scores
        # Scores are stored negated, so bisect_left(-n) counts the stored
        # scores strictly greater than n. This is the same position that
        # add(-n) followed by index(-n) reports, but without mutating the list.
        index = sorted_scores.bisect_left(-n)
        highest_ratio = (index + 1) / len(sorted_scores)
        return highest_ratio <= self.threshold
def busiest_servers(self, k: int, arrival: List[int], load: List[int]) -> List[int]:
    """Return the ids of every server that handled the maximum number of requests.

    Request ``i`` arrives at ``arrival[i]`` and occupies a server for
    ``load[i]`` time units. It is assigned to the first free server whose id
    is at least ``i % k``, wrapping around to the lowest free id. If no
    server is free, the request is dropped.
    """
    free = SortedList(range(k))
    running = []  # min-heap of (finish_time, server_id)
    handled = [0] * k

    for i, now in enumerate(arrival):
        # Return every server whose job has finished by ``now`` to the pool.
        while running and running[0][0] <= now:
            _, done = heapq.heappop(running)
            free.add(done)
        if not free:
            continue

        pos = free.bisect_left(i % k)
        server = free[pos] if pos < len(free) else free[0]
        handled[server] += 1
        heapq.heappush(running, (now + load[i], server))
        free.remove(server)

    top = max(handled)
    return [server for server, count in enumerate(handled) if count == top]
def trade(self):
    """Trading as described in Algorithm 1.

    Every agent posts an order priced by ``self.get_prices`` from the current
    returns. While more than one order remains, a random order is removed as
    the bid and matched against the cheapest remaining order (the ask). A
    trade executes when all of the following hold:

    * the bid price covers the ask price;
    * the bidder's cash covers the ask price;
    * the seller holds more than one stock.

    After each executed trade, returns and prices are recomputed and the
    remaining orders are re-priced.

    :return: list of the executed ``Trade`` objects, in execution order.
    """
    returns = self.returns([self.data, self.book])
    prices = self.get_prices(returns)

    def by_price(order):
        return order.price

    orders = SortedList(key=by_price)
    for agent_id in np.arange(self.agent_num):
        orders.add(Order(agent_id=agent_id, price=prices[agent_id]))

    trades = []
    while len(orders) > 1:
        # Draw a random agent's order as the bid. Pop it by position rather
        # than passing the SortedList to np.random.choice, which needs a 1-D
        # array-like and may not convert a list of order records into one.
        bid = orders.pop(np.random.randint(len(orders)))
        cash_left = self.book[bid.agent_id].numpy()
        ask = orders[0]
        _seller_stocks = self.book[self.agent_num + ask.agent_id].numpy()

        if bid.price >= ask.price and cash_left >= ask.price and _seller_stocks > 1:
            _quantity = min(np.floor(cash_left / ask.price), _seller_stocks)
            _trade = Trade(bid.agent_id, ask.agent_id, _quantity, ask.price)
            self._execute_trade(_trade)
            trades.append(_trade)

            # Comment the following lines to prevent recomputation of utilities during trading.
            returns = self.returns([self.data, self.book])
            prices = self.get_prices(returns)
            # Iterating the SortedList yields Order objects, so re-price each
            # one via its own agent_id (indexing prices by the Order is wrong).
            orders = SortedList(
                (Order(agent_id=order.agent_id, price=prices[order.agent_id])
                 for order in orders),
                key=by_price,
            )

    return trades
def getSkyline(self, buildings: List[List[int]]) -> List[List[int]]:
    """Return the skyline key points ``[x, height]`` for ``buildings``.

    Each building is ``[left, right, height]``. The x coordinates are swept
    in ascending order while a sorted multiset of the heights currently
    standing is maintained. A key point is emitted whenever the tallest
    standing height changes.
    """
    # x -> (heights that start at x, heights that end at x)
    events = {}
    for left, right, height in buildings:
        events.setdefault(left, ([], []))
        events.setdefault(right, ([], []))
        events[left][0].append(height)
        events[right][1].append(height)

    standing = SortedList()
    previous = 0  # ground level before the first building
    skyline = []
    for x in sorted(events):
        rising, falling = events[x]
        for height in rising:
            standing.add(height)
        for height in falling:
            standing.remove(height)

        tallest = standing[-1] if standing else 0
        if tallest == previous:
            continue
        skyline.append([x, tallest])
        previous = tallest
    return skyline
class MovieRentingSystem:
    """Track movie copies across shops.

    Supports four operations: searching the cheapest unrented copies of a
    movie, renting a copy, dropping (returning) a copy, and reporting the
    cheapest rented copies.
    """

    def __init__(self, n: int, entries: List[List[int]]):
        self.unrented = defaultdict(SortedList)  # {movie: (price, shop)}
        self.shopAndMovieToPrice = {}  # {(shop, movie): price}
        self.rented = SortedList()  # (price, shop, movie)
        for entry in entries:
            shop, movie, price = entry
            self.shopAndMovieToPrice[(shop, movie)] = price
            self.unrented[movie].add((price, shop))

    def search(self, movie: int) -> List[int]:
        """Return up to 5 shops with an unrented copy, cheapest (then lowest shop) first."""
        cheapest = self.unrented[movie][:5]
        return [entry[1] for entry in cheapest]

    def rent(self, shop: int, movie: int) -> None:
        """Move the copy of ``movie`` at ``shop`` from unrented to rented."""
        cost = self.shopAndMovieToPrice[(shop, movie)]
        self.unrented[movie].remove((cost, shop))
        self.rented.add((cost, shop, movie))

    def drop(self, shop: int, movie: int) -> None:
        """Return the copy of ``movie`` at ``shop`` to the unrented pool."""
        cost = self.shopAndMovieToPrice[(shop, movie)]
        self.unrented[movie].add((cost, shop))
        self.rented.remove((cost, shop, movie))

    def report(self) -> List[List[int]]:
        """Return up to 5 ``[shop, movie]`` pairs of rented copies, cheapest first."""
        cheapest = self.rented[:5]
        return [list(entry[1:]) for entry in cheapest]
class MovieRentingSystem:
    """Track movie copies across shops.

    Supports four operations: searching the cheapest unrented copies of a
    movie, renting a copy, dropping (returning) a copy, and reporting the
    cheapest rented copies.
    """

    def __init__(self, n: int, entries: List[List[int]]):
        self.movie2priceshops = defaultdict(SortedList)  # movie -> sorted (price, shop)
        self.shopmovie2price = {}  # (shop, movie) -> price
        self.rented_BST = SortedList()  # [(price, shop, movie)]
        for shop, movie, price in entries:
            self.movie2priceshops[movie].add((price, shop))
            self.shopmovie2price[(shop, movie)] = price

    def search(self, movie: int) -> List[int]:
        """Return up to 5 shops with an unrented copy, cheapest (then lowest shop) first."""
        return [item[1] for item in self.movie2priceshops[movie][:5]]

    def rent(self, shop: int, movie: int) -> None:
        """Move the copy of ``movie`` at ``shop`` from unrented to rented."""
        price = self.shopmovie2price[(shop, movie)]
        self.movie2priceshops[movie].remove((price, shop))
        self.rented_BST.add((price, shop, movie))

    def drop(self, shop: int, movie: int) -> None:
        """Return the copy of ``movie`` at ``shop`` to the unrented pool."""
        price = self.shopmovie2price[(shop, movie)]
        self.movie2priceshops[movie].add((price, shop))
        self.rented_BST.remove((price, shop, movie))

    def report(self) -> List[List[int]]:
        """Return up to 5 ``[shop, movie]`` lists of rented copies, cheapest first."""
        # Build real lists (not tuple slices) to match the declared return type.
        return [[shop, movie] for _, shop, movie in self.rented_BST[:5]]
class StockPrice(object):
    """Record a stock price per timestamp, allowing later corrections.

    Reports the price at the latest timestamp, plus the maximum and minimum
    over all currently recorded prices.
    """

    def __init__(self):
        self.__curr = 0  # latest timestamp seen so far
        self.__lookup = {}  # timestamp -> price
        self.__sl_by_price = SortedList()  # every recorded price, ascending

    def update(self, timestamp, price):
        """
        :type timestamp: int
        :type price: int
        :rtype: None
        """
        self.__curr = max(self.__curr, timestamp)
        by_price = self.__sl_by_price
        if timestamp in self.__lookup:
            # Correction: forget the stale price for this timestamp first.
            by_price.remove(self.__lookup[timestamp])
        self.__lookup[timestamp] = price
        by_price.add(price)

    def current(self):
        """
        :rtype: int
        """
        latest = self.__curr
        return self.__lookup[latest]

    def maximum(self):
        """
        :rtype: int
        """
        descending = reversed(self.__sl_by_price)
        return next(descending)

    def minimum(self):
        """
        :rtype: int
        """
        ascending = iter(self.__sl_by_price)
        return next(ascending)
class MKAverage:
    """MK-average over a sliding window of the last ``m`` stream values.

    The window (``q``) is partitioned into three sorted lists:

    * ``left`` holds the ``k`` smallest values;
    * ``mid`` holds the next ``n = m - 2k`` values;
    * ``right`` holds the rest.

    ``sums`` is the running total of ``mid``.
    """

    def __init__(self, m: int, k: int):
        self.q = collections.deque()
        self.left = SortedList()
        self.mid = SortedList()
        self.right = SortedList()
        self.sums = 0
        self.m = m
        self.k = k
        self.n = m - 2 * k

    def addElement(self, num: int) -> None:
        """Append ``num`` to the window, evicting the oldest value when full."""
        window_full = len(self.q) == self.m
        if window_full:
            expired = self.q.popleft()
            self.remove(expired)
        self.q.append(num)
        self.add(num)

    def calculateMKAverage(self) -> int:
        """Return floor(mean of ``mid``), or -1 until ``m`` values have arrived."""
        if len(self.q) < self.m:
            return -1
        return self.sums // self.n

    def add(self, num: int) -> None:
        """Insert ``num`` via ``left`` and cascade overflow into ``mid``/``right``."""
        self.left.add(num)
        if len(self.left) > self.k:
            spill = self.left.pop()
            self.sums += spill
            self.mid.add(spill)
        if len(self.mid) > self.n:
            spill = self.mid.pop()
            self.sums -= spill
            self.right.add(spill)

    def remove(self, num: int) -> None:
        """Delete ``num`` from its partition and refill ``left``/``mid`` from the right."""
        if num <= self.left[-1]:
            self.left.remove(num)
        elif num <= self.mid[-1]:
            self.sums -= num
            self.mid.remove(num)
        else:
            self.right.remove(num)

        if len(self.left) < self.k:
            moved = self.mid.pop(0)
            self.sums -= moved
            self.left.add(moved)
        if len(self.mid) < self.n:
            moved = self.right.pop(0)
            self.sums += moved
            self.mid.add(moved)
class ExamRoom:
    """Seat students in a row of ``N`` seats, maximising the gap to the nearest neighbour.

    ``used`` holds the occupied seat indices (0..N-1) in sorted order.
    """

    def __init__(self, N: int):
        self.used = SortedList()
        self.N = N

    def seat(self) -> int:
        """Seat a student and return the chosen seat index.

        Picks the seat whose distance to the closest occupied seat is largest,
        preferring the lowest index on ties. An empty room seats at 0.

        The candidates are the front wall (seat 0), the midpoint of every
        gap between neighbours, and the back wall (seat N - 1). They are
        evaluated directly, instead of temporarily inserting mirrored
        sentinel seats into ``self.used`` and removing them afterwards.
        """
        used = self.used
        if not used:
            best = 0
        else:
            best, best_dist = None, -1
            # Front wall: the nearest neighbour of seat 0 is used[0].
            if used[0] != 0:
                best, best_dist = 0, used[0]
            # Interior gaps: the midpoint (rounded down) is (hi - lo) // 2
            # away from its closer neighbour. Strict '>' keeps the lowest
            # seat on ties.
            for lo, hi in zip(used, used[1:]):
                dist = (hi - lo) // 2
                if dist > best_dist:
                    best, best_dist = lo + dist, dist
            # Back wall: the nearest neighbour of seat N - 1 is used[-1].
            last_seat = self.N - 1
            if used[-1] != last_seat and last_seat - used[-1] > best_dist:
                best, best_dist = last_seat, last_seat - used[-1]
        used.add(best)
        return best

    def leave(self, p: int) -> None:
        """Free seat ``p`` (it must currently be occupied)."""
        self.used.remove(p)
class MKAverage:
    """MK-average over a sliding window of the last ``m`` stream values.

    The window values are split into three sorted buckets:

    * ``ksmallest`` holds the k smallest values;
    * ``klargest`` holds the k largest values;
    * ``middle`` holds everything in between.

    ``sum`` is the running total of ``middle``.
    """

    def __init__(self, m: int, k: int):
        self.m = m
        self.k = k
        self.queue = deque()  # window contents in arrival order
        self.ksmallest = SortedList()
        self.middle = SortedList()
        self.klargest = SortedList()
        self.sum = 0  # sum of the values currently in ``middle``

    def _discard(self, expired):
        """Remove ``expired`` from whichever bucket holds it."""
        if expired <= self.ksmallest[-1]:
            self.ksmallest.remove(expired)
        elif self.middle and expired <= self.middle[-1]:
            self.sum -= expired
            self.middle.remove(expired)
        else:
            self.klargest.remove(expired)

    def addElement(self, num: int) -> None:
        """Append ``num`` to the window, evicting the oldest value when full."""
        if len(self.queue) == self.m:
            self._discard(self.queue.popleft())
        self.queue.append(num)

        # Pool the new value with the boundary element of each outer bucket...
        self.middle.add(num)
        self.sum += num
        if self.ksmallest:
            boundary = self.ksmallest.pop(-1)
            self.sum += boundary
            self.middle.add(boundary)
        if self.klargest:
            boundary = self.klargest.pop(0)
            self.sum += boundary
            self.middle.add(boundary)

        # ...then refill the outer buckets from the extremes of ``middle``.
        while self.middle and len(self.ksmallest) < self.k:
            lowest = self.middle.pop(0)
            self.sum -= lowest
            self.ksmallest.add(lowest)
        while self.middle and len(self.klargest) < self.k:
            highest = self.middle.pop()
            self.sum -= highest
            self.klargest.add(highest)

    def calculateMKAverage(self) -> int:
        """Return floor(mean of ``middle``), or -1 until ``m`` values have arrived."""
        if len(self.queue) < self.m:
            return -1
        return self.sum // (self.m - self.k * 2)
class SCEngine:
    '''
    Fast tree-based implementation for indexing, using the
    ``sortedcontainers`` package.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool (defaults to False)
        Whether the values of the index must be unique
    '''

    def __init__(self, data, row_index, unique=False):
        # Each row of ``data`` becomes a tuple key paired with its row number.
        node_keys = map(tuple, data)
        self._nodes = SortedList(starmap(Node, zip(node_keys, row_index)))
        self._unique = unique

    def add(self, key, value):
        '''
        Add a key, value pair.

        Raises
        ------
        ValueError
            If the index is unique and ``key`` is already present.
        '''
        if self._unique and (key in self._nodes):
            # ``!r`` is a conversion and must come before any ':'; the former
            # ``{0:!r}`` was an invalid format spec, so formatting itself
            # failed instead of producing this message.
            message = 'duplicate {0!r} in unique index'.format(key)
            raise ValueError(message)
        self._nodes.add(Node(key, value))

    def find(self, key):
        '''
        Find rows corresponding to the given key.
        '''
        return [node.value for node in self._nodes.irange(key, key)]

    def remove(self, key, data=None):
        '''
        Remove data from the given key.

        If ``data`` is given, only the (key, data) node is removed; otherwise
        every node with ``key`` is removed. Returns True if anything was
        removed.
        '''
        if data is not None:
            item = Node(key, data)
            try:
                self._nodes.remove(item)
            except ValueError:
                return False
            return True
        items = list(self._nodes.irange(key, key))
        for item in items:
            self._nodes.remove(item)
        return bool(items)

    def shift_left(self, row):
        '''
        Decrement rows larger than the given row.
        '''
        for node in self._nodes:
            if node.value > row:
                node.value -= 1

    def shift_right(self, row):
        '''
        Increment rows greater than or equal to the given row.
        '''
        for node in self._nodes:
            if node.value >= row:
                node.value += 1

    def items(self):
        '''
        Return a list of key, data tuples.

        Each key appears once, paired with the list of its row values in
        index order.
        '''
        result = OrderedDict()
        for node in self._nodes:
            if node.key in result:
                result[node.key].append(node.value)
            else:
                result[node.key] = [node.value]
        return result.items()

    def sort(self):
        '''
        Make row order align with key order.
        '''
        for index, node in enumerate(self._nodes):
            node.value = index

    def sorted_data(self):
        '''
        Return a list of rows in order sorted by key.
        '''
        return [node.value for node in self._nodes]

    def range(self, lower, upper, bounds=(True, True)):
        '''
        Return row values in the given range.

        ``bounds`` gives the (lower, upper) inclusivity passed to ``irange``.
        '''
        iterator = self._nodes.irange(lower, upper, bounds)
        return [node.value for node in iterator]

    def replace_rows(self, row_map):
        '''
        Replace rows with the values in row_map.

        Nodes whose row is not a key of ``row_map`` are dropped.
        '''
        nodes = [node for node in self._nodes if node.value in row_map]
        for node in nodes:
            node.value = row_map[node.value]
        self._nodes.clear()
        self._nodes.update(nodes)

    def __repr__(self):
        return '{0!r}'.format(list(self._nodes))
def test_remove_valueerror2():
    """Removing a value greater than every element raises ValueError."""
    values = SortedList(range(100))
    values._reset(10)  # rebuild with a small load factor of 10
    absent = 100
    with pytest.raises(ValueError):
        values.remove(absent)
def test_remove_valueerror1():
    """Removing from an empty SortedList raises ValueError."""
    slt = SortedList()
    # remove() raises on a missing value; the test must expect that rather
    # than let the exception fail it.
    with pytest.raises(ValueError):
        slt.remove(0)
def test_remove_valueerror3():
    """Removing a value that falls between existing elements raises ValueError."""
    slt = SortedList([1, 2, 2, 2, 3, 3, 5])
    with pytest.raises(ValueError):
        slt.remove(4)
class PriorityDict(MutableMapping): """ A PriorityDict provides the same methods as a dict. Additionally, a PriorityDict efficiently maintains its keys in value sorted order. Consequently, the keys method will return the keys in value sorted order, the popitem method will remove the item with the highest value, etc. """ def __init__(self, *args, **kwargs): """ A PriorityDict provides the same methods as a dict. Additionally, a PriorityDict efficiently maintains its keys in value sorted order. Consequently, the keys method will return the keys in value sorted order, the popitem method will remove the item with the highest value, etc. If the first argument is the boolean value False, then it indicates that keys are not comparable. By default this setting is True and duplicate values are tie-breaked on the key. Using comparable keys improves the performance of the PriorityDict. An optional *iterable* argument provides an initial series of items to populate the PriorityDict. Each item in the sequence must itself contain two items. The first is used as a key in the new dictionary, and the second as the key's value. If a given key is seen more than once, the last value associated with it is retained in the new dictionary. If keyword arguments are given, the keywords themselves with their associated values are added as items to the dictionary. If a key is specified both in the positional argument and as a keyword argument, the value associated with the keyword is retained in the dictionary. For example, these all return a dictionary equal to ``{"one": 2, "two": 3}``: * ``SortedDict(one=2, two=3)`` * ``SortedDict({'one': 2, 'two': 3})`` * ``SortedDict(zip(('one', 'two'), (2, 3)))`` * ``SortedDict([['two', 3], ['one', 2]])`` The first example only works for keys that are valid Python identifiers; the others work with any valid keys. Note that this constructor mimics the Python dict constructor. 
If you're looking for a constructor like collections.Counter(...), see PriorityDict.count(...). """ self._dict = dict() if len(args) > 0 and isinstance(args[0], bool): if args[0]: self._list = SortedList() else: self._list = SortedListWithKey(key=lambda tup: tup[0]) else: self._list = SortedList() self.iloc = _IlocWrapper(self) self.update(*args, **kwargs) def clear(self): """Remove all elements from the dictionary.""" self._dict.clear() self._list.clear() def clean(self, value=0): """ Remove all items with value less than or equal to `value`. Default `value` is 0. """ _list, _dict = self._list, self._dict pos = self.bisect_right(value) for key in (key for value, key in _list[:pos]): del _dict[key] del _list[:pos] def __contains__(self, key): """Return True if and only if *key* is in the dictionary.""" return key in self._dict def __delitem__(self, key): """ Remove ``d[key]`` from *d*. Raises a KeyError if *key* is not in the dictionary. """ value = self._dict[key] self._list.remove((value, key)) del self._dict[key] def __getitem__(self, key): """ Return the priority of *key* in *d*. Raises a KeyError if *key* is not in the dictionary. """ return self._dict[key] def __iter__(self): """ Create an iterator over the keys of the dictionary ordered by the value sort order. """ return iter(key for value, key in self._list) def __reversed__(self): """ Create an iterator over the keys of the dictionary ordered by the reversed value sort order. 
""" return iter(key for value, key in reversed(self._list)) def __len__(self): """Return the number of (key, value) pairs in the dictionary.""" return len(self._dict) def __setitem__(self, key, value): """Set `d[key]` to *value*.""" if key in self._dict: old_value = self._dict[key] self._list.remove((old_value, key)) self._list.add((value, key)) self._dict[key] = value def copy(self): """Create a shallow copy of the dictionary.""" result = PriorityDict() result._dict = self._dict.copy() result._list = self._list.copy() result.iloc = _IlocWrapper(result) return result def __copy__(self): """Create a shallow copy of the dictionary.""" return self.copy() @classmethod def fromkeys(cls, iterable, value=0): """ Create a new dictionary with keys from `iterable` and values set to `value`. The default *value* is 0. """ return PriorityDict((key, value) for key in iterable) def get(self, key, default=None): """ Return the value for *key* if *key* is in the dictionary, else *default*. If *default* is not given, it defaults to ``None``, so that this method never raises a KeyError. """ return self._dict.get(key, default) def has_key(self, key): """Return True if and only in *key* is in the dictionary.""" return key in self._dict def pop(self, key, default=_NotGiven): """ If *key* is in the dictionary, remove it and return its value, else return *default*. If *default* is not given and *key* is not in the dictionary, a KeyError is raised. """ if key in self._dict: value = self._dict[key] self._list.remove((value, key)) return self._dict.pop(key) else: if default == _NotGiven: raise KeyError else: return default def popitem(self, index=-1): """ Remove and return item at *index* (default: -1). Raises IndexError if dict is empty or index is out of range. Negative indices are supported as for slice indices. """ value, key = self._list.pop(index) del self._dict[key] return key, value def setdefault(self, key, default=0): """ If *key* is in the dictionary, return its value. 
If not, insert *key* with a value of *default* and return *default*. *default* defaults to ``0``. """ if key in self._dict: return self._dict[key] else: self._dict[key] = default self._list.add((default, key)) return default def elements(self): """ Return an iterator over elements repeating each as many times as its count. Elements are returned in value sort-order. If an element’s count is less than one, elements() will ignore it. """ values = (repeat(key, value) for value, key in self._list) return chain.from_iterable(values) def most_common(self, count=None): """ Return a list of the `count` highest priority elements with their priority. If `count` is not specified, `most_common` returns *all* elements in the dict. Elements with equal counts are ordered by key. """ _list, _dict = self._list, self._dict if count is None: return [(key, value) for value, key in reversed(_list)] end = len(_dict) start = end - count return [(key, value) for value, key in reversed(_list[start:end])] def subtract(self, elements): """ Elements are subtracted from an iterable or from another mapping (or counter). Like dict.update() but subtracts counts instead of replacing them. Both inputs and outputs may be zero or negative. """ self -= Counter(elements) def tally(self, *args, **kwargs): """ Elements are counted from an iterable or added-in from another mapping (or counter). Like dict.update() but adds counts instead of replacing them. Also, the iterable is expected to be a sequence of elements, not a sequence of (key, value) pairs. """ self += Counter(*args, **kwargs) @classmethod def count(self, *args, **kwargs): """ Consume `args` and `kwargs` with a Counter and use that mapping to initialize a PriorityDict. """ return PriorityDict(Counter(*args, **kwargs)) def update(self, *args, **kwargs): """ Update the dictionary with the key/value pairs from *other*, overwriting existing keys. 
*update* accepts either another dictionary object or an iterable of key/value pairs (as a tuple or other iterable of length two). If keyword arguments are specified, the dictionary is then updated with those key/value pairs: ``d.update(red=1, blue=2)``. """ _list, _dict = self._list, self._dict if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], Mapping): items = args[0] else: items = dict(*args, **kwargs) if (10 * len(items)) > len(_dict): _dict.update(items) _list.clear() _list.update((value, key) for key, value in iteritems(_dict)) else: for key, value in iteritems(items): old_value = _dict[key] _list.remove((old_value, key)) _dict[key] = value _list.add((value, key)) def index(self, key): """ Return the smallest *i* such that `d.iloc[i] == key`. Raises KeyError if *key* is not present. """ value = self._dict[key] return self._list.index((value, key)) def bisect_left(self, value): """ Similar to the ``bisect`` module in the standard library, this returns an appropriate index to insert *value* in PriorityDict. If *value* is already present in PriorityDict, the insertion point will be before (to the left of) any existing entries. """ return self._list.bisect_left((value,)) def bisect(self, value): """Same as bisect_left.""" return self._list.bisect((value,)) def bisect_right(self, value): """ Same as `bisect_left`, but if *value* is already present in PriorityDict, the insertion point will be after (to the right of) any existing entries. 
""" return self._list.bisect_right((value, _Biggest)) def __iadd__(self, that): """Add values from `that` mapping.""" _list, _dict = self._list, self._dict if len(_dict) == 0: _dict.update(that) _list.update((value, key) for key, value in iteritems(_dict)) elif len(that) * 3 > len(_dict): _list.clear() for key, value in iteritems(that): if key in _dict: _dict[key] += value else: _dict[key] = value _list.update((value, key) for key, value in iteritems(_dict)) else: for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _list.remove((old_value, key)) value = old_value + value _dict[key] = value _list.add((value, key)) return self def __isub__(self, that): """Subtract values from `that` mapping.""" _list, _dict = self._list, self._dict if len(_dict) == 0: _dict.clear() _list.clear() elif len(that) * 3 > len(_dict): _list.clear() for key, value in iteritems(that): if key in _dict: _dict[key] -= value _list.update((value, key) for key, value in iteritems(_dict)) else: for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _list.remove((old_value, key)) value = old_value - value _dict[key] = value _list.add((value, key)) return self def __ior__(self, that): """Or values from `that` mapping (max(v1, v2)).""" _list, _dict = self._list, self._dict if len(_dict) == 0: _dict.update(that) _list.update((value, key) for key, value in iteritems(_dict)) elif len(that) * 3 > len(_dict): _list.clear() for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _dict[key] = old_value if old_value > value else value else: _dict[key] = value _list.update((value, key) for key, value in iteritems(_dict)) else: for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _list.remove((old_value, key)) value = old_value if old_value > value else value _dict[key] = value _list.add((value, key)) return self def __iand__(self, that): """And values from `that` mapping (min(v1, v2)).""" _list, _dict = self._list, self._dict if 
len(_dict) == 0: _dict.clear() _list.clear() elif len(that) * 3 > len(_dict): _list.clear() for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _dict[key] = old_value if old_value < value else value _list.update((value, key) for key, value in iteritems(_dict)) else: for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _list.remove((old_value, key)) value = old_value if old_value < value else value _dict[key] = value _list.add((value, key)) return self def __add__(self, that): """Add values from this and `that` mapping.""" result = PriorityDict() _list, _dict = result._list, result._dict _dict.update(self._dict) for key, value in iteritems(that): if key in _dict: _dict[key] += value else: _dict[key] = value _list.update((value, key) for key, value in iteritems(_dict)) return result def __sub__(self, that): """Subtract values in `that` mapping from this.""" result = PriorityDict() _list, _dict = result._list, result._dict _dict.update(self._dict) for key, value in iteritems(that): if key in _dict: _dict[key] -= value _list.update((value, key) for key, value in iteritems(_dict)) return result def __or__(self, that): """Or values from this and `that` mapping.""" result = PriorityDict() _list, _dict = result._list, result._dict _dict.update(self._dict) for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _dict[key] = old_value if old_value > value else value else: _dict[key] = value _list.update((value, key) for key, value in iteritems(_dict)) return result def __and__(self, that): """And values from this and `that` mapping.""" result = PriorityDict() _list, _dict = result._list, result._dict _dict.update(self._dict) for key, value in iteritems(that): if key in _dict: old_value = _dict[key] _dict[key] = old_value if old_value < value else value _list.update((value, key) for key, value in iteritems(_dict)) return result def __eq__(self, that): """Compare two mappings for equality.""" if isinstance(that, 
PriorityDict): that = that._dict return self._dict == that def __ne__(self, that): """Compare two mappings for inequality.""" if isinstance(that, PriorityDict): that = that._dict return self._dict != that def __lt__(self, that): """Compare two mappings for less than.""" if isinstance(that, PriorityDict): that = that._dict _dict = self._dict return (_dict != that and self <= that) def __le__(self, that): """Compare two mappings for less than equal.""" if isinstance(that, PriorityDict): that = that._dict _dict = self._dict return (len(_dict) <= len(that) and all(_dict[key] <= that[key] if key in that else False for key in _dict)) def __gt__(self, that): """Compare two mappings for greater than.""" if isinstance(that, PriorityDict): that = that._dict _dict = self._dict return (_dict != that and self >= that) def __ge__(self, that): """Compare two mappings for greater than equal.""" if isinstance(that, PriorityDict): that = that._dict _dict = self._dict return (len(_dict) >= len(that) and all(_dict[key] >= that[key] if key in _dict else False for key in that)) def isdisjoint(self, that): """ Return True if no key in `self` is also in `that`. This doesn't check that the value is greater than zero. To remove keys with value less than or equal to zero see *clean*. """ return not any(key in self for key in that) def items(self): """ Return a list of the dictionary's items (``(key, value)`` pairs). Items are ordered by their value from least to greatest. """ return list((key, value) for value, key in self._list) def iteritems(self): """ Return an iterable over the items (``(key, value)`` pairs) of the dictionary. Items are ordered by their value from least to greatest. """ return iter((key, value) for value, key in self._list) @not26 def viewitems(self): """ In Python 2.7 and later, return a new `ItemsView` of the dictionary's items. Beware iterating the `ItemsView` as items are unordered. In Python 2.6, raise a NotImplementedError. 
""" if hexversion < 0x03000000: return self._dict.viewitems() else: return self._dict.items() def keys(self): """ Return a list of the dictionary's keys. Keys are ordered by their corresponding value from least to greatest. """ return list(key for value, key in self._list) def iterkeys(self): """ Return an iterable over the keys of the dictionary. Keys are ordered by their corresponding value from least to greatest. """ return iter(key for value, key in self._list) @not26 def viewkeys(self): """ In Python 2.7 and later, return a new `KeysView` of the dictionary's keys. Beware iterating the `KeysView` as keys are unordered. In Python 2.6, raise a NotImplementedError. """ if hexversion < 0x03000000: return self._dict.viewkeys() else: return self._dict.keys() def values(self): """ Return a list of the dictionary's values. Values are ordered from least to greatest. """ return list(value for value, key in self._list) def itervalues(self): """ Return an iterable over the values of the dictionary. Values are iterated from least to greatest. """ return iter(value for value, key in self._list) @not26 def viewvalues(self): """ In Python 2.7 and later, return a `ValuesView` of the dictionary's values. Beware iterating the `ValuesView` as values are unordered. In Python 2.6, raise a NotImplementedError. """ if hexversion < 0x03000000: return self._dict.viewvalues() else: return self._dict.values() def __repr__(self): """Return a string representation of PriorityDict.""" return 'PriorityDict({0})'.format(repr(dict(self))) def _check(self): self._list._check() assert len(self._dict) == len(self._list) assert all(key in self._dict and self._dict[key] == value for value, key in self._list)
def test_remove_valueerror2():
    """Removing a value that is not present raises ValueError.

    The list holds ``0..99`` split into small sublists, and ``100`` is one
    past the largest stored value, so ``remove`` must fail.

    NOTE(review): another function named ``test_remove_valueerror2`` is
    defined immediately after this one in the module, so pytest only collects
    that later definition; this one is dead code unless it is renamed.
    """
    slt = SortedList(range(100))
    # The ``load=`` constructor keyword is not accepted by newer
    # sortedcontainers releases; ``_reset(10)`` is the form the later
    # definition of this test uses to set the same sublist load.
    slt._reset(10)
    # Without an expectation the ValueError would make the test error out.
    with pytest.raises(ValueError):
        slt.remove(100)
def test_remove_valueerror2():
    """``remove`` of a value past the end of a re-split SortedList raises ValueError."""
    sorted_values = SortedList(range(100))
    sorted_values._reset(10)
    absent = 100  # one past the largest stored value (99)
    with pytest.raises(ValueError):
        sorted_values.remove(absent)
def _all_segment_intersections_no_horizontal(segments):  # noqa
    """Yield every point where two or more of ``segments`` meet.

    Sweep-line search. Events (segment endpoints plus intersection points
    discovered along the way) are processed in increasing ``(y, x)`` order.

    Args:
        segments: a sized collection of hashable segments, each a pair of
            points ``(p, q)`` whose coordinates are read as ``pt[0]`` (x)
            and ``pt[1]`` (y). Segments must be unique and non-degenerate
            (both asserted below). Per the function name and the comment in
            the main loop, no segment may be horizontal; that is not checked.

    Yields:
        ``(pt, segs)`` for each event point ``pt`` touched by more than one
        segment. ``segs`` is a tuple of those segments: first the ones
        already in the sweep that pass through ``pt``, then the ones whose
        lower endpoint is ``pt``.
    """
    # Must be unique
    assert len(set(segments)) == len(segments)
    segments = list(segments)
    # Must not be degenerate
    for segment in segments:
        assert segment[0] != segment[1]

    # Use the convention from the book: sweep on Y axis
    def event_key(pt):
        return (pt[1], pt[0])

    # From point to list of segments
    # (keys are points ordered by event_key; each value lists the _SweepKeys
    # of segments whose lower endpoint is that point)
    event_queue = SortedDict(event_key)

    def add_event(pt, segment_key=None):
        # Register ``pt`` as an event; optionally attach a segment starting there.
        if pt not in event_queue:
            event_queue[pt] = []
        if segment_key is not None:
            event_queue[pt].append(segment_key)

    # The lower endpoint (by event_key) carries the segment's _SweepKey so the
    # segment enters the sweep there. The upper endpoint is a bare event; the
    # segment is found as a "touch" at that point and is not re-added because
    # the relevance check below drops it.
    # NOTE(review): the enumerate index ``i`` is unused.
    for i, segment in enumerate(segments):
        if event_key(segment[0]) < event_key(segment[1]):
            add_event(segment[0], _SweepKey(segment, segment[0]))
            add_event(segment[1], None)
        else:
            add_event(segment[0], None)
            add_event(segment[1], _SweepKey(segment, segment[1]))

    # Segments currently crossing the sweep line, kept in _SweepKey order.
    active = SortedList()

    y = -math.inf
    while len(event_queue) > 0:
        # Pop the smallest pending event.
        v = event_queue.popitem(0)
        pt, segstarts = v

        # Can't be > since while there are no horizontal segments,
        # there can still be points in horizontal relation to one another
        assert pt[1] >= y
        y = pt[1]

        # Find all segments within the event point

        # Zero-width probe key placed at ``pt``, used only to split ``active``
        # into the keys at-or-before and after the event point.
        fake_segment = ((pt[0], pt[1]), (pt[0], pt[1] + 1))
        fake_key = _SweepKey(fake_segment, pt)

        touches = []

        # The next lower / higher keys, respectively, to enter new events for
        neighbours = []

        if _extra_checks:
            _assert_fully_sorted(list(active), y)

        # Iterate on both sides
        # Walk outward from the probe. Keys whose x at the current y equals
        # pt[0] pass through the event point and are collected in ``touches``.
        # The first key that does not is recorded as that side's neighbour
        # (None if that side runs out).
        # (at_y is defined on _SweepKey elsewhere; presumably the x of the
        # segment at height y.)
        for it in (
            active.irange(None, fake_key, inclusive=(True, True), reverse=True),
            active.irange(fake_key, None, inclusive=(False, True)),
        ):
            neighbour = None
            for sweep_key in it:
                if sweep_key.at_y(y) != pt[0]:
                    neighbour = sweep_key
                    break
                touches.append(sweep_key)
            neighbours.append(neighbour)

        # Remove the old sweep keys
        for touch in touches:
            active.remove(touch)

        segments_at_pt = [
            sweep_key.segment for sweep_key in touches + segstarts
        ]
        if len(segments_at_pt) > 1:
            yield (pt, tuple(segments_at_pt))

        # Create new _SweepKeys, automatically sorts
        # according to order after point
        sweep_keys = []
        for segment in segments_at_pt:
            # Is this segment still relevant?
            # (a segment whose top endpoint is at or below pt has ended)
            if max(segment[0][1], segment[1][1]) <= pt[1]:
                continue
            sweep_keys.append(_SweepKey(segment, pt))

        sweep_keys = list(sorted(sweep_keys))

        # Add new events for neighbours
        # Only intersections strictly above the current sweep height are
        # scheduled. _nonparallel_intersection_point is defined elsewhere;
        # presumably it returns a falsy value when there is no intersection.
        if len(sweep_keys) == 0:
            # If we just removed stuff, the neighbours might now meet...
            if neighbours[0] is not None and neighbours[1] is not None:
                ipt = _nonparallel_intersection_point(neighbours[0].segment, neighbours[1].segment)
                if ipt and ipt[1] > pt[1]:
                    add_event(ipt)

            continue

        # Otherwise the outermost re-inserted keys become adjacent to the
        # neighbours on each side; check those pairs for future intersections.
        if neighbours[0] is not None:
            ipt = _nonparallel_intersection_point(sweep_keys[0].segment, neighbours[0].segment)
            # hyp.note(fstr('IPTL', ipt, pt))
            if ipt and ipt[1] > pt[1]:
                add_event(ipt)

        if neighbours[1] is not None:
            ipt = _nonparallel_intersection_point(sweep_keys[-1].segment, neighbours[1].segment)
            # hyp.note(fstr('IPTR', ipt, pt))
            if ipt and ipt[1] > pt[1]:
                add_event(ipt)

        # Add them in and continue
        for sweep_key in sweep_keys:
            active.add(sweep_key)
class TxGraph(object):
    """Represents a graph of all transactions within the current time window.

    Each transaction ``(tstamp, source, target)`` becomes an Edge between two
    Nodes. The graph keeps only transactions whose timestamp lies in
    ``[lowMarker, highMarker]``. A newer transaction slides the window forward
    (``WINDOW_SIZE`` timestamps wide) and evicts everything older than the new
    ``lowMarker``. Whenever node degrees can change, ``median`` is recomputed
    as the median degree of the nodes in the window.

    Attributes:
        median(float): the current median of the degree of the nodes
        highMarker(int): the latest timestamp seen so far
        lowMarker(int): the earliest timestamp of the window we are
            interested in
        txMap(SortedDict): EdgeList's keyed by timestamp (unix epoch)
        edgeMap(SortedDict): all Edges within the window, keyed by the name
            of the Edge
        nodeMap(SortedDict): all Nodes within the window, keyed by the name
            of the Node
        degreeList(SortedList): sorted degrees of the nodes with degree > 0
    """

    WINDOW_SIZE = 60

    def __init__(self):
        self.median = 0
        self.highMarker = TxGraph.WINDOW_SIZE
        self.lowMarker = 1
        self.txMap = SortedDict()  # sorted by unix epoch (timestamp)
        self.edgeMap = SortedDict()  # sorted by edge name
        self.nodeMap = SortedDict()  # sorted by node name
        self.degreeList = SortedList()  # sorted by degree

    def __calculate_median(self, use_existing_list=False):
        """Return the median node degree.

        Args:
            use_existing_list(bool): when False, rebuild ``degreeList`` from
                every node in ``nodeMap`` with degree > 0. When True, trust
                the incrementally maintained ``degreeList``.

        Returns:
            float: the median of ``degreeList``.

        Raises:
            Exception: if ``degreeList`` is empty.
        """
        if not use_existing_list:
            # rebuild the list from scratch
            self.degreeList = SortedList(
                node.degree for node in self.nodeMap.values()
                if node.degree > 0)

        listLen = len(self.degreeList)
        if listLen == 0:
            raise Exception("No items in the degreeList")

        # Floor division keeps the index an int on Python 3 as well
        # (``listLen/2`` is a float there and cannot index a list).
        mid = listLen // 2
        if listLen % 2 == 0:
            return (self.degreeList[mid] + self.degreeList[mid - 1]) / 2.0
        return float(self.degreeList[mid])

    def __get_edgelist(self, tstamp, create=True):
        """Return the EdgeList for ``tstamp``.

        Args:
            tstamp(int): unix epoch of the timestamp
            create(bool): create and register a new EdgeList if none exists

        Returns:
            EdgeList or None: None only when missing and ``create`` is false.
        """
        edgeList = self.txMap.get(tstamp, None)
        if edgeList is None and create:
            edgeList = EdgeList(tstamp)
            self.txMap[tstamp] = edgeList
        return edgeList

    def __getnode_with_name(self, name, create=True):
        """Return the Node called ``name``.

        Args:
            name(str): name of the node
            create(bool): create and register a new Node if none exists

        Returns:
            Node or None: None only when missing and ``create`` is false.
        """
        node = self.nodeMap.get(name, None)
        if node is None and create:
            node = Node(name)
            self.nodeMap[name] = node
        return node

    def __incr_degree_of_edge_nodes(self, edge):
        """Increment the degree of both endpoint nodes of ``edge``.

        Nodes are created on demand.

        Returns:
            tuple: ``(source degree, target degree)`` after incrementing.
        """
        src = self.__getnode_with_name(edge.source)
        src.incr_degree()
        tar = self.__getnode_with_name(edge.target)
        tar.incr_degree()
        return (src.degree, tar.degree)

    def __decr_degree_of_edge_nodes(self, edge):
        """Decrement the degree of both endpoint nodes of ``edge``."""
        self.__decr_degree_of_node(edge.source)
        self.__decr_degree_of_node(edge.target)

    def __decr_degree_of_node(self, name):
        """Decrement the degree of node ``name``.

        The node is dropped from ``nodeMap`` once its degree reaches 0.
        """
        node = self.__getnode_with_name(name, create=False)
        node.decr_degree()
        if node.degree == 0:
            del self.nodeMap[node.name]

    def __remove_edge(self, edge):
        """Remove ``edge`` from ``edgeMap`` and update its nodes' degrees.

        A node whose degree drops to 0 is removed as well. The edge is not
        removed from its EdgeList here.

        Args:
            edge(Edge): the edge to remove
        """
        self.__decr_degree_of_edge_nodes(edge)
        del self.edgeMap[edge.name]

    def __update_tstamp_for_existing_edge(self, edgeName, tstamp):
        """Move an existing edge to a newer timestamp.

        The edge is taken out of the EdgeList for its old timestamp and put
        into the EdgeList for ``tstamp``. Older or equal timestamps are
        ignored, and so is an unknown ``edgeName``.

        Args:
            edgeName(str): name of the edge to be updated
            tstamp(int): unix epoch of the timestamp
        """
        currEdge = self.edgeMap.get(edgeName)
        if currEdge is None:
            return

        if tstamp <= currEdge.tstamp:
            return  # ignore older transactions within the window

        # remove the edge from the edgelist with the old timestamp
        edgeList = self.__get_edgelist(currEdge.tstamp, create=False)
        del edgeList.edges[currEdge.name]

        # update the tstamp in the edge
        currEdge.tstamp = tstamp

        # move this edge to the correct edgelist
        edgeList = self.__get_edgelist(tstamp)
        edgeList.edges[currEdge.name] = currEdge

    def __update_tx_window(self):
        """Evict every transaction older than ``lowMarker``.

        This is called after the markers have moved forward for a newer
        transaction. Each EdgeList with a timestamp below ``lowMarker`` is
        removed from ``txMap``, and each of its edges is removed from the
        graph, which also updates or removes their nodes. ``degreeList`` is
        not touched here; the caller rebuilds it.
        """
        # Materialize the stale timestamps first: txMap is mutated below and
        # must not change while irange() is iterating over it.
        staleTStamps = list(
            self.txMap.irange(None, self.lowMarker, inclusive=(True, False)))
        for tstamp in staleTStamps:
            edgeList = self.txMap.pop(tstamp)
            # assumes EdgeList.edges is a dict-like mapping name -> Edge
            for edge in edgeList.edges.values():
                self.__remove_edge(edge)

    def process_transaction(self, tstamp, source, target):
        """Add one transaction to the graph and update ``median``.

        - Transactions older than the window (``tstamp < lowMarker``) are
          ignored.
        - Within the window, an existing edge only gets its timestamp
          refreshed, which leaves the median unchanged. A new edge is added
          and the median is updated incrementally.
        - A transaction newer than the window moves the window, evicts stale
          edges, adds or refreshes the edge, and recomputes the median from
          scratch.

        Args:
            tstamp(int): unix epoch of the transaction
            source(str): name of the source node
            target(str): name of the target node

        Raises:
            Exception: if ``source``/``target`` is None or empty, or if they
                are equal.
        """
        # basic sanity checks
        if source is None or target is None:
            raise Exception("Invalid node")

        if len(source) == 0 or len(target) == 0:
            raise Exception("Invalid node")

        if source == target:
            raise Exception("source and target cannot be the same")

        # timestamp of the transaction is old and can be ignored
        if tstamp < self.lowMarker:
            return

        # create a new edge representing this transaction
        newEdge = Edge(tstamp, source, target)

        if tstamp <= self.highMarker:
            if newEdge.name in self.edgeMap:
                self.__update_tstamp_for_existing_edge(newEdge.name, tstamp)
                # no need to recalculate the median here since degree does not change
                return

            # Handle a new edge inside the current window: file it under its
            # timestamp and in edgeMap, then update the endpoint degrees.
            edgeList = self.__get_edgelist(tstamp)
            edgeList.edges[newEdge.name] = newEdge
            self.edgeMap[newEdge.name] = newEdge

            # Optimization: only the two endpoint degrees changed, so patch
            # degreeList in place instead of rebuilding it. A degree of 1 is a
            # brand-new entry; otherwise swap the old degree for the new one.
            srcDegree, tarDegree = self.__incr_degree_of_edge_nodes(newEdge)
            for degree in (srcDegree, tarDegree):
                if degree != 1:
                    self.degreeList.remove(degree - 1)
                self.degreeList.add(degree)

            self.median = self.__calculate_median(use_existing_list=True)
            return

        # This transaction is newer than the window: move the markers, evict
        # stale edges, add (or refresh) the edge, and recompute the median
        # with a freshly rebuilt degreeList.
        self.highMarker = tstamp
        self.lowMarker = tstamp - TxGraph.WINDOW_SIZE + 1

        self.__update_tx_window()

        if newEdge.name in self.edgeMap:
            self.__update_tstamp_for_existing_edge(newEdge.name, tstamp)
        else:
            edgeList = self.__get_edgelist(tstamp)
            edgeList.edges[newEdge.name] = newEdge
            self.edgeMap[newEdge.name] = newEdge
            self.__incr_degree_of_edge_nodes(newEdge)

        self.median = self.__calculate_median()
class Gauge(object):
    """Represents a gauge. A gauge has a value at any moment. It can be
    modified by a user's adjustment or by an effective momentum.

    The state is a base point ``(time, value)`` plus a sorted collection of
    momenta, each a velocity that applies over a ``since``..``until`` span.
    Values at arbitrary times are read from the cached :attr:`determination`,
    which is rebuilt lazily after :meth:`invalidate`.

    The value is bounded by a maximum and a minimum. Each limit is either a
    constant (``max_value``/``min_value``) or another gauge
    (``max_gauge``/``min_gauge``). A gauge used as a limit keeps a weak set
    of the gauges it limits, so it can invalidate or rebase them when it
    changes.
    """

    __slots__ = (
        #: The base time and value.
        'base',
        #: A sorted list of momenta. The items are :class:`Momentum` objects.
        'momenta',
        #: The constant maximum value.
        'max_value',
        #: The gauge to indicate maximum value.
        'max_gauge',
        #: The constant minimum value.
        'min_value',
        #: The gauge to indicate minimum value.
        'min_gauge',
        # internal attributes.
        '_determination', '_events', '_limited_gauges', '__weakref__',
    )

    def __init__(self, value, max, min=0, at=None):
        self.__preinit__()
        at = now_or(at)
        self.base = (at, value)
        # _incomplete=True: there is no determination yet, so only the limit
        # attributes are set (no clamping, no forget_past).
        self._set_range(max, min, at=at, _incomplete=True)

    def __preinit__(self):
        """Called by :meth:`__init__` and :meth:`__setstate__`.

        Resets the limit gauges, momenta, cached determination, event list
        and the weak set of limited gauges.

        NOTE(review): no ``__setstate__`` is defined in this class; pickling
        goes through :meth:`__reduce__` / ``restore_gauge`` (defined
        elsewhere), which presumably calls this. Confirm.
        """
        self.max_gauge = self.min_gauge = None
        # presumably ordered by each momentum's ``until`` (see by_until)
        self.momenta = SortedListWithKey(key=by_until)
        self._determination = None
        # (time, ADD|REMOVE, momentum) tuples; see momentum_events().
        self._events = SortedList()
        # a weak set of gauges that refer the gauge as a limit gauge.
        self._limited_gauges = WeakSet()

    @property
    def determination(self):
        """The cached determination. If there is no cached determination,
        one is computed and cached first.

        A determination is a sorted list of 2-dimensional points which take
        times as x-values, gauge values as y-values.
        """
        if self._determination is None:
            # redetermine and cache.
            self._determination = Determination(self)
        return self._determination

    def invalidate(self):
        """Invalidates the cached determination. The next access to
        :attr:`determination` recomputes it.

        You don't need to call this method because all mutating methods such
        as :meth:`incr` or :meth:`add_momentum` call it.

        Gauges that use this gauge as a limit are invalidated as well.

        :returns: whether the gauge is invalidated actually.
        """
        if self._determination is None:
            return False
        # remove the cached determination.
        self._determination = None
        # invalidate limited gauges together.
        for gauge in self._limited_gauges:
            gauge._limit_gauge_invalidated(self)
        return True

    def get_max(self, at=None):
        """Predicts the current maximum value."""
        if self.max_gauge is None:
            return self.max_value
        else:
            return self.max_gauge.get(at)

    def get_min(self, at=None):
        """Predicts the current minimum value."""
        if self.min_gauge is None:
            return self.min_value
        else:
            return self.min_gauge.get(at)

    #: The alias of :meth:`get_max`.
    max = get_max

    #: The alias of :meth:`get_min`.
    min = get_min

    def _set_range(self, max_=None, min_=None, at=None, _incomplete=False):
        """Set the maximum and/or minimum at ``at``.

        A ``None`` limit is left unchanged. A limit may be a number or a
        :class:`Gauge`. Unless ``_incomplete``, the current value is clamped
        into the new range if the gauge was in range at ``at``, and the gauge
        is rebased via :meth:`forget_past`.

        :returns: ``None`` when ``_incomplete``, else the value returned by
                  :meth:`forget_past`.
        """
        at = now_or(at)
        forget_until = at
        # _incomplete=True when __init__() calls it.
        if not _incomplete:
            value = self.get(at)
            in_range_since = self.determination.in_range_since
        # Inside method bodies ``min``/``max`` are the builtins (the class-level
        # aliases above are not in scope). The clamp for the maximum is
        # builtin min(), which caps the value from above, and vice versa.
        items = [('max', max_, self.max_gauge, min),
                 ('min', min_, self.min_gauge, max)]
        for name, limit, prev_limit_gauge, clamp in items:
            if limit is None:
                continue
            if prev_limit_gauge is not None:
                # unlink from the previous limit gauge.
                prev_limit_gauge._limited_gauges.discard(self)
            if isinstance(limit, Gauge):
                limit_gauge, limit_value = limit, limit.get(at)
                # rebase no later than the limit gauge's own base time
                forget_until = min(forget_until, limit_gauge.base[TIME])
            else:
                limit_gauge, limit_value = None, limit
            # set limit attrs
            value_attr, gauge_attr = name + '_value', name + '_gauge'
            if limit_gauge is None:
                setattr(self, value_attr, limit_value)
                setattr(self, gauge_attr, None)
            else:
                setattr(self, value_attr, None)
                setattr(self, gauge_attr, limit_gauge)
                limit_gauge._limited_gauges.add(self)
            if _incomplete or in_range_since is None:
                continue
            elif in_range_since <= at:
                value = clamp(value, limit_value)
        if _incomplete:
            return
        return self.forget_past(value, at=forget_until)

    def set_max(self, max, at=None):
        """Changes the maximum.

        :param max: a number or gauge to set as the maximum.
        :param at: the time to change. (default: now)
        """
        return self._set_range(max_=max, at=at)

    def set_min(self, min, at=None):
        """Changes the minimum.

        :param min: a number or gauge to set as the minimum.
        :param at: the time to change. (default: now)
        """
        return self._set_range(min_=min, at=at)

    def set_range(self, max=None, min=None, at=None):
        """Changes both the maximum and the minimum at once.

        :param max: a number or gauge to set as the maximum. (optional)
        :param min: a number or gauge to set as the minimum. (optional)
        :param at: the time to change. (default: now)
        """
        return self._set_range(max, min, at=at)

    def _predict(self, at=None):
        """Predicts the current value and velocity.

        The value is interpolated on the determination segment containing
        ``at``. Before the first point or after the last one, the velocity is
        ``0.``. If the gauge has been in range since the segment's start, the
        value is clamped into the current range.

        :param at: the time to observe. (default: now)
        :returns: ``(value, velocity)``
        """
        at = now_or(at)
        determination = self.determination
        if len(determination) == 1:
            # skip bisect_right() because it is expensive
            x = 0
        else:
            # index of the first point strictly after ``at``
            x = bisect_right(determination, (at, +inf))
        if x == 0:
            return (determination[0][VALUE], 0.)
        try:
            time2, value2 = determination[x]
        except IndexError:
            # ``at`` is past the last point: the gauge has settled.
            return (determination[-1][VALUE], 0.)
        time1, value1 = determination[x - 1]
        value = Segment._calc_value(at, time1, time2, value1, value2)
        velocity = Segment._calc_velocity(time1, time2, value1, value2)
        if determination.in_range_since is None:
            pass
        elif determination.in_range_since <= time1:
            value = self._clamp(value, at=at)
        return (value, velocity)

    def get(self, at=None):
        """Predicts the current value.

        :param at: the time to observe. (default: now)
        """
        value, velocity = self._predict(at)
        return value

    def velocity(self, at=None):
        """Predicts the current velocity.

        :param at: the time to observe. (default: now)
        """
        value, velocity = self._predict(at)
        return velocity

    def goal(self):
        """Predicts the final value."""
        return self.determination[-1][VALUE]

    def incr(self, delta, outbound=ERROR, at=None):
        """Increases the value by the given delta immediately. The
        determination would be changed.

        :param delta: the value to increase.
        :param outbound: the strategy to control modification to out of the
                         range. (default: ERROR)
        :param at: the time to increase. (default: now)
        :raises ValueError: the value is out of the range.
        """
        at = now_or(at)
        prev_value = self.get(at=at)
        value = prev_value + delta
        # ONCE: the change may leave the range only if the gauge is currently
        # in range; otherwise it behaves like ERROR.
        if outbound == ONCE:
            outbound = OK if self.in_range(at) else ERROR
        if outbound != OK:
            # Builtin max/min again. With CLAMP the value never moves past
            # prev_value in the wrong direction: e.g. an increase on a value
            # already above the maximum leaves it where it was.
            items = [(
                self.get_max, max, operator.gt,
                'The value to set is bigger than the maximum ({0} > {1})'
            ), (
                self.get_min, min, operator.lt,
                'The value to set is smaller than the minimum ({0} < {1})'
            )]
            for get_limit, clamp, cmp_, error_form in items:
                # only check the limit in the direction of the change
                if not cmp_(delta, 0):
                    continue
                limit = get_limit(at)
                if not cmp_(value, limit):
                    continue
                if outbound == ERROR:
                    raise ValueError(error_form.format(value, limit))
                elif outbound == CLAMP:
                    value = clamp(prev_value, limit)
                    break
        return self.forget_past(value, at=at)

    def decr(self, delta, outbound=ERROR, at=None):
        """Decreases the value by the given delta immediately. The
        determination would be changed.

        :param delta: the value to decrease.
        :param outbound: the strategy to control modification to out of the
                         range. (default: ERROR)
        :param at: the time to decrease. (default: now)
        :raises ValueError: the value is out of the range.
        """
        return self.incr(-delta, outbound=outbound, at=at)

    def set(self, value, outbound=ERROR, at=None):
        """Sets the current value immediately. The determination would be
        changed.

        :param value: the value to set.
        :param outbound: the strategy to control modification to out of the
                         range. (default: ERROR)
        :param at: the time to set. (default: now)
        :raises ValueError: the value is out of the range.
        """
        at = now_or(at)
        delta = value - self.get(at=at)
        return self.incr(delta, outbound=outbound, at=at)

    def _clamp(self, value, at=None):
        """Return ``value`` limited to ``[get_min(at), get_max(at)]``."""
        at = now_or(at)
        max_ = self.get_max(at)
        if value > max_:
            return max_
        min_ = self.get_min(at)
        if value < min_:
            return min_
        return value

    def clamp(self, at=None):
        """Clamps the current value."""
        at = now_or(at)
        value = self._clamp(self.get(at), at=at)
        return self.set(value, outbound=OK, at=at)

    def when(self, value, after=0):
        """When the gauge reaches the goal value.

        :param value: the goal value.
        :param after: take (n+1)th time. (default: 0)
        :raises ValueError: the gauge will not reach the goal value.
        """
        x = 0
        for x, at in enumerate(self.whenever(value)):
            if x == after:
                return at
        # NOTE(review): ``x`` is the index of the last time yielded, not the
        # count. If ``whenever`` yields exactly once, x == 0 and the message
        # omits the count, reading as if the value is never reached.
        form = 'The gauge will not reach to {0}' + \
               (' more than {1} times' if x else '')
        raise ValueError(form.format(value, x))

    def whenever(self, value):
        """Yields each time the gauge reaches the goal value.

        Times are yielded in determination order and linearly interpolated
        within each segment.

        :param value: the goal value.
        """
        if self.determination:
            determination = self.determination
            first_time, first_value = determination[0]
            if first_value == value:
                yield first_time
            zipped_determination = zip(determination[:-1], determination[1:])
            for (time1, value1), (time2, value2) in zipped_determination:
                # Half-open test (excludes value1, includes value2): a hit
                # exactly on a point is reported once, by the segment that
                # ends there. Flat segments never match, so there is no
                # division by zero below.
                if not (value1 < value <= value2 or value1 > value >= value2):
                    continue
                ratio = (value - value1) / float(value2 - value1)
                yield (time1 + (time2 - time1) * ratio)

    def in_range(self, at=None):
        """Whether the gauge is between the range at the given time.

        :param at: the time to check. (default: now)
        """
        in_range_since = self.determination.in_range_since
        if in_range_since is None:
            return False
        at = now_or(at)
        return in_range_since <= at

    @staticmethod
    def _make_momentum(velocity_or_momentum, since=None, until=None):
        """Makes a :class:`Momentum` object by the given arguments.

        Override this if you want to use your own momentum class.

        :param velocity_or_momentum: a :class:`Momentum` object or just a
                                     number for the velocity.
        :param since: if the first argument is a velocity, it is the time to
                      start to affect the momentum. (default: ``-inf``)
        :param until: if the first argument is a velocity, it is the time to
                      finish to affect the momentum. (default: ``+inf``)

        :raises ValueError: `since` later than or same with `until`.
        :raises TypeError: the first argument is a momentum, but other
                           arguments passed.
        """
        if isinstance(velocity_or_momentum, Momentum):
            if not (since is until is None):
                raise TypeError('Arguments behind the first argument as a '
                                'momentum should be None')
            momentum = velocity_or_momentum
        else:
            velocity = velocity_or_momentum
            if since is None:
                since = -inf
            if until is None:
                until = +inf
            momentum = Momentum(velocity, since, until)
        since, until = momentum.since, momentum.until
        # open-ended spans are always accepted; finite ones must be non-empty
        if since == -inf or until == +inf or since < until:
            pass
        else:
            raise ValueError('\'since\' should be earlier than \'until\'')
        return momentum

    def add_momenta(self, momenta):
        """Adds multiple momenta.

        Each momentum also records an ADD event at ``since`` and, when it is
        finite, a REMOVE event at ``until``.
        """
        for momentum in momenta:
            self.momenta.add(momentum)
            self._events.add((momentum.since, ADD, momentum))
            if momentum.until != +inf:
                self._events.add((momentum.until, REMOVE, momentum))
        self.invalidate()

    def remove_momenta(self, momenta):
        """Removes multiple momenta.

        :raises ValueError: one of the momenta is not in the gauge.
        """
        for momentum in momenta:
            try:
                self.momenta.remove(momentum)
            except ValueError:
                raise ValueError('{0} not in the gauge'.format(momentum))
            self._events.remove((momentum.since, ADD, momentum))
            if momentum.until != +inf:
                self._events.remove((momentum.until, REMOVE, momentum))
        self.invalidate()

    def add_momentum(self, *args, **kwargs):
        """Adds a momentum. A momentum includes the velocity and the times to
        start to affect and to stop to affect. The determination would be
        changed.

        All arguments will be passed to :meth:`_make_momentum`.

        :returns: a momentum object. Use this to remove the momentum by
                  :meth:`remove_momentum`.
        :raises ValueError: `since` later than or same with `until`.
        """
        momentum = self._make_momentum(*args, **kwargs)
        self.add_momenta([momentum])
        return momentum

    def remove_momentum(self, *args, **kwargs):
        """Removes the given momentum. The determination would be changed.

        All arguments will be passed to :meth:`_make_momentum`.

        :raises ValueError: the given momentum not in the gauge.
        """
        momentum = self._make_momentum(*args, **kwargs)
        self.remove_momenta([momentum])
        return momentum

    def momentum_events(self):
        """Yields momentum adding and removing events. An event is a tuple
        of ``(time, ADD|REMOVE, momentum)``.

        The sequence starts with ``(base time, None, None)`` and ends with
        ``(+inf, None, None)``. Events whose momentum is no longer in
        :attr:`momenta` (e.g. dropped by :meth:`_rebase`, which does not
        touch ``_events``) are pruned here instead of being yielded.
        """
        yield (self.base[TIME], None, None)
        momentum_ids = set(id(m) for m in self.momenta)
        # iterate a copy because stale events are removed along the way
        for time, method, momentum in list(self._events):
            if id(momentum) not in momentum_ids:
                self._events.remove((time, method, momentum))
                continue
            yield time, method, momentum
        yield (+inf, None, None)

    def _rebase(self, value=None, at=None, remove_momenta_before=None):
        """Sets the base and removes the momenta before the given index.

        Gauges limited by this one are notified (via
        ``_limit_gauge_rebased``) before the base changes.

        :param value: the value to set coercively. (default: the current
                      value)
        :param at: the time to set. (default: now)
        :param remove_momenta_before: momenta at indexes below this one are
                                      removed; ``None`` removes all of them.
        :returns: the new base value.
        """
        at = now_or(at)
        if value is None:
            value = self.get(at=at)
        for gauge in self._limited_gauges:
            gauge._limit_gauge_rebased(self, value, at=at)
        self.base = (at, value)
        del self.momenta[:remove_momenta_before]
        self.invalidate()
        return value

    def clear_momenta(self, value=None, at=None):
        """Removes all momenta. The value is set as the current value. The
        determination would be changed.

        :param value: the value to set coercively.
        :param at: the time base. (default: now)
        """
        return self._rebase(value, at=at, remove_momenta_before=None)

    def forget_past(self, value=None, at=None):
        """Discards the momenta which no longer have an effect.

        :param value: the value to set coercively.
        :param at: the time base. (default: now)
        """
        at = now_or(at)
        # Presumably by_until reads the ``until`` field, so this finds the
        # first momentum still active at ``at``; earlier ones are dropped.
        # TODO confirm against by_until.
        x = self.momenta.bisect_left((-inf, -inf, at))
        return self._rebase(value, at=at, remove_momenta_before=x)

    def _limit_gauge_invalidated(self, limit_gauge):
        """The callback function which will be called when a limit gauge is
        invalidated.
        """
        self.invalidate()

    def _limit_gauge_rebased(self, limit_gauge, limit_value, at=None):
        """The callback function which will be called when a limit gauge is
        rebased.

        The gauge is rebased at ``at``, but never earlier than its own base
        time. If it is in range there, its value is first clamped against the
        limit's new value.
        """
        at = max(now_or(at), self.base[TIME])
        value = self.get(at)
        if self.in_range(at):
            # NOTE(review): if the same gauge is both max_gauge and
            # min_gauge, this dict collapses to a single entry.
            clamp = {self.max_gauge: min, self.min_gauge: max}[limit_gauge]
            value = clamp(value, limit_value)
        self.forget_past(value, at=at)

    def __reduce__(self):
        # Pickle via restore_gauge (defined elsewhere) with the base,
        # momenta as plain tuples, and both limits.
        return restore_gauge, (
            self.__class__, self.base, list(tuple(m) for m in self.momenta),
            self.max_value, self.max_gauge, self.min_value, self.min_gauge
        )

    def __repr__(self, at=None):
        """Example strings:

        - ``<Gauge 0.00/2.00>``
        - ``<Gauge 0.00 between 1.00~2.00>``
        - ``<Gauge 0.00 between <Gauge 0.00/2.00>~<Gauge 2.00/2.00>>``

        """
        at = now_or(at)
        value = self.get(at=at)
        hyper = False
        # limit_reprs[0] is the maximum ({2}), limit_reprs[1] the minimum ({3})
        limit_reprs = []
        limit_items = [(self.max_value, self.max_gauge),
                       (self.min_value, self.min_gauge)]
        for limit_value, limit_gauge in limit_items:
            if limit_gauge is None:
                limit_reprs.append('{0:.2f}'.format(limit_value))
            else:
                hyper = True
                limit_reprs.append('{0!r}'.format(limit_gauge))
        form = '<{0} {1:.2f}'
        if not hyper and self.min_value == 0:
            form += '/{2}>'
        else:
            form += ' between {3}~{2}>'
        return form.format(type(self).__name__, value, *limit_reprs)
class TTLCache(object):
    """A key/value cache implementation where each entry has its own TTL.

    Entries live in two structures. ``_data`` maps each key to its
    ``_CacheEntry``. ``_expiry_list`` holds the same entries in sorted order,
    by expiry time as far as ``_CacheEntry``'s ordering provides it, so
    :meth:`expire` can drop due entries from the front. Every accessor runs
    :meth:`expire` first, so callers never observe an expired entry.
    """

    def __init__(self, cache_name, timer=time.time):
        # map from key to _CacheEntry
        self._data = {}

        # the _CacheEntries, sorted by expiry time
        self._expiry_list = SortedList()

        self._timer = timer

        self._metrics = register_cache("ttl", cache_name, self)

    def set(self, key, value, ttl):
        """Add/update an entry in the cache

        Args:
            key: key for this entry
            value: value for this entry
            ttl (float): TTL for this entry, in seconds
        """
        expiry = self._timer() + ttl

        self.expire()
        # Drop any previous entry for this key from the expiry list too.
        e = self._data.pop(key, SENTINEL)
        if e is not SENTINEL:
            self._expiry_list.remove(e)

        entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
        self._data[key] = entry
        self._expiry_list.add(entry)

    def get(self, key, default=SENTINEL):
        """Get a value from the cache

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.get(key, SENTINEL)
        # Identity checks: ``default`` is caller-supplied and may define a
        # custom ``__eq__`` (or one that does not return a bool).
        if e is SENTINEL:
            self._metrics.inc_misses()
            if default is SENTINEL:
                raise KeyError(key)
            return default
        self._metrics.inc_hits()
        return e.value

    def get_with_expiry(self, key):
        """Get a value, and its expiry time, from the cache

        Args:
            key: key to look up

        Returns:
            Tuple[Any, float]: the value from the cache, and the expiry time

        Raises:
            KeyError if the entry is not found
        """
        self.expire()
        try:
            e = self._data[key]
        except KeyError:
            self._metrics.inc_misses()
            raise
        self._metrics.inc_hits()
        return e.value, e.expiry_time

    def pop(self, key, default=SENTINEL):
        """Remove a value from the cache

        If key is in the cache, remove it and return its value, else return default.
        If default is not given and key is not in the cache, a KeyError is raised.

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.pop(key, SENTINEL)
        if e is SENTINEL:
            self._metrics.inc_misses()
            if default is SENTINEL:
                raise KeyError(key)
            return default
        self._expiry_list.remove(e)
        self._metrics.inc_hits()
        return e.value

    def __getitem__(self, key):
        return self.get(key)

    def __delitem__(self, key):
        self.pop(key)

    def __contains__(self, key):
        # Expire first, like get() and __len__(), so that ``key in cache``
        # never reports an entry that ``cache[key]`` would refuse.
        self.expire()
        return key in self._data

    def __len__(self):
        self.expire()
        return len(self._data)

    def expire(self):
        """Run the expiry on the cache. Any entries whose expiry times are due will
        be removed
        """
        now = self._timer()
        while self._expiry_list:
            first_entry = self._expiry_list[0]
            if first_entry.expiry_time - now > 0.0:
                break
            del self._data[first_entry.key]
            del self._expiry_list[0]
class Fuzzer(object):
    """Label-driven corpus fuzzer/shrinker.

    ``classifier(string)`` returns the labels a byte string exhibits. For each
    label the fuzzer keeps the least string seen so far, by ``CorpusItem``
    ordering. The corpus holds every string that is currently the best for at
    least one label, reference-counted by how many labels it serves.
    :meth:`fuzz` repeatedly tries shrinking corpus strings to find better
    ones. ``cache_key``, ``CorpusItem`` and ``LifeCycle`` are defined
    elsewhere.
    """

    def __init__(self, classifier, lifecycle=None):
        # label -> best (least by CorpusItem ordering) string for that label
        self.__strings_by_tag = {}
        # CorpusItems for strings that are currently best for some label
        self.__corpus = SortedList()
        # string -> number of labels it is currently best for
        self.__refcounts = {}
        self.__classifier = classifier
        # cache_key()s of every string already classified
        self.__seen = set()
        # cache_key()s of corpus strings whose shrinks are exhausted
        self.__fully_shrunk = set()
        self.__lifecycle = lifecycle or LifeCycle()
        # bumped on every corpus add/remove; not read within this class
        self.__counter = 0
        self.incorporate(b'')
        # Requires classifier(b'') to return at least one label; otherwise
        # nothing enters the corpus and this assertion fails.
        assert len(self.__corpus) == len(self.__refcounts) == len(self.__seen) \
            == 1

    def incorporate(self, string):
        """Classify ``string`` and make it the best string for any label it
        improves on.

        Strings already seen (by ``cache_key``) are ignored. The lifecycle
        is notified once per call with the sets of newly discovered labels
        and of improved labels.
        """
        key = cache_key(string)
        if key in self.__seen:
            return
        self.__seen.add(key)
        labels = self.__classifier(string)
        item = CorpusItem(string)
        new_labels = set()
        improved_labels = set()

        for l in labels:
            if (
                l not in self.__strings_by_tag or
                item < CorpusItem(self.__strings_by_tag[l])
            ):
                # incref before decref so a string replacing itself is never
                # dropped from the corpus in between
                self.__incref(string)
                if l in self.__strings_by_tag:
                    self.__decref(self.__strings_by_tag[l])
                    improved_labels.add(l)
                else:
                    new_labels.add(l)
                self.__strings_by_tag[l] = string
        if new_labels:
            self.__lifecycle.new_labels(new_labels)
        if improved_labels:
            self.__lifecycle.labels_improved(improved_labels)

    def fuzz(self):
        """Repeatedly shrink corpus strings, from the greatest down.

        Each round picks the greatest corpus item not yet fully shrunk and
        feeds its shrinks to :meth:`incorporate`. The round ends early if the
        target leaves the corpus; if every shrink is tried, the target is
        marked fully shrunk.

        NOTE(review): ``while True`` has no exit. Once every corpus item is in
        ``__fully_shrunk``, the inner ``for`` completes without a break and
        this spins forever unless a callee (e.g. the lifecycle) raises.
        """
        while True:
            for target in reversed(self.__corpus):
                key = cache_key(target.string)
                if key not in self.__fully_shrunk:
                    for string in self.__shrinks(target.string):
                        self.incorporate(string)
                        if target not in self.__corpus:
                            break
                    else:
                        self.__fully_shrunk.add(key)
                    # incorporate() may have mutated __corpus, so stop
                    # iterating it and restart from the top.
                    break

    def __shrinkers(self):
        """Yield shrink strategies (callables string -> iterable of strings).

        Lengths are based on the greatest corpus item at the time each loop
        starts. The order is: block deletions of halving sizes, then
        single-byte-value clearing, then deletions of every size ``n``
        (decreasing) at halving steps.
        """
        n = len(self.__corpus[-1].string)
        while n > 1:
            yield self.__cutter(n, n)
            n //= 2
        yield self.__byte_clearing
        n = len(self.__corpus[-1].string)
        while n > 1:
            i = n
            while i > 0:
                yield self.__cutter(i, n)
                i //= 2
            n -= 1

    def __shrinks(self, string):
        """Yield every candidate produced by every shrinker for ``string``."""
        for shrinker in self.__shrinkers():
            for s in shrinker(string):
                yield s

    def __byte_clearing(self, string):
        """Yield ``string`` with all occurrences of one byte value removed.

        The most frequent byte values go first, with ties broken by byte
        value.
        """
        counter = Counter(string)
        for c in sorted(counter, key=lambda x: (-counter[x], x)):
            yield string.replace(bytes([c]), b'')

    def __cutter(self, step, size):
        """Return a shrinker that deletes ``size``-byte blocks.

        The blocks start at offsets 0, step, 2*step, ... The returned
        shrinker yields nothing for strings no longer than ``size``.
        """
        assert step > 0
        assert size > 0

        def accept(string):
            if size >= len(string):
                return
            i = 0
            while i + size <= len(string):
                yield string[:i] + string[i+size:]
                i += step
        accept.__name__ = '__cutter(%d, %d)' % (step, size)
        return accept

    def __incref(self, string):
        """Count one more label for ``string``; add it to the corpus on 0 -> 1."""
        c = self.__refcounts.get(string, 0)
        assert c >= 0
        if c == 0:
            self.__counter += 1
            self.__corpus.add(CorpusItem(string))
            self.__lifecycle.item_added(string)
        self.__refcounts[string] = c + 1

    def __decref(self, string):
        """Count one fewer label for ``string``; remove it from the corpus at 0."""
        assert self.__refcounts[string] > 0
        self.__refcounts[string] -= 1
        if self.__refcounts[string] <= 0:
            self.__counter += 1
            self.__corpus.remove(CorpusItem(string))
            del self.__refcounts[string]
            self.__lifecycle.item_removed(string)
def test_remove_valueerror3():
    """``remove`` of a value falling in a gap (4, between 3 and 5) raises ValueError."""
    contents = [1, 2, 2, 2, 3, 3, 5]
    sorted_values = SortedList(contents)
    absent = 4
    with pytest.raises(ValueError):
        sorted_values.remove(absent)
class ParetoFront:
    """Maintains an approximate pareto front of ConjectureData objects. That
    is, we try to maintain a collection of objects such that no element of the
    collection is pareto dominated by any other. In practice we don't quite
    manage that, because doing so is computationally very expensive. Instead
    we maintain a random sample of data objects that are "rarely" dominated by
    any other element of the collection (roughly, no more than about 10%).

    Only valid test cases are considered to belong to the pareto front - any
    test case with a status less than valid is discarded.

    Note that the pareto front is potentially quite large, and currently this
    will store the entire front in memory. This is bounded by the number of
    valid examples we run, which is max_examples in normal execution, and
    currently we do not support workflows with large values of max_examples
    very well anyway, so this isn't a major issue. In future we may wish to
    implement some sort of paging out to disk so that we can work with larger
    fronts.

    Additionally, because this is only an approximate pareto front, there are
    scenarios where it can be much larger than the actual pareto front. There
    isn't a huge amount we can do about this - checking an exact pareto front
    is intrinsically quadratic.

    "Most" of the time we should be relatively close to the true pareto front,
    say within an order of magnitude, but it's not hard to construct scenarios
    where this is not the case. e.g. suppose we enumerate all valid test cases
    in increasing shortlex order as s_1, ..., s_n, ... and have scores f and
    g such that f(s_i) = min(i, N) and g(s_i) = 1 if i >= N (and 0
    otherwise). Then the pareto front is the set {s_1, ..., s_N}, but the only
    element of the front that will dominate s_i when i > N is s_N, which we
    select with probability 1 / N. A better data structure could solve this,
    but at the cost of more expensive operations and higher per element memory
    use, so we'll wait to see how much of a problem this is in practice before
    we try that.
    """

    def __init__(self, random):
        # Random source used to choose which elements add() re-checks.
        self.__random = random
        # Callbacks registered through on_evict().
        self.__eviction_listeners = []

        # Results currently in the front, ordered by sort_key of their buffer.
        self.front = SortedList(key=lambda d: sort_key(d.buffer))
        # The result being inserted by an in-progress add(), else None.
        # __remove() skips eviction callbacks for this item.
        self.__pending = None

    def add(self, data):
        """Attempts to add ``data`` to the pareto front. Returns True if
        ``data`` is now in the front, including if data is already in the
        collection, and False otherwise.

        Adding may evict other elements found to be dominated; eviction
        listeners are called for each of them (but not for ``data`` itself).
        """
        data = data.as_result()
        if data.status < Status.VALID:
            return False

        if not self.front:
            self.front.add(data)
            return True

        if data in self.front:
            return True

        # We add data to the pareto front by adding it unconditionally and then
        # doing a certain amount of randomized "clear down" - testing a random
        # set of elements (currently 10) to see if they are dominated by
        # something else in the collection. If they are, we remove them.
        self.front.add(data)
        assert self.__pending is None
        try:
            self.__pending = data

            # We maintain a set of the current exact pareto front of the
            # values we've sampled so far. When we sample a new element we
            # either add it to this exact pareto front or remove it from the
            # collection entirely.
            # `front` is a working copy that the loops below swap/pop in;
            # presumably LazySequenceCopy leaves self.front untouched.
            front = LazySequenceCopy(self.front)

            # We track which values we are going to remove and remove them all
            # at the end so the shape of the front doesn't change while we're
            # using it.
            to_remove = []

            # We now iteratively sample elements from the approximate pareto
            # front to check whether they should be retained. When the set of
            # dominators gets too large we have sampled at least 10 elements
            # and it gets too expensive to continue, so we consider that enough
            # due diligence.
            i = self.front.index(data)

            # First we attempt to look for values that must be removed by the
            # addition of the data. These are necessarily to the right of it
            # in the list.

            # Stop after 10 consecutive samples that data does not dominate.
            failures = 0
            while i + 1 < len(front) and failures < 10:
                # Sample uniformly from the unsampled tail: swap the chosen
                # element to the end and pop it so it is not drawn again.
                j = self.__random.randrange(i + 1, len(front))
                swap(front, j, len(front) - 1)
                candidate = front.pop()
                dom = dominance(data, candidate)
                assert dom != DominanceRelation.RIGHT_DOMINATES
                if dom == DominanceRelation.LEFT_DOMINATES:
                    to_remove.append(candidate)
                    failures = 0
                else:
                    failures += 1

            # Now we look at the points up to where we put data in to see if
            # it is dominated. While we're here we spend some time looking for
            # anything else that might be dominated too, compacting down parts
            # of the list.

            dominators = [data]

            while i >= 0 and len(dominators) < 10:
                # Pick a random element from positions 0..i and move it to i,
                # so each step examines an unsampled element.
                swap(front, i, self.__random.randint(0, i))

                candidate = front[i]

                already_replaced = False
                j = 0
                while j < len(dominators):
                    v = dominators[j]

                    dom = dominance(candidate, v)
                    if dom == DominanceRelation.LEFT_DOMINATES:
                        # candidate beats dominator v: the first such v is
                        # replaced in place by candidate, any further ones are
                        # removed (swap-with-last then pop). Either way v is
                        # scheduled for removal from the front.
                        if not already_replaced:
                            already_replaced = True
                            dominators[j] = candidate
                            j += 1
                        else:
                            dominators[j], dominators[-1] = (
                                dominators[-1],
                                dominators[j],
                            )
                            dominators.pop()
                        to_remove.append(v)
                    elif dom == DominanceRelation.RIGHT_DOMINATES:
                        to_remove.append(candidate)
                        break
                    elif dom == DominanceRelation.EQUAL:
                        break
                    else:
                        j += 1
                else:
                    # Loop finished without break: candidate is not dominated
                    # by (or equal to) any current dominator, so track it.
                    dominators.append(candidate)
                i -= 1

            for v in to_remove:
                self.__remove(v)
            return data in self.front
        finally:
            self.__pending = None

    def on_evict(self, f):
        """Register a listener function that will be called with data when
        it gets removed from the front because something else dominates it."""
        self.__eviction_listeners.append(f)

    def __contains__(self, data):
        # Only ConjectureData/ConjectureResult instances can be members; they
        # are compared in their as_result() form.
        return isinstance(data, (ConjectureData, ConjectureResult)) and (
            data.as_result() in self.front
        )

    def __iter__(self):
        return iter(self.front)

    def __getitem__(self, i):
        return self.front[i]

    def __len__(self):
        return len(self.front)

    def __remove(self, data):
        # Removing an element that is already gone is a no-op; listeners are
        # only told about evictions of elements other than the pending add.
        try:
            self.front.remove(data)
        except ValueError:
            return
        if data is not self.__pending:
            for f in self.__eviction_listeners:
                f(data)
class TTLCache(object):
    """A key/value cache implementation where each entry has its own TTL"""

    def __init__(self, cache_name, timer=time.time):
        """
        :param cache_name: Name of this cache (not used by this class itself).
        :param timer: Zero-argument callable returning the current time, in
            the same units as the ``ttl`` passed to ``set`` (seconds).
            Defaults to ``time.time``.
        """
        # map from key to _CacheEntry
        self._data = {}

        # the _CacheEntries, sorted by expiry time
        self._expiry_list = SortedList()

        self._timer = timer

    def set(self, key, value, ttl):
        """Add/update an entry in the cache

        :param key: Key for this entry.
        :param value: Value for this entry.
        :param ttl: TTL for this entry, in seconds.
        :type ttl: float
        """
        expiry = self._timer() + ttl

        self.expire()
        e = self._data.pop(key, SENTINEL)
        # Identity check: SENTINEL means "no previous entry for this key".
        if e is not SENTINEL:
            self._expiry_list.remove(e)

        entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
        self._data[key] = entry
        self._expiry_list.add(entry)

    def get(self, key, default=SENTINEL):
        """Get a value from the cache

        :param key: The key to look up.
        :param default: default value to return, if key is not found. If not
            set, and the key is not found, a KeyError will be raised.
        :returns a value from the cache, or the default.
        """
        self.expire()
        e = self._data.get(key, SENTINEL)
        if e is SENTINEL:
            if default is SENTINEL:
                raise KeyError(key)
            return default
        return e.value

    def get_with_expiry(self, key):
        """Get a value, and its expiry time, from the cache

        :param key: key to look up
        :returns The value from the cache, and the expiry time.
        :rtype: Tuple[Any, float]

        Raises:
            KeyError if the entry is not found
        """
        self.expire()
        e = self._data[key]
        return e.value, e.expiry_time

    def pop(self, key, default=SENTINEL):
        """Remove a value from the cache

        If key is in the cache, remove it and return its value, else return
        default. If default is not given and key is not in the cache, a
        KeyError is raised.

        :param key: key to look up
        :param default: default value to return, if key is not found. If not
            set, and the key is not found, a KeyError will be raised
        :returns a value from the cache, or the default
        """
        self.expire()
        e = self._data.pop(key, SENTINEL)
        if e is SENTINEL:
            if default is SENTINEL:
                raise KeyError(key)
            return default
        self._expiry_list.remove(e)
        return e.value

    def __getitem__(self, key):
        return self.get(key)

    def __delitem__(self, key):
        self.pop(key)

    def __contains__(self, key):
        # Expire first so membership agrees with get()/__getitem__: an
        # expired key must not be reported as present.
        self.expire()
        return key in self._data

    def __len__(self):
        self.expire()
        return len(self._data)

    def expire(self):
        """Run the expiry on the cache. Any entries whose expiry times are due will
        be removed
        """
        now = self._timer()
        # The expiry list is ordered soonest-first, so stop at the first entry
        # that is still in the future.
        while self._expiry_list:
            first_entry = self._expiry_list[0]
            if first_entry.expiry_time - now > 0.0:
                break
            del self._data[first_entry.key]
            del self._expiry_list[0]
class Book:
    """
    Represent bid / ask book.
    * Book only allows one user order at a time
    * User order will not affect book statistics like quote and volume
    * The front of the book is ``prices[0]``: the price with the smallest sort key
    """

    def __init__(self, side: str, key_func: Optional[Callable[[int], int]]) -> None:
        """
        :param side: Side label. Executions of matched user limit orders are
            signed positive when this is ``'B'`` and negative otherwise.
        :param key_func: Sort key applied to prices; ``None`` sorts prices ascending.
        """
        self.side = side
        self.key_func = key_func if key_func else lambda x: x
        # Price levels follow price priority, not the insertion (time) order a dict alone would give
        self.prices = SortedList(key=key_func)  # Sorted prices
        self.price_levels: Dict[int, PriceLevel] = {}  # Price to level map
        self.order_pool: Dict[int, PriceLevel] = {}  # Order ID to level map
        # Store order price and PriceLevel. We do not need ID since there is only one order
        self.user_order_info: Optional[Tuple[int, PriceLevel]] = None
        # Index into self.prices of the first level with real (non-user) shares, or None
        self._front_idx: Optional[int] = None

    def reset(self):
        """Drop every price level, real order and the user order."""
        self.prices.clear()
        self.price_levels.clear()
        self.order_pool.clear()
        self.user_order_info = None
        self._front_idx = None

    # ========== Order Operations ==========
    def add_limit_order(self, order: LimitOrder) -> None:
        """
        Add limit order to the correct price level

        :raises RuntimeError: if an order with the same id is already in the book.
        """
        if order.id in self.order_pool:
            raise RuntimeError(f'LimitOrder {order.id} already exists')
        self.order_pool[order.id] = self._get_price_level(order.price, force_index=True).add_limit_order(order)

    def match_limit_order(self, market_order: MarketOrder) -> Tuple[bool, Optional[Execution]]:
        """
        Match environment order against limit order. Remove empty price level where needed

        :param market_order: Environment order; its ``id`` references the resting LimitOrder it executes.
        :returns: ``(exhausted, execution)``. ``exhausted`` tells whether the referenced
            LimitOrder is used up. ``execution`` is the user limit order execution triggered
            along the way, or None.
        :raises RuntimeError: if the front level holds real shares but the referenced
            order rests at a different level.
        """
        # Sometime environment order may not follow time priority. We should follow the referenced order ID in this case
        user_order = None
        target_price_level = self.order_pool[market_order.id]

        # User orders may create price levels that do not exist in the real market. Need to match against those first
        if target_price_level.price != self.prices[0]:
            top_level = self.price_levels[self.prices[0]]
            if top_level.shares > 0:
                # Shares > 0 means that there are real LimitOrder exists in the top level
                raise RuntimeError('Market order being matched against levels not in the front')
            user_order = top_level.pop_user_order()
            self._remove_price_level_if_empty(top_level)

        # Now get the user orders that are in front of the matched real LimitOrder
        price_level, exhausted, executed_order = target_price_level.match_limit_order(market_order)
        self._remove_price_level_if_empty(price_level)

        # It can be that both order are None
        if executed_order is not None:
            user_order = executed_order

        # Whether the matching limit order is already exhausted
        if exhausted:
            del self.order_pool[market_order.id]

        # Update user order pool and return executions
        return exhausted, self._handle_matched_user_limit_order(user_order) if user_order else None

    def cancel_order(self, order: CancelOrder) -> None:
        """
        Cancel (partial) shares of a LimitOrder
        """
        self.order_pool[order.id].cancel_order(order)

    def delete_order(self, order: DeleteOrder):
        """
        Delete the whole LimitOrder
        """
        price_level = self.order_pool[order.id].delete_order(order)
        del self.order_pool[order.id]
        self._remove_price_level_if_empty(price_level)

    # ========== User Order Operation ==========
    def add_user_limit_order(self, order: UserLimitOrder) -> None:
        """
        Add user limit order to the correct price level
        * Remove the old user order if exists
        * We do not want to deal with time priority because
            * This simplifies the flow
            * Last action's effect will spill over to the current one
        """
        if self.user_order_info:
            # user_order_info is (price, PriceLevel)
            _, price_level = self.user_order_info
            price_level.pop_user_order()  # Only one user order is allowed
            self._remove_price_level_if_empty(price_level)
        self.user_order_info = order.price, self._get_price_level(order.price).add_user_limit_order(order)

    def match_limit_order_for_user(self, order: UserMarketOrder) -> Execution:
        """
        Match LimitOrder for UserMarketOrder

        Walks the levels from the front, accumulating the executed shares and their value,
        and returns one Execution at the truncated average price.

        :raises RuntimeError: if this side holds the user limit order, or if the
            book cannot fill the whole order.
        """
        if self.user_order_info:
            raise RuntimeError('Cannot execute MarketOrder on the side that also has user LimitOrder')

        total_value = 0
        shares = 0
        # Recall that we are not actually matching the LimitOrders. No need to remove the executed LimitOrder.
        for price in self.prices:
            # order.shares is read before the call below updates it
            executed = order.shares - self.price_levels[price].match_limit_order_for_user(order)
            total_value += price * executed
            shares += executed
            if order.shares == 0:
                break

        if order.shares > 0:
            raise RuntimeError('User market order cannot be fully executed')

        return Execution(order.id, int(total_value / shares), shares if order.side == 'B' else -shares)

    def delete_user_order(self):
        """
        Remove user order
        """
        if self.user_order_info:
            _, price_level = self.user_order_info
            price_level.pop_user_order()
            self._remove_price_level_if_empty(price_level)
            self.user_order_info = None

    def resolve_book_crossing_on_user_order(self, price: int) -> Optional[Execution]:
        """
        User orders may be placed inside the real market, in which case the newly added real order may cross with
        the user orders. When this happens, we assume that the user orders are executed

        :raises RuntimeError: if ``price`` crosses the real quote of this side.
        """
        signed_price = self.key_func(price)
        quote = self.quote
        # quote is Optional[int]; compare against None so a price of 0 is not treated as "no quote"
        if quote is not None and self.key_func(quote) <= signed_price:
            raise RuntimeError('Real order crosses real order')
        if self.user_order_info and self.key_func(self.user_order_info[0]) <= signed_price:
            price_level = self.price_levels[self.user_order_info[0]]
            execution = self._handle_matched_user_limit_order(price_level.pop_user_order())
            # Must be empty
            self._remove_price_level_if_empty(price_level)
            return execution
        return None

    # ========== Private Methods ==========
    def _get_price_level(self, price: int, force_index=False) -> PriceLevel:
        """
        Return price level indicated by price. Price level will be added if not already exists
        """
        level = self.price_levels.get(price, None)
        # shares == 0 means that the PriceLevel was previously occupied by user order only
        if level is None:
            self.prices.add(price)
            level = PriceLevel(price)
            self.price_levels[price] = level
            # force_index is used when we are adding a new price level for real order. Order is not added at this point
            # and shares will be 0. Therefore, we need to force it
            # On the other hand, we still need to run update_front_index for user order because it may change the
            # ordering
            self._update_front_index(force_index, price)
        elif level.shares == 0:
            self._update_front_index(force_index, price)
        return level

    def _remove_price_level_if_empty(self, price_level: PriceLevel):
        """
        Remove PriceLevel if empty
        """
        if price_level.empty:
            del self.price_levels[price_level.price]
            # "remove" will raise ValueError if not exists
            self.prices.remove(price_level.price)
        if price_level.shares == 0:
            # Separate from the logic above because we may be in the situation where real orders are exhausted
            # but at least one user order is waiting. In this case, this price level is technically gone
            self._update_front_index()

    def _update_front_index(self, force_index=False, target_price=None) -> None:
        """
        Find out the first price level that has real order

        Only prices[0] is inspected: if it has no real shares (and is not being forced as the
        level for a new real order at ``target_price``), the front is assumed to be prices[1].
        """
        if not self.prices:
            self._front_idx = None
        else:
            price = self.prices[0]
            if self.price_levels[price].shares > 0 or (force_index and price == target_price):
                self._front_idx = 0
            else:
                self._front_idx = 1 if len(self.prices) > 1 else None

    def _handle_matched_user_limit_order(self, order: UserLimitOrder) -> Execution:
        """
        Book-keeping actions for UserLimitOrder execution
        """
        self.user_order_info = None
        return Execution(order.id, order.price, order.shares if self.side == 'B' else -order.shares)

    # ========== Properties ==========
    # These statistics should not include user orders. Otherwise, we may end up being our own market
    @property
    def quote(self) -> Optional[int]:
        """
        Return the front price without user orders
        """
        if self._front_idx is not None:
            return self.prices[self._front_idx]
        return None

    @property
    def volume(self) -> Optional[int]:
        """
        Return the volume at the front without user orders
        """
        if self._front_idx is not None:
            return self.price_levels[self.quote].shares
        return None

    def get_depth(self, num_levels: int) -> List[Tuple[int, int]]:
        """
        Return the top n price levels without user orders

        Levels with no real shares (e.g. a level holding only the user order) are skipped,
        so up to ``num_levels`` ``(price, shares)`` pairs are returned, front first.
        """
        depth: List[Tuple[int, int]] = []
        if self._front_idx is None or num_levels <= 0:
            return depth
        for price in self.prices.islice(self._front_idx):
            shares = self.price_levels[price].shares
            if shares > 0:
                depth.append((price, shares))
                if len(depth) >= num_levels:
                    break
        return depth

    @property
    def empty(self) -> bool:
        """True when the book holds no real orders and no level other than the user order's own."""
        if len(self.order_pool) == 0:
            if self.user_order_info is None:
                return len(self.prices) == 0
            return len(self.prices) == 1 and self.price_levels[self.prices[0]].shares == 0
        return False

    @property
    def user_order_price(self) -> Optional[int]:
        """Price of the resting user limit order, or None if there is none."""
        if self.user_order_info:
            return self.user_order_info[0]
        return None
class SweepLine(object):
    '''
    This class represents the vertical sweep line which sweeps over the set of
    segments in the Bentley-Ottmann algorithm. At any moment, it contains a
    sorted list of all the ComparableSegments which intersect with the sweep
    line in its current position.

    Note that if 2 segments s1 and s2 are overlapping, you cannot assume
    anything about their order in the sorted list, as s1 < s2 and s1 > s2 are
    both false.

    Such a sorted list would usually rely on a balanced binary search tree
    data structure in order to have O(log(N)) insertion, deletion and
    swapping. Instead, I chose to use a SortedList of Grant Jenks'
    SortedContainers module, which has several advantages that you can
    discover by browsing its page. It allows O(log(N)) insertion, deletion and
    swapping, and I find it to be faster in practice.

    The ordering appears to depend on the class attribute
    ComparableSegment.currentX, which several methods set before searching or
    inserting. Callers should be aware that this is shared, mutable state.

    NOTE(review): revertOrder assigns list items by index (self.l[i] = ...).
    In sortedcontainers >= 2.0, SortedList does not support item assignment
    and raises NotImplementedError. This code seems to target the 1.x API;
    confirm the pinned version.
    '''
    def __init__(self):
        '''
        Initializes an empty sweep line.
        '''
        self.l = SortedList()

    def isEmpty(self):
        '''
        Returns true if and only if the sweep line is empty.
        '''
        return len(self.l) == 0

    def addSegment(self, seg):
        '''
        Adds seg to the sweep line, comparing segments at x-coordinate seg.x1.
        '''
        ComparableSegment.currentX = seg.x1
        self.l.add(seg)

    def removeSegment(self, seg):
        '''
        Removes seg from the sweep line.

        Raises ValueError (from SortedList.remove) if seg cannot be found
        under the ordering at the current ComparableSegment.currentX.
        '''
        self.l.remove(seg)

    def belowSegments(self, seg):
        '''
        Returns a list containing :
        - The highest segment s_below contained in the sweep line such that
          s_below.isBelow(seg)
        - All the segments s contained before s_below in the sweep line such
          that s.isBelow(s_below) is false, (i.e. which have the same
          y-coordinate at ComparableSegment.currentX and gradient).
        The list is empty if no segment before seg is below it.
        '''
        res = []
        # i = index of seg
        i = self.l.index(seg)
        # Passes segments which have same y-coordinate and gradient
        # to find s_below
        while i-1 >= 0:
            prev = self.l[i-1]
            i -= 1
            if prev.isBelow(seg):
                res.append(prev)
                break
        # Appends all the segments which have same y-coordinate and
        # gradient as s_below (only reachable with res non-empty: if no
        # s_below was found, i is already 0)
        while i-1 >= 0:
            prev = self.l[i-1]
            if prev.isBelow(res[0]):
                break
            res.append(prev)
            i -= 1
        return res

    def aboveSegments(self, seg):
        '''
        Returns a list containing :
        - The lowest segment s_above contained in the sweep line such that
          seg.isBelow(s_above)
        - All the segments s contained after s_above in the sweep line such
          that s_above.isBelow(s) is false, (i.e. which have the same
          y-coordinate at ComparableSegment.currentX and gradient).
        The list is empty if no segment after seg is above it.
        '''
        res = []
        # i = index of seg
        i = self.l.index(seg)
        # Passes segments which have same y-coordinate and gradient
        # to find s_above
        while i+1 < len(self.l):
            succ = self.l[i+1]
            i += 1
            if seg.isBelow(succ):
                res.append(succ)
                break
        # Appends all the segments which have same y-coordinate and
        # gradient as s_above (only reachable with res non-empty: if no
        # s_above was found, i is already the last index)
        while i+1 < len(self.l):
            succ = self.l[i+1]
            if res[0].isBelow(succ):
                break
            res.append(succ)
            i += 1
        return res

    def sameLevelAs(self, seg):
        '''
        Returns a list containing the segments s of the line such that
        s.isBelow(seg) and seg.isBelow(s) are both false, i.e. all the
        segments with same y-coordinate at ComparableSegment.currentX and
        gradient as seg. The first element is the list entry found at
        seg's own index.
        '''
        i = self.l.index(seg)
        res = [self.l[i]]
        # Looks for same level segments above
        j = i + 1
        while j < len(self.l) and not seg.isBelow(self.l[j]):
            res.append(self.l[j])
            j += 1
        # Looks for same level segments below
        j = i - 1
        while j >= 0 and not self.l[j].isBelow(seg):
            res.append(self.l[j])
            j -= 1
        return res

    def betweenY(self, y_inf, y_sup, x):
        '''
        Returns a list of all the segments intersecting the sweep line between
        y-coordinates y_inf and y_sup included, at x-coordinate x.
        Side effect: sets ComparableSegment.currentX to x.
        '''
        ComparableSegment.currentX = x
        res = []
        i = 0
        # Passes segments whose y-coordinate is < y_inf (linear scan from the
        # start of the list)
        while i < len(self.l) and self.l[i].yAtX(x) < y_inf:
            i += 1
        while i < len(self.l) and self.l[i].yAtX(x) <= y_sup:
            res.append(self.l[i])
            i += 1
        return res

    def revertOrder(self, x, segments):
        '''
        Reverses the order of the given segments in the sweep line at
        x-coordinate x: segments[k] is written to the position of
        segments[-k-1], for each symmetric pair.
        Side effect: leaves ComparableSegment.currentX set to x.
        '''
        indices = []
        for seg in segments:
            # Each position is looked up with the ordering at that segment's
            # own x1 (presumably the ordering it was inserted under).
            ComparableSegment.currentX = seg.x1
            indices.append(self.l.index(seg))
        # Update segments currentX so that the swap keep sort order
        ComparableSegment.currentX = x
        # Swaps the segments
        for i in range(floor(len(segments)/2)):
            i1 = indices[i]
            i2 = indices[-i-1]
            self.l[i1] = segments[-i-1]
            self.l[i2] = segments[i]
def bentley_ottmann(filename, nodisp=False, noinfo=False):
    """Find the crossings between the segments loaded from ``filename``.

    Main function of the project. Events are processed by increasing y, then
    decreasing x. Active segments are kept in a SortedList that is rebuilt
    after ``Segment.y_cour`` (the current sweep position) changes; their order
    presumably depends on that class attribute.

    :param filename: Passed to ``load_segments``.
    :param nodisp: If True, do not display the result with ``tycat``.
    :param noinfo: If True, do not print the intersection / cut counts.
    :returns: ``(pt_inter, intersections)``. ``pt_inter`` maps each segment to
        the list of segments it was found to cross. ``intersections`` lists the
        distinct (adjusted) intersection points. The module-level ``COUPE`` is
        set to the number of recorded crossings, which may count a point
        several times.
    """
    global COUPE
    COUPE = 0
    y_cour = None
    adjuster, segments = load_segments(filename)
    actifs = SortedList()  # segments currently crossing the sweep line
    # Event list: [y, -x, segment, tag, ...]; turned into a heap below.
    evenements = []
    # Dict returned at the end: segment -> segments it intersects.
    pt_inter = {}
    index = 0
    # Cache: segment -> segments it has already been compared with.
    cache_inters = {}
    # All distinct intersection points.
    intersections = []

    for seg in segments:  # build the start / end events
        (x_0, y_0) = seg.endpoints[0].coordinates
        (x_1, y_1) = seg.endpoints[1].coordinates
        Segment.y_cour = [x_0, y_0]
        if y_0 < y_1:  # segment increasing in y
            evenements.append([y_0, -x_0, seg, 'D'])
            evenements.append([y_1, -x_1, seg, 'F'])
        elif y_0 > y_1:  # segment decreasing in y
            evenements.append([y_0, -x_0, seg, 'F'])
            evenements.append([y_1, -x_1, seg, 'D'])
        else:  # horizontal segment: one event, its tag is the rightmost x
            evenements.append([y_1, -min(x_0, x_1), seg, max(x_0, x_1)])
        pt_inter[seg] = []
        cache_inters[seg] = []

    # Event heap. Tags: 'D' (start), 'F' (end), 'I' (intersection), or a
    # number for horizontal segments. Ordered by increasing y, then
    # decreasing x.
    heapify(evenements)

    def indice(seg):
        """Return the index of ``seg`` in ``actifs`` (by identity), or None if
        absent. Fallback added because ``SortedList.index`` failed in some cases."""
        for i, elmt in enumerate(actifs):
            if seg is elmt:
                return i
        return None

    def intersection(seg, seg_2):
        """Test ``seg`` against ``seg_2`` once and record a proper crossing.

        Argument order matters: ``seg`` must be the lower-index (left) one of
        the pair. The order is stored in the pushed 'I' event, and the 'I'
        handler relies on it to know which segment ends up on which side.
        """
        global COUPE
        if seg_2 not in cache_inters[seg]:  # never compare the same pair twice
            intersection = seg.intersection_with(seg_2)
            cache_inters[seg].append(seg_2)
            cache_inters[seg_2].append(seg)
            if intersection is not None:
                intersection = adjuster.hash_point(intersection)  # snap the point
                # Ignore points that are an endpoint of both segments.
                if intersection not in seg.endpoints or intersection not in seg_2.endpoints:
                    pt_inter[seg].append(seg_2)
                    pt_inter[seg_2].append(seg)
                    heappush(evenements, [intersection.coordinates[1], -intersection.coordinates[0], seg, 'I', seg_2])
                    if intersection not in intersections:
                        intersections.append(intersection)
                    COUPE += 1
        return

    while evenements:  # process events until the heap is empty
        y_cour = heappop(evenements)
        if y_cour[3] == 'D':  # segment start
            Segment.y_cour = [-y_cour[1], y_cour[0]]
            actifs = SortedList(actifs)  # re-sort for the new sweep position
            seg = y_cour[2]
            actifs.add(seg)
            if len(actifs) > 1:  # a single active segment has no neighbour
                try:
                    index = actifs.index(seg)
                except ValueError:
                    index = indice(seg)
                if index != len(actifs) - 1:
                    seg_2 = actifs[index + 1]
                    intersection(seg, seg_2)
                if index != 0:
                    seg_2 = actifs[index - 1]
                    intersection(seg_2, seg)
        elif y_cour[3] == 'F':  # segment end
            Segment.y_cour = [-y_cour[1], y_cour[0]]
            actifs = SortedList(actifs)  # re-sort for the new sweep position
            seg = y_cour[2]
            try:
                index = actifs.index(seg)
            except ValueError:
                index = indice(seg)
            actifs.pop(index)
            actifs = SortedList(actifs)
            if len(actifs) > 1:
                # The removed segment's two neighbours become adjacent, unless
                # it was the left-most / right-most active segment.
                if 0 < index < len(actifs):
                    seg = actifs[index]
                    seg_2 = actifs[index - 1]
                    # Lower-index segment first, as every other call does.
                    intersection(seg_2, seg)
        elif y_cour[3] == 'I':  # intersection point
            seg, seg_2 = y_cour[2], y_cour[4]
            try:
                actifs.remove(seg)
            except ValueError:
                index = indice(seg)
                if index is not None:  # lookup sometimes fails: "segment not in actifs"
                    del actifs[index]
            try:
                actifs.remove(seg_2)
            except ValueError:
                index_2 = indice(seg_2)
                if index_2 is not None:
                    del actifs[index_2]
            # Convention: at an intersection we stand just above it, so the two
            # segments are re-inserted in their swapped order.
            Segment.y_cour = [-y_cour[1], y_cour[0] + 0.00000000001]
            actifs = SortedList(actifs)
            actifs.add(seg)
            actifs.add(seg_2)
            try:
                index = actifs.index(seg)  # seg is now the right one of the pair
            except ValueError:
                index = indice(seg)
            if len(actifs) > 2:  # test the newly adjacent pairs
                if index < len(actifs) - 1:  # seg is not the right-most segment
                    seg_2 = actifs[index + 1]
                    intersection(seg, seg_2)
                # y_cour[4] is expected at index - 1; it has a left neighbour
                # only when index >= 2 (avoids negative-index wraparound).
                if index >= 2:
                    seg_2 = actifs[index - 2]
                    intersection(seg_2, y_cour[4])
        else:  # horizontal segment: test it against every active segment
            seg_h = y_cour[2]
            for seg in actifs:
                inter = seg_h.intersection_with(seg)
                if inter:
                    inter = adjuster.hash_point(inter)
                    # Ignore points that are an endpoint of both segments.
                    if inter not in seg_h.endpoints or inter not in seg.endpoints:
                        pt_inter[seg_h].append(seg)
                        pt_inter[seg].append(seg_h)
                        if inter not in intersections:
                            intersections.append(inter)
                        COUPE += 1

    if nodisp and noinfo:
        return pt_inter, intersections
    if noinfo:
        tycat(segments, intersections)
        return pt_inter, intersections
    if nodisp:
        print(
            "Le nombre d'intersections (= le nombre de points differents) est : ",
            len(intersections))
        print("Le nombre de coupes est : ", COUPE)
        return pt_inter, intersections
    # Default: both print the counts and display, like the two branches above combined.
    print(
        "le nombre d'intersections (= le nombre de points differents) est : ",
        len(intersections))
    print("le nombre de coupes est : ", COUPE)
    tycat(segments, intersections)
    return pt_inter, intersections
class Table: 'Class that implements PK and FK constraints with Aerospike' #Class variables _registry = [] _NameSpace = "test_DB_SEF" _VerifyConstraints = True _UseFKTables = True _CurrentClient = None _Debug = True #Class methods @staticmethod def SetDebugMode(DebugTF): if DebugTF: Table._Debug = True else: Table._Debug = False @staticmethod def GetDebugMode(): if Table._Debug == True: print("Debug mode is On") else: print("Debug mode is Off") @staticmethod def SetVerifyConstraints(EnabledTF): if EnabledTF: Table._VerifyConstraints = True else: Table._VerifyConstraints = False @staticmethod def GetVerifyConstraints(): if Table._VerifyConstraints == True: print("Verifying Constraints is ENABLED") return True else: print("Verifying Constraints is DISABLED") return False @staticmethod def UseFKTables(EnabledTF): if EnabledTF: Table._UseFKTables = True else: Table._UseFKTables = False @staticmethod def GetUseFKTables(): if Table._UseFKTables == True: print("Use FK Tables is ENABLED") return True else: print("Use FK Tables is is DISABLED") return False @staticmethod def SetTableClient(thisClient): Table._CurrentClient = thisClient print("Info: Table Client Set, connection state is: ", Table._CurrentClient.is_connected()) @staticmethod def AddTableRowHash(TableName, hashkey): refTable = None for indexTable in Table._registry: #Get the table object reference print("indexTab:", indexTable.TableName, "table passed:", TableName) if (indexTable.TableName == TableName): refTable = indexTable break if refTable == None: print("Error: ", TableName, " not found in Table _registry!") return False #refTable.Rows.append(hashkey) refTable.Rows.add(hashkey) # Sorted return True @staticmethod def RemoveAllTables(thisClient, Force = False): if Force == False: confirmation = input("Do you want to remove all tables? 
(yes to confirm): ") if confirmation != "yes": print("Note: Ignoring call to RemoveAllTables") return False records_removed = 0 for indexTable in Table._registry: #Get the table object reference print("Removing rows in Table ", indexTable.TableName) for thisRowHash in indexTable.Rows: key = (indexTable.NameSpace, indexTable.TableName, None, thisRowHash) try: thisClient.remove(key) records_removed+=1 except Exception as e: print("error: {0}".format(e), sys.stderr) print("Error occured after ", records_removed, "records removed") Table._registry.clear() print("Note: ", records_removed, "records removed") return @staticmethod def RemoveDBStructure(thisClient, DBName): confirmation = input("Do you want to remove the DB structure? (yes to confirm): ") if confirmation != "yes": print("Note: Ignoring call to RemoveDBStructure") return False try: hashkey = thisClient.get_key_digest(Table._NameSpace, DBName, "Structure") key = (Table._NameSpace, DBName, None, hashkey) thisClient.remove(key) except Exception as e: print("Error attempting to remove ", DBName, "structure record.") print("error: {0}".format(e), sys.stderr) return #constructor def __init__(self, namespace, name): self.NameSpace = namespace self.TableName = name self.Columns = [] # List of colNames for this Table. self.PK = [] # List of colNames that are to be used as PKs for this Table. self.FK = [] # List of dicts composed of {'colName': someColName, 'refTable': <obj ref>, 'refColName': otherColName} self.FKnames = [] # simplified list of FK colNames for this Table self.FKtables = {} # dict with FK table names as keys. Values are the FKTable objects self.Rows = SortedList() # sortedcontainers version of a List. Values are hashkeys of DB records in this Table. 
self._registry.append(self) def AddCol(self, colName): if colName not in self.Columns: self.Columns.append(colName) return True else: print("ERROR: AddCol with column name ", colName, "ignored: Name already exists.") return False def AddPK(self, colName): if colName in self.Columns: self.PK.append(colName) return True else: print("ERROR: AddPK with column name ", colName, " ignored: Not a Column in Table.") return False def AddFK(self, colName, refTableName, refColName): #Verify FK name exists in table if colName not in self.Columns: print("ERROR: AddFK with column name", colName, "to", refColName, "ignored: Not a Column in Table.") return False #Verify Ref table exists #Determine refTable (object) from refTableName refTable = None for indexTable in Table._registry: if (indexTable.NameSpace == self.NameSpace) and (indexTable.TableName == refTableName): refTable = indexTable if not refTable: print("ERROR: AddFK with ref Table name ", refTableName, " ignored: Not a Table in Namespace.") return False #Verify Ref column exits in Ref table if refColName in refTable.Columns: #Add FK information to FK[] list of dict FKdict = {"colName":colName, "refTable":refTable, "refColName":refColName} self.FK.append(FKdict) #Create an FK Table for this FK, regardless of UseFKTables setting. self.FKtables[str(colName+"_FKTable")] = Table("test_DB_SEF", colName+"_FKTable") self.FKtables[str(colName+"_FKTable")].AddCol(colName+"_FKPK") self.FKtables[str(colName+"_FKTable")].AddCol("TableHashes") self.FKtables[str(colName+"_FKTable")].AddPK(colName+"_FKPK") return True else: print("ERROR: AddFK with ref Column name ", refColName, " ignored: Not a Column in ", refTableName) return False def VerifyStructure(self, listOfValues): if len(listOfValues) != len(self.Columns): print("VerifyStructure with attributes ", listOfValues, " failed: Incorrect number of attributes.") return False else: return True def VerifyUniquePKval(self, listOfPKValues): #Need to be able to handle composite PK values. 
#1. Verify the # of values in listOfValues matches the # of values in self.PK[]. False = Error # -VerifyStructure checks length of attribs for entire table, this step just checks #PKs, which # means calling function has to isolate PK vals from rest of vals if type(listOfPKValues) != list: listOfPKValues = [listOfPKValues] if len(listOfPKValues) != len(self.PK): print("VerifyUniquePK with PKs ", listOfPKValues, " failed: Incorrect number of PK attributes.") return False #2. Verify that listOfValues is NOT in the list of PK values for this Table #Convert to hashkey hashkey = self.getHashKey(str(listOfPKValues)) if hashkey in self.Rows: if Table._Debug: print(listOfPKValues, " exists in the following row (VerifyUniquePKval):") print(self.Read_PK(listOfPKValues)) return False return True def VerifyFKvalExists(self, listOfFKValues): # Suspect this may not work for multiple FK values or FK refs multivariate PK if len(listOfFKValues) != len(self.FK): print("VerifyFKvalExists with FKs ", listOfFKValues, " failed: Incorrect number of FK attributes.") return False #1. When verifying FKs, need to group by refTable. #2. For the set of FKs for each refTable, get the hash and see if it exists in the refTables list of Rows (hashes) #3. 
When getting the hash, the PK values used will have to be in the same order as specified by the table structure # -FKs are not necessarily in the same order as their PK counterparts as defined in the tables #Orig: insertDict = dict(zip(self.Columns, listOfFKValues)) insertDict = dict(zip(self.FKnames, listOfFKValues)) if Table._Debug: print("VerifyFK insertDict", insertDict) # if this works, Replace the next 5 lines with self.refTableList, created in VerifyFKrefsCompletePK refTableList = [] # First, make a list of all referenced tables (it's not a 1:1 mapping of PK/FK to tables) for thisFK in self.FK: if thisFK["refTable"] not in refTableList: refTableList.append(thisFK["refTable"]) for thisRefTable in refTableList: PKValList = [] #Each ref'd table should have a value (named by FK) for each of its PKs for thisPK in thisRefTable.PK: #find the value of the corresponding FK in self Table, add it to the PKValList for thisFK in self.FK: if Table._Debug: print("VerifyFKvalExists for refTable, PK, FK: ", thisRefTable.TableName, thisPK, thisFK) if thisFK["refColName"] == thisPK: PKValList.append(insertDict[thisFK["colName"]]) if Table._Debug: print("VerifyFKvalExists, PKValList: ", PKValList) #Now PKValList should have a list of values that correspond to PKs in the refTable, in the same order #Check if they exist by looking for a fail on VerifyUniquePKval unique = thisRefTable.VerifyUniquePKval(PKValList) if unique: print("VerifyFKvalExists with FKs ", listOfFKValues, "failed, those values not in", thisRefTable.TableName) return False #If we get to here, all ref tables and PKs should have been checked return True def VerifyFKrefsCompletePK(self): if Table._VerifyConstraints: if len(self.FK) == 0: print(self.TableName, "has no FK constraints to verify.") return True refTableList = [] # First, make a list of all referenced tables (it's not a 1:1 mapping of PK/FK to tables) for thisFK in self.FK: if thisFK["refTable"] not in refTableList: 
refTableList.append(thisFK["refTable"]) if Table._Debug: print("refTables in FK list:", refTableList) # Next, for each table in that list, check all of its PKs are listed in this table's FK list for thisRefTable in refTableList: PKfound = 0 for thisPK in thisRefTable.PK: for thisFK in self.FK: if thisPK == thisFK["refColName"]: PKfound+=1 print("thisPK == thisFK refColName: ",thisPK, thisFK["refColName"], "with #PKs found", PKfound) if PKfound != len(thisRefTable.PK): print("ERROR:", self.TableName, "has incomplete FK refs to table",thisRefTable.TableName,"PKs, does not ref", thisPK) return False # If we're here, FK/PK refs are ok. Make a FKnames list for FKdict in self.FK: self.FKnames.append(FKdict["colName"]) return True def Report(self): print("NameSpace : ", self.NameSpace) print("Table Name : ", self.TableName) print("Columns : ", self.Columns, "Total: ", len(self.Columns)) print("PK Columns : ", self.PK, "Total: ", len(self.PK)) print("FK Columns : ", self.FK, "Total: ", len(self.FK)) print("FK Col Names : ", self.FKnames, "Total: ", len(self.FKnames)) print("FK Tables : ", self.FKtables) print("1st 3 Rows of ", len(self.Rows), " rows total:") for thisRow in self.Rows: if self.Rows.index(thisRow) >= 3: break print("Application storage ", thisRow) def Insert(self, listOfValues): currentPKvals = [] #Assume we've already verified PK name is in Columns #1st, build a list of PK values based on what's in listOfValues for PKcol in self.PK: #Take the index of PK from Columns and use that to grab same index from listOfValues currentPKvals.append(listOfValues[self.Columns.index(PKcol)]) if Table._Debug: print("PKvals from Insert: ", currentPKvals, "for Table: ", self.TableName) if Table._VerifyConstraints: if not self.VerifyStructure(listOfValues): print("Insert not successful, VerifyStructure failed") return False if not self.VerifyUniquePKval(currentPKvals): print("Insert failed! 
PK values not unique!") return False #if Table has FKkeys, VerifyFKs exist in the refTable if len(self.FK) > 0: currentFKvals = [] for FKcol in self.FK: currentFKvals.append(listOfValues[self.Columns.index(FKcol["colName"])]) if Table._Debug: print("FKvals from Insert: ", currentFKvals, "for Table: ", self.TableName) if not self.VerifyFKvalExists(currentFKvals): print("Insert failed! FK values don't exist in referenced Table!") return False #if all of the verifies are successful, insert the data. # -JSONize it (which means forming a dictionary by combining listOfValues with Columns names) string4JSON = dict(zip(self.Columns, listOfValues)) if Table._Debug: jsonString = json.dumps(string4JSON) print("JSON string:", jsonString) #Get the HK from the put. # Records are addressable via a tuple of (namespace, set, key) # key = (self.NameSpace, self.TableName, str(currentPKvals)) # but, can also pre-build the key and store it using the key digest hashkey = self.getHashKey(str(currentPKvals)) key = (self.NameSpace, self.TableName, None, hashkey) try: # Write a record #client.put(key, jsonString) Table._CurrentClient.put(key, string4JSON) except Exception as e: print("error: {0}".format(e), sys.stderr) #self.Rows.append(hashkey) # reglular List self.Rows.add(hashkey) # SortedList #And don't forget to update the FKTables if there are any FKs in this table if len(self.FK) > 0 and Table._UseFKTables: for FKcol in self.FK: currentFKval = [] currentFKval.append(listOfValues[self.Columns.index(FKcol["colName"])]) if Table._Debug: print("FKval from Insert: ", currentFKval, "for Table: ", self.TableName) print("FKTable is: ", self.FKtables[str(FKcol["colName"]+"_FKTable")].TableName) self.FKtables[str(FKcol["colName"]+"_FKTable")].Insert4FK(currentFKval, hashkey) #Verify, key or hashkey #if not self.VerifyFKvalExists(currentFKvals): # ??? Why is this here? 
return True def Insert4FK(self, FKValue, parent_keyValue): if Table._Debug: print("**** The parent_keyValue is", parent_keyValue) #We'll need the key regardless of whether or not this FKValue already exists hashkey = self.getHashKey(str(FKValue)) #Setting the PK key = (self.NameSpace, self.TableName, None, hashkey) #Check and see if the FKValue is among the PK values for this Table. if self.VerifyUniquePKval(FKValue): #For an FKTable, the FK in question is being used as the PK ## Yes (is unique)- It's not there yet so add a row with PK = FKValue, the keyValue as data to that row. try: # Store a record with the PK value (FK as PK) in one bin and start the list of parent keys in the next Table._CurrentClient.put(key, {self.PK[0]:FKValue, "TableHashes":[parent_keyValue]}) except Exception as e: print("error: {0}".format(e), sys.stderr) self.Rows.add(hashkey) # SortedList else: ## No (not unique) so we need to add the parent_keyValue to an existing row. ## Add the parent_keyValue to the list for this PK try: (key, metadata, record) = Table._CurrentClient.get(key) #Do an RMW, adding the parent key to the list currentHashList = SortedList(record["TableHashes"]) # Use these 2 lines for Sorted currentHashList.add(parent_keyValue) #currentHashList = record["TableHashes"] # Use these 2 lines for unsorted #currentHashList.append(parent_keyValue) Table._CurrentClient.put(key, {"TableHashes":currentHashList}) except Exception as e: print("error: {0}".format(e), sys.stderr) def Update(self, listOfValues): # This code is nearly identical to that of Insert function. 
See comments there for # basic understanding, comments here will be for changes due to Update vs Insert currentPKvals = [] for PKcol in self.PK: currentPKvals.append(listOfValues[self.Columns.index(PKcol)]) if Table._Debug: print("PKvals from Update: ", currentPKvals, "for Table: ", self.TableName) if Table._VerifyConstraints: if not self.VerifyStructure(listOfValues): print("Update not successful, VerifyStructure failed") return False # Note: Opposite logic as insert. Can't update a record that doesn't exist. # By skipping this verification (VerifyConstraints = False) Update becomes an Insert if self.VerifyUniquePKval(currentPKvals): print("Update failed! PK value doesn't currently exist in Table: ", self.TableName) return False #if Table has FKkeys, VerifyFKs (the updated values) exist in the refTable. Same as for Insert if len(self.FK) > 0: currentFKvals = [] for FKcol in self.FK: currentFKvals.append(listOfValues[self.Columns.index(FKcol["colName"])]) if Table._Debug: print("FKvals from Update: ", currentFKvals, "for Table: ", self.TableName) if not self.VerifyFKvalExists(currentFKvals): print("Update failed! FK values don't exist in referenced Table!") return False #if all of the verifies are successful, update the data. 
# -JSONize it (which means forming a dictionary by combining listOfValues with Columns names) string4JSON = dict(zip(self.Columns, listOfValues)) jsonString = json.dumps(string4JSON) if Table._Debug: print("JSON string:", jsonString) hashkey = self.getHashKey(str(currentPKvals)) key = (self.NameSpace, self.TableName, None, hashkey) try: # Write a record #client.put(key, jsonString) Table._CurrentClient.put(key, string4JSON) except Exception as e: print("error: {0}".format(e), sys.stderr) #No need to append a new hashkey since an update assumes we're not adding a new hashkey (just updating data) #self.Rows.add(hashkey) # SortedList return True def Delete(self, listOfPKValues): if self.VerifyUniquePKval(listOfPKValues): # A True result implies the values don't yet exist in the table. print("Error: PK values specfied don't exist in ", self.TableName) return False #hashkey to remove from FKTable or to use for generating a key to remove a record parent_hashkey = self.getHashKey(listOfPKValues) # Not using FKTables. Just remove the record if it exists. if Table._UseFKTables == False: if Table._Debug: print("****Removing record ", listOfPKValues) key = (self.NameSpace, self.TableName, None, parent_hashkey) try: Table._CurrentClient.remove(key) #delete from the database except Exception as e: print("Error: Attempted delete. Not using FKTables") print("Error: {0}".format(e), sys.stderr) self.Rows.remove(parent_hashkey) return True #Implemente Restricted Delete. If the PK value is referenced by another table, the Delete will fail. ## When a row delete request is made at the parent Table, remove the row hash value from all child # FK value rows. If no more hashes exist in that row, the row can be deleted. ## When a row delete request is made at another table that is referenced by the parent table with # an FK, search for that FK’s value in the corresponding FKTable child table. If it exists, block # the delete. If it does not exist, complete the delete. 
#Case 1: Calling table has FKs in it. if len(self.FK) > 0: if Table._Debug: print("****This table has FK refs (it's a parent):", self.TableName) #Put listOfPKValues in a dict with their colNames so we can easily refer to them FKPKDict = dict(zip(self.PK, listOfPKValues)) #Some of these are not FKs but they won't be requested. for thisFKref in self.FK: FKTableToUpdate = self.FKtables[str(thisFKref['colName'] + '_FKTable')] #If the FK is part of the PK, it's already in this dict. if thisFKref['colName'] in self.PK: checkFKvalue = FKPKDict[thisFKref['colName']] else: #If the FK isn't part of the PK, we'll need to loop up its value so we can find it in the FKTable. record = self.Read_PK(listOfPKValues) checkFKvalue = record[thisFKref['colName']] hashkey = FKTableToUpdate.getHashKey([checkFKvalue]) key = (FKTableToUpdate.NameSpace, FKTableToUpdate.TableName, None, hashkey) if Table._Debug: print("****Updating FKtable ", FKTableToUpdate.TableName) try: (key, metadata, record) = Table._CurrentClient.get(key) #Do an RMW, revmoing the parent key fromm the list currentHashList = SortedList(record["TableHashes"]) if len(currentHashList) > 1: currentHashList.remove(parent_hashkey) Table._CurrentClient.put(key, {"TableHashes":currentHashList}) else: Table._CurrentClient.remove(key) #delete from the database FKTableToUpdate.Rows.remove(hashkey) #delete from the target table's list of record hashkeys. except Exception as e: print("error: {0}".format(e), sys.stderr) print("Error trying to delete in ", FKTableToUpdate.TableName) return False #if all that went well, actually remove the requested record if Table._Debug: print("****Removing record ", listOfPKValues) key = (self.NameSpace, self.TableName, None, parent_hashkey) try: Table._CurrentClient.remove(key) #delete from the database except Exception as e: print("Error: Attempted delete. 
Using FKTables, this table has FKs") print("Error: {0}".format(e), sys.stderr) self.Rows.remove(parent_hashkey) else: #Case 2: Calling table has no FKs (but it might be referenced) if Table._Debug: print("****This table has no FK refs, checking for references:", self.TableName) #Need to determine if the values in the delete are used in a referencing table. If yes, terminate #the delete. Otherwise, let it continue. PKDict = dict(zip(self.PK, listOfPKValues)) PKvalueIsReferenced = False for thisTable in Table._registry: for thisFKref in thisTable.FK: # If there are 0 elements in FK[], the table isn't ref'g the calling table if thisFKref['refTable'] == self: # The calling table is refd by this table. Now look for this value checkPKvalue = PKDict[thisFKref['refColName']] FKTableToCheck = thisTable.FKtables[str(thisFKref['colName'] + '_FKTable')] if FKTableToCheck.VerifyUniquePKval(checkPKvalue) == False: #If False, then the value exists in the FKTable PKvalueIsReferenced = True if PKvalueIsReferenced: print("WARNING: Can't delete this record, it is referenced by another record!") return False else: if Table._Debug: print("****No references found, removing record", listOfPKValues) key = (self.NameSpace, self.TableName, None, parent_hashkey) try: Table._CurrentClient.remove(key) #delete from the database except Exception as e: print("Error: Attempted delete. Using FKTables, this table has no FKs") print("Error: {0}".format(e), sys.stderr) self.Rows.remove(parent_hashkey) return True def Read_PK(self, PKValues): hashkey = self.getHashKey(str(PKValues)) key = (self.NameSpace, self.TableName, None, hashkey) (key, metadata, record) = Table._CurrentClient.get(key) return record def Read_hashkey(self, hashkey): key = (self.NameSpace, self.TableName, None, hashkey) (key, metadata, record) = Table._CurrentClient.get(key) return record def getHashKey(self, PKValues): hashkey = Table._CurrentClient.get_key_digest(self.NameSpace, self.TableName, str(PKValues)) return hashkey
def test_remove_valueerror1():
    """SortedList.remove must raise ValueError for a value that is not present (empty list)."""
    slt = SortedList()
    try:
        slt.remove(0)
    except ValueError:
        return
    raise AssertionError("SortedList().remove(0) did not raise ValueError")
class AllocQueue:
    """Queue of the objects of a Spec that still have to be allocated.

    Objects with a truthy ``paddr`` are "unfungible": each is kept in
    ``unfun_objects`` keyed by that physical address.  Every other object that
    reports a non-zero size is "fungible" and is queued in a per-size deque in
    ``objects``; ``sizes`` holds the size_bits values that currently have a
    non-empty deque.
    """

    def __init__(self, s):
        assert isinstance(s, Spec)
        # paddr -> object.  The key function negates the address, so the dict
        # is ordered from the highest paddr to the lowest, and popitem() (which
        # takes the last item) hands out the lowest paddr first.
        self.unfun_objects = SortedDict(lambda x: -x)
        # size_bits -> deque of fungible objects of that size.
        self.objects = {}
        # Sorted size_bits values that currently have queued objects.
        self.sizes = SortedList()

        # Visit objects in descending name order.  Each deque is popped from
        # its right end, so objects come back out in ascending name order.
        # This is deterministic (set iteration order is not), and it makes
        # objects of one size likely to be allocated contiguously by name,
        # which e.g. the CAmkES DMA allocator relies on for large contiguous
        # allocations.
        for obj in sorted(s.objs, key=lambda candidate: candidate.name, reverse=True):
            if getattr(obj, 'paddr', None):
                self._push_unfun_obj(obj)
            elif obj.get_size_bits():
                self._push_fun(obj)

    def _push_fun(self, o):
        """Queue fungible object ``o`` in the deque for its size_bits, registering new sizes."""
        size_bits = o.get_size_bits()
        bucket = self.objects.get(size_bits)
        if bucket is None:
            self.objects[size_bits] = collections.deque([o])
            self.sizes.add(size_bits)
        else:
            bucket.append(o)

    def pop_fun(self, size_bits):
        """Pop a fungible object of ``size_bits``, or return None if none is queued.

        A size whose deque becomes empty is forgotten entirely.
        """
        bucket = self.objects.get(size_bits)
        if bucket is None:
            return None
        popped = bucket.pop()
        if not bucket:
            self.sizes.remove(size_bits)
            del self.objects[size_bits]
        return popped

    def _push_unfun_obj(self, o):
        """Store unfungible object ``o`` under its paddr; raise AllocatorException on a clash."""
        paddr = o.paddr
        if paddr in self.unfun_objects:
            clash = self.unfun_objects[paddr]
            raise AllocatorException("Duplicate paddr 0x%x (%s, %s)" % (paddr, clash.name, o.name))
        self.unfun_objects[paddr] = o

    def pop_unfun(self):
        """Remove and return the unfungible object with the lowest paddr, or None if there are none."""
        if not self.unfun_objects:
            return None
        _paddr, obj = self.unfun_objects.popitem()
        return obj

    def max_size(self):
        """Largest size_bits with queued fungible objects (IndexError if there are none)."""
        return self.sizes[-1]

    def min_size(self):
        """Smallest size_bits with queued fungible objects (IndexError if there are none)."""
        return self.sizes[0]

    def more_unfun(self):
        """True while any unfungible object remains."""
        return bool(self.unfun_objects)

    def more_fun(self, size_bits=0):
        """Report whether fungible objects remain: of ``size_bits`` if given, otherwise of any size."""
        if not size_bits:
            return len(self.objects) > 0
        return size_bits in self.objects and len(self.objects[size_bits])
def test_remove_valueerror3():
    """SortedList.remove must raise ValueError for an absent value that falls between present ones."""
    slt = SortedList([1, 2, 2, 2, 3, 3, 5])
    try:
        slt.remove(4)
    except ValueError:
        return
    raise AssertionError("SortedList([1, 2, 2, 2, 3, 3, 5]).remove(4) did not raise ValueError")
class Orderbook:
    """A limit order book snapshot, maintained by replaying Update objects.

    Attributes:
        timestamp: time of the most recently applied update.  Each update
            advances it, and (with assertions enabled) an update older than
            it is rejected with AssertionError.  Feeding updates in time order
            is the caller's responsibility.
        order_dict: order id -> Order; the container that owns the Order objects.
        price_volume_dict: price -> SortedList of ids of resting orders at that
            price.  Used to find an order from its (price, remaining volume).
            Price levels are removed once they become empty.
        ask_list: ids of ask orders sorted by price, ascending (best ask first).
        bid_list: ids of bid orders sorted by price, descending (best bid first).
        error_tol: tolerance used when comparing remaining volumes.
        update_counter: number of updates passed to execute_update.

    Update reasons: 1 = cancel, 2 = place, 3 = trade.  A trade update is assumed
    to execute only at the best price.
    """

    def __init__(self, data, timestamp=0, error_tol=1e-6):
        """Build an order book from ``data``, a numpy ndarray with one initial order per row.

        Each row is turned into an Order; the column layout is defined by Order.
        """
        self.timestamp = timestamp
        self.order_dict = {}
        self.price_volume_dict = {}
        # The key functions look ids up in order_dict, so an id must be present
        # in order_dict whenever it is added to or removed from these lists.
        self.ask_list = SortedList(key=self._id_to_price)
        self.bid_list = SortedList(key=self._id_to_price_neg)
        self.error_tol = error_tol
        self.update_counter = 0
        for row in range(data.shape[0]):
            self._place_order_helper(Order(data[row, :]))

    def get_highest_bid_info(self):
        """Return {'price', 'volume'} of the best (highest-priced) bid; IndexError if there are no bids."""
        best = self.order_dict[self.bid_list[0]]
        return {'price': best.get_price(), 'volume': best.get_remaining()}

    def get_lowest_ask_info(self):
        """Return {'price', 'volume'} of the best (lowest-priced) ask; IndexError if there are no asks."""
        best = self.order_dict[self.ask_list[0]]
        return {'price': best.get_price(), 'volume': best.get_remaining()}

    def _remove_order(self, id):
        """Remove order ``id`` from order_dict, price_volume_dict and ask_list/bid_list.

        The Order object itself is not modified.  The id is removed from the
        sorted side list before being popped from order_dict, because that
        list's key function looks the id up in order_dict.
        """
        order = self.order_dict[id]
        # Remove by id: matching by volume could pick another order at the
        # same price with an equal remaining volume.
        self._drop_from_price_level(order.get_price(), id)
        if order.get_is_bid():
            self.bid_list.remove(id)
        else:
            self.ask_list.remove(id)
        self.order_dict.pop(id)

    def _cancel_order(self, update):
        """Apply a cancel update (reason 1).

        The resting order is found by the update's price, side and pre-update
        volume (remaining - delta).  If several orders match, the first in the
        price level's sorted id order is used.  NOTE(review): an earlier
        comment said the smallest birthtime wins, which holds only if ids grow
        with birthtime; confirm.  The order is removed from all containers when
        the update leaves no remaining volume, and is then modified by the
        update.
        """
        assert update.get_reason() == 1, "INCONSISTEN UPDATE REASON"
        # Call outside the assert: the check also advances self.timestamp,
        # which must still happen when assertions are disabled (python -O).
        in_order = self._check_timestamp_consistency(update)
        assert in_order, "INCONSISTEN TIMESTAMPS, ATTEMPT TO EXECTUE PAST UPDATE"
        order_id = self._get_id_from_price_remaining(
            update.get_price(), update.get_remaining() - update.get_delta(), update.get_is_bid())
        order = self.order_dict[order_id]
        if update.get_remaining() == 0:
            self._remove_order(order_id)
        order.modify(update)

    def _trade_order(self, update):
        """Apply a trade update (reason 3).

        The traded order is located by the update's price, side and pre-trade
        volume (remaining - delta).  It is removed from all containers when
        fully executed, and is then modified by the update.  The caller is
        responsible for the update actually executing at the best price.
        """
        assert update.get_reason() == 3, "INCONSISTEN UPDATE REASON"
        in_order = self._check_timestamp_consistency(update)
        assert in_order, "INCONSISTEN TIMESTAMPS, ATTEMPT TO EXECTUE PAST UPDATE"
        order_id = self._get_id_from_price_remaining(
            update.get_price(), update.get_remaining() - update.get_delta(), update.get_is_bid())
        order = self.order_dict[order_id]
        if update.get_remaining() == 0.0:
            self._remove_order(order_id)
        order.modify(update)

    def _place_order(self, update):
        """Apply a place update (reason 2).

        If the volume before the update (remaining - delta) is within error_tol
        of zero, this is a new order: it is built from the update and added to
        every container.  Otherwise the existing order with that pre-update
        volume, price and side is located and modified by the update.
        """
        assert update.get_reason() == 2, "INCONSISTEN UPDATE REASON"
        in_order = self._check_timestamp_consistency(update)
        assert in_order, "INCONSISTEN TIMESTAMPS, ATTEMPT TO EXECTUE PAST UPDATE"
        if abs(update.get_remaining() - update.get_delta()) < self.error_tol:
            self._place_order_helper(Order(update))
        else:
            order_id = self._get_id_from_price_remaining(
                update.get_price(), update.get_remaining() - update.get_delta(), update.get_is_bid())
            self.order_dict[order_id].modify(update)

    def _check_timestamp_consistency(self, update, match_time=True):
        """Return True if the update's timestamp is >= self.timestamp, else False.

        When it is and ``match_time`` is true, self.timestamp is advanced to the
        update's timestamp.
        """
        if update.get_timestamp() >= self.timestamp:
            if match_time:
                self.timestamp = update.get_timestamp()
            return True
        return False

    def execute_update(self, update):
        """Apply ``update`` by dispatching on its reason (1 cancel, 2 place, 3 trade).

        update_counter is incremented before dispatch, even when the update is
        then rejected.  Raises Exception for any other reason code.
        """
        self.update_counter += 1
        reason = update.get_reason()
        if reason == 1:
            self._cancel_order(update)
        elif reason == 3:
            self._trade_order(update)
        elif reason == 2:
            self._place_order(update)
        else:
            raise Exception("INVALID UPDATE REASON, VALUE NOT IN {1, 2, 3}")

    def show_head(self, n=20):
        """Print book counters followed by up to ``n`` best asks and up to ``n`` best bids."""
        print("number of executed updates: ", self.update_counter)
        print("number of total orders: ", len(self.order_dict))
        print("number of asks: ", len(self.ask_list))
        print("number of bids: ", len(self.bid_list))
        print("ASK: ")
        self._print_side(self.ask_list, n)
        print("BID: ")
        self._print_side(self.bid_list, n)

    def _print_side(self, id_list, n):
        """Print the first ``n`` orders of ``id_list``, or all of them if the side is shorter."""
        for i in range(min(n, len(id_list))):
            order = self.order_dict[id_list[i]]
            print(i, "." + "price =", order.get_price(), "volume =", order.get_remaining())

    def _id_to_price(self, id):
        return self.order_dict[id].get_price()

    def _id_to_price_neg(self, id):
        # Negated so that bid_list sorts from the highest price down.
        return -self.order_dict[id].get_price()

    def _id_to_birthtime(self, id):
        return self.order_dict[id].get_birthtime()

    def _place_order_helper(self, new_order):
        """Add ``new_order`` to order_dict, price_volume_dict and its side's sorted list."""
        order_id = new_order.get_id()
        # order_dict first: the sorted lists' key functions read from it.
        self.order_dict[order_id] = new_order
        self._add_to_pvdict(new_order)
        if new_order.get_is_bid():
            self.bid_list.add(order_id)
        else:
            self.ask_list.add(order_id)

    def _id_to_remaining(self, id):
        return self.order_dict[id].get_remaining()

    def _add_to_pvdict(self, order):
        """Register ``order``'s id under its price in price_volume_dict."""
        price = order.get_price()
        if price not in self.price_volume_dict:
            self.price_volume_dict[price] = SortedList()
        self.price_volume_dict[price].add(order.get_id())

    def _matches_volume(self, order_id, remaining, is_bid):
        """True if order ``order_id`` is on side ``is_bid`` with remaining volume within error_tol of ``remaining``."""
        return (abs(self._id_to_remaining(order_id) - remaining) < self.error_tol
                and self.order_dict[order_id].get_is_bid() == is_bid)

    def _drop_from_price_level(self, price, order_id):
        """Remove ``order_id`` from price level ``price``, deleting the level once it is empty."""
        ids_at_price = self.price_volume_dict[price]
        ids_at_price.remove(order_id)
        if not ids_at_price:
            # Drop empty levels so price_volume_dict does not grow without bound.
            del self.price_volume_dict[price]

    def _remove_from_pv_dict(self, price, remaining, is_bid):
        """Remove from price level ``price`` the first order on side ``is_bid`` whose volume matches ``remaining``.

        Raises KeyError for an unknown price, and Exception if no order matches.
        """
        for candidate in self.price_volume_dict[price]:
            if self._matches_volume(candidate, remaining, is_bid):
                # Safe despite iterating: we return right after removing.
                self._drop_from_price_level(price, candidate)
                return
        raise Exception("CANNOT FIND REMAINING VOLUME WITHIN THRESHOLD")

    def _get_id_from_price_remaining(self, price, remaining, is_bid):
        """Return the id of the first order at ``price`` on side ``is_bid`` whose volume matches ``remaining``.

        Candidates are scanned in the price level's sorted id order.  If none
        matches (or the price level does not exist), diagnostics are printed
        and Exception is raised.
        """
        ids_at_price = self.price_volume_dict.get(price, ())
        for candidate in ids_at_price:
            if self._matches_volume(candidate, remaining, is_bid):
                return candidate
        print(self.timestamp)
        print("ids at this price", ids_at_price)
        print("price and volume for orders at the same price:")
        for order_id in ids_at_price:
            print(self.order_dict[order_id].get_price(), self.order_dict[order_id].get_remaining())
        print(price, remaining)
        raise Exception("CANNOT FIND REMAINING VOLUME WITHIN THRESHOLD")
def test_remove_valueerror2():
    """SortedList.remove must raise ValueError for a value past the end of a list split into small sublists."""
    try:
        slt = SortedList(range(100), load=10)
    except TypeError:
        # sortedcontainers 2.x dropped the ``load`` constructor argument;
        # there the load factor is set with ``_reset``.
        slt = SortedList(range(100))
        slt._reset(10)
    try:
        slt.remove(100)
    except ValueError:
        return
    raise AssertionError("SortedList(range(100)).remove(100) did not raise ValueError")
class Analyse():
    '''
    Gap-trading analysis over per-row (intraday) price data.

    For every stock, each day-start row (09:15) is turned into a "gap":
    the VWAP of the opening window versus the VWAP of the previous close
    window. Gaps are normalised by volatility, and the return of fading the
    gap (taking a position against its direction) is tracked for several
    holding periods. Optionally the return is offset by a beta-scaled
    benchmark ("hedge") return. compileResults() then aggregates the trades
    by relative rank, by gap percentile, or by normalised gap-size bucket.

    NOTE(review): the methods rely on module-level names defined elsewhere
    in the file (``debug``, ``minInDay``, ``grandDict``,
    ``bucketTradeListGlobal``, ``processStatMatrix``, ``dataStore``,
    ``SortedList``, ``plt``) -- confirm they are available.
    '''

    def __init__(self, config):
        '''
        Read the analysis settings from ``config`` and build the data store.

        Args:
            config: dict of settings (YEAR, START_TIME, END_TIME, STOCK_LIST,
                MODE, HEDGE, VWAP_* windows, STOP_LOSS, TARGET_PRICE, ...).
                Mode- and hedge-specific keys are only read when that mode /
                hedging is enabled.
        '''
        self.year = config['YEAR']
        self.startTime = config['START_TIME']
        self.endTime = config['END_TIME']
        self.stockList = config['STOCK_LIST']
        self.mode = config['MODE']
        self.logBucket = config['LOG_BUCKET_DATA']
        self.hedgeFlag = config['HEDGE']
        self.hedgeStock = config['HEDGE_STOCK']
        self.divideByVol = config['DIVIDE_BY_VOLATILITY']
        # List of stocks to analyse. When hedging, the benchmark is put first
        # so that its per-day data exists before the other stocks need it.
        self.modStockList = self.stockList
        if self.hedgeFlag:
            self.betaCorrelation = config['BETA_CORR']
            self.modStockList = [config['HEDGE_STOCK']] + self.stockList
            self.corrFlag = config['BETA_CORR_TYPE']
        if self.mode == 'bucket':
            self.bucketSize = config['BUCKET_SIZE']
            self.numBucket = config['NUM_BUCKET']
        elif self.mode == 'percentile':
            self.bucketSize = config['BUCKET_SIZE']
            self.minSize = config['MIN_SIZE']
            self.maxSize = config['MAX_SIZE']
        self.absFlag = config['ABS_FLAG']
        # The data store is built with STOCK_LIST temporarily including the
        # benchmark; the caller's list is restored afterwards.
        config['STOCK_LIST'] = self.modStockList
        # Datastore contains functions to read and update prices
        self.dataStore = dataStore(config)
        config['STOCK_LIST'] = self.stockList

        # self.results: stock name -> list of per-(day, hold period) stat
        # dicts as produced by getGapStats().
        self.results = {}
        # Volatility-normalised gap sizes of every analysed day (for plotting)
        self.gapListNormalized = []
        self.prevCloseVWAPWindow = config['VWAP_PREV_CLOSE_WINDOW']
        self.currOpenVWAPWindow = config['VWAP_CURR_OPEN_WINDOW']
        self.posEntryVWAPWindow = config['VWAP_POSITION_ENTRY_WINDOW']
        self.posExitVWAPWindow = config['VWAP_POSITION_EXIT_WINDOW']
        self.printFlag = 0
        self.stopLoss = config['STOP_LOSS']
        self.targetPrice = config['TARGET_PRICE']
        self.tTestFlag = config['T_TEST_FLAG']
        if self.tTestFlag:
            # Percentile (0..99) -> list of trade profits in that percentile
            self.profitByGapPercentile = {i: [] for i in range(100)}
        # Stock name -> pandas Series of row-over-row close returns
        self.stockReturns = {}

    def loadData(self):
        '''
        Loads price data for the specified year and stock list and
        precomputes row-over-row close returns for every stock (and the
        benchmark when hedging).

        Returns:
            None, only class members are modified
        '''
        self.dataStore.loadPriceData()
        stocks = list(self.stockList)
        if self.hedgeFlag:
            stocks.append(self.hedgeStock)
        for stock in stocks:
            # Column 6 is the close price (see getVolAvgPrice column layout)
            price = pd.DataFrame(
                self.dataStore.priceDataList[stock][:]).iloc[:, 6]
            self.stockReturns[stock] = ((price / price.shift(1)) - 1)[1:]

    def getRetList(self, stock):
        '''Close-to-close returns of ``stock`` sampled every ``minInDay`` rows.'''
        price = pd.DataFrame(
            self.dataStore.priceDataList[stock][::minInDay]).iloc[:, 6]
        return ((price / price.shift(1)) - 1)[1:]

    def getBenchmarkVolatility(self):
        '''
        Close-to-close returns of the benchmark sampled every ``minInDay``
        rows (same sampling as getRetList()).
        '''
        # The original read ``self.hedgePriceList``, which is never assigned
        # and always raised AttributeError; the benchmark prices live in the
        # data store like every other stock's.
        price = pd.DataFrame(
            self.dataStore.priceDataList[self.hedgeStock][::minInDay]).iloc[:, 6]
        return ((price / price.shift(1)) - 1)[1:]

    def getVolatilityNDays(self, stock, n, currTimeRow):
        """
        Standard deviation of the precomputed returns P(t) / P(t-1) - 1 over
        the n * 375 rows preceding ``currTimeRow``.

        NOTE(review): 375 is presumably the number of rows (minutes) in one
        trading day -- confirm.
        """
        returns = self.stockReturns[stock].iloc[
            currTimeRow - 1 - (n * 375):currTimeRow - 1]
        if debug:
            print("Volatility: " + str(np.std(returns)))
        return np.std(returns)

    def getCorrelation(self, stock1, stock2, i1, i2, n):
        """
        Correlation of the returns of two stocks over the n * 375 rows that
        precede row ``i1`` (for stock1) and row ``i2`` (for stock2).
        """
        window = n * 375
        returns1 = self.stockReturns[stock1].iloc[i1 - 1 - window:i1 - 1]
        returns2 = self.stockReturns[stock2].iloc[i2 - 1 - window:i2 - 1]
        if debug:
            # Previously printed unconditionally on every call.
            print(i1, i2)
        return np.corrcoef(returns1, returns2)[1][0]

    def getVolAvgPrice(self, stock, left, right):
        '''
        Computes the volume weighted price for the row range [left, right),
        where each row's price = (open + close + low + high) / 4.
        '''
        if debug:
            print('\n' + ''.join(['*'] * 50))
            print("Stock prices")
            print(left, right)
            print("Left price: " + str(self.dataStore.priceDataList[stock][left]))
            print("Right price: " + str(self.dataStore.priceDataList[stock][right]))
        # Raw columns 5..9 are Open, Close, Low, High, Volume; after dropping
        # the leading (string) columns they become 0..4.
        price = np.array(
            self.dataStore.priceDataList[stock][left:right])[:, 5:]
        price = price.astype(np.float64)
        avgPrice = (price[:, 0] + price[:, 1] + price[:, 2] + price[:, 3]) / 4.0
        volume = price[:, 4]
        return np.average(avgPrice, weights=volume)

    def getTTestScores(self, boundary, profitByGapPercentileLocal, verbose=False):
        '''
        Two-sample t-test between the profits in percentiles 1..boundary and
        those in percentiles boundary+1..98.

        Returns:
            (t value, p value)
        '''
        arr1 = []
        arr2 = []
        for i in range(1, boundary + 1):
            arr1 += profitByGapPercentileLocal[i]
        for i in range(boundary + 1, 99):
            arr2 += profitByGapPercentileLocal[i]
        tTest = ttest_ind(arr1, arr2)
        tValue, pValue = tTest[0], tTest[1]
        if verbose:
            print("Boundary: " + str(boundary))
            print("T Value: " + str(tValue))
            print("P Value: " + str(pValue))
        return tValue, pValue

    def getGapStats(self, holdPeriodList, volType='nGapVol', verbose=False):
        '''
        Gives the statistics (Gap trading) for all hold periods specified.

        The stats include timestamp, curr open price (after VWAP), prev close
        price (after VWAP), volatility, holding period (H), min price/max
        price in interval, closing price after H, profits, etc.

        Args:
            holdPeriodList: Contains holding periods as number of rows
            volType: 'stdVol' uses returns over the last 70 sampled days;
                anything else uses the std dev of the last 30 gaps
            verbose: print per-trade beta / return details when hedging

        Returns:
            Dictionary stock -> list of per-(day, hold period) stat dicts.
            Also stored in self.results.
        '''
        statList = {}
        priceList = {}
        gapList = {}
        if self.hedgeFlag:
            # BM is benchmark. These are filled while the benchmark itself is
            # analysed (it is first in modStockList) and read for the others.
            volListBM = []
            timeListBM = []
            retListBM = []  # benchmark gap size at each day start
            priceListBM = []
            priceTimeBM = []
            # All timestamps for which the benchmark is indexed
            benchmarkTimeStamps = [
                eachList[0]
                for eachList in self.dataStore.priceDataList[self.hedgeStock]
            ]
        # Number of past samples used for the gap-size volatility
        volN = 70
        if volType != 'stdVol':
            volN = 30
        # Days of per-row returns used for stock / index volatility
        volDays = 70
        for stock in self.modStockList:
            infoList = self.dataStore.priceDataList[stock]
            statList[stock] = []
            priceList[stock] = []
            gapList[stock] = []
            retList = self.getRetList(stock)
            prevTime = 0
            print('Currently analysing: ' + str(stock))
            for i in range(len(infoList)):
                currTime = infoList[i][0]
                currTimeStamp = datetime.fromtimestamp(currTime)
                currHour = currTimeStamp.time().hour
                currMins = currTimeStamp.time().minute
                # Skip duplicated rows
                if prevTime == currTime:
                    continue
                prevTime = currTime
                # Only rows inside the configured time range
                if not (self.startTime <= currTime <= self.endTime):
                    continue
                # Only day-start rows (09:15) open a new gap
                if not (currHour == 9 and currMins == 15):
                    continue
                # NOTE(review): hard-coded debugging trigger -- dumps every
                # stat dict for SBIN on 9 Nov 2016.
                if (stock == 'SBIN' and currTimeStamp.date().day == 9
                        and currTimeStamp.date().month == 11
                        and self.year == 2016):
                    self.printFlag = 1
                if debug:
                    print('\n' + ''.join(['*'] * 50))
                # VWAP prices around the day start
                currOpen = self.getVolAvgPrice(
                    stock, i, i + self.currOpenVWAPWindow)
                prevClose = self.getVolAvgPrice(
                    stock, i - self.prevCloseVWAPWindow, i)
                posEntryPrice = self.getVolAvgPrice(
                    stock, i + self.currOpenVWAPWindow,
                    i + self.currOpenVWAPWindow + self.posEntryVWAPWindow)
                if self.hedgeFlag and self.hedgeStock == stock:
                    priceListBM.append(currOpen)
                    priceTimeBM.append(currTime)
                priceList[stock].append(currOpen)
                gapList[stock].append((currOpen - prevClose) / prevClose)
                # Not enough samples to compute std dev; +5 handles edge cases
                if len(gapList[stock]) < volN + 5:
                    continue

                # Stats common across the holding periods
                commStats = {}
                commStats['time'] = currTime
                commStats['readableTime'] = currTimeStamp
                commStats['ticker'] = stock
                commStats['currOpen'] = currOpen
                commStats['prevClose'] = prevClose
                commStats['posEntryPrice'] = posEntryPrice
                commStats['gapSize'] = (currOpen - prevClose) / prevClose
                if self.absFlag:
                    commStats['gapSize'] = np.abs(commStats['gapSize'])
                numGaps = len(gapList[stock])
                if volType == 'stdVol':
                    commStats['volatility'] = np.std(
                        retList[numGaps - volN:numGaps])
                else:
                    commStats['volatility'] = np.std(gapList[stock][-volN:])
                commStats['gapRatio'] = (commStats['gapSize'] /
                                         commStats['volatility'])
                # Volatility of per-row returns over the last volDays days
                commStats['stockVolatility'] = self.getVolatilityNDays(
                    stock, volDays, i)

                if self.hedgeFlag:
                    if stock != self.hedgeStock:
                        # Latest benchmark day at or before currTime
                        row = bisect(timeListBM, currTime) - 1
                        retBM = retListBM[row]
                        volBM = volListBM[row]
                        bmI = bisect_left(benchmarkTimeStamps, currTime)
                        posEntryBM = self.getVolAvgPrice(
                            self.hedgeStock, bmI + self.currOpenVWAPWindow,
                            bmI + self.currOpenVWAPWindow + self.posEntryVWAPWindow)
                        commStats['indexVolatility'] = self.getVolatilityNDays(
                            self.hedgeStock, volDays, bmI)
                        if debug:
                            # Timestamps of the stock row and benchmark row
                            print(self.dataStore.priceDataList[stock][i][0],
                                  self.dataStore.priceDataList[self.hedgeStock][bmI][0])
                        commStats['posEntryPriceBM'] = posEntryBM
                        if self.corrFlag != 'constant':
                            # Rolling correlation of day-open VWAPs
                            priceRow = bisect(priceTimeBM, currTime)
                            self.betaCorrelation = np.corrcoef(
                                priceList[stock][-volN:],
                                priceListBM[-volN + priceRow:priceRow])[1][0]
                        beta = self.betaCorrelation * (
                            commStats['stockVolatility'] /
                            commStats['indexVolatility'])
                        if debug:
                            print("Stock Volatility: " + str(commStats['stockVolatility']))
                            print("Index Volatility: " + str(commStats['indexVolatility']))
                        commStats['betaCorr'] = self.betaCorrelation
                        commStats['Beta'] = beta
                        if verbose:
                            print(''.join(['*'] * 50))
                            print("Beta : " + str(beta))
                            print("Stock : " + stock)
                            print("Stock currOpen: " + str(currOpen))
                            print("Stock prevClose: " + str(prevClose))
                            print("Stock Return: " + str(commStats['gapSize']))
                            print("Stock Volatility: " + str(commStats['volatility']))
                            print("Stock Normalized Return: " + str(commStats['gapRatio']))
                            print("Benchmark Return: " + str(retBM))
                            print("Benchmark Volatility: " + str(volBM))
                            print("Benchmark Normalized Return: " + str(retBM / volBM))
                    else:
                        timeListBM.append(currTime)
                        retListBM.append(commStats['gapSize'])
                        volListBM.append(commStats['volatility'])

                # Running min / max of the close over the holding horizon
                minPriceList = [float(infoList[i][6])]
                maxPriceList = [float(infoList[i][6])]
                holdLim = min(max(holdPeriodList), len(infoList) - i - 1)
                for j in range(holdLim):
                    close = float(infoList[i + j][6])
                    minPriceList.append(min(minPriceList[-1], close))
                    maxPriceList.append(max(maxPriceList[-1], close))
                # Volatility-normalised gap, for the distribution plot
                self.gapListNormalized.append(commStats['gapRatio'])

                # Stop-loss / target state carries across increasing holds
                reachedStopOrTarget = 0
                stopOrTargetRelReturn = 0
                for hold in holdPeriodList:
                    tmpStats = commStats.copy()
                    tmpStats['holdPeriod'] = hold
                    tmpStats['min'] = minPriceList[min(hold, holdLim)]
                    tmpStats['max'] = maxPriceList[min(hold, holdLim)]
                    exitRow = min(
                        i + self.currOpenVWAPWindow + self.posEntryVWAPWindow + hold,
                        len(infoList) - 1)
                    tmpStats['finClose'] = infoList[exitRow][6]
                    # Scale per-row volatility to the holding period
                    tmpStats['stockVolAfterNorm'] = (
                        commStats['stockVolatility'] * np.sqrt(hold))
                    if self.hedgeFlag:
                        bmI = bisect_left(benchmarkTimeStamps, currTime)
                        bmPrices = self.dataStore.priceDataList[self.hedgeStock]
                        bmExitRow = min(
                            bmI + self.currOpenVWAPWindow + self.posEntryVWAPWindow + hold,
                            len(bmPrices) - 1)
                        tmpStats['finCloseBM'] = bmPrices[bmExitRow][6]
                    if stock != self.hedgeStock:
                        # Fade the gap: profit when price moves against it
                        tmpStats['profit'] = (
                            (-np.sign(tmpStats['currOpen'] - tmpStats['prevClose'])) *
                            ((tmpStats['finClose'] - tmpStats['posEntryPrice']) /
                             tmpStats['posEntryPrice']))
                        tmpStats['absReturn'] = tmpStats['profit']
                        tmpStats['absReturnPerUnitVol'] = (
                            tmpStats['absReturn'] / tmpStats['stockVolAfterNorm'])
                        if self.hedgeFlag:
                            marketReturn = (
                                (tmpStats['finCloseBM'] - tmpStats['posEntryPriceBM']) /
                                tmpStats['posEntryPriceBM'])
                            tmpStats['marketReturn'] = marketReturn
                            tmpStats['returnOffset'] = marketReturn * tmpStats['Beta']
                            tmpStats['relReturn'] = tmpStats['profit'] + (
                                np.sign(tmpStats['currOpen'] - tmpStats['prevClose']) *
                                tmpStats['returnOffset'])
                            tmpStats['relReturnPerUnitVol'] = (
                                tmpStats['relReturn'] / tmpStats['stockVolAfterNorm'])
                            # The first hold period that hits stop-loss or
                            # target freezes the return for longer holds.
                            if ((tmpStats['relReturn'] <= self.stopLoss or
                                 tmpStats['relReturn'] >= self.targetPrice)
                                    and reachedStopOrTarget == 0):
                                reachedStopOrTarget = 1
                                stopOrTargetRelReturn = tmpStats['relReturn']
                            if reachedStopOrTarget:
                                tmpStats['relReturnWithStopLoss'] = stopOrTargetRelReturn
                            else:
                                tmpStats['relReturnWithStopLoss'] = tmpStats['relReturn']
                        else:
                            tmpStats['relReturn'] = tmpStats['profit']
                            tmpStats['relReturnPerUnitVol'] = tmpStats['absReturnPerUnitVol']
                    if self.printFlag == 1:
                        for key in tmpStats:
                            print(key + ": " + str(tmpStats[key]))
                    statList[stock].append(tmpStats)
                    if stock != self.hedgeStock:
                        grandDict[stock].append(tmpStats)
                self.printFlag = 0
        self.results = statList
        return statList

    def compileResults(self, holdPeriodList):
        '''
        Compile the results extracted from getGapStats().

        The rows are indexed by RELATIVE RANK ('relative' mode), by gap
        percentile bucket ('percentile' mode) or by normalised gap-size
        bucket (any other mode). Columns: Count (all stocks), E, P, R
        fractions.
            Win Rate: The fraction of actual fades
            Anti: Average profit on winning fade trades
            With: Average loss on losing fade trades
            Exp: Expectation of profit

        Args:
            holdPeriodList: should be consistent with getGapStats()

        Returns:
            None. Fills self.cumStats[hold], a matrix with the column
            convention 0: Count, 1: E, 2: P, 3: R, 4: P(S), 5: Anti,
            6: With, 7: Exp (and self.timeWiseStats / bucket logs).
        '''
        self.timeWiseStats = {}
        self.cumStats = {}
        # The row count depends only on the mode; computing it once also
        # keeps it defined when holdPeriodList is empty.
        if self.mode == 'relative':
            numRows = len(self.stockList)
        elif self.mode == 'percentile':
            numRows = int(100 / self.bucketSize)
        else:
            # Buckets -numBucket .. +numBucket around the zero bucket
            numRows = (2 * self.numBucket) + 1
        for hold in holdPeriodList:
            self.cumStats[hold] = np.zeros((numRows, 8))
            self.timeWiseStats[hold] = {}
        if self.logBucket:
            # Per bucket: list of trade timestamps
            self.bucketTimeList = [[] for _ in range(numRows)]
            # Per bucket: stock -> list of trade profits. Each bucket needs
            # its own lists; a shallow dict.copy() made all buckets share
            # the same per-stock list objects.
            self.bucketTradeList = [
                {key: [] for key in self.stockList} for _ in range(numRows)]

        for stockId, stock in enumerate(self.stockList):
            for tmpStats in self.results[stock]:
                time = tmpStats['time']
                hold = tmpStats['holdPeriod']
                currOpen = tmpStats['currOpen']
                prevClose = tmpStats['prevClose']
                gapRatio = tmpStats['gapRatio']
                gapSize = tmpStats['gapSize']
                posEntry = tmpStats['posEntryPrice']
                finClose = tmpStats['finClose']
                volatility = tmpStats['stockVolAfterNorm']
                # Not a real hedge: the beta-scaled index return is used to
                # offset the trade's return.
                if self.hedgeFlag:
                    hedge = ((tmpStats['finCloseBM'] - tmpStats['posEntryPriceBM']) /
                             tmpStats['posEntryPriceBM']) * tmpStats['Beta']
                    if self.divideByVol:
                        hedge /= volatility
                # Elements 0..7 are the stats columns; 8..11 hold
                # stockId, gapSize, gapRatio, stockId for ranking/logging.
                tmpArr = np.zeros(12)
                tmpArr[0] += 1
                tmpArr[8] = stockId
                tmpArr[9] = gapSize
                tmpArr[10] = gapRatio
                tmpArr[11] = stockId
                targetPrice = finClose
                gapSign = np.sign(currOpen - prevClose)
                profit = (-gapSign) * ((targetPrice - posEntry) / posEntry)
                if self.divideByVol:
                    profit /= volatility
                if self.hedgeFlag:
                    profit -= (-gapSign) * hedge
                if np.sign(profit) < 0:
                    # E case, i.e. extension
                    tmpArr[1] += 1
                    tmpArr[6] += profit
                else:
                    if (currOpen - prevClose) * (prevClose - targetPrice) < 0:
                        # P case, i.e. partial fill
                        tmpArr[2] += 1
                    else:
                        # R case, i.e. reversal
                        tmpArr[3] += 1
                    tmpArr[5] += profit
                self.timeWiseStats[hold].setdefault(time, []).append(tmpArr)

        for hold in holdPeriodList:
            if self.mode == 'percentile':
                minSize = self.minSize
                maxSize = self.maxSize
                # Sliding window of recent gap ratios, mirrored in a sorted
                # container for percentile lookups.
                self.gapQueue = deque([], maxlen=maxSize)
                # `load=` was removed from SortedList in sortedcontainers 2.x;
                # it only tuned performance.
                self.orderedGaps = SortedList()
            for time in sorted(self.timeWiseStats[hold].keys()):
                entries = self.timeWiseStats[hold][time]
                if self.mode == 'relative':
                    # Rank by gap-size magnitude (column 9). The original key
                    # was x[-1], which is the stock id, not the gap size.
                    entries.sort(key=lambda x: np.abs(x[9]), reverse=True)
                    for rank, arr in enumerate(entries):
                        self.cumStats[hold][rank] += arr[:8]
                elif self.mode == 'percentile':
                    newGapLen = len(entries)
                    # Only bucket once enough history exists for percentiles
                    if len(self.gapQueue) >= minSize:
                        currSize = len(self.gapQueue)
                        for arr in entries:
                            searchKey = arr[10]
                            position = self.orderedGaps.bisect_left(searchKey)
                            # +2 keeps the percentile strictly below 1.0
                            percentile = position / (currSize + 2.0)
                            row = int(percentile * int(100 / self.bucketSize))
                            self.cumStats[hold][row] += arr[:8]
                            # At most one of columns 5 / 6 is non-zero
                            profit = arr[5] + arr[6]
                            if self.tTestFlag:
                                self.profitByGapPercentile[int(percentile * 100)].append(profit)
                            if self.logBucket:
                                self.bucketTimeList[row].append(time)
                                stockName = self.stockList[int(arr[11])]
                                self.bucketTradeList[row][stockName].append(profit)
                                bucketTradeListGlobal[row][stockName].append(profit)
                        # Drop the oldest values from the window and the tree
                        for _ in range(newGapLen):
                            lastVal = self.gapQueue.popleft()
                            self.orderedGaps.remove(lastVal)
                    # Add this time's values to the window and the tree
                    newValList = [arr[10] for arr in entries]
                    self.gapQueue.extend(newValList)
                    self.orderedGaps.update(newValList)
                else:
                    for arr in entries:
                        gapRatio = arr[10]
                        # Signed bucket index, clamped to +/- numBucket
                        bucket = int(np.sign(gapRatio) *
                                     int(np.abs(gapRatio * 10) / self.bucketSize))
                        if np.abs(bucket) >= self.numBucket:
                            bucket = int(np.sign(bucket) * self.numBucket)
                        row = self.numBucket + bucket
                        self.cumStats[hold][row] += arr[:8]

    def tTestWrapper(self, profitByGapPercentile, verbose=True):
        """
        Runs getTTestScores() with the percentile boundaries 10, 20, ..., 90
        and optionally prints the t and p values for each boundary.
        """
        print(''.join(['*'] * 50))
        print("Cumulative Stats")
        if self.tTestFlag:
            for i in range(10, 100, 10):
                tValue, pValue = self.getTTestScores(i, profitByGapPercentile)
                if verbose:
                    print("Boundary: " + str(i))
                    print("T Value: " + str(tValue))
                    print("P Value: " + str(pValue))

    def getProfitGapPercentile(self):
        '''Return the percentile -> profits mapping filled by compileResults().'''
        return self.profitByGapPercentile

    def finalizeStats(self, holdPeriodList):
        '''
        Finally processes the stats matrices via processStatMatrix(). The
        resulting matrices cannot be compiled again directly because
        frequencies have become probabilities.
        '''
        for hold in holdPeriodList:
            self.cumStats[hold] = processStatMatrix(self.cumStats[hold])

    def plotDistribution(self, plotSeries, saveAsFile=False, logValues=False):
        '''
        Plots a 100-bin histogram of ``plotSeries`` with ticks at
        -3..+3 standard deviations.

        Args:
            saveAsFile: save to results/gapDistribution.svg instead of showing
            logValues: whether the y axis is log scaled

        Return:
            None, side effects could include saving a file
        '''
        stDev = np.std(plotSeries)
        xLabels = np.array(range(-3, 4)) * stDev
        # One figure only: the original also opened an unused 100x100-inch
        # figure via plt.figure() that was never drawn on or closed.
        fig, ax = plt.subplots(1, 1)
        ax.hist(plotSeries, bins=100, log=logValues)
        ax.set_xlabel("Normalized Gap Size")
        ax.set_ylabel("Number of Gap Sizes")
        ax.set_xlim([xLabels[0] - 0.5, xLabels[-1] + 0.5])
        ax.set_xticks(xLabels)
        fig.tight_layout()
        if saveAsFile:
            fig.savefig("results/gapDistribution.svg")
            plt.close(fig)
        else:
            plt.show()
class Log(object):
    """Keep a bounded random sample (reservoir) of the items seen so far.

    Based on Dr. Menzies' implementation. Up to ``max_size`` items are kept
    in a SortedList; once full, each new item replaces a random kept one
    with probability ``max_size / n`` (n = items seen so far). Items seen
    later are therefore less likely to get in, which yields a uniform sample
    of the whole run. ``max_size=None`` keeps everything.
    """
    MAX_SIZE = 256

    def __init__(self, inits=None, label=None, max_size=MAX_SIZE):
        self._cache = SortedList()
        self._report = None
        self.label = label or ''
        # Total number of items offered so far (not just the ones kept)
        self._n = 0
        self.max_size = max_size
        self._valid_statistics = False
        self._invalidate_statistics()
        if inits:
            # Explicit loop: ``map`` is lazy in Python 3, so
            # ``map(self.__iadd__, inits)`` added nothing.
            for x in inits:
                self.__iadd__(x)

    def random_index(self):
        return base.random_index(self._cache)

    @classmethod
    def wrap(cls, x, max_size=MAX_SIZE):
        """Return ``x`` if it already is a ``cls``, otherwise a new ``cls``
        built from the items of ``x``."""
        if isinstance(x, cls):
            return x
        return cls(inits=x, max_size=max_size)

    def __len__(self):
        """Number of items currently kept in the sample."""
        return len(self._cache)

    def extend(self, xs):
        """Add every item of the iterable ``xs``.

        Raises:
            TypeError: if ``xs`` is not iterable.
        """
        try:
            from collections.abc import Iterable
        except ImportError:  # Python 2
            from collections import Iterable
        if not isinstance(xs, Iterable):
            raise TypeError()
        for x in xs:
            self.__iadd__(x)

    def __iadd__(self, x):
        """Offer ``x`` to the sample and return ``self``.

        ``None`` is ignored. Adding another Log offers each of its kept items.
        """
        if x is None:
            # Return self so that ``log += None`` does not rebind to None.
            return self
        if isinstance(x, Log):
            # Each item is counted by the recursive call; the wrapper itself
            # is not counted as an item.
            for item in x._cache:
                self.__iadd__(item)
            return self
        self._n += 1
        changed = False
        if self.max_size is None or len(self._cache) < self.max_size:
            # Cache has room: always keep the item.
            changed = True
            self._cache.add(x)
        elif random.random() <= self.max_size / float(self._n):
            # Cache is full: replace a random item with probability
            # max_size / n. The original divided by len(self), which equals
            # max_size once full, so every new item replaced an old one.
            changed = True
            self._cache.remove(random.choice(self._cache))
            self._cache.add(x)
        if changed:
            self._invalidate_statistics()
            self._change(x)
        return self

    def __add__(self, x, max_size=MAX_SIZE):
        """Return a new Log sampled from the kept items of both logs."""
        inits = itertools.chain(self._cache, x._cache)
        return self.__class__(inits=inits, max_size=max_size)

    def any(self):
        """Return a random kept item."""
        return random.choice(self._cache)

    def report(self):
        """Return the cached report, generating it on first use."""
        if self._report is None:
            self._report = self._generate_report()
        return self._report

    def setup(self):
        raise NotImplementedError()

    def as_list(self):
        """Kept items as a plain sorted list."""
        # SortedList.as_list() no longer exists in sortedcontainers 2.x;
        # list() gives the same result on every version.
        return list(self._cache)

    def _invalidate_statistics(self):
        '''
        default implementation. if _valid_statistics is something other than
        a boolean, reimplement!
        '''
        self._valid_statistics = False

    def ish(self, *args, **kwargs):
        raise NotImplementedError()

    def _change(self, x):
        '''
        override to add incremental updating functionality
        '''
        pass

    def _prepare_data(self):
        s = '_prepare_data() not implemented for ' + self.__class__.__name__
        raise NotImplementedError(s)

    def __iter__(self):
        return iter(self._cache)

    def contents(self):
        """Kept items as a plain sorted list."""
        return list(self._cache)
class ExamRoom:
    """Seat students so each new one maximises the distance to the nearest
    occupied seat, choosing the lowest seat number on ties.

    Free seats are tracked as maximal runs [start, end] of empty seats,
    stored in two sorted containers:

    * ``treemap1`` holds ``(distance, -start, end)``; ``pop()`` yields the
      run offering the largest distance, lowest ``start`` among ties.
    * ``treemap2`` holds ``(start, end)`` ordered by position, so ``leave``
      can find the runs adjacent to the freed seat.
    """

    def __init__(self, n: int):
        self.treemap1 = SortedList()
        self.treemap2 = SortedList()
        self.n = n
        self._add_interval(0, n - 1)

    def distance(self, start, end):
        """Distance to the nearest occupied seat achieved by the best seat in
        the free run [start, end].

        A run touching a wall is served by the wall seat, so its distance is
        its full length. The original used the (odd-adjusted) length for
        every run, which ranks wall runs as if both sides were occupied and
        picks a middle seat when the wall seat is strictly better.
        """
        if start == 0 and end == self.n - 1:
            # Empty room: it is the only run, so its value only needs to exist.
            return self.n
        if start == 0:
            # Seat 0 is taken; nearest student sits at end + 1.
            return end + 1
        if end == self.n - 1:
            # Seat n - 1 is taken; nearest student sits at start - 1.
            return self.n - start
        # Students at start - 1 and end + 1; the (lower) midpoint is taken.
        return (end - start + 2) // 2

    def _add_interval(self, start, end):
        """Register the free run [start, end] in both containers."""
        self.treemap1.add((self.distance(start, end), -start, end))
        self.treemap2.add((start, end))

    def _remove_interval(self, start, end):
        """Unregister the free run [start, end] from both containers."""
        self.treemap2.remove((start, end))
        self.treemap1.remove((self.distance(start, end), -start, end))

    def seat(self) -> int:
        """Seat a student in the best free seat and return its number."""
        _, neg_start, end = self.treemap1.pop()  # best free run
        start = -neg_start
        self.treemap2.remove((start, end))
        if start == 0:  # run touches the left wall
            p = 0
        elif end == self.n - 1:  # run touches the right wall
            p = self.n - 1
        else:  # run between two students: take the midpoint
            p = (start + end) // 2
        # Whatever is left of the run on either side of p stays free.
        if p > start:
            self._add_interval(start, p - 1)
        if p < end:
            self._add_interval(p + 1, end)
        return p

    def leave(self, p: int) -> None:
        """Free seat ``p``, merging it with adjacent free runs."""
        lo, hi = p, p
        idx = self.treemap2.bisect_left((p, p))
        if idx > 0:
            left_start, left_end = self.treemap2[idx - 1]
            if left_end == p - 1:  # free run ends right before p: merge
                self._remove_interval(left_start, left_end)
                lo = left_start
                idx -= 1
        if idx < len(self.treemap2):
            right_start, right_end = self.treemap2[idx]
            if right_start == p + 1:  # free run starts right after p: merge
                self._remove_interval(right_start, right_end)
                hi = right_end
        self._add_interval(lo, hi)
def test_remove_valueerror1():
    """Calling remove() on an empty SortedList must raise ValueError."""
    empty_list = SortedList()
    with pytest.raises(ValueError):
        empty_list.remove(0)
def test_remove_valueerror1():
    # NOTE(review): identical in name and body to the test defined just
    # before it; this later definition replaces that one in the module
    # namespace, so only one copy is collected. Consider removing one.
    sl = SortedList()
    pytest.raises(ValueError, sl.remove, 0)
def test_remove_valueerror3():
    """remove() of a value absent from a populated SortedList raises ValueError."""
    values = SortedList([1, 2, 2, 2, 3, 3, 5])
    missing = 4
    with pytest.raises(ValueError):
        values.remove(missing)
def test_remove_valueerror2():
    """Removing a value larger than every element raises ValueError.

    The list is re-chunked with a small load factor first so the lookup has
    to search across several internal sublists before failing.
    """
    # NOTE(review): an earlier test in this module has the same name; this
    # definition replaces it, so only this one is collected.
    slt = SortedList(range(100))
    slt._reset(10)
    # The original called remove() bare, so the expected ValueError
    # propagated and failed the test.
    with pytest.raises(ValueError):
        slt.remove(100)
class Timeline:
    """
    Ordered set of segments.

    A timeline can be seen as an ordered set of non-empty segments (Segment).
    Segments can overlap -- though adding an already existing segment to a
    timeline does nothing.

    Parameters
    ----------
    segments : Segment iterator, optional
        initial set of (non-empty) segments
    uri : string, optional
        name of segmented resource

    Returns
    -------
    timeline : Timeline
        New timeline
    """

    @classmethod
    def from_df(cls, df: pd.DataFrame, uri: Optional[str] = None) -> 'Timeline':
        """Build a timeline from the PYANNOTE_SEGMENT column of `df`."""
        segments = list(df[PYANNOTE_SEGMENT])
        timeline = cls(segments=segments, uri=uri)
        return timeline

    def __init__(self,
                 segments: Optional[Iterable[Segment]] = None,
                 uri: Optional[str] = None):
        if segments is None:
            segments = ()

        # set of segments (used for checking inclusion)
        segments_set = set(segments)

        if any(not segment for segment in segments_set):
            raise ValueError('Segments must not be empty.')

        self.segments_set_ = segments_set

        # sorted list of segments (used for sorted iteration)
        self.segments_list_ = SortedList(segments_set)

        # sorted list of (possibly redundant) segment boundaries
        boundaries = (boundary for segment in segments_set for boundary in segment)
        self.segments_boundaries_ = SortedList(boundaries)

        # path to (or any identifier of) segmented resource
        self.uri: Optional[str] = uri

    def __len__(self):
        """Number of segments

        >>> len(timeline)  # timeline contains three segments
        3
        """
        return len(self.segments_set_)

    def __nonzero__(self):
        return self.__bool__()

    def __bool__(self):
        """Emptiness

        >>> if timeline:
        ...    # timeline is not empty
        ... else:
        ...    # timeline is empty
        """
        return len(self.segments_set_) > 0

    def __iter__(self) -> Iterable[Segment]:
        """Iterate over segments (in chronological order)

        >>> for segment in timeline:
        ...     # do something with the segment

        See also
        --------
        :class:`pyannote.core.Segment` describes how segments are sorted.
        """
        return iter(self.segments_list_)

    def __getitem__(self, k: int) -> Segment:
        """Get segment by index (in chronological order)

        >>> first_segment = timeline[0]
        >>> penultimate_segment = timeline[-2]
        """
        return self.segments_list_[k]

    def __eq__(self, other: 'Timeline'):
        """Equality

        Two timelines are equal if and only if their segments are equal.
        Comparing with a non-Timeline defers to the other operand.

        >>> timeline1 = Timeline([Segment(0, 1), Segment(2, 3)])
        >>> timeline2 = Timeline([Segment(2, 3), Segment(0, 1)])
        >>> timeline3 = Timeline([Segment(2, 3)])
        >>> timeline1 == timeline2
        True
        >>> timeline1 == timeline3
        False
        """
        if not isinstance(other, Timeline):
            return NotImplemented
        return self.segments_set_ == other.segments_set_

    def __ne__(self, other: 'Timeline'):
        """Inequality"""
        if not isinstance(other, Timeline):
            return NotImplemented
        return self.segments_set_ != other.segments_set_

    def index(self, segment: Segment) -> int:
        """Get index of (existing) segment

        Parameters
        ----------
        segment : Segment
            Segment that is being looked for.

        Returns
        -------
        position : int
            Index of `segment` in timeline

        Raises
        ------
        ValueError if `segment` is not present.
        """
        return self.segments_list_.index(segment)

    def add(self, segment: Segment) -> 'Timeline':
        """Add a segment (in place)

        Parameters
        ----------
        segment : Segment
            Segment that is being added

        Returns
        -------
        self : Timeline
            Updated timeline.

        Note
        ----
        If the timeline already contains this segment, it will not be added
        again, as a timeline is meant to be a **set** of segments (not a list).

        If the segment is empty, it will not be added either, as a timeline
        only contains non-empty segments.
        """
        segments_set_ = self.segments_set_
        if segment in segments_set_ or not segment:
            return self

        segments_set_.add(segment)

        self.segments_list_.add(segment)

        segments_boundaries_ = self.segments_boundaries_
        segments_boundaries_.add(segment.start)
        segments_boundaries_.add(segment.end)

        return self

    def remove(self, segment: Segment) -> 'Timeline':
        """Remove a segment (in place)

        Parameters
        ----------
        segment : Segment
            Segment that is being removed

        Returns
        -------
        self : Timeline
            Updated timeline.

        Note
        ----
        If the timeline does not contain this segment, this does nothing
        """
        segments_set_ = self.segments_set_
        if segment not in segments_set_:
            return self

        segments_set_.remove(segment)

        self.segments_list_.remove(segment)

        # boundaries may be shared by several segments: drop one occurrence
        segments_boundaries_ = self.segments_boundaries_
        segments_boundaries_.remove(segment.start)
        segments_boundaries_.remove(segment.end)

        return self

    def discard(self, segment: Segment) -> 'Timeline':
        """Same as `remove`

        See also
        --------
        :func:`pyannote.core.Timeline.remove`
        """
        return self.remove(segment)

    def __ior__(self, timeline: 'Timeline') -> 'Timeline':
        return self.update(timeline)

    def update(self, timeline: 'Timeline') -> 'Timeline':
        """Add every segment of an existing timeline (in place)

        Parameters
        ----------
        timeline : Timeline
            Timeline whose segments are being added

        Returns
        -------
        self : Timeline
            Updated timeline

        Note
        ----
        Only segments that do not already exist will be added, as a timeline is
        meant to be a **set** of segments (not a list).
        """
        segments_set = self.segments_set_

        segments_set |= timeline.segments_set_

        # sorted list of segments (used for sorted iteration)
        self.segments_list_ = SortedList(segments_set)

        # sorted list of (possibly redundant) segment boundaries
        boundaries = (boundary for segment in segments_set for boundary in segment)
        self.segments_boundaries_ = SortedList(boundaries)

        return self

    def __or__(self, timeline: 'Timeline') -> 'Timeline':
        return self.union(timeline)

    def union(self, timeline: 'Timeline') -> 'Timeline':
        """Create new timeline made of union of segments

        Parameters
        ----------
        timeline : Timeline
            Timeline whose segments are being added

        Returns
        -------
        union : Timeline
            New timeline containing the union of both timelines.

        Note
        ----
        This does the same as timeline.update(...) except it returns a new
        timeline, and the original one is not modified.
        """
        segments = self.segments_set_ | timeline.segments_set_
        return Timeline(segments=segments, uri=self.uri)

    def co_iter(self, other: 'Timeline') -> Iterator[Tuple[Segment, Segment]]:
        """Iterate over pairs of intersecting segments

        >>> timeline1 = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
        >>> timeline2 = Timeline([Segment(1, 3), Segment(3, 5)])
        >>> for segment1, segment2 in timeline1.co_iter(timeline2):
        ...     print(segment1, segment2)
        (<Segment(0, 2)>, <Segment(1, 3)>)
        (<Segment(1, 2)>, <Segment(1, 3)>)
        (<Segment(3, 4)>, <Segment(3, 5)>)

        Parameters
        ----------
        other : Timeline
            Second timeline

        Returns
        -------
        iterable : (Segment, Segment) iterable
            Yields pairs of intersecting segments in chronological order.
        """
        for segment in self.segments_list_:

            # iterate over segments that starts before 'segment' ends
            temp = Segment(start=segment.end, end=segment.end)
            for other_segment in other.segments_list_.irange(maximum=temp):
                if segment.intersects(other_segment):
                    yield segment, other_segment

    def crop_iter(self,
                  support: Support,
                  mode: CropMode = 'intersection',
                  returns_mapping: bool = False) \
            -> Iterator[Union[Tuple[Segment, Segment], Segment]]:
        """Like `crop` but returns a segment iterator instead

        See also
        --------
        :func:`pyannote.core.Timeline.crop`
        """
        if mode not in {'loose', 'strict', 'intersection'}:
            raise ValueError("Mode must be one of 'loose', 'strict', or "
                             "'intersection'.")

        if not isinstance(support, (Segment, Timeline)):
            raise TypeError("Support must be a Segment or a Timeline.")

        if isinstance(support, Segment):
            # corner case where "support" is empty
            if support:
                segments = [support]
            else:
                segments = []

            support = Timeline(segments=segments, uri=self.uri)
            yield from self.crop_iter(support, mode=mode,
                                      returns_mapping=returns_mapping)
            return

        # As documented in `crop`, a Timeline support is first reduced to its
        # own (non-overlapping) support.  Otherwise overlapping support
        # segments would make 'loose' mode yield duplicates, 'strict' mode miss
        # segments only covered by their union, and 'intersection' mode return
        # fragmented, overlapping pieces.
        support = support.support()

        # loose mode
        if mode == 'loose':
            for segment, _ in self.co_iter(support):
                yield segment
            return

        # strict mode
        if mode == 'strict':
            for segment, other_segment in self.co_iter(support):
                if segment in other_segment:
                    yield segment
            return

        # intersection mode
        for segment, other_segment in self.co_iter(support):
            mapped_to = segment & other_segment
            if not mapped_to:
                continue
            if returns_mapping:
                yield segment, mapped_to
            else:
                yield mapped_to

    def crop(self,
             support: Support,
             mode: CropMode = 'intersection',
             returns_mapping: bool = False) \
            -> Union['Timeline', Tuple['Timeline', Dict[Segment, Segment]]]:
        """Crop timeline to new support

        Parameters
        ----------
        support : Segment or Timeline
            If `support` is a `Timeline`, its support is used.
        mode : {'strict', 'loose', 'intersection'}, optional
            Controls how segments that are not fully included in `support` are
            handled. 'strict' mode only keeps fully included segments. 'loose'
            mode keeps any intersecting segment. 'intersection' mode keeps any
            intersecting segment but replace them by their actual intersection.
        returns_mapping : bool, optional
            In 'intersection' mode, return a dictionary whose keys are segments
            of the cropped timeline, and values are list of the original
            segments that were cropped. Defaults to False.

        Returns
        -------
        cropped : Timeline
            Cropped timeline
        mapping : dict
            When 'returns_mapping' is True, dictionary whose keys are segments
            of 'cropped', and values are lists of corresponding original
            segments.

        Examples
        --------

        >>> timeline = Timeline([Segment(0, 2), Segment(1, 2), Segment(3, 4)])
        >>> timeline.crop(Segment(1, 3))
        <Timeline(uri=None, segments=[<Segment(1, 2)>])>

        >>> timeline.crop(Segment(1, 3), mode='loose')
        <Timeline(uri=None, segments=[<Segment(0, 2)>, <Segment(1, 2)>])>

        >>> timeline.crop(Segment(1, 3), mode='strict')
        <Timeline(uri=None, segments=[<Segment(1, 2)>])>

        >>> cropped, mapping = timeline.crop(Segment(1, 3), returns_mapping=True)
        >>> print(mapping)
        {<Segment(1, 2)>: [<Segment(0, 2)>, <Segment(1, 2)>]}
        """
        if mode == 'intersection' and returns_mapping:
            segments, mapping = [], {}
            for segment, mapped_to in self.crop_iter(support,
                                                     mode='intersection',
                                                     returns_mapping=True):
                segments.append(mapped_to)
                mapping[mapped_to] = mapping.get(mapped_to, list()) + [segment]
            return Timeline(segments=segments, uri=self.uri), mapping

        return Timeline(segments=self.crop_iter(support, mode=mode),
                        uri=self.uri)

    def overlapping(self, t: float) -> List[Segment]:
        """Get list of segments overlapping `t`

        Parameters
        ----------
        t : float
            Timestamp, in seconds.

        Returns
        -------
        segments : list
            List of all segments of timeline containing time t
        """
        return list(self.overlapping_iter(t))

    def overlapping_iter(self, t: float) -> Iterator[Segment]:
        """Like `overlapping` but returns a segment iterator instead

        See also
        --------
        :func:`pyannote.core.Timeline.overlapping`
        """
        # only segments sorted before the empty segment [t, t] can contain t
        probe = Segment(start=t, end=t)
        for segment in self.segments_list_.irange(maximum=probe):
            if segment.overlaps(t):
                yield segment

    def __str__(self):
        """Human-readable representation

        >>> timeline = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
        >>> print(timeline)
        [[ 00:00:00.000 -->  00:00:10.000]
         [ 00:00:01.000 -->  00:00:13.370]]

        """

        n = len(self.segments_list_)
        string = "["
        for i, segment in enumerate(self.segments_list_):
            string += str(segment)
            string += "\n " if i + 1 < n else ""
        string += "]"
        return string

    def __repr__(self):
        """Computer-readable representation

        >>> Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
        <Timeline(uri=None, segments=[<Segment(0, 10)>, <Segment(1, 13.37)>])>

        """

        return "<Timeline(uri=%s, segments=%s)>" % (self.uri,
                                                    list(self.segments_list_))

    def __contains__(self, included: Union[Segment, 'Timeline']):
        """Inclusion

        Check whether every segment of `included` does exist in timeline.

        Parameters
        ----------
        included : Segment or Timeline
            Segment or timeline being checked for inclusion

        Returns
        -------
        contains : bool
            True if every segment in `included` exists in timeline,
            False otherwise

        Examples
        --------
        >>> timeline1 = Timeline(segments=[Segment(0, 10), Segment(1, 13.37)])
        >>> timeline2 = Timeline(segments=[Segment(0, 10)])
        >>> timeline1 in timeline2
        False
        >>> timeline2 in timeline1
        True
        >>> Segment(1, 13.37) in timeline1
        True

        """

        if isinstance(included, Segment):
            return included in self.segments_set_

        elif isinstance(included, Timeline):
            return self.segments_set_.issuperset(included.segments_set_)

        else:
            raise TypeError(
                'Checking for inclusion only supports Segment and '
                'Timeline instances')

    def empty(self) -> 'Timeline':
        """Return an empty copy

        Returns
        -------
        empty : Timeline
            Empty timeline using the same 'uri' attribute.

        """
        return Timeline(uri=self.uri)

    def copy(self, segment_func: Optional[Callable[[Segment], Segment]] = None) \
            -> 'Timeline':
        """Get a copy of the timeline

        If `segment_func` is provided, it is applied to each segment first.

        Parameters
        ----------
        segment_func : callable, optional
            Callable that takes a segment as input, and returns a segment.
            Defaults to identity function (segment_func(segment) = segment)

        Returns
        -------
        timeline : Timeline
            Copy of the timeline

        """

        # if segment_func is not provided
        # just add every segment
        if segment_func is None:
            return Timeline(segments=self.segments_list_, uri=self.uri)

        # if is provided
        # apply it to each segment before adding them
        return Timeline(segments=[segment_func(s) for s in self.segments_list_],
                        uri=self.uri)

    def extent(self) -> Segment:
        """Extent

        The extent of a timeline is the segment of minimum duration that
        contains every segments of the timeline. It is unique, by definition.
        The extent of an empty timeline is an empty segment.

        A picture is worth a thousand words::

            timeline
            |------|    |------|     |----|
              |--|    |-----|     |----------|

            timeline.extent()
            |--------------------------------|

        Returns
        -------
        extent : Segment
            Timeline extent

        Examples
        --------
        >>> timeline = Timeline(segments=[Segment(0, 1), Segment(9, 10)])
        >>> timeline.extent()
        <Segment(0, 10)>

        """
        if self.segments_set_:
            segments_boundaries_ = self.segments_boundaries_
            start = segments_boundaries_[0]
            end = segments_boundaries_[-1]
            return Segment(start=start, end=end)
        else:
            import numpy as np
            return Segment(start=np.inf, end=-np.inf)

    def support_iter(self) -> Iterator[Segment]:
        """Like `support` but returns a segment generator instead

        See also
        --------
        :func:`pyannote.core.Timeline.support`
        """

        # The support of an empty timeline is an empty timeline.
        if not self:
            return

        # Principle:
        #   * gather all segments with no gap between them
        #   * add one segment per resulting group (their union |)
        # Note:
        #   Since segments are kept sorted internally,
        #   there is no need to perform an exhaustive segment clustering.
        #   We just have to consider them in their natural order.

        # Initialize new support segment
        # as very first segment of the timeline
        new_segment = self.segments_list_[0]

        for segment in self:

            # If there is no gap between new support segment and next segment,
            if not (segment ^ new_segment):
                # Extend new support segment using next segment
                new_segment |= segment

            # If there actually is a gap,
            else:
                yield new_segment

                # Initialize new support segment as next segment
                # (right after the gap)
                new_segment = segment

        # Add new segment to the timeline support
        yield new_segment

    def support(self) -> 'Timeline':
        """Timeline support

        The support of a timeline is the timeline with the minimum number of
        segments with exactly the same time span as the original timeline. It
        is (by definition) unique and does not contain any overlapping
        segments.

        A picture is worth a thousand words::

            timeline
            |------|    |------|     |----|
              |--|    |-----|     |----------|

            timeline.support()
            |------|  |--------|  |----------|

        Returns
        -------
        support : Timeline
            Timeline support
        """
        return Timeline(segments=self.support_iter(), uri=self.uri)

    def duration(self) -> float:
        """Timeline duration

        The timeline duration is the sum of the durations of the segments
        in the timeline support.

        Returns
        -------
        duration : float
            Duration of timeline support, in seconds.
        """

        # The timeline duration is the sum of the durations
        # of the segments in the timeline support.
        return sum(s.duration for s in self.support_iter())

    def gaps_iter(self, support: Optional[Support] = None) -> Iterator[Segment]:
        """Like `gaps` but returns a segment generator instead

        See also
        --------
        :func:`pyannote.core.Timeline.gaps`

        """

        if support is None:
            support = self.extent()

        if not isinstance(support, (Segment, Timeline)):
            raise TypeError("unsupported operand type(s) for -':"
                            "%s and Timeline." % type(support).__name__)

        # segment support
        if isinstance(support, Segment):

            # `end` is meant to store the end time of former segment
            # initialize it with beginning of provided segment `support`
            end = support.start

            # support on the intersection of timeline and provided segment
            for segment in self.crop(support, mode='intersection').support():

                # add gap between each pair of consecutive segments
                # if there is no gap, segment is empty, therefore not added
                gap = Segment(start=end, end=segment.start)
                if gap:
                    yield gap

                # keep track of the end of former segment
                end = segment.end

            # add final gap (if not empty)
            gap = Segment(start=end, end=support.end)
            if gap:
                yield gap

        # timeline support
        elif isinstance(support, Timeline):

            # yield gaps for every segment in support of provided timeline
            for segment in support.support():
                yield from self.gaps_iter(support=segment)

    def gaps(self, support: Optional[Support] = None) \
            -> 'Timeline':
        """Gaps

        A picture is worth a thousand words::

            timeline
            |------|    |------|     |----|
              |--|    |-----|     |----------|

            timeline.gaps()
                   |--|        |--|

        Parameters
        ----------
        support : None, Segment or Timeline
            Support in which gaps are looked for. Defaults to timeline extent

        Returns
        -------
        gaps : Timeline
            Timeline made of all gaps from original timeline, and delimited
            by provided support

        See also
        --------
        :func:`pyannote.core.Timeline.extent`

        """
        return Timeline(segments=self.gaps_iter(support=support),
                        uri=self.uri)

    def segmentation(self) -> 'Timeline':
        """Segmentation

        Create the unique timeline with same support and same set of segment
        boundaries as original timeline, but with no overlapping segments.

        A picture is worth a thousand words::

            timeline
            |------|    |------|     |----|
              |--|    |-----|     |----------|

            timeline.segmentation()
            |-|--|-|  |-|---|--|  |--|----|--|

        Returns
        -------
        timeline : Timeline
            (unique) timeline with same support and same set of segment
            boundaries as original timeline, but with no overlapping segments.
        """
        # COMPLEXITY: O(n)
        support = self.support()

        # COMPLEXITY: O(n.log n)
        # get all boundaries (sorted)
        # |------|    |------|     |----|
        #   |--|    |-----|     |----------|
        # becomes
        # | |  | |  | |   |  |  |  |    |  |
        timestamps = set()
        for (start, end) in self:
            timestamps.add(start)
            timestamps.add(end)
        timestamps = sorted(timestamps)

        if not timestamps:
            return Timeline(uri=self.uri)

        # create new partition timeline
        # | |  | |  | |   |  |  |  |    |  |
        # becomes
        # |-|--|-|  |-|---|--|  |--|----|--|
        segments = []
        start = timestamps[0]
        for end in timestamps[1:]:
            # only add segments that are covered by original timeline
            segment = Segment(start=start, end=end)
            if segment and support.overlapping(segment.middle):
                segments.append(segment)
            # next segment...
            start = end

        return Timeline(segments=segments, uri=self.uri)

    def to_annotation(self,
                      generator: Union[str, Iterable[Label], None] = 'string',
                      modality: Optional[str] = None) \
            -> 'Annotation':
        """Turn timeline into an annotation

        Each segment is labeled by a unique label.

        Parameters
        ----------
        generator : 'string', 'int', or iterable, optional
            If 'string' (default) generate string labels. If 'int', generate
            integer labels. If iterable, use it to generate labels.
        modality : str, optional

        Returns
        -------
        annotation : Annotation
            Annotation
        """

        from .annotation import Annotation
        annotation = Annotation(uri=self.uri, modality=modality)
        if generator == 'string':
            from .utils.generators import string_generator
            generator = string_generator()
        elif generator == 'int':
            from .utils.generators import int_generator
            generator = int_generator()
        else:
            # accept any iterable (e.g. a list of labels), not only iterators,
            # since labels are drawn with next() below
            generator = iter(generator)

        for segment in self:
            annotation[segment] = next(generator)

        return annotation

    def write_uem(self, file: TextIO):
        """Dump timeline to file using UEM format

        Parameters
        ----------
        file : file object

        Usage
        -----
        >>> with open('file.uem', 'w') as file:
        ...    timeline.write_uem(file)
        """

        uri = self.uri if self.uri else "<NA>"

        for segment in self:
            line = f"{uri} 1 {segment.start:.3f} {segment.end:.3f}\n"
            file.write(line)

    def for_json(self):
        """Serialization

        See also
        --------
        :mod:`pyannote.core.json`
        """

        data = {PYANNOTE_JSON: self.__class__.__name__}
        data[PYANNOTE_JSON_CONTENT] = [s.for_json() for s in self]

        if self.uri:
            data[PYANNOTE_URI] = self.uri

        return data

    @classmethod
    def from_json(cls, data):
        """Deserialization

        See also
        --------
        :mod:`pyannote.core.json`
        """

        uri = data.get(PYANNOTE_URI, None)
        segments = [Segment.from_json(s) for s in data[PYANNOTE_JSON_CONTENT]]
        return cls(segments=segments, uri=uri)

    def _repr_png_(self):
        """IPython notebook support

        See also
        --------
        :mod:`pyannote.core.notebook`
        """

        from .notebook import repr_timeline
        return repr_timeline(self)
def Delete(self, listOfPKValues):
    """Delete the record identified by the primary-key values `listOfPKValues`.

    Behaviour depends on ``Table._UseFKTables``:

    * FK tables not used: the record is removed from the database and its
      hash key is removed from ``self.Rows``.
    * Case 1, this table has FKs: for every FK column, the parent row hash is
      removed from the matching ``<colName>_FKTable`` row (the row itself is
      deleted when this was its only hash), then the record is removed.
    * Case 2, this table has no FKs: the delete is refused (restricted delete)
      if any table in ``Table._registry`` references this table with the
      given value; otherwise the record is removed.

    Parameters
    ----------
    listOfPKValues : list
        Primary-key values, in the order of ``self.PK``.

    Returns
    -------
    bool
        True once the record has been removed from ``self.Rows``; False if the
        PK values do not exist, the record is still referenced, or an FK-table
        update failed.  A failing database remove is only reported.
    """
    if self.VerifyUniquePKval(listOfPKValues):
        # A True result implies the values don't yet exist in the table.
        print("Error: PK values specfied don't exist in ", self.TableName)
        return False

    # Hash key of the record being deleted: builds its database key and is
    # the value to drop from FK tables' hash lists.
    parent_hashkey = self.getHashKey(listOfPKValues)

    def remove_this_record(context_msg):
        """Remove the record from the database (best effort) and from Rows."""
        key = (self.NameSpace, self.TableName, None, parent_hashkey)
        try:
            Table._CurrentClient.remove(key)  # delete from the database
        except Exception as e:
            print(context_msg)
            print("Error: {0}".format(e), file=sys.stderr)
        self.Rows.remove(parent_hashkey)

    # Not using FKTables: just remove the record if it exists.
    if Table._UseFKTables == False:
        if Table._Debug:
            print("****Removing record ", listOfPKValues)
        remove_this_record("Error: Attempted delete. Not using FKTables")
        return True

    # Restricted delete:
    # - When a row is deleted from a table holding FKs, remove the row's hash
    #   from every referenced FK-table row; an FK-table row left with no
    #   hashes is deleted.
    # - When a row is deleted from a table referenced by others through an
    #   FK, look the value up in the corresponding FKTable: if it exists,
    #   block the delete; otherwise complete it.

    # Case 1: calling table has FKs in it.
    if len(self.FK) > 0:
        if Table._Debug:
            print("****This table has FK refs (it's a parent):", self.TableName)
        # PK column name -> value, so FK columns that are part of the PK can
        # be looked up directly (non-FK entries are simply never requested).
        FKPKDict = dict(zip(self.PK, listOfPKValues))
        stored_record = None  # read lazily, at most once
        for thisFKref in self.FK:
            FKTableToUpdate = self.FKtables[str(thisFKref['colName'] + '_FKTable')]
            if thisFKref['colName'] in self.PK:
                checkFKvalue = FKPKDict[thisFKref['colName']]
            else:
                # FK not part of the PK: read the record to get its value.
                if stored_record is None:
                    stored_record = self.Read_PK(listOfPKValues)
                checkFKvalue = stored_record[thisFKref['colName']]

            hashkey = FKTableToUpdate.getHashKey([checkFKvalue])
            key = (FKTableToUpdate.NameSpace, FKTableToUpdate.TableName, None, hashkey)
            if Table._Debug:
                print("****Updating FKtable ", FKTableToUpdate.TableName)
            # NOTE(review): a failure here returns False after earlier FK
            # tables have already been updated; nothing is rolled back.
            try:
                # Read-modify-write: remove the parent key from the hash list.
                (key, metadata, fk_record) = Table._CurrentClient.get(key)
                currentHashList = SortedList(fk_record["TableHashes"])
                if len(currentHashList) > 1:
                    currentHashList.remove(parent_hashkey)
                    # Store a plain list: SortedList is not a list subclass
                    # and is not a native bin type for the DB client.
                    Table._CurrentClient.put(key, {"TableHashes": list(currentHashList)})
                else:
                    # Last hash in this FK row: delete the whole row.
                    Table._CurrentClient.remove(key)  # delete from the database
                    # and from the target table's list of record hashkeys
                    FKTableToUpdate.Rows.remove(hashkey)
            except Exception as e:
                print("error: {0}".format(e), file=sys.stderr)
                print("Error trying to delete in ", FKTableToUpdate.TableName)
                return False

        # If all that went well, actually remove the requested record.
        if Table._Debug:
            print("****Removing record ", listOfPKValues)
        remove_this_record("Error: Attempted delete. Using FKTables, this table has FKs")
    else:
        # Case 2: calling table has no FKs (but it might be referenced).
        if Table._Debug:
            print("****This table has no FK refs, checking for references:", self.TableName)
        # Refuse the delete if a referencing table still uses these values.
        PKDict = dict(zip(self.PK, listOfPKValues))
        PKvalueIsReferenced = False
        for thisTable in Table._registry:
            # A table with no FK entries doesn't reference the calling table.
            for thisFKref in thisTable.FK:
                if thisFKref['refTable'] == self:
                    # The calling table is referenced by this table: look for
                    # the value in its FK table.
                    checkPKvalue = PKDict[thisFKref['refColName']]
                    FKTableToCheck = thisTable.FKtables[str(thisFKref['colName'] + '_FKTable')]
                    # NOTE(review): Case 1 hashes FK values as a one-element
                    # list (getHashKey([value])) while this passes the bare
                    # value; confirm VerifyUniquePKval accepts a scalar here.
                    if FKTableToCheck.VerifyUniquePKval(checkPKvalue) == False:
                        # False means the value exists in the FKTable.
                        PKvalueIsReferenced = True
        if PKvalueIsReferenced:
            print("WARNING: Can't delete this record, it is referenced by another record!")
            return False
        if Table._Debug:
            print("****No references found, removing record", listOfPKValues)
        remove_this_record("Error: Attempted delete. Using FKTables, this table has no FKs")
    return True