class MaxStack:
    """Stack that can also report and remove its largest element.

    Each pushed value is stored as a ``(value, push_number)`` record in two
    sorted views over the same records:

    * ``by_time`` -- most recently pushed record first;
    * ``by_val``  -- largest value first, ties broken by most recent push.
    """

    def __init__(self):
        def newest_first(record):
            return -record[1]

        def largest_then_newest(record):
            return (-record[0], -record[1])

        self.by_time = SortedKeyList(key=newest_first)
        self.by_val = SortedKeyList(key=largest_then_newest)
        # Incremented before use, so the first push is stamped 0.
        self.time = -1

    def push(self, x: int) -> None:
        """Push ``x`` onto the stack."""
        self.time = self.time + 1
        record = (x, self.time)
        for view in (self.by_time, self.by_val):
            view.add(record)

    def pop(self) -> int:
        """Remove and return the most recently pushed value."""
        record = self.by_time.pop(0)
        self.by_val.remove(record)
        value, _stamp = record
        return value

    def top(self) -> int:
        """Return the most recently pushed value without removing it."""
        value, _stamp = self.by_time[0]
        return value

    def peekMax(self) -> int:
        """Return the largest value (most recent on ties) without removing it."""
        largest, _stamp = self.by_val[0]
        return largest

    def popMax(self) -> int:
        """Remove and return the largest value (most recent on ties)."""
        record = self.by_val.pop(0)
        self.by_time.remove(record)
        largest, _stamp = record
        return largest
def solve(self, initial_state: State, heuristic: callable) -> State:
    """Beam-limited best-first search for a state whose board is sorted.

    States are ordered by ``len(board_history) / 10 + heuristic(state)``,
    smallest first. Before each child is queued the frontier is trimmed so
    that it holds at most 100 states, dropping the worst-ranked ones.

    :param initial_state: state the search starts from.
    :param heuristic: callable scoring a state (lower is better).
    :return: the first popped state whose board content equals its sorted
        version, or ``None`` when the frontier runs dry.
    """

    def priority(state: State):
        return len(state.board_history) / 10 + heuristic(state)

    frontier = SortedKeyList(key=priority)
    frontier.add(initial_state)
    seen = set()
    while frontier:
        current: State = frontier.pop(0)
        seen.add(list_to_string(current.current_board.content))
        if current.current_board.content == sorted(current.current_board.content):
            return current
        for child in current.next_states():
            if list_to_string(child.current_board.content) in seen:
                continue
            # Keep the beam bounded: evict the worst-ranked states first.
            while len(frontier) > 100:
                frontier.pop(-1)
            frontier.add(child)
    return None
def test_pop():
    """Pop from both ends and the middle of a SortedKeyList keyed by ``negate``.

    As the expected values show, the list is ordered from 9 down to 0. The
    internal invariants are re-checked after every mutation.
    """
    slt = SortedKeyList(range(10), key=negate)
    slt._reset(4)
    slt._check()
    # (positional arguments passed to pop, value expected back)
    cases = [((), 0), ((0,), 9), ((-2,), 2), ((4,), 4)]
    for args, expected in cases:
        assert slt.pop(*args) == expected
        slt._check()
def test_pop():
    """Pop from both ends and the middle of a SortedKeyList keyed by ``modulo``.

    As the expected values show, the ten values keep ascending order. The
    internal invariants are re-checked after every mutation.
    """
    slt = SortedKeyList(range(10), key=modulo)
    slt._reset(4)
    slt._check()
    expectations = (
        ((), 9),
        ((0,), 0),
        ((-2,), 7),
        ((4,), 5),
    )
    for pop_args, value in expectations:
        assert slt.pop(*pop_args) == value
        slt._check()
def test_pop():
    """Exercise ``pop`` at the tail, head and interior of a ``modulo``-keyed list.

    The expected values show the list holding 0..9 in ascending order; the
    structure is validated with ``_check`` after every pop.
    """
    slt = SortedKeyList(range(10), key=modulo)
    slt._reset(4)
    slt._check()
    for pop_args, value in [((), 9), ((0,), 0), ((-2,), 7), ((4,), 5)]:
        popped = slt.pop(*pop_args)
        assert popped == value
        slt._check()
def representative_trajectory(self, cluster):
    """Compute a representative trajectory for a cluster of line segments.

    The segments are rotated (both end points) into a frame whose x axis is
    the cluster's average direction and swept from left to right. At every
    segment start/end x where at least ``self.min_lns`` segments overlap the
    sweep line, the mean y of those segments at that x is taken, and the
    resulting point is rotated back into the original frame.

    Each segment must expose ``a`` (start point), ``vector`` (direction) and
    ``length``. As in the previous implementation, the end point is taken to
    be ``a + length * vector``, i.e. ``vector`` is assumed to be a unit
    direction -- TODO confirm against LineSegment.

    :param cluster: iterable of line segments.
    :return: list of 2-element ``np.ndarray`` points in the original frame;
        empty if the cluster is empty or its directions cancel out.
    """
    segments = list(cluster)
    if not segments:
        return []

    directions = np.array([np.asarray(line.vector, dtype=float) for line in segments])
    av_vector = directions.mean(axis=0)
    if np.linalg.norm(av_vector) == 0.0:
        # No dominant direction to sweep along.
        return []

    angle = np.arctan2(av_vector[1], av_vector[0])
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # to_local rotates by -angle, mapping the average direction onto +x;
    # its transpose is the inverse rotation back to the original frame.
    to_local = np.array([[cos_a, sin_a], [-sin_a, cos_a]])
    to_world = to_local.T

    # Rotate both end points and orient every segment left-to-right so that
    # its start event always precedes its end event in the sweep.
    local_segments = []
    for line in segments:
        start = np.asarray(line.a, dtype=float)
        end = start + line.length * np.asarray(line.vector, dtype=float)
        p, q = to_local.dot(start), to_local.dot(end)
        if q[0] < p[0]:
            p, q = q, p
        local_segments.append((p, q))

    # Sweep events: (x, 0 = start / 1 = end, segment index). Starts sort
    # before ends at the same x, so touching segments count as overlapping.
    events = sorted(
        [(p[0], 0, i) for i, (p, _) in enumerate(local_segments)]
        + [(q[0], 1, i) for i, (_, q) in enumerate(local_segments)]
    )

    active = set()
    rep_trajectory = []
    for x, is_end, idx in events:
        if is_end:
            active.discard(idx)
        else:
            active.add(idx)
        if not active or len(active) < self.min_lns:
            continue
        total_y = 0.0
        for i in sorted(active):
            p, q = local_segments[i]
            dx = q[0] - p[0]
            if dx == 0.0:
                # Perpendicular to the sweep direction: use its midpoint.
                total_y += (p[1] + q[1]) / 2.0
            else:
                # y of the segment's supporting line at x (intercept included).
                total_y += p[1] + (x - p[0]) * (q[1] - p[1]) / dx
        local_point = np.array([x, total_y / len(active)])
        rep_trajectory.append(to_world.dot(local_point))
    return rep_trajectory
def test_pop():
    """Check ``pop`` positions on a ``negate``-keyed SortedKeyList.

    The expected values show the list ordered from 9 down to 0, so a bare
    ``pop()`` yields 0. Invariants are validated after every pop.
    """
    slt = SortedKeyList(range(10), key=negate)
    slt._reset(4)
    slt._check()
    for pop_args, value in [((), 0), ((0,), 9), ((-2,), 2), ((4,), 4)]:
        popped = slt.pop(*pop_args)
        assert popped == value
        slt._check()
class MemoryTimestampIndex(TimestampIndex):
    """ Index of transactions sorted by their timestamps. """

    _index: 'SortedKeyList[TransactionIndexElement]'

    def __init__(self) -> None:
        self.log = logger.new()
        # Ordered by (timestamp, hash) so equal timestamps still have a total order.
        self._index = SortedKeyList(key=lambda x: (x.timestamp, x.hash))

    def add_tx(self, tx: BaseTransaction) -> bool:
        """Add ``tx`` to the index.

        :return: ``False`` if an equal element was already indexed, ``True`` otherwise.
        """
        assert tx.hash is not None
        # It is safe to use the in operator because it is O(log(n)).
        # http://www.grantjenks.com/docs/sortedcontainers/sortedlist.html#sortedcontainers.SortedList.__contains__
        element = TransactionIndexElement(tx.timestamp, tx.hash)
        if element in self._index:
            return False
        self._index.add(element)
        return True

    def del_tx(self, tx: BaseTransaction) -> None:
        """Remove ``tx`` from the index if present; otherwise do nothing."""
        idx = self._index.bisect_key_left((tx.timestamp, tx.hash))
        if idx < len(self._index) and self._index[idx].hash == tx.hash:
            self._index.pop(idx)

    def get_newest(self, count: int) -> Tuple[List[bytes], bool]:
        return get_newest_sorted_key_list(self._index, count)

    def get_older(self, timestamp: int, hash_bytes: bytes, count: int) -> Tuple[List[bytes], bool]:
        return get_older_sorted_key_list(self._index, timestamp, hash_bytes, count)

    def get_newer(self, timestamp: int, hash_bytes: bytes, count: int) -> Tuple[List[bytes], bool]:
        return get_newer_sorted_key_list(self._index, timestamp, hash_bytes, count)

    def get_hashes_and_next_idx(self, from_idx: RangeIdx, count: int) -> Tuple[List[bytes], Optional[RangeIdx]]:
        """Return up to ``count`` hashes starting at ``from_idx``.

        ``from_idx`` is a ``(timestamp, offset)`` pair: start at the
        ``offset``-th element whose key is >= ``(timestamp, b'')``.

        :return: the hashes, plus the RangeIdx to resume from, or ``None``
            when fewer than ``count`` hashes were available. The returned
            offset is relative to the first element at the new timestamp.
        """
        timestamp, offset = from_idx
        start = self._index.bisect_key_left((timestamp, b''))
        # The index is already sorted, so a plain slice is enough; there is no
        # need to re-sort the window into a fresh SortedKeyList.
        ret_txs = self._index[start + offset:start + offset + count]
        hashes = [tx.hash for tx in ret_txs]
        if len(ret_txs) < count:
            return hashes, None
        next_offset = offset + count
        next_timestamp = ret_txs[-1].timestamp
        if next_timestamp != timestamp:
            # Re-base the offset onto the first element at next_timestamp.
            first_at_next = self._index.bisect_key_left((next_timestamp, b''))
            next_offset = start + offset + count - first_at_next
        return hashes, RangeIdx(next_timestamp, next_offset)
def best_first_search(board: Board, heuristic: callable, move_order=DEFAULT_MOVE_ORDER) -> Optional[Board]:
    """Greedy best-first search over board positions.

    The frontier is ordered by ``heuristic`` (smallest value popped first).
    A position is marked visited when it is popped; successors whose layout
    was already visited are not queued.

    :param board: starting position.
    :param heuristic: callable mapping a Board to a sortable priority.
    :param move_order: directions to try, in order, from every position.
    :return: the first solved board popped, or ``None`` if the frontier empties.
    """
    frontier = SortedKeyList(key=heuristic)
    frontier.add(board)
    seen = set()
    while frontier:
        current = frontier.pop(0)
        seen.add(array_to_string(current.board))
        if current.is_solved():
            return current
        for direction in move_order:
            if not current.is_move_possible(direction):
                continue
            successor = Board(np.copy(current.board), current.move_history[:])
            successor.move(direction)
            if array_to_string(successor.board) not in seen:
                frontier.add(successor)
    return None
def A_star(board: Board, heuristic: callable, move_order=DEFAULT_MOVE_ORDER) -> Optional[Board]:
    """A*-style search: best-first search on weighted path length plus heuristic.

    A board's priority is ``len(move_history) / 100 + heuristic(board)``
    (smallest first). The search loop is exactly the one in
    ``best_first_search``, so this delegates to it rather than duplicating it.

    :param board: starting position.
    :param heuristic: callable estimating the remaining cost of a Board.
    :param move_order: directions to try, in order, from every position.
    :return: the first solved board popped, or ``None`` if none is found.
    """
    def smart_heuristic(candidate):
        return len(candidate.move_history) / 100 + heuristic(candidate)

    return best_first_search(board, smart_heuristic, move_order)
def _run_clustering(self):
    """Cluster known URLs by feature presence and refill the sub-queue.

    1. Select features whose count lies in
       ``[int(min_freq * n_urls), int(max_freq * n_urls)]``.
    2. Build a binary URL x feature matrix and cluster it with
       ``self._clusterizer.fit_predict``.
    3. Put up to ``self._subqueue_len`` unused URLs into ``self._subqueue``,
       each time drawing from the cluster with the fewest used URLs, so the
       queue is spread across clusters.

    Falls back to ``self._next_queue_fallback()`` when no feature is in range.
    """
    print("try to use clustering")
    n_urls = len(self._urls)
    max_feature_count = int(self._max_freq * n_urls)
    min_feature_count = int(self._min_freq * n_urls)
    # _features_count_list is searched as a list of (count, name) tuples in
    # ascending order (assumed from the bisect usage -- TODO confirm).
    start_index = bisect.bisect_left(self._features_count_list, (min_feature_count, ''))
    # (max + 1,) sorts after every (max, name) and before every (max + 1, name)
    # whatever the name is, assuming integer counts. The former sentinel
    # (max, 'ZZZ') dropped names sorting after 'ZZZ', e.g. any lowercase name.
    end_index = bisect.bisect_left(self._features_count_list, (max_feature_count + 1,))
    if start_index >= end_index:
        print("not enough features")
        return self._next_queue_fallback()

    chosen_features = SortedSet()
    for i in range(start_index, end_index):
        chosen_features.add(self._features_count_list[i][1])

    # Binary feature matrix: X[i, j] == 1 iff URL i has chosen feature j.
    X = np.empty((n_urls, len(chosen_features)))
    for i in range(n_urls):
        features = self._urls[self._urls_keys[i]][self._i_features]
        X[i] = [1 if fname in features else 0 for fname in chosen_features]

    y = self._clusterizer.fit_predict(X)

    def new_cluster_entry():
        # Slots 0, 1, 2 hold used URLs, unused URLs and the total count; they
        # are addressed via self._i_list_for_used / _i_list_for_unused /
        # _i_list_total (presumably 0, 1, 2 -- TODO confirm).
        return [set(), set(), 0]

    url_in_cluster = defaultdict(new_cluster_entry)
    for i, label in enumerate(y):
        url = self._urls_keys[i]
        entry = url_in_cluster[label]
        if self._urls[url][self._i_is_used]:
            entry[self._i_list_for_used].add(url)
        else:
            entry[self._i_list_for_unused].add(url)
        entry[self._i_list_total] += 1

    limit = self._subqueue_len
    # Sorted by descending used count, so pop() yields the least-used cluster.
    cluster_keys = SortedKeyList(url_in_cluster.keys(),
                                 key=lambda x: -len(url_in_cluster[x][self._i_list_for_used]))
    while limit > 0 and len(cluster_keys) > 0:
        less_index = cluster_keys.pop()
        unused_urls = url_in_cluster[less_index][self._i_list_for_unused]
        if not unused_urls:
            continue  # exhausted cluster: drop it for good
        url = unused_urls.pop()
        self._subqueue.put(url)
        limit -= 1
        if unused_urls:
            # Count the URL as used before re-inserting so the key is updated.
            url_in_cluster[less_index][self._i_list_for_used].add(url)
            cluster_keys.add(less_index)
class PriorityQueue:
    """Optionally bounded collection of values kept in rank order.

    Items are stored as ``Element(value, rank)`` in a SortedKeyList. Each
    item's sort key is its ``rank`` unless ``key`` is supplied, in which
    case ``key`` is called with the element's unpacked fields. When more than
    ``capacity`` items are held, items are dropped from the end (largest
    sort key) until the size fits again.
    """

    def __init__(self, capacity=None, key=None):
        self._data = SortedKeyList(key=self._rank)
        if capacity is None:
            self._capacity = inf
        else:
            self._capacity = capacity
        self._key = key

    def _rank(self, item):
        """Sort key for a stored element."""
        return self._key(*item) if self._key else item.rank

    def add(self, value, rank):
        """Insert ``value`` with ``rank``, then enforce the capacity."""
        self._data.add(Element(value, rank))
        self._shrink()

    def clear(self):
        return self._data.clear()

    def __repr__(self):
        body = ', '.join(f'{v}: {r}' for v, r in self._data)
        return f"PriorityQueue([{body}])"

    def _shrink(self):
        # Evict from the end (largest key) until within capacity.
        while self._capacity < len(self._data):
            self._data.pop()

    def update(self, src):
        raise NotImplementedError

    def __contains__(self, value):
        raise NotImplementedError

    def __iter__(self):
        for value, _rank in self._data:
            yield value

    def __getitem__(self, index):
        if not isinstance(index, int):
            return list(self)[index]
        return self._data[index].value

    def size(self):
        """Number of stored items."""
        return len(self._data)
def rearrangeString(self, s: str, k: int) -> str:
    """Rearrange ``s`` so equal characters are at least ``k`` positions apart.

    Greedy rounds: each round emits up to ``k`` distinct characters, taken
    in order of (highest remaining count, then lexicographically smallest).
    If a character would be placed closer than ``k`` to its previous
    occurrence, no arrangement is produced.

    :param s: input string.
    :param k: minimum distance between equal characters; 0 means no constraint.
    :return: the rearranged string, or ``''`` if the greedy fails.
    """
    import heapq

    if k == 0:
        return s
    # Min-heap on (-count, char) pops in exactly the (count desc, char asc) order.
    heap = [(-count, sym) for sym, count in Counter(s).items()]
    heapq.heapify(heap)
    last = defaultdict(lambda: -k)  # position of each char's previous use
    out = []  # list + join avoids quadratic string concatenation
    while heap:
        updates = []
        for _ in range(min(k, len(heap))):
            neg_count, sym = heapq.heappop(heap)
            if len(out) - last[sym] < k:
                return ''
            last[sym] = len(out)
            out.append(sym)
            if neg_count < -1:
                updates.append((neg_count + 1, sym))
        # Re-insert only after the round so a char cannot repeat within it.
        for item in updates:
            heapq.heappush(heap, item)
    return ''.join(out)
class MyMongoCollection:
    """In-memory, MongoDB-like collection of documents.

    Documents are kept in a SortedKeyList ordered by their ``'objectId'``
    value (until ``sort`` re-orders them). ``create_index(field)`` builds a
    secondary SortedKeyList of the documents containing ``field``, ordered by
    that field's value; equality lookups on indexed fields use it to narrow
    the candidate documents.
    """

    def __init__(self, database, name: str):
        from mymongoDB import MyMongoDB
        if not isinstance(database, MyMongoDB):
            raise MongoException(
                "Only MongoDB objects can be passed as the database argument")
        self.name: str = name
        self.parent_database: MyMongoDB = database
        # Documents, sorted by their 'objectId' value.
        self.__docs: SortedKeyList = SortedKeyList(
            key=lambda doc: doc['objectId'])
        # Field name -> SortedKeyList of the documents having that field,
        # ordered by the field's value.
        self.__indices: Dict[FieldName, SortedKeyList] = SortedDict()
        # objectIds of the documents currently stored.
        self.__reserved_ids: MutableSet = SortedSet(set())

    def __repr__(self) -> str:
        meta = f"MyMongoCollection({repr(self.parent_database)}, {self.name})"
        return meta

    def __len__(self) -> int:
        return len(self.__docs)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__ = state

    def sort(self, key):
        """Re-order the stored documents by ``key`` and rebuild every index."""
        self.__docs = SortedKeyList(self.__docs, key=key)
        for field in list(self.__indices.keys()):
            self.create_index(field)

    def create_index(self, field: str) -> None:
        """(Re)build the secondary index for ``field``.

        Only documents that contain ``field`` are indexed.

        :raises MongoException: if ``field`` is not a string.
        """
        if not isinstance(field, str):
            raise MongoException("'Field' argument must be a string")
        relevant_docs: List[MyMongoDoc] = [
            doc for doc in self.__docs if field in doc.keys()
        ]
        self.__indices[field] = SortedKeyList(relevant_docs,
                                              key=lambda doc: doc[field])

    def insert_one(self, doc: Union[Dict, MutableMapping]) -> MongoLastInserted:
        """Insert one document and add it to every index whose field it has.

        :raises MongoException: if ``doc`` is not a mapping, or a document
            with the same objectId is already stored.
        """
        if not (isinstance(doc, dict) or issubclass(doc.__class__, MutableMapping)):
            raise MongoException(
                f"Document \n{doc}\n must be an instance of dict "
                "or a type that inherits from collections.MutableMapping")
        new_doc = _MyMongoDocFactory.get_doc(data=doc)
        if new_doc.objectId in self.__reserved_ids:
            raise MongoException(
                f"Duplicate key error: document already in collection {new_doc}"
            )
        self.__reserved_ids.add(new_doc.objectId)
        self.__docs.add(new_doc)
        # Keep existing secondary indices in sync with the new document.
        for field, index in self.__indices.items():
            if field in new_doc.keys():
                index.add(new_doc)
        return MongoLastInserted(new_doc)

    def insert_many(self, *docs) -> List[MyMongoDoc]:
        """Insert several documents.

        Accepts documents as positional arguments (``insert_many(d1, d2)``)
        or a single iterable of documents (``insert_many([d1, d2])``).

        :return: the inserted documents, in insertion order.
        :raises MongoException: for any other argument shape.
        """
        last: List[MyMongoDoc] = list()
        if not docs:
            return last
        if isinstance(docs[0], dict) or issubclass(docs[0].__class__, MutableMapping):
            pass
        # Any other iterable is accepted at the caller's own risk.
        elif hasattr(docs[0], "__iter__") and len(docs) == 1:
            docs = docs[0]
        else:
            raise MongoException(
                f"Function accepts iterables of dicts or a type that inherits from "
                "collections.MutableMapping, or simply non-keyword arguments.")
        for doc in docs:
            last.append(self.insert_one(doc).document)
        return last

    def delete_one(self, object_id: MongoId) -> None:
        """Delete the document whose objectId is ``object_id``.

        The collection is only modified once the document has been located.

        :raises TypeError: if ``object_id`` is not a MongoId.
        :raises KeyError: if the collection holds no such document.
        """
        if not isinstance(object_id, MongoId):
            raise TypeError(
                "Only instances of MongoId can serve as document identifiers."
            )
        try:
            if object_id not in self.__reserved_ids:
                raise KeyError(object_id)
            doc_idx = self.__find_doc_index(object_id)
        except KeyError as e:
            raise KeyError(
                f"Collection {self.name} has no object with id {object_id}.") from e
        doc: MyMongoDoc = self.__docs[doc_idx]
        # Drop the document from every secondary index that holds it.
        for field, index in self.__indices.items():
            if field in doc.keys():
                self.__remove_identical(index, doc)
        self.__docs.pop(doc_idx)
        self.__reserved_ids.remove(object_id)

    def delete_many(self, object_ids: Iterable[MongoId]):
        for object_id in object_ids:
            self.delete_one(object_id)

    def clear(self):
        """Remove all documents, all indices and all reserved ids."""
        self.__docs.clear()
        self.__indices.clear()
        self.__reserved_ids.clear()

    def find_one(self, query: Dict[str, Any]):
        """Return the first document with ``doc[key] == value`` for every
        ``key, value`` in ``query``, or ``None`` if there is none."""
        relevant_docs, _query = self.__get_relevant_info(query)
        for candidate in relevant_docs:
            if self.__matches(candidate, _query, lambda actual, expected: actual == expected):
                return candidate
        return None

    def find(self, query: Dict[str, Any]) -> Iterable[MyMongoDoc]:
        """Return a list of every document matching all ``key == value`` pairs."""
        relevant_docs, _query = self.__get_relevant_info(query)
        return [
            candidate for candidate in relevant_docs
            if self.__matches(candidate, _query, lambda actual, expected: actual == expected)
        ]

    def find_and_update(self, filter_, update_: Dict[FieldName, Any]) -> List[MyMongoDoc]:
        """Apply ``update_`` in place to every document matching ``filter_``.

        :return: the updated documents.
        """
        docs = self.find(filter_)
        for doc in docs:
            for k, v in update_.items():
                doc[k] = v
        self.__reindex(update_.keys())
        return docs

    def find_one_and_update(self, filter_, update_: Dict[FieldName, Any]) -> "Optional[MyMongoDoc]":
        """Apply ``update_`` in place to the first document matching ``filter_``.

        :return: the updated document, or ``None`` if nothing matched.
        """
        doc = self.find_one(filter_)
        if doc is None:
            return None
        for k, v in update_.items():
            doc[k] = v
        self.__reindex(update_.keys())
        return doc

    def query(
        self, where: MutableMapping[FieldName, Callable[..., Bool]]
    ) -> Iterable[MyMongoDoc]:
        """Return documents for which ``where[key](doc[key])`` is truthy for every key.

        The values are predicates rather than literal values, so the
        equality-based indices cannot narrow the search; all documents are scanned.
        """
        return [
            candidate for candidate in self.__docs
            if self.__matches(candidate, where, lambda actual, predicate: predicate(actual))
        ]

    @staticmethod
    def __matches(doc, conditions, test) -> bool:
        # A document lacking any queried field does not match.
        try:
            return all(test(doc[key], value) for key, value in conditions.items())
        except KeyError:
            return False

    def __reindex(self, fields) -> None:
        # Index positions depend on field values; rebuild indices whose field changed.
        for field in fields:
            if field in self.__indices:
                self.create_index(field)

    def __find_doc_index(self, object_id) -> int:
        """Position of the document with ``'objectId' == object_id``.

        :raises KeyError: if no stored document has that id.
        """
        try:
            # Fast path, valid while documents are ordered by objectId.
            idx = self.__docs.bisect_left({"objectId": object_id})
        except (KeyError, TypeError):
            idx = len(self.__docs)
        if idx < len(self.__docs) and self.__docs[idx]['objectId'] == object_id:
            return idx
        # Slow path: sort() may have re-ordered the documents by another key.
        for pos, doc in enumerate(self.__docs):
            if doc['objectId'] == object_id:
                return pos
        raise KeyError(object_id)

    @staticmethod
    def __remove_identical(index, doc) -> None:
        # SortedKeyList.remove() matches by equality; remove this exact object.
        for pos in range(index.bisect_left(doc), index.bisect_right(doc)):
            if index[pos] is doc:
                index.pop(pos)
                return
        # Fallback in case the field value changed after indexing.
        for pos, candidate in enumerate(index):
            if candidate is doc:
                index.pop(pos)
                return

    def __get_relevant_info(self, query) -> Tuple[Iterable, MutableMapping]:
        """Split ``query`` into index-resolved candidates and remaining conditions.

        All query conditions are ANDed, so the candidates are the documents
        matching every indexed field (intersection). If no queried field is
        indexed, every stored document is a candidate.

        :return: ``(candidates, remaining)`` where ``remaining`` holds the
            conditions on non-indexed fields still to be checked.
        """
        indexed_query_fields = self.__indices.keys() & query.keys()
        if len(indexed_query_fields) == 0:
            return self.__docs, dict(query)

        candidates = None
        for indexed_query_field in indexed_query_fields:
            value = query[indexed_query_field]
            index: SortedKeyList = self.__indices[indexed_query_field]
            # bisect_left on a {field: value} marker lands before every
            # document whose field equals value; collect that run.
            idx = index.bisect_left({indexed_query_field: value})
            matching = {}
            while idx < len(index) and value == index[idx][indexed_query_field]:
                matching[id(index[idx])] = index[idx]
                idx += 1
            if candidates is None:
                candidates = matching
            else:
                candidates = {key: doc for key, doc in candidates.items() if key in matching}
            if not candidates:
                break

        remaining = {
            k: v for k, v in query.items() if k not in indexed_query_fields
        }
        return list(candidates.values()), remaining
class TransactionsIndex:
    """ Index of transactions sorted by their timestamps. """

    transactions: 'SortedKeyList[TransactionIndexElement]'

    def __init__(self) -> None:
        # Ordered by (timestamp, hash) so equal timestamps still have a total order.
        self.transactions = SortedKeyList(key=lambda x: (x.timestamp, x.hash))

    def __getitem__(self, index: slice) -> List[TransactionIndexElement]:
        """ Get items from SortedKeyList given a slice

        :param index: list index slice, for eg [1:6]
        """
        return self.transactions[index]

    def update(self, values: List[TransactionIndexElement]) -> None:
        """ Update sorted list by adding all values from iterable

        :param values: new values to add to SortedKeyList
        """
        self.transactions.update(values)

    def add_tx(self, tx: BaseTransaction) -> None:
        """ Add a transaction to the index; does nothing if it is already indexed.

        :param tx: Transaction to be added
        """
        assert tx.hash is not None
        # It is safe to use the in operator because it is O(log(n)).
        # http://www.grantjenks.com/docs/sortedcontainers/sortedlist.html#sortedcontainers.SortedList.__contains__
        element = TransactionIndexElement(tx.timestamp, tx.hash)
        if element in self.transactions:
            return
        self.transactions.add(element)

    def del_tx(self, tx: BaseTransaction) -> None:
        """ Delete a transaction from the index; does nothing if it is absent.

        :param tx: Transaction to be deleted
        """
        idx = self.transactions.bisect_key_left((tx.timestamp, tx.hash))
        if idx < len(self.transactions) and self.transactions[idx].hash == tx.hash:
            self.transactions.pop(idx)

    def find_tx_index(self, tx: BaseTransaction) -> Optional[int]:
        """Return the index of a transaction in the index, or None if it is absent.

        :param tx: Transaction to be found
        """
        idx = self.transactions.bisect_key_left((tx.timestamp, tx.hash))
        if idx < len(self.transactions) and self.transactions[idx].hash == tx.hash:
            return idx
        return None

    def get_newest(self, count: int) -> Tuple[List[bytes], bool]:
        """ Get transactions or blocks from the newest to the oldest

        :param count: Number of transactions or blocks to be returned
        :return: List of tx hashes and a boolean indicating if has more txs
        """
        return get_newest_sorted_key_list(self.transactions, count)

    def get_older(self, timestamp: int, hash_bytes: bytes, count: int) -> Tuple[List[bytes], bool]:
        """ Get transactions or blocks from the timestamp/hash_bytes reference to the oldest

        :param timestamp: Timestamp reference to start the search
        :param hash_bytes: Hash reference to start the search
        :param count: Number of transactions or blocks to be returned
        :return: List of tx hashes and a boolean indicating if has more txs
        """
        return get_older_sorted_key_list(self.transactions, timestamp, hash_bytes, count)

    def get_newer(self, timestamp: int, hash_bytes: bytes, count: int) -> Tuple[List[bytes], bool]:
        """ Get transactions or blocks from the timestamp/hash_bytes reference to the newest

        :param timestamp: Timestamp reference to start the search
        :param hash_bytes: Hash reference to start the search
        :param count: Number of transactions or blocks to be returned
        :return: List of tx hashes and a boolean indicating if has more txs
        """
        return get_newer_sorted_key_list(self.transactions, timestamp, hash_bytes, count)

    def find_first_at_timestamp(self, timestamp: int) -> int:
        """ Get index of first element at the given timestamp, or where it would be inserted if
        the timestamp is not in the list.

        Eg: SortedKeyList = [(3,hash1), (3, hash2), (7, hash3), (8, hash4)]
        find_first_at_timestamp(7) = 2, which is the index of (7, hash3)
        find_first_at_timestamp(4) = 2, which is the index of (7, hash3)
        find_first_at_timestamp(9) = 4, i.e. the length of the list

        :param timestamp: timestamp we're interested in
        :return: the index of the element, or ``len(self.transactions)`` if timestamp
            is greater than all in the list (never None)
        """
        idx = self.transactions.bisect_key_left((timestamp, b''))
        return idx
def test_pop_indexerror1():
    """Popping one position before the head of a 10-item list raises IndexError."""
    slt = SortedKeyList(range(10), key=negate)
    slt._reset(4)
    too_far_back = -11
    with pytest.raises(IndexError):
        slt.pop(too_far_back)
class OrderBook:
    """
    An order book for a single instrument.

    Attributes:
        bids: resting buy orders other than ``best_bid``, in a SortedKeyList
            ordered by descending price, then by arrival (price-time priority).
        asks: resting sell orders other than ``best_ask``, ordered by
            ascending price, then by arrival.
        best_bid: the bid first in line to be executed (kept outside ``bids``).
        best_ask: the ask first in line to be executed (kept outside ``asks``).
        attempt_match: whether ``match`` should try to cross the book.
        trades: deque of completed trades, appended in execution order.
        complete_orders: deque of orders that were filled or cancelled.
    """

    def __init__(self):
        # Arrival sequence number per order, keyed by id(order). Used as the
        # secondary sort key so that, at equal prices, earlier orders stay
        # ahead -- including a best order displaced back into the book.
        self._arrival = {}
        self._next_arrival = 0
        self.bids = SortedKeyList(key=lambda x: (-x.price, self._sequence_of(x)))
        self.asks = SortedKeyList(key=lambda x: (x.price, self._sequence_of(x)))
        self.best_bid: Optional[BaseOrder] = None
        self.best_ask: Optional[BaseOrder] = None
        self.attempt_match = False
        self.trades: deque = deque()
        self.complete_orders: deque = deque()

    def _sequence_of(self, order: BaseOrder) -> int:
        """Return the order's arrival number, assigning the next one on first sight."""
        key = id(order)
        if key not in self._arrival:
            self._arrival[key] = self._next_arrival
            self._next_arrival += 1
        return self._arrival[key]

    def _retire(self, order: BaseOrder) -> None:
        """Record a filled/cancelled order; must be called after it left the book."""
        self.complete_orders.append(order)
        self._arrival.pop(id(order), None)

    def add_bid(self, order: BaseOrder) -> None:
        """
        Adding a bid to the order book
        If there is no best bid, the order becomes it. Otherwise the lower
        priced of the order and the best bid goes into the book (an equal
        price keeps the existing best bid, which arrived first). Within a
        price level, orders are kept in arrival order.
        """
        self._sequence_of(order)  # stamp arrival time
        best_bid = self.best_bid
        if not best_bid:
            self.best_bid = order
            self.attempt_match = True
        elif order.price <= best_bid.price:
            self.bids.add(order)
        else:
            # The displaced best bid keeps its original arrival stamp, so it
            # goes ahead of later orders resting at the same price.
            self.bids.add(best_bid)
            self.best_bid = order
            self.attempt_match = True

    def add_ask(self, order: BaseOrder) -> None:
        """
        Adding an ask to the order book
        If there is no best ask, the order becomes it. Otherwise the higher
        priced of the order and the best ask goes into the book (an equal
        price keeps the existing best ask, which arrived first). Within a
        price level, orders are kept in arrival order.
        """
        self._sequence_of(order)  # stamp arrival time
        best_ask = self.best_ask
        if not best_ask:
            self.best_ask = order
            self.attempt_match = True
        elif order.price >= best_ask.price:
            self.asks.add(order)
        else:
            self.asks.add(best_ask)
            self.best_ask = order
            self.attempt_match = True

    def find_in_list(self, orders: List[BaseOrder], order_id: int) -> Optional[BaseOrder]:
        """Return the first order in ``orders`` with ``order_id``, or None."""
        for order in orders:
            if order.order_id == order_id:
                return order
        return None

    def add_cancel(self, order: CancelOrder) -> None:
        """
        Cancelling an existing order
        Find the first order on the cancel's side with a matching order_id
        (best order first, then the book) and cancel it.
        """
        if order.order_direction == OrderDirection.buy and self.best_bid is not None:
            best_bid = self.best_bid
            bids = self.bids
            if order.order_id == best_bid.order_id:
                order.cancel_order(best_bid)
                self._retire(best_bid)
                if bids:
                    self.best_bid = bids.pop(0)
                    self.attempt_match = True
                else:
                    self.best_bid = None
            elif bids:
                matched_order = self.find_in_list(bids, order.order_id)
                if matched_order:
                    bids.remove(matched_order)
                    self._retire(matched_order)
                    order.cancel_order(matched_order)
        elif order.order_direction == OrderDirection.sell and self.best_ask is not None:
            best_ask = self.best_ask
            asks = self.asks
            if order.order_id == best_ask.order_id:
                order.cancel_order(best_ask)
                self._retire(best_ask)
                if asks:
                    self.best_ask = asks.pop(0)
                    self.attempt_match = True
                else:
                    self.best_ask = None
            elif asks:
                matched_order = self.find_in_list(asks, order.order_id)
                if matched_order:
                    asks.remove(matched_order)
                    self._retire(matched_order)
                    order.cancel_order(matched_order)
        return None

    def add_order(self, order: BaseOrder) -> None:
        """Dispatch ``order`` to cancel / bid / ask handling.

        :raises InvalidOrderDirectionException: for an unknown direction.
        """
        if order.order_type == OrderType.cancel:
            self.add_cancel(order)
        elif order.order_direction == OrderDirection.buy:
            self.add_bid(order)
        elif order.order_direction == OrderDirection.sell:
            self.add_ask(order)
        else:
            raise InvalidOrderDirectionException()

    def match(self) -> None:
        """
        Attempt to match orders.
        While the best bid price is at or above the best ask price, trade the
        smaller unfilled quantity at the mid price, record the trade, and
        replace whichever best order is no longer live with the next order of
        its side. Stops when the book no longer crosses or a side is empty.
        """
        while self.attempt_match and self.best_bid and self.best_ask:
            self.attempt_match = False
            best_bid = self.best_bid
            best_ask = self.best_ask
            if best_bid.price >= best_ask.price:
                execution_price = (best_bid.price + best_ask.price) / 2
                matched_quantity = min(best_ask.unfilled_quantity, best_bid.unfilled_quantity)
                trade = Trade(datetime=np.datetime64("now"), price=execution_price, quantity=matched_quantity)
                best_bid.update_on_trade(trade)
                best_ask.update_on_trade(trade)
                self.trades.append(trade)
                if best_bid.status != OrderStatus.live:
                    self._retire(best_bid)
                    if self.bids:
                        self.best_bid = self.bids.pop(0)
                        self.attempt_match = True
                    else:
                        self.best_bid = None
                if best_ask.status != OrderStatus.live:
                    self._retire(best_ask)
                    if self.asks:
                        self.best_ask = self.asks.pop(0)
                        self.attempt_match = True
                    else:
                        self.best_ask = None
            else:
                break
        self.attempt_match = False

    def plot_order_book(self) -> None:
        """Save a step plot of cumulative unfilled volume by price to
        images/order_book.png. Nothing is drawn if either side is empty."""
        if not self.best_bid or not self.best_ask:
            return None
        # Remaining (unfilled) volume: a best order may be partially filled.
        bid_orders = [self.best_bid] + list(self.bids)
        bids = list(np.cumsum([bid.unfilled_quantity for bid in bid_orders]))
        bid_prices = [bid.price for bid in bid_orders]
        # Plot the bid side from lowest to highest price.
        bids.reverse()
        bid_prices.reverse()
        ask_orders = [self.best_ask] + list(self.asks)
        asks = list(np.cumsum([ask.unfilled_quantity for ask in ask_orders]))
        ask_prices = [ask.price for ask in ask_orders]

        fig = plt.figure()
        try:
            ax = fig.add_subplot(111)
            ax.set_title("Limit Order Book")
            ax.set_xlabel("Price")
            ax.set_ylabel("Quantity")
            ax.step(bid_prices, bids, color='green')
            ax.step(ask_prices, asks, color='red')
            ax.set_xlim([min(bid_prices), max(ask_prices)])
            fig.savefig("images/order_book.png")
        finally:
            # Release the figure; otherwise every call leaks one.
            plt.close(fig)

    def plot_executions(self) -> None:
        """ Save line plots of historic execution prices and quantities to
        images/executions.png. Nothing is drawn when there are no trades."""
        if not self.trades:
            return None  # min()/max() below need at least one trade
        fig, (ax1, ax2) = plt.subplots(2)
        try:
            fig.suptitle("Historic Executions")
            ax1.set_xlabel("Time")
            ax1.set_ylabel("Execution Price")
            ax2.set_xlabel("Time")
            ax2.set_ylabel("Executed Quantity")
            times = [t.datetime for t in self.trades]
            ax1.plot(times, [t.price for t in self.trades])
            ax2.plot(times, [t.quantity for t in self.trades])
            ax1.set_xlim([min(times), max(times)])
            ax2.set_xlim([min(times), max(times)])
            fig.savefig("images/executions.png")
        finally:
            plt.close(fig)
def test_pop_indexerror2():
    """Popping the position just past the tail of a 10-item list raises IndexError."""
    slt = SortedKeyList(range(10), key=modulo)
    slt._reset(4)
    one_past_end = 10
    with pytest.raises(IndexError):
        slt.pop(one_past_end)
def astar(start_pose, goal_position, occupancy_grid):
    """A* search on an occupancy grid.

    Nodes are expanded in ascending ``cost`` order; each grid cell is
    expanded at most once (tracked via ``occupancy_grid.set_explored``).
    Successors come from ``explore_cardinal`` and ``explore_diagonal``. A
    successor equal to a queued node replaces it only when its ``pathcost``
    is lower.

    :param start_pose: start pose; its first two entries are checked for
        free space.
    :param goal_position: goal position.
    :param occupancy_grid: grid queried for free/explored cells; path cells
        are marked on it with ``set_path``.
    :return: list of positions from start to goal, or ``None`` if the start
        or goal is blocked or no path exists.
    """
    frontier = SortedKeyList(key=lambda node: node.cost)

    # Shared by every Node instance via class attributes.
    Node.occupancy_grid = occupancy_grid
    Node.goal = Node(goal_position)

    if not occupancy_grid.is_free(start_pose[:2]):
        print('Start position is not in the free space.')
        return None
    if not occupancy_grid.is_free(goal_position):
        print('Goal position is not in the free space.')
        return None

    frontier.add(Node(start_pose))

    while frontier:
        node = frontier.pop(0)
        if occupancy_grid.is_explored(node.indices):
            continue
        occupancy_grid.set_explored(node.indices)

        if node == Node.goal:
            # Walk the parent links back to the start, marking the path.
            path = []
            walker = node
            while walker is not None:
                path.append(walker.position)
                occupancy_grid.set_path(walker.indices)
                walker = walker.parent
            path.reverse()
            return path

        successors = []
        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            found = explore_cardinal(occupancy_grid, node, dx, dy)
            if found:
                successors.append(found)
        for dx, dy in ((1, 1), (-1, 1), (1, -1), (-1, -1)):
            found = explore_diagonal(occupancy_grid, node, dx, dy)
            if found:
                successors.extend(found)

        for child in successors:
            # Position of the first queued node equal to child, if any.
            duplicate_at = next((i for i, queued in enumerate(frontier) if child == queued), None)
            if duplicate_at is None:
                frontier.add(child)
            elif child.pathcost < frontier[duplicate_at].pathcost:
                del frontier[duplicate_at]
                frontier.add(child)
    return None
class ParamsTaskQueue(disp.TaskQueue, ps.SweepListener):
    """A greedy task queue whose tasks each call model_manager.train for one
    set of model parameters.

    Tasks are selected as follows:

    - If a hyperparameter sweep is queued and at least ``sweep_cores`` cores
      are available, dispatch that sweep.
    - Otherwise dispatch trials, preferring the queued trial task with the
      most trials. A parameter set is only eligible once its sweep has been
      dispatched and while it has no task in progress, so trials never run
      before (or alongside) the sweep that finds their hyperparameters.

    :ivar int total_cores: number of physical cores available; stored here
        but not consulted by the selection logic in this class
    :ivar int sweep_cores: number of cores a sweep uses (not greater than
        total_cores)
    :ivar str module: module the model/dataset/etc are loaded from
    :ivar HyperparameterSettings hparams: hyperparameter strategy
    :ivar str folder: folder that contains the points folder
    :ivar list[SweepListener] listeners: listeners handed to every task; the
        last entry is this queue itself
    :ivar set[tuple] in_progress: parameters whose task is currently running
    :ivar list[ParamsTask] sweeps: sweeps still to be dispatched, in no
        particular order
    :ivar dict[tuple, int] params_to_sweeps_ind: parameters -> index into
        ``sweeps`` for every sweep that is still queued
    :ivar SortedKeyList[ParamsTask] trials: queued trial tasks in ascending
        order of their number of trials
    :ivar dict[tuple, int] params_to_id: parameters -> unique identifier
    :ivar dict[tuple, int] params_to_ntrials: parameters -> number of trials
        reported finished through on_trials_completed
    :ivar int next_id: identifier handed to the next new parameter set
    :ivar dict[tuple, ParamsTask] params_to_task: parameters -> trial task
        that still has undispatched trials (sweeps are not included)
    :ivar int _len: number of queued tasks (sweeps plus trial tasks)
    :ivar bool expecting_more_trials: when True, have_more_tasks reports more
        work even if the queue is empty
    """

    def __init__(self, total_cores: int, sweep_cores: int, module: str,
                 hparams: hyperparams.HyperparameterSettings, folder: str,
                 listeners: typing.List[ps.SweepListener]):
        self.total_cores = total_cores
        self.sweep_cores = sweep_cores
        self.module = module
        self.hparams = hparams
        self.folder = folder
        # This queue listens too, so finished work clears in_progress.
        self.listeners = [*listeners, self]

        self.in_progress = set()
        self.sweeps = []
        self.params_to_sweeps_ind = {}
        self.trials = SortedKeyList(key=lambda tsk: tsk.trials)
        self.params_to_id = {}
        self.params_to_ntrials = {}
        self.next_id = 0
        self.params_to_task = {}
        self._len = 0
        self.expecting_more_trials = False

    def add_task_by_params(self, params: tuple, trials: int) -> None:
        """Queue ``trials`` trials for the model parameters ``params``.

        The first time a parameter set is seen, a hyperparameter sweep is
        queued for it, even when ``trials`` <= 0. Further trials for a set
        that already has a queued trial task are merged into that task.
        """
        if self.params_to_id.get(params) is None:
            new_id = self.next_id
            self.next_id = new_id + 1
            self.sweeps.append(ParamsTask(params, 0))
            self.params_to_sweeps_ind[params] = len(self.sweeps) - 1
            self.params_to_id[params] = new_id
            self.params_to_ntrials[params] = 0
            self._len += 1

        if trials <= 0:
            return

        existing = self.params_to_task.get(params)
        if existing is None:
            fresh = ParamsTask(params, trials)
            self.params_to_task[params] = fresh
            self.trials.add(fresh)
            self._len += 1
        else:
            # Take it out before changing its sort key, then put it back.
            self.trials.remove(existing)
            existing.trials += trials
            self.trials.add(existing)

    def set_total_cores(self, total_cores):
        """Record a new number of available physical cores."""
        self.total_cores = total_cores

    def on_hparams_found(self, values, lr_min, lr_max, batch_size):
        """Sweep finished for ``values``: log it and mark it no longer running."""
        logger.debug('Found hyperparameters for %s: lr=(%s, %s), bs=%s',
                     values, lr_min, lr_max, batch_size)
        self.in_progress.remove(values)

    def on_trials_completed(self, values, perfs_train, losses_train,
                            perfs_val, losses_val):
        """Trials finished for ``values``: log the results, mark the set no
        longer running and add the finished trials to its count."""
        logger.info(
            'Completed some trials for %s - mean train/val perf = %s / %s',
            values, perfs_train.mean(), perfs_val.mean())
        logger.debug('%s - perf: %s, loss: %s, val - perf: %s, loss: %s',
                     values, perfs_train, losses_train, perfs_val, losses_val)
        self.in_progress.remove(values)
        self.params_to_ntrials[values] += len(losses_train)

    def _get_next_task(self, cores):
        """Pick the next task that fits in ``cores`` free cores, or None.

        A queued sweep wins when enough cores are free. Otherwise the
        eligible trial task with the most queued trials is dispatched; if it
        has more trials than ``cores`` only ``cores`` of them are handed out
        and the rest stay queued.
        """
        if cores <= 0:
            return None

        if self.sweeps and cores >= self.sweep_cores:
            # Pop from the end so no other stored sweep index shifts.
            swp: ParamsTask = self.sweeps.pop()
            del self.params_to_sweeps_ind[swp.params]
            self._len -= 1
            self.in_progress.add(swp.params)
            return swp.as_task(self.module, self.hparams, self.folder,
                               self.params_to_id[swp.params], 0,
                               self.sweep_cores, self.listeners,
                               self.params_to_ntrials[swp.params])

        # Scan from the most trials down for a set that is free to run.
        chosen = None
        for idx in reversed(range(len(self.trials))):
            candidate = self.trials[idx].params
            if (candidate not in self.in_progress
                    and candidate not in self.params_to_sweeps_ind):
                chosen = idx
                break
        if chosen is None:
            return None

        trl = self.trials.pop(chosen)
        params = trl.params
        if trl.trials > cores:
            # Partial dispatch: the remainder goes back into the queue.
            count = cores
            trl.trials -= cores
            self.trials.add(trl)
        else:
            count = trl.trials
            del self.params_to_task[params]
            self._len -= 1
        self.in_progress.add(params)
        return trl.as_task(self.module, self.hparams, self.folder,
                           self.params_to_id[params], count,
                           self.sweep_cores, self.listeners,
                           self.params_to_ntrials[params])

    def get_next_task(self, cores):
        """Return the next task for ``cores`` free cores (or None), logging it."""
        task = self._get_next_task(cores)
        if task is not None:
            logger.debug('starting task %s', str(task))
        return task

    def have_more_tasks(self):
        """True while tasks are queued or more trials are expected."""
        return self.expecting_more_trials or self._len > 0

    def __len__(self) -> int:
        return self._len
def test_pop_indexerror1():
    """pop() with a negative index beyond -len (-11 for 10 items) must raise IndexError."""
    items = SortedKeyList(range(10), key=negate)
    items._reset(4)
    with pytest.raises(IndexError):
        items.pop(-11)
def compute_dom_homomorphic_map():
    """Build a binary -> source basic-block map that respects dominance.

    Starting from ``fixed_points``, each binary block taken from a sorted
    worklist is mapped to its first non-conflicting candidate from
    ``potential_map_bin2src``. After each round, every pair of mapped blocks
    is checked: if binary pre-dominance between b and b_ differs from source
    pre-dominance between their images, the pair is recorded as a conflict
    and both blocks go back on the worklist. Finally, leaf siblings that
    could be confused are removed (see remove_ambiguous).

    Reads from the enclosing scope: self, btfg, stfg, nodes_b, nodes_new_b,
    nodes_new_s, fixed_points, potential_map_bin2src and report (the latter
    receives the 'ambiguous-bin' entry).

    NOTE(review): block structure was reconstructed from a flattened copy;
    confirm the indentation of the per-round logging and of
    report['ambiguous-bin'] against the original file.

    :returns GraphMap
    """

    def translate_id(node_id, isBinary):
        """Map IDs that were newly inserted for collapsed graphs (above the
        flow graph's max ID) back to their original equivalent; other IDs
        are returned unchanged."""
        if isBinary:
            if node_id > self.bFlow.get_max_id():
                return nodes_new_b[node_id]
        else:
            if node_id > self.sFlow.get_max_id():
                return nodes_new_s[node_id]
        return node_id

    def test_homomorphism(binary_nodes):
        """Check whether the current mapping of ``binary_nodes`` preserves
        pre-dominance in both directions.

        Every violating pair is pushed back onto the worklist (which also
        clears its mapping). Returns the number of violating ordered pairs.
        """
        failed_count = 0
        for b in binary_nodes:
            for b_ in binary_nodes:
                if b_ == b:
                    continue
                a = f_map.get(b, None)
                a_ = f_map.get(b_, None)
                if a is None or a_ is None:
                    continue
                # Get original IDs for dominance check
                og_b = translate_id(b, True)
                og_b_ = translate_id(b_, True)
                og_a = translate_id(a, False)
                og_a_ = translate_id(a_, False)
                log.debug("b,b_={},{}; a,a_={},{}".format(
                    b, b_, a, a_))
                log.debug("og_b,og_b_={},{}; og_a,og_a_={},{}".format(
                    og_b, og_b_, og_a, og_a_))
                if self.bFlow.predom_tree().test_dominance(og_b, og_b_) != \
                        self.sFlow.predom_tree().test_dominance(og_a, og_a_) or \
                        self.bFlow.predom_tree().test_dominance(og_b_, og_b) != \
                        self.sFlow.predom_tree().test_dominance(og_a_, og_a):
                    log.debug(
                        "bin_dominance={}, src_dominance={}".format(
                            self.bFlow.predom_tree().test_dominance(
                                og_b, og_b_),
                            self.sFlow.predom_tree().test_dominance(
                                og_a, og_a_)))
                    log.debug(
                        "Preorder numbers og_b,og_b_: {},{}".format(
                            self.bFlow.predom_tree().get_preorder_number(og_b),
                            self.bFlow.predom_tree().get_preorder_number(og_b_)))
                    log.debug(
                        "Preorder numbers og_a,og_a_: {},{}".format(
                            self.sFlow.predom_tree().get_preorder_number(og_a),
                            self.sFlow.predom_tree().get_preorder_number(og_a_)))
                    add_back_to_worklist(b)
                    add_back_to_worklist(b_)
                    failed_count += 1
                    log.debug("Homomorphism failed")
        return failed_count

    def add_back_to_worklist(b):
        """Requeue binary block ``b`` and clear its mapping (sets it to None).
        Fixed points are never touched."""
        if b in fixed_points:
            return
        worklist.add(b)
        f_map[b] = None

    def check_conflict(r, b):
        """Check whether src-bb r is known to be a bad choice for bin-bb b,
        given the current state of the mapping.

        Returns True if any (b_, r_) recorded as conflicting with b->r is
        currently in f_map.
        """
        if r not in f_confl[b]:
            return False
        # see if any of the known conflicts are already in the map
        hasConflict = False
        for b_, r_ in f_confl[b][r]:
            if f_map.get(
                    b_, None) == r_:  # is the conflicting one in the map?
                log.debug(
                    "conflict: {}->{} not allowed because {}->{} in mapping"
                    .format(b, r, b_, r_))
                hasConflict = True
                break
        return hasConflict

    def select_reference(b):
        """Return the first candidate source block for ``b`` (in
        potential_map_bin2src order) that does not conflict with the current
        map, or None if every candidate conflicts."""
        p_b = potential_map_bin2src[b]
        for r in p_b:
            if not check_conflict(r, b):
                return r
        return None

    def add_conflict(b, a, b_, a_):
        """Record that the decisions b->a and b_->a_ are mutually conflicting.

        b, b_ are binary blocks; a, a_ are source blocks. The conflict is
        stored in both directions in f_confl.
        """
        if a not in f_confl[b]:
            f_confl[b][a] = set()
        if a_ not in f_confl[b_]:
            f_confl[b_][a_] = set()
        f_confl[b][a].add((b_, a_))  # b->a conflicts with b'->a'
        f_confl[b_][a_].add((b, a))  # b'->a' conflicts with b->a
        log.debug("{}->{} conflicts with {}->{}".format(b, a, b_, a_))

    def remove_ambiguous():
        """Remove from f_map every leaf whose sibling leaves could be confused
        with it, and return the set of removed binary blocks.

        Walks the binary pre-dominator tree. At each level, mapped leaf
        children that share a candidate source block are considered
        ambiguous and their f_map entries are deleted.
        """

        def do_level(node):
            """Dive down dom tree, and check for ambiguity at each level.

            Recursive, so recursion depth equals the depth of the dominator
            tree.
            """
            mapped_by = dict()  # src-bb -> bin-bb in this btfg
            for ch in pdt.successors(node):
                # if has children, their dom. relationships will make it unambig.
                # NOTE(review): `ch in f_map` is also True for entries that
                # add_back_to_worklist set to None -- confirm that is intended.
                if ch in f_map and pdt.out_degree(ch) == 0:
                    srcbbs = potential_map_bin2src[ch]
                    for sbb in srcbbs:
                        if sbb not in mapped_by:
                            mapped_by[sbb] = set()
                        mapped_by[sbb].add(ch)
            # remove those which have multiple src locations
            # (dict.iteritems() exists only in Python 2)
            delbb = {
                bb for _, bbb in mapped_by.iteritems() if len(bbb) > 1
                for bb in bbb
            }
            if delbb:
                ambiguous_bbb.update(delbb)
                for db in delbb:
                    del f_map[db]
            # dive down
            for ch in pdt.successors(node):
                do_level(ch)

        ambiguous_bbb = set()
        pdt = self.bFlow.predom_tree().get_tree()
        do_level(self.bFlow.predom_tree().get_root())
        # --
        return ambiguous_bbb

    log.info(
        "Running dominator homomorphism mapping on '{}', order: {}".
        format(btfg.name, self.hom_order))
    # The worklist is sorted by preorder number; the two '*-dominated-first'
    # orders negate it so the deepest blocks come first.
    if self.hom_order == 'predominated-first':
        worklist = SortedKeyList(iterable=nodes_b,
                                 key=lambda x: -self.bFlow.predom_tree(
                                 ).get_preorder_number(x))
    elif self.hom_order == 'postdominated-first':
        worklist = SortedKeyList(iterable=nodes_b,
                                 key=lambda x: -self.bFlow.
                                 postdom_tree().get_preorder_number(x))
    elif self.hom_order == 'predominator-first':
        worklist = SortedKeyList(
            iterable=nodes_b,
            key=self.bFlow.predom_tree().get_preorder_number)
    elif self.hom_order == 'postdominator-first':
        worklist = SortedKeyList(
            iterable=nodes_b,
            key=self.bFlow.postdom_tree().get_preorder_number)
    else:
        assert False, "Invalid argument (self.hom_order)."

    # Add known relations between entry and exit nodes of subgraphs & test for safety
    f_map = dict()
    f_map.update(fixed_points)
    log.debug("Fixed points={}".format(f_map.items()))
    # NOTE(review): this check lives inside an assert, so it is skipped
    # entirely under `python -O`.
    assert test_homomorphism(f_map.keys()) == 0, \
        "Initial homomorphism test failed for fixed points."
    # f_confl[b][a] -> set of (b_, a_) decisions that conflict with b->a
    f_confl = {n: dict() for n in nodes_b}
    f_confl.update({n: dict() for n in fixed_points.keys()})
    log.debug("Initial worklist={}".format(worklist))
    rounds = 0
    while len(worklist) > 0:
        rounds += 1
        # Select non conflicting elements for all in worklist
        for _ in range(len(worklist)):
            # NOTE(review): hom_order can only be one of the four values
            # accepted above (anything else hits `assert False`), so the
            # 'pre' branch is never taken and pop(0) is always used; the
            # processing order is already encoded in the worklist sort key.
            if self.hom_order == 'pre':
                b = worklist.pop(
                    -1
                )  # using preDom, matching bin dominated (body) first
            else:
                b = worklist.pop(
                    0
                )  # using postDom, matching bin dominator (header) first
            log.debug("Current worklist element: {}".format(b))
            if b in fixed_points.keys():
                continue  # don't touch
            a = select_reference(
                b)  # multiple b's might pull the same a here.
            if a is None:
                # b is dropped from the worklist and stays unmapped unless
                # add_back_to_worklist requeues it later.
                log.debug("Only conflicting references for {} left...".
                          format(b))
                continue
            else:
                f_map[b] = a
            # In non-quick mode only one new mapping is made per round.
            if not self.quick:  # avoids spurious conflicts, but is at least O(n^3)
                break
        # Test for homomorphism and reject those violating it
        rejected = False
        test_nodes = {
            k for k, v in f_map.iteritems() if v is not None
        }  # was: nodes_b
        for b in test_nodes:  # reversing improves run-time (heuristic)
            for b_ in test_nodes:
                if b_ == b:
                    continue
                a = f_map.get(b, None)
                a_ = f_map.get(b_, None)
                if a is None or a_ is None:  # could still be None if we removed it
                    continue
                # FIXME: could cache the following
                fwd_fail = self.bFlow.predom_tree().test_dominance(
                    translate_id(b, True), translate_id(b_, True)) != \
                    self.sFlow.predom_tree().test_dominance(
                        translate_id(a, False), translate_id(a_, False))
                rev_fail = self.bFlow.predom_tree().test_dominance(
                    translate_id(b_, True), translate_id(b, True)) != \
                    self.sFlow.predom_tree().test_dominance(
                        translate_id(a_, False), translate_id(a, False))
                if fwd_fail or rev_fail:
                    log.debug(
                        "Dominance check failed: b,a=({},{}) ; b_,a_=({},{})"
                        .format(b, a, b_, a_) + ". Fail type: {}".format(
                            'both' if fwd_fail and rev_fail else
                            ('fwd' if fwd_fail else 'rev')))
                    add_conflict(b, a, b_, a_)
                    add_back_to_worklist(b)  # and remove from map
                    add_back_to_worklist(b_)
                    rejected = True
        if not rejected:
            log.debug("Nothing was rejected by homomorphism")
        log.debug("Map after {} rounds: {}".format(
            rounds, {k: v for k, v in f_map.iteritems() if v is not None}))
    log.debug(
        "Homomorphism mapper finished on {} after {} rounds".format(
            btfg.name, rounds))
    # some undistinguishable BBs might have been mapped. Remove to prevent switching some.
    rem_bbs = remove_ambiguous()
    if rem_bbs:
        log.info("{}: Removed {} ambiguous map entries: {}".format(
            btfg.name, len(rem_bbs), rem_bbs))
    report['ambiguous-bin'] = rem_bbs
    # --
    g = GraphMap(gA=btfg.flow, gB=stfg.flow, dict_map=f_map,
                 name="dominator homomorphism")
    return g
class OfflineLearner:
    """An offline method for re-evaluating surprising results that the bot
    has encountered.

    Every call records the callback's loss in a rolling window. Once the
    window holds ``samples`` losses, any loss more than two standard
    deviations from the window mean is remembered. ``think`` later re-runs
    the callback on remembered inputs, highest loss first.

    Attributes:
        surprising (SortedKeyList[(loss, args, kwargs, counter)]): remembered
            events ordered by descending loss ([0] highest, [-1] lowest);
            ``counter`` is how many times the inputs have been evaluated
        heap_size (int): maximum number of events kept in ``surprising``
        callback (callable[inputs] -> float): evaluates the inputs and
            returns their loss
        samples (int): size of our rolling window
        random_skip_factor (float): probability that ``think`` skips an event
            instead of re-evaluating it (0-1, 1 excluded)
        sum_roll_loss (float): sum(rolling_loss, 0)
        sum_roll_loss_sqd (float): sum((x*x for x in rolling_loss), 0)
        rolling_loss (deque[float]): the most recent ``samples`` losses
    """

    def __init__(self, callback: typing.Callable, heap_size: int = 10,
                 samples: int = 100, random_skip_factor: float = 0.7):
        self.surprising = SortedKeyList(key=lambda x: -x[0])
        self.callback = callback
        self.heap_size = heap_size
        self.samples = samples
        self.sum_roll_loss = 0
        self.sum_roll_loss_sqd = 0
        self.rolling_loss = deque()
        self.random_skip_factor = random_skip_factor

    def __call__(self, *args, **kwargs) -> None:
        """Evaluate the callback on the given arguments and register the loss."""
        loss = self.callback(*args, **kwargs)
        self._handle(loss, 1, args, kwargs)

    def _handle(self, loss, counter, args, kwargs):
        """Record ``loss`` and remember the inputs if the loss is surprising.

        Only the first evaluation of an input (``counter == 1``) feeds the
        rolling window. Inputs evaluated more than 5 times are dropped.
        Nothing is flagged until the window is full.
        """
        if counter > 5:
            return

        if counter == 1:
            popped = (self.rolling_loss.popleft()
                      if len(self.rolling_loss) >= self.samples else 0)
            self.rolling_loss.append(loss)
            self.sum_roll_loss += loss - popped
            self.sum_roll_loss_sqd += loss * loss - popped * popped

        if len(self.rolling_loss) < self.samples:
            return

        meas_first = self.sum_roll_loss / len(self.rolling_loss)
        meas_second = self.sum_roll_loss_sqd / len(self.rolling_loss)
        # The running sums accumulate rounding error, so E[x^2] - E[x]^2 can
        # come out slightly negative; clamp it so math.sqrt cannot raise.
        variance = max(meas_second - (meas_first ** 2), 0.0)
        std_dev = math.sqrt(variance)

        deviation = loss - meas_first
        # With (near) zero spread the strict band is empty, which would flag
        # a loss equal to the mean as surprising; treat that as ordinary.
        if (-2 * std_dev < deviation < 2 * std_dev
                or math.isclose(loss, meas_first)):
            return

        if counter == 1:
            print(f'[offline] found something surprising: {loss} ' +
                  f'(mean: {self.sum_roll_loss / self.samples}, ' +
                  f'vari: {variance}, std: {std_dev})')
            sys.stdout.flush()

        if len(self.surprising) >= self.heap_size:
            if loss < self.surprising[-1][0]:
                return
            self.surprising.pop()  # evict the lowest-loss event
        self.surprising.add((loss, args, kwargs, counter))

    def think(self, maxtime: float):
        """Re-evaluate remembered events for up to ``maxtime`` seconds.

        Events are popped highest loss first. Each is skipped with
        probability ``random_skip_factor`` (skipped events are restored
        afterwards); otherwise the callback is re-run and the new loss is
        handled with an incremented counter. At most as many events as were
        stored when the call started are visited.
        """
        end = time.time() + maxtime
        max_todo = len(self.surprising)
        done = 0
        to_add = []
        while time.time() < end and self.surprising:
            loss, args, kwargs, counter = self.surprising.pop(0)
            if random.random() < self.random_skip_factor:
                to_add.append((loss, args, kwargs, counter))
            else:
                loss = self.callback(*args, **kwargs)
                self._handle(loss, counter + 1, args, kwargs)
            done += 1
            # Applies to skipped events as well, so the visit bound holds.
            if done >= max_todo:
                break

        for val in to_add:
            self.surprising.add(val)
        # _handle may have stored events while skipped ones were held aside,
        # so restoring them can exceed heap_size; drop the lowest losses.
        while len(self.surprising) > self.heap_size:
            self.surprising.pop()
def test_pop_indexerror2():
    """Index 10 is one past the last of the 10 items, so pop() raises IndexError."""
    sorted_items = SortedKeyList(range(10), key=modulo)
    sorted_items._reset(4)
    with pytest.raises(IndexError):
        sorted_items.pop(10)
class RangeModule:
    """Track a set of numbers as half-open ranges ``[left, right)``.

    ``self.ranges`` is a list of ``(start, end)`` tuples sorted by start.
    Every mutation keeps the entries disjoint and non-touching (a range
    ending at ``x`` is merged with one starting at ``x``). As a result the
    ends are sorted too, and a query range is tracked iff a single entry
    covers it. Callers are expected to pass ``left < right``; empty ranges
    are ignored by ``addRange`` and ``removeRange``.
    """

    def __init__(self):
        self.ranges = []

    def _span(self, left, right):
        """Return ``(lo, hi)`` such that ``self.ranges[lo:hi]`` are exactly
        the entries that overlap or touch the closed interval [left, right]."""
        from bisect import bisect_left, bisect_right

        # First entry whose start is >= left: (left,) sorts before every
        # (left, end) tuple.
        lo = bisect_left(self.ranges, (left,))
        # Only the entry just before can start below left and still reach it;
        # earlier entries end before that entry starts.
        if lo > 0 and self.ranges[lo - 1][1] >= left:
            lo -= 1
        # One past the last entry whose start is <= right.
        hi = bisect_right(self.ranges, (right, float('inf')))
        return lo, hi

    def addRange(self, left: int, right: int) -> None:
        """Start tracking every number in ``[left, right)``."""
        if left >= right:
            return
        lo, hi = self._span(left, right)
        if lo < hi:
            # Absorb every overlapping or adjacent entry into one range.
            left = min(left, self.ranges[lo][0])
            right = max(right, self.ranges[hi - 1][1])
        self.ranges[lo:hi] = [(left, right)]

    def queryRange(self, left: int, right: int) -> bool:
        """Return True iff every number in ``[left, right)`` is tracked."""
        from bisect import bisect_right

        # Last entry whose start is <= left; entries never touch, so it is
        # the only one that could cover the whole query.
        idx = bisect_right(self.ranges, (left, float('inf'))) - 1
        if idx < 0:
            return False
        end = self.ranges[idx][1]
        return left < end and right <= end

    def removeRange(self, left: int, right: int) -> None:
        """Stop tracking every number in ``[left, right)``."""
        if left >= right:
            return
        lo, hi = self._span(left, right)
        if lo == hi:
            return
        first_start = self.ranges[lo][0]
        last_end = self.ranges[hi - 1][1]
        # Keep only the parts of the affected entries outside [left, right).
        remainder = []
        if first_start < left:
            remainder.append((first_start, left))
        if last_end > right:
            remainder.append((right, last_end))
        self.ranges[lo:hi] = remainder
def astar(start_pose, goal_position, occupancy_grid):
    """Run A* on ``occupancy_grid`` with 8-connected moves between cells.

    Nodes are expanded lowest ``Node.cost`` first. A move to a neighbouring
    cell costs the step length (1 for axis-aligned moves, sqrt(2) for
    diagonals) times ``occupancy_grid.resolution``.

    Returns:
        The list of node positions from the goal back to the start (the list
        is not reversed), or ``None`` when the goal is not free or the open
        list runs dry. Every node on the path is marked with
        ``occupancy_grid.set_path``.
    """
    # Open list sorted by node cost; index 0 is the next node to expand.
    frontier = SortedKeyList(key=lambda n: n.cost)

    # Shared state for every Node instance. The goal is snapped to the
    # position of the cell containing goal_position.
    Node.occupancy_grid = occupancy_grid
    Node.goal = occupancy_grid.get_position(
        *occupancy_grid.get_index(goal_position))

    start_node = Node(start_pose)
    goal_node = Node(goal_position)

    if not occupancy_grid.is_free(goal_position):
        print('Goal position is not in the free space.')
        return None

    frontier.add(start_node)
    moves = ((0, -1), (0, 1), (-1, 0), (1, 0),
             (1, 1), (-1, 1), (1, -1), (-1, -1))

    while frontier:
        node = frontier.pop(0)
        # Cells can be queued again after expansion; skip those repeats.
        if occupancy_grid.is_explored(
                occupancy_grid.get_index(node.position)):
            continue
        occupancy_grid.set_explored(node.indices)

        if node == goal_node:
            positions = []
            walker = node
            while walker is not None:
                positions.append(walker.position)
                occupancy_grid.set_path(walker.indices)
                walker = walker.parent
            # NOTE(review): returned goal-first (not reversed) -- confirm
            # callers expect this order.
            return positions

        successors = []
        for dx, dy in moves:
            cell = (node.indices[0] + dx, node.indices[1] + dy)
            if not occupancy_grid.is_free_index(cell):
                continue
            succ = Node(occupancy_grid.get_position(*cell))
            succ.parent = node
            succ.pathcost = node.pathcost + np.sqrt(
                abs(dx) + abs(dy)) * occupancy_grid.resolution
            successors.append(succ)

        for succ in successors:
            already_queued = False
            for pos, queued in enumerate(frontier):
                if succ == queued:
                    already_queued = True
                    # Replace the queued copy only when this route is cheaper.
                    if succ.pathcost < queued.pathcost:
                        del frontier[pos]
                        frontier.add(succ)
                        break
            if not already_queued:
                frontier.add(succ)

    return None
class LearnerND(BaseLearner): """Learns and predicts a function 'f: ℝ^N → ℝ^M'. Parameters ---------- func: callable The function to learn. Must take a tuple of N real parameters and return a real number or an arraylike of length M. bounds : list of 2-tuples or `scipy.spatial.ConvexHull` A list ``[(a_1, b_1), (a_2, b_2), ..., (a_n, b_n)]`` containing bounds, one pair per dimension. Or a ConvexHull that defines the boundary of the domain. loss_per_simplex : callable, optional A function that returns the loss for a simplex. If not provided, then a default is used, which uses the deviation from a linear estimate, as well as triangle area, to determine the loss. Attributes ---------- data : dict Sampled points and values. points : numpy array Coordinates of the currently known points values : numpy array The values of each of the known points pending_points : set Points that still have to be evaluated. Notes ----- The sample points are chosen by estimating the point where the gradient is maximal. This is based on the currently known points. In practice, this sampling protocol results to sparser sampling of flat regions, and denser sampling of regions where the function has a high gradient, which is useful if the function is expensive to compute. This sampling procedure is not fast, so to benefit from it, your function needs to be slow enough to compute. This class keeps track of all known points. It triangulates these points and with every simplex it associates a loss. Then if you request points that you will compute in the future, it will subtriangulate a real simplex with the pending points inside it and distribute the loss among it's children based on volume. 
""" def __init__(self, func, bounds, loss_per_simplex=None): self._vdim = None self.loss_per_simplex = loss_per_simplex or default_loss if hasattr(self.loss_per_simplex, "nth_neighbors"): if self.loss_per_simplex.nth_neighbors > 1: raise NotImplementedError( "The provided loss function wants " "next-nearest neighboring simplices for the loss computation, " "this feature is not yet implemented, either use " "nth_neightbors = 0 or 1" ) self.nth_neighbors = self.loss_per_simplex.nth_neighbors else: self.nth_neighbors = 0 self.data = OrderedDict() self.pending_points = set() self.bounds = bounds if isinstance(bounds, scipy.spatial.ConvexHull): hull_points = bounds.points[bounds.vertices] self._bounds_points = sorted(list(map(tuple, hull_points))) self._bbox = tuple(zip(hull_points.min(axis=0), hull_points.max(axis=0))) self._interior = scipy.spatial.Delaunay(self._bounds_points) else: self._bounds_points = sorted(list(map(tuple, itertools.product(*bounds)))) self._bbox = tuple(tuple(map(float, b)) for b in bounds) self.ndim = len(self._bbox) self.function = func self._tri = None self._losses = dict() self._pending_to_simplex = dict() # vertex → simplex # triangulation of the pending points inside a specific simplex self._subtriangulations = dict() # simplex → triangulation # scale to unit hypercube # for the input self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat)) # for the output self._min_value = None self._max_value = None self._output_multiplier = ( 1 # If we do not know anything, do not scale the values ) self._recompute_losses_factor = 1.1 # create a private random number generator with fixed seed self._random = random.Random(1) # all real triangles that have not been subdivided and the pending # triangles heap of tuples (-loss, real simplex, sub_simplex or None) # _simplex_queue is a heap of tuples (-loss, real_simplex, sub_simplex) # It contains all real and pending simplices except for real simplices # that have been subdivided. 
# _simplex_queue may contain simplices that have been deleted, this is # because deleting those items from the heap is an expensive operation, # so when popping an item, you should check that the simplex that has # been returned has not been deleted. This checking is done by # _pop_highest_existing_simplex self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority) @property def npoints(self): """Number of evaluated points.""" return len(self.data) @property def vdim(self): """Length of the output of ``learner.function``. If the output is unsized (when it's a scalar) then `vdim = 1`. As long as no data is known `vdim = 1`. """ if self._vdim is None and self.data: try: value = next(iter(self.data.values())) self._vdim = len(value) except TypeError: self._vdim = 1 return self._vdim if self._vdim is not None else 1 def to_numpy(self): """Data as NumPy array of size ``(npoints, dim+vdim)``, where ``dim`` is the size of the input dimension and ``vdim`` is the length of the return value of ``learner.function``.""" return np.array([(*p, *np.atleast_1d(v)) for p, v in sorted(self.data.items())]) @property def bounds_are_done(self): return all(p in self.data for p in self._bounds_points) def _ip(self): """A `scipy.interpolate.LinearNDInterpolator` instance containing the learner's data.""" # XXX: take our own triangulation into account when generating the _ip return interpolate.LinearNDInterpolator(self.points, self.values) @property def tri(self): """An `adaptive.learner.triangulation.Triangulation` instance with all the points of the learner.""" if self._tri is not None: return self._tri try: self._tri = Triangulation(self.points) except ValueError: # A ValueError is raised if we do not have enough points or # the provided points are coplanar, so we need more points to # create a valid triangulation return None self._update_losses(set(), self._tri.simplices) return self._tri @property def values(self): """Get the values from `data` as a numpy array.""" return 
np.array(list(self.data.values()), dtype=float) @property def points(self): """Get the points from `data` as a numpy array.""" return np.array(list(self.data.keys()), dtype=float) def tell(self, point, value): point = tuple(point) if point in self.data: return # we already know about the point if value is None: return self.tell_pending(point) self.pending_points.discard(point) tri = self.tri self.data[point] = value if not self.inside_bounds(point): return self._update_range(value) if tri is not None: simplex = self._pending_to_simplex.get(point) if simplex is not None and not self._simplex_exists(simplex): simplex = None to_delete, to_add = tri.add_point(point, simplex, transform=self._transform) self._update_losses(to_delete, to_add) def _simplex_exists(self, simplex): simplex = tuple(sorted(simplex)) return simplex in self.tri.simplices def inside_bounds(self, point): """Check whether a point is inside the bounds.""" if hasattr(self, "_interior"): return self._interior.find_simplex(point, tol=1e-8) >= 0 else: eps = 1e-8 return all( (mn - eps) <= p <= (mx + eps) for p, (mn, mx) in zip(point, self._bbox) ) def tell_pending(self, point, *, simplex=None): point = tuple(point) if not self.inside_bounds(point): return self.pending_points.add(point) if self.tri is None: return simplex = tuple(simplex or self.tri.locate_point(point)) if not simplex: return # Simplex is None if pending point is outside the triangulation, # then you do not have subtriangles simplex = tuple(simplex) simplices = [self.tri.vertex_to_simplices[i] for i in simplex] neighbors = set.union(*simplices) # Neighbours also includes the simplex itself for simpl in neighbors: _, to_add = self._try_adding_pending_point_to_simplex(point, simpl) if to_add is None: continue self._update_subsimplex_losses(simpl, to_add) def _try_adding_pending_point_to_simplex(self, point, simplex): # try to insert it if not self.tri.point_in_simplex(point, simplex): return None, None if simplex not in 
self._subtriangulations: vertices = self.tri.get_vertices(simplex) self._subtriangulations[simplex] = Triangulation(vertices) self._pending_to_simplex[point] = simplex return self._subtriangulations[simplex].add_point(point) def _update_subsimplex_losses(self, simplex, new_subsimplices): loss = self._losses[simplex] loss_density = loss / self.tri.volume(simplex) subtriangulation = self._subtriangulations[simplex] for subsimplex in new_subsimplices: subloss = subtriangulation.volume(subsimplex) * loss_density self._simplex_queue.add((subloss, simplex, subsimplex)) def _ask_and_tell_pending(self, n=1): xs, losses = zip(*(self._ask() for _ in range(n))) return list(xs), list(losses) def ask(self, n, tell_pending=True): """Chose points for learners.""" if not tell_pending: with restore(self): return self._ask_and_tell_pending(n) else: return self._ask_and_tell_pending(n) def _ask_bound_point(self): # get the next bound point that is still available new_point = next( p for p in self._bounds_points if p not in self.data and p not in self.pending_points ) self.tell_pending(new_point) return new_point, np.inf def _ask_point_without_known_simplices(self): assert not self._bounds_available # pick a random point inside the bounds # XXX: change this into picking a point based on volume loss a = np.diff(self._bbox).flat b = np.array(self._bbox)[:, 0] p = None while p is None or not self.inside_bounds(p): r = np.array([self._random.random() for _ in range(self.ndim)]) p = r * a + b p = tuple(p) self.tell_pending(p) return p, np.inf def _pop_highest_existing_simplex(self): # find the simplex with the highest loss, we do need to check that the # simplex hasn't been deleted yet while len(self._simplex_queue): # XXX: Need to add check that the loss is the most recent computed loss loss, simplex, subsimplex = self._simplex_queue.pop(0) if ( subsimplex is None and simplex in self.tri.simplices and simplex not in self._subtriangulations ): return abs(loss), simplex, subsimplex if ( 
simplex in self._subtriangulations and simplex in self.tri.simplices and subsimplex in self._subtriangulations[simplex].simplices ): return abs(loss), simplex, subsimplex # Could not find a simplex, this code should never be reached assert self.tri is not None raise AssertionError( "Could not find a simplex to subdivide. Yet there should always" " be a simplex available if LearnerND.tri() is not None." ) def _ask_best_point(self): assert self.tri is not None loss, simplex, subsimplex = self._pop_highest_existing_simplex() if subsimplex is None: # We found a real simplex and want to subdivide it points = self.tri.get_vertices(simplex) else: # We found a pending simplex and want to subdivide it subtri = self._subtriangulations[simplex] points = subtri.get_vertices(subsimplex) point_new = tuple(choose_point_in_simplex(points, transform=self._transform)) self._pending_to_simplex[point_new] = simplex self.tell_pending(point_new, simplex=simplex) # O(??) return point_new, loss @property def _bounds_available(self): return any( (p not in self.pending_points and p not in self.data) for p in self._bounds_points ) def _ask(self): if self._bounds_available: return self._ask_bound_point() # O(1) if self.tri is None: # All bound points are pending or have been evaluated, but we do not # have enough evaluated points to construct a triangulation, so we # pick a random point return self._ask_point_without_known_simplices() # O(1) return self._ask_best_point() # O(log N) def _compute_loss(self, simplex): # get the loss vertices = self.tri.get_vertices(simplex) values = [self.data[tuple(v)] for v in vertices] # scale them to a cube with sides 1 vertices = vertices @ self._transform values = self._output_multiplier * np.array(values) if self.nth_neighbors == 0: # compute the loss on the scaled simplex return float( self.loss_per_simplex(vertices, values, self._output_multiplier) ) # We do need the neighbors neighbors = self.tri.get_opposing_vertices(simplex) neighbor_points = 
self.tri.get_vertices(neighbors) neighbor_values = [self.data.get(x, None) for x in neighbor_points] for i, point in enumerate(neighbor_points): if point is not None: neighbor_points[i] = point @ self._transform for i, value in enumerate(neighbor_values): if value is not None: neighbor_values[i] = self._output_multiplier * value return float( self.loss_per_simplex( vertices, values, self._output_multiplier, neighbor_points, neighbor_values, ) ) def _update_losses(self, to_delete: set, to_add: set): # XXX: add the points outside the triangulation to this as well pending_points_unbound = set() for simplex in to_delete: loss = self._losses.pop(simplex, None) subtri = self._subtriangulations.pop(simplex, None) if subtri is not None: pending_points_unbound.update(subtri.vertices) pending_points_unbound = { p for p in pending_points_unbound if p not in self.data } for simplex in to_add: loss = self._compute_loss(simplex) self._losses[simplex] = loss for p in pending_points_unbound: self._try_adding_pending_point_to_simplex(p, simplex) if simplex not in self._subtriangulations: self._simplex_queue.add((loss, simplex, None)) continue self._update_subsimplex_losses( simplex, self._subtriangulations[simplex].simplices ) if self.nth_neighbors: points_of_added_simplices = set.union(*[set(s) for s in to_add]) neighbors = ( self.tri.get_simplices_attached_to_points(points_of_added_simplices) - to_add ) for simplex in neighbors: loss = self._compute_loss(simplex) self._losses[simplex] = loss if simplex not in self._subtriangulations: self._simplex_queue.add((loss, simplex, None)) continue self._update_subsimplex_losses( simplex, self._subtriangulations[simplex].simplices ) def _recompute_all_losses(self): """Recompute all losses and pending losses.""" # amortized O(N) complexity if self.tri is None: return # reset the _simplex_queue self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority) # recompute all losses for simplex in self.tri.simplices: loss = 
self._compute_loss(simplex) self._losses[simplex] = loss # now distribute it around the the children if they are present if simplex not in self._subtriangulations: self._simplex_queue.add((loss, simplex, None)) continue self._update_subsimplex_losses( simplex, self._subtriangulations[simplex].simplices ) @property def _scale(self): # get the output scale return self._max_value - self._min_value def _update_range(self, new_output): if self._min_value is None or self._max_value is None: # this is the first point, nothing to do, just set the range self._min_value = np.min(new_output) self._max_value = np.max(new_output) self._old_scale = self._scale or 1 return False # if range in one or more directions is doubled, then update all losses self._min_value = min(self._min_value, np.min(new_output)) self._max_value = max(self._max_value, np.max(new_output)) scale_multiplier = 1 / (self._scale or 1) # the maximum absolute value that is in the range. Because this is the # largest number, this also has the largest absolute numerical error. max_absolute_value_in_range = max(abs(self._min_value), abs(self._max_value)) # since a float has a relative error of 1e-15, the absolute error is the value * 1e-15 abs_err = 1e-15 * max_absolute_value_in_range # when scaling the floats, the error gets increased. 
scaled_err = abs_err * scale_multiplier # do not scale along the axis if the numerical error gets too big if scaled_err > 1e-2: # allowed_numerical_error = 1e-2 scale_multiplier = 1 self._output_multiplier = scale_multiplier scale_factor = self._scale / self._old_scale if scale_factor > self._recompute_losses_factor: self._old_scale = self._scale self._recompute_all_losses() return True return False @cache_latest def loss(self, real=True): # XXX: compute pending loss if real == False losses = self._losses if self.tri is not None else dict() return max(losses.values()) if losses else float("inf") def remove_unfinished(self): # XXX: implement this method self.pending_points = set() self._subtriangulations = dict() self._pending_to_simplex = dict() ########################## # Plotting related stuff # ########################## def plot(self, n=None, tri_alpha=0): """Plot the function we want to learn, only works in 2D. Parameters ---------- n : int the number of boxes in the interpolation grid along each axis tri_alpha : float (0 to 1) Opacity of triangulation lines """ hv = ensure_holoviews() if self.vdim > 1: raise NotImplementedError( "holoviews currently does not support", "3D surface plots in bokeh." ) if self.ndim != 2: raise NotImplementedError( "Only 2D plots are implemented: You can " "plot a 2D slice with 'plot_slice'." ) x, y = self._bbox lbrt = x[0], y[0], x[1], y[1] if len(self.data) >= 4: if n is None: # Calculate how many grid points are needed. 
# factor from A=√3/4 * a² (equilateral triangle) scale_factor = np.product(np.diag(self._transform)) a_sq = np.sqrt(np.min(self.tri.volumes()) * scale_factor) n = max(10, int(0.658 / a_sq)) xs = ys = np.linspace(0, 1, n) xs = xs * (x[1] - x[0]) + x[0] ys = ys * (y[1] - y[0]) + y[0] z = self._ip()(xs[:, None], ys[None, :]).squeeze() im = hv.Image(np.rot90(z), bounds=lbrt) if tri_alpha: points = np.array( [self.tri.get_vertices(s) for s in self.tri.simplices] ) points = np.pad( points[:, [0, 1, 2, 0], :], pad_width=((0, 0), (0, 1), (0, 0)), mode="constant", constant_values=np.nan, ).reshape(-1, 2) tris = hv.EdgePaths([points]) else: tris = hv.EdgePaths([]) else: im = hv.Image([], bounds=lbrt) tris = hv.EdgePaths([]) im_opts = dict(cmap="viridis") tri_opts = dict(line_width=0.5, alpha=tri_alpha) no_hover = dict(plot=dict(inspection_policy=None, tools=[])) return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover) def plot_slice(self, cut_mapping, n=None): """Plot a 1D or 2D interpolated slice of a N-dimensional function. Parameters ---------- cut_mapping : dict (int → float) for each fixed dimension the value, the other dimensions are interpolated. e.g. ``cut_mapping = {0: 1}``, so from dimension 0 ('x') to value 1. 
n : int the number of boxes in the interpolation grid along each axis """ hv = ensure_holoviews() plot_dim = self.ndim - len(cut_mapping) if plot_dim == 1: if not self.data: return hv.Scatter([]) * hv.Path([]) elif self.vdim > 1: raise NotImplementedError( "multidimensional output not yet supported by `plot_slice`" ) n = n or 201 values = [ cut_mapping.get(i, np.linspace(*self._bbox[i], n)) for i in range(self.ndim) ] ind = next(i for i in range(self.ndim) if i not in cut_mapping) x = values[ind] y = self._ip()(*values) p = hv.Path((x, y)) # Plot with 5% margins such that the boundary points are visible margin = 0.05 / self._transform[ind, ind] plot_bounds = (x.min() - margin, x.max() + margin) return p.redim(x=dict(range=plot_bounds)) elif plot_dim == 2: if self.vdim > 1: raise NotImplementedError( "holoviews currently does not support 3D surface plots in bokeh." ) if n is None: # Calculate how many grid points are needed. # factor from A=√3/4 * a² (equilateral triangle) scale_factor = np.product(np.diag(self._transform)) a_sq = np.sqrt(np.min(self.tri.volumes()) * scale_factor) n = max(10, int(0.658 / a_sq)) xs = ys = np.linspace(0, 1, n) xys = [xs[:, None], ys[None, :]] values = [ cut_mapping[i] if i in cut_mapping else xys.pop(0) * (b[1] - b[0]) + b[0] for i, b in enumerate(self._bbox) ] lbrt = [b for i, b in enumerate(self._bbox) if i not in cut_mapping] lbrt = np.reshape(lbrt, (2, 2)).T.flatten().tolist() if len(self.data) >= 4: z = self._ip()(*values).squeeze() im = hv.Image(np.rot90(z), bounds=lbrt) else: im = hv.Image([], bounds=lbrt) return im.opts(style=dict(cmap="viridis")) else: raise ValueError("Only 1 or 2-dimensional plots can be generated.") def plot_3D(self, with_triangulation=False): """Plot the learner's data in 3D using plotly. Does *not* work with the `adaptive.notebook_integration.live_plot` functionality. Parameters ---------- with_triangulation : bool, default: False Add the verticices to the plot. 
Returns ------- plot : `plotly.offline.iplot` object The 3D plot of ``learner.data``. """ plotly = ensure_plotly() plots = [] vertices = self.tri.vertices if with_triangulation: Xe, Ye, Ze = [], [], [] for simplex in self.tri.simplices: for s in itertools.combinations(simplex, 2): Xe += [vertices[i][0] for i in s] + [None] Ye += [vertices[i][1] for i in s] + [None] Ze += [vertices[i][2] for i in s] + [None] plots.append( plotly.graph_objs.Scatter3d( x=Xe, y=Ye, z=Ze, mode="lines", line=dict(color="rgb(125,125,125)", width=1), hoverinfo="none", ) ) Xn, Yn, Zn = zip(*vertices) colors = [self.data[p] for p in self.tri.vertices] marker = dict( symbol="circle", size=3, color=colors, colorscale="Viridis", line=dict(color="rgb(50,50,50)", width=0.5), ) plots.append( plotly.graph_objs.Scatter3d( x=Xn, y=Yn, z=Zn, mode="markers", name="actors", marker=marker, hoverinfo="text", ) ) axis = dict( showbackground=False, showline=False, zeroline=False, showgrid=False, showticklabels=False, title="", ) layout = plotly.graph_objs.Layout( showlegend=False, scene=dict(xaxis=axis, yaxis=axis, zaxis=axis), margin=dict(t=100), hovermode="closest", ) fig = plotly.graph_objs.Figure(data=plots, layout=layout) return plotly.offline.iplot(fig) def _get_data(self): return self.data def _set_data(self, data): if data: self.tell_many(*zip(*data.items())) def _get_iso(self, level=0.0, which="surface"): if which == "surface": if self.ndim != 3 or self.vdim != 1: raise Exception( "Isosurface plotting is only supported" " for a 3D input and 1D output" ) get_surface = True get_line = False elif which == "line": if self.ndim != 2 or self.vdim != 1: raise Exception( "Isoline plotting is only supported for a 2D input and 1D output" ) get_surface = False get_line = True vertices = [] # index -> (x,y,z) faces_or_lines = [] # tuple of indices of the corner points @functools.lru_cache() def _get_vertex_index(a, b): vertex_a = self.tri.vertices[a] vertex_b = self.tri.vertices[b] value_a = 
self.data[vertex_a] value_b = self.data[vertex_b] da = abs(value_a - level) db = abs(value_b - level) dab = da + db new_pt = db / dab * np.array(vertex_a) + da / dab * np.array(vertex_b) new_index = len(vertices) vertices.append(new_pt) return new_index for simplex in self.tri.simplices: plane_or_line = [] for a, b in itertools.combinations(simplex, 2): va = self.data[self.tri.vertices[a]] vb = self.data[self.tri.vertices[b]] if min(va, vb) < level <= max(va, vb): vi = _get_vertex_index(a, b) should_add = True for pi in plane_or_line: if np.allclose(vertices[vi], vertices[pi]): should_add = False if should_add: plane_or_line.append(vi) if get_surface and len(plane_or_line) == 3: faces_or_lines.append(plane_or_line) elif get_surface and len(plane_or_line) == 4: faces_or_lines.append(plane_or_line[:3]) faces_or_lines.append(plane_or_line[1:]) elif get_line and len(plane_or_line) == 2: faces_or_lines.append(plane_or_line) if len(faces_or_lines) == 0: r_min = min(self.data[v] for v in self.tri.vertices) r_max = max(self.data[v] for v in self.tri.vertices) raise ValueError( f"Could not draw isosurface for level={level}, as" " this value is not inside the function range. Please choose" f" a level strictly inside interval ({r_min}, {r_max})" ) return vertices, faces_or_lines def plot_isoline(self, level=0.0, n=None, tri_alpha=0): """Plot the isoline at a specific level, only works in 2D. Parameters ---------- level : float, default: 0 The value of the function at which you would like to see the isoline. n : int The number of boxes in the interpolation grid along each axis. This is passed to `plot`. tri_alpha : float The opacity of the overlaying triangulation. This is passed to `plot`. Returns ------- `holoviews.core.Overlay` The plot of the isoline(s). This overlays a `plot` with a `holoviews.element.Path`. 
""" hv = ensure_holoviews() if n == -1: plot = hv.Path([]) else: plot = self.plot(n=n, tri_alpha=tri_alpha) if isinstance(level, Iterable): for l in level: plot = plot * self.plot_isoline(level=l, n=-1) return plot vertices, lines = self._get_iso(level, which="line") paths = [[vertices[i], vertices[j]] for i, j in lines] contour = hv.Path(paths) contour_opts = dict(color="black") contour = contour.opts(style=contour_opts) return plot * contour def plot_isosurface(self, level=0.0, hull_opacity=0.2): """Plots a linearly interpolated isosurface. This is the 3D analog of an isoline. Does *not* work with the `adaptive.notebook_integration.live_plot` functionality. Parameters ---------- level : float, default: 0.0 the function value which you are interested in. hull_opacity : float, default: 0.0 the opacity of the hull of the domain. Returns ------- plot : `plotly.offline.iplot` object The plot object of the isosurface. """ plotly = ensure_plotly() vertices, faces = self._get_iso(level, which="surface") x, y, z = zip(*vertices) fig = plotly.figure_factory.create_trisurf( x=x, y=y, z=z, plot_edges=False, simplices=faces, title="Isosurface" ) isosurface = fig.data[0] isosurface.update( lighting=dict(ambient=1, diffuse=1, roughness=1, specular=0, fresnel=0) ) if hull_opacity < 1e-3: # Do not compute the hull_mesh. return plotly.offline.iplot(fig) hull_mesh = self._get_hull_mesh(opacity=hull_opacity) return plotly.offline.iplot([isosurface, hull_mesh]) def _get_hull_mesh(self, opacity=0.2): plotly = ensure_plotly() hull = scipy.spatial.ConvexHull(self._bounds_points) # Find the colors of each plane, giving triangles which are coplanar # the same color, such that a square face has the same color. color_dict = {} def _get_plane_color(simplex): simplex = tuple(simplex) # If the volume of the two triangles combined is zero then they # belong to the same plane. 
for simplex_key, color in color_dict.items(): points = [hull.points[i] for i in set(simplex_key + simplex)] points = np.array(points) if np.linalg.matrix_rank(points[1:] - points[0]) < 3: return color if scipy.spatial.ConvexHull(points).volume < 1e-5: return color color_dict[simplex] = tuple(random.randint(0, 255) for _ in range(3)) return color_dict[simplex] colors = [_get_plane_color(simplex) for simplex in hull.simplices] x, y, z = zip(*self._bounds_points) i, j, k = hull.simplices.T lighting = dict(ambient=1, diffuse=1, roughness=1, specular=0, fresnel=0) return plotly.graph_objs.Mesh3d( x=x, y=y, z=z, i=i, j=j, k=k, facecolor=colors, opacity=opacity, lighting=lighting, )