class SortedSet(MutableSet, Sequence):
    """
    A `SortedSet` provides the same methods as a `set`.  Additionally, a
    `SortedSet` maintains its items in sorted order, allowing the `SortedSet`
    to be indexed.

    Unlike a `set`, a `SortedSet` requires items be hashable and comparable.
    """
    # Two containers are kept in lock-step:
    #   self._set  -- a builtin set: membership tests and len();
    #   self._list -- a SortedList holding the same items in sorted order,
    #                 used for iteration, indexing, slicing and bisection.
    # Every mutating method below must update both.

    def __init__(self, iterable=None, load=1000, _set=None):
        """Create a sorted set, optionally seeded from *iterable*.

        :param iterable: initial items; duplicates are collapsed.
        :param load: load factor forwarded to the underlying `SortedList`.
        :param _set: private -- a ready-made `set` to adopt as the backing
            set.  It is used directly, not copied; internal callers (`copy`
            and the set-algebra methods) always pass a freshly built set.
        """
        if _set is None:
            self._set = set()
        else:
            # Adopt the caller's set (not the builtin ``set`` type).
            self._set = _set

        self._list = SortedList(self._set, load=load)

        if iterable is not None:
            self.update(iterable)

    def __contains__(self, value):
        """Return True if *value* is a member (looked up in the backing set)."""
        return value in self._set

    def __getitem__(self, index):
        """Return the item at *index*, or a new `SortedSet` for a slice.

        A slice result keeps this set's load factor.
        """
        if isinstance(index, slice):
            return SortedSet(self._list[index], load=self._list._load)
        return self._list[index]

    def __delitem__(self, index):
        """Remove the item(s) at *index* (an int or a slice)."""
        _list = self._list
        if isinstance(index, slice):
            values = _list[index]
            self._set.difference_update(values)
        else:
            value = _list[index]
            self._set.remove(value)
        del _list[index]

    def __setitem__(self, index, value):
        """Replace the item(s) at *index* with *value*.

        The underlying `SortedList` assignment is performed first, so any
        error it raises leaves both containers untouched.

        :raises ValueError: if the assignment would introduce an item that
            is already present elsewhere in the set (or, for a slice, a
            repeated item), which would break set uniqueness.
        """
        _list, _set = self._list, self._set
        prev = _list[index]

        if isinstance(index, slice):
            # Materialize once: *value* may be a one-shot iterator and is
            # needed for the uniqueness check, the list and the set.
            values = list(value)
            new_items = set(values)
            remaining = _set.difference(prev)
            if len(new_items) != len(values) or not remaining.isdisjoint(new_items):
                raise ValueError('assignment would duplicate items in SortedSet')
            _list[index] = values
            _set.difference_update(prev)
            _set.update(new_items)
        else:
            if value != prev and value in _set:
                raise ValueError('assignment would duplicate items in SortedSet')
            _list[index] = value
            _set.remove(prev)
            _set.add(value)

    def __eq__(self, that):
        """Return True if *that* holds exactly the same items.

        Sizes must match.  A `SortedSet` is compared by sorted contents, a
        `set` by set equality, and any other sized iterable by checking that
        each of its items is a member.
        """
        if len(self) != len(that):
            return False
        if isinstance(that, SortedSet):
            return self._list == that._list
        elif isinstance(that, set):
            return self._set == that
        _set = self._set
        return all(val in _set for val in that)

    def __ne__(self, that):
        """Return True if *that* does not hold exactly the same items."""
        if len(self) != len(that):
            return True
        if isinstance(that, SortedSet):
            return self._list != that._list
        elif isinstance(that, set):
            return self._set != that
        _set = self._set
        return any(val not in _set for val in that)

    def __lt__(self, that):
        """Return True if this set is a proper subset of *that*."""
        if isinstance(that, set):
            return self._set < that
        return (len(self) < len(that)) and all(val in that for val in self._list)

    def __gt__(self, that):
        """Return True if this set is a proper superset of *that*."""
        if isinstance(that, set):
            return self._set > that
        _set = self._set
        return (len(self) > len(that)) and all(val in _set for val in that)

    def __le__(self, that):
        """Return True if this set is a subset of *that*."""
        if isinstance(that, set):
            return self._set <= that
        return all(val in that for val in self._list)

    def __ge__(self, that):
        """Return True if this set is a superset of *that*."""
        if isinstance(that, set):
            return self._set >= that
        _set = self._set
        return all(val in _set for val in that)

    def __and__(self, that):
        """Return ``self & that`` as a new `SortedSet`."""
        return self.intersection(that)

    def __or__(self, that):
        """Return ``self | that`` as a new `SortedSet`."""
        return self.union(that)

    def __sub__(self, that):
        """Return ``self - that`` as a new `SortedSet`."""
        return self.difference(that)

    def __xor__(self, that):
        """Return ``self ^ that`` as a new `SortedSet`."""
        return self.symmetric_difference(that)

    def __iter__(self):
        """Iterate over the items in ascending order."""
        return iter(self._list)

    def __len__(self):
        """Return the number of items."""
        return len(self._set)

    def __reversed__(self):
        """Iterate over the items in descending order."""
        return reversed(self._list)

    def add(self, value):
        """Add *value* if it is not already present."""
        if value not in self._set:
            self._set.add(value)
            self._list.add(value)

    def bisect_left(self, value):
        """Return the leftmost sorted insertion index for *value*."""
        return self._list.bisect_left(value)

    def bisect(self, value):
        """Alias of `SortedList.bisect` on the sorted items."""
        return self._list.bisect(value)

    def bisect_right(self, value):
        """Return the rightmost sorted insertion index for *value*."""
        return self._list.bisect_right(value)

    def clear(self):
        """Remove all items."""
        self._set.clear()
        self._list.clear()

    def copy(self):
        """Return a shallow copy with the same load factor."""
        return SortedSet(load=self._list._load, _set=set(self._set))

    def __copy__(self):
        return self.copy()

    def count(self, value):
        """Return 1 if *value* is present, else 0."""
        return int(value in self._set)

    def discard(self, value):
        """Remove *value* if present; do nothing otherwise."""
        if value in self._set:
            self._set.remove(value)
            self._list.discard(value)

    def index(self, value, start=None, stop=None):
        """Return the sorted position of *value* (delegates to `SortedList`)."""
        return self._list.index(value, start, stop)

    def isdisjoint(self, that):
        """Return True if no item is shared with *that*."""
        return self._set.isdisjoint(that)

    def issubset(self, that):
        """Return True if every item is in *that*."""
        return self._set.issubset(that)

    def issuperset(self, that):
        """Return True if every item of *that* is present."""
        return self._set.issuperset(that)

    def pop(self, index=-1):
        """Remove and return the item at *index* (default: the largest)."""
        value = self._list.pop(index)
        self._set.remove(value)
        return value

    def remove(self, value):
        """Remove *value*; the backing set raises KeyError if it is absent."""
        self._set.remove(value)
        self._list.remove(value)

    def difference(self, *iterables):
        """Return a new `SortedSet` of items not in any of *iterables*."""
        diff = self._set.difference(*iterables)
        return SortedSet(load=self._list._load, _set=diff)

    def difference_update(self, *iterables):
        """Remove every item found in any of *iterables*."""
        values = set(chain(*iterables))
        # Heuristic: for a large batch, update the set in bulk and rebuild
        # the sorted list once instead of discarding item by item.
        if (4 * len(values)) > len(self):
            self._set.difference_update(values)
            self._list.clear()
            self._list.update(self._set)
        else:
            _discard = self.discard
            for value in values:
                _discard(value)

    def intersection(self, *iterables):
        """Return a new `SortedSet` of items present in all *iterables*."""
        comb = self._set.intersection(*iterables)
        return SortedSet(load=self._list._load, _set=comb)

    def intersection_update(self, *iterables):
        """Keep only items present in all *iterables*."""
        self._set.intersection_update(*iterables)
        self._list.clear()
        self._list.update(self._set)

    def symmetric_difference(self, that):
        """Return a new `SortedSet` of items in exactly one of self/*that*."""
        diff = self._set.symmetric_difference(that)
        return SortedSet(load=self._list._load, _set=diff)

    def symmetric_difference_update(self, that):
        """Keep only items in exactly one of self/*that*."""
        self._set.symmetric_difference_update(that)
        self._list.clear()
        self._list.update(self._set)

    def union(self, *iterables):
        """Return a new `SortedSet` of items from self and all *iterables*."""
        return SortedSet(chain(iter(self), *iterables), load=self._list._load)

    def update(self, *iterables):
        """Add every item from all *iterables*."""
        values = set(chain(*iterables))
        # Same bulk-vs-incremental heuristic as `difference_update`.
        if (4 * len(values)) > len(self):
            self._set.update(values)
            self._list.clear()
            self._list.update(self._set)
        else:
            _add = self.add
            for value in values:
                _add(value)

    @recursive_repr
    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, list(self))
class SortedListWithKey(MutableSequence):
    """A sorted list ordered by ``key(value)`` rather than by the values.

    Values are stored in an underlying `SortedList` as pairs built from
    ``(key(value), value)``.  With ``value_orderable=True`` each pair is a
    plain tuple, so values whose keys tie are ordered by the values
    themselves (which must therefore be mutually comparable).  With
    ``value_orderable=False`` each pair is a `Pair` object instead, and
    lookups for a specific value scan every stored pair sharing its key,
    matching values with ``==``.
    """

    def __init__(self, iterable=None, key=lambda val: val,
                 value_orderable=True, load=1000):
        """Create the list, optionally seeded from *iterable*.

        :param iterable: optional initial values.
        :param key: callable mapping a value to its sort key (identity by
            default).
        :param value_orderable: if true, pairs are ``(key, value)`` tuples;
            if false, `Pair` objects are used (see class docstring).
        :param load: load factor forwarded to the underlying `SortedList`.
        """
        self._key = key
        self._list = SortedList(load=load)
        self._ordered = value_orderable

        if value_orderable:
            self._pair = lambda key, value: (key, value)
        else:
            self._pair = Pair

        if iterable is not None:
            self.update(iterable)

    def _scan_key(self, key, pair):
        """Yield ``(pos, idx, stored)`` for each stored pair whose key is *key*.

        Used only when ``value_orderable`` is false.  Bisects the sublist
        maxima and then the chosen sublist to the first stored pair not less
        than *pair*, then walks forward across sublists until it meets a pair
        with a different key or the end of the list.  Positions are visited
        in ascending index order.
        NOTE(review): relies on `Pair` ordering by key alone, so the bisect
        lands on the first pair for *key* -- confirm against `Pair`.
        """
        _list = self._list
        _maxes = _list._maxes
        if not _maxes:
            return

        pos = bisect_left(_maxes, pair)
        if pos == len(_maxes):
            return

        _lists = _list._lists
        idx = bisect_left(_lists[pos], pair)
        len_lists = len(_lists)
        len_sublist = len(_lists[pos])

        while True:
            stored = _lists[pos][idx]
            if key != stored.key:
                return
            yield pos, idx, stored
            idx += 1
            if idx == len_sublist:
                pos += 1
                if pos == len_lists:
                    return
                len_sublist = len(_lists[pos])
                idx = 0

    def clear(self):
        """Remove all values."""
        self._list.clear()

    def add(self, value):
        """Insert *value* at its sorted position."""
        pair = self._pair(self._key(value), value)
        self._list.add(pair)

    def update(self, iterable):
        """Insert every value from *iterable*."""
        _key, _pair = self._key, self._pair
        self._list.update(_pair(_key(val), val) for val in iterable)

    def __contains__(self, value):
        """Return True if a value equal to *value* is stored."""
        _key = self._key(value)
        _pair = self._pair(_key, value)

        if self._ordered:
            return _pair in self._list

        return any(value == stored.value
                   for _, _, stored in self._scan_key(_key, _pair))

    def discard(self, value):
        """Remove the first stored value equal to *value*, if any."""
        _list = self._list
        _key = self._key(value)
        _pair = self._pair(_key, value)

        if self._ordered:
            _list.discard(_pair)
            return

        for pos, idx, stored in self._scan_key(_key, _pair):
            if value == stored.value:
                _list._delete(pos, idx)
                return

    def remove(self, value):
        """Remove the first stored value equal to *value*.

        :raises ValueError: if no such value is stored.
        """
        _list = self._list
        _key = self._key(value)
        _pair = self._pair(_key, value)

        if self._ordered:
            _list.remove(_pair)
            return

        for pos, idx, stored in self._scan_key(_key, _pair):
            if value == stored.value:
                _list._delete(pos, idx)
                return

        raise ValueError

    def __delitem__(self, index):
        """Delete the value(s) at *index* (an int or a slice)."""
        del self._list[index]

    def __getitem__(self, index):
        """Return the value at *index*, or a plain list for a slice."""
        if isinstance(index, slice):
            return [tup[1] for tup in self._list[index]]
        return self._list[index][1]

    def __setitem__(self, index, value):
        """Replace the value(s) at *index*; the underlying list checks order."""
        _key, _pair = self._key, self._pair
        if isinstance(index, slice):
            self._list[index] = [_pair(_key(val), val) for val in value]
        else:
            self._list[index] = _pair(_key(value), value)

    def __iter__(self):
        """Iterate over values in sorted order."""
        return (tup[1] for tup in self._list)

    def __reversed__(self):
        """Iterate over values in reverse sorted order."""
        return (tup[1] for tup in reversed(self._list))

    def __len__(self):
        return len(self._list)

    def bisect_left(self, value):
        """Return the leftmost insertion index for *value*."""
        pair = self._pair(self._key(value), value)
        return self._list.bisect_left(pair)

    def bisect(self, value):
        """Alias of `SortedList.bisect` applied to *value*'s pair."""
        pair = self._pair(self._key(value), value)
        return self._list.bisect(pair)

    def bisect_right(self, value):
        """Return the rightmost insertion index for *value*."""
        pair = self._pair(self._key(value), value)
        return self._list.bisect_right(pair)

    def count(self, value):
        """Return how many stored values equal *value*."""
        _key = self._key(value)
        _pair = self._pair(_key, value)

        if self._ordered:
            return self._list.count(_pair)

        return sum(1 for _, _, stored in self._scan_key(_key, _pair)
                   if value == stored.value)

    def copy(self):
        """Return a shallow copy with the same key, ordering mode and load."""
        return SortedListWithKey(
            self, key=self._key, value_orderable=self._ordered,
            load=self._list._load)

    def __copy__(self):
        return self.copy()

    def append(self, value):
        """Append *value* at the end (the underlying list checks order)."""
        pair = self._pair(self._key(value), value)
        self._list.append(pair)

    def extend(self, iterable):
        """Append values from *iterable* (the underlying list checks order)."""
        _key, _pair = self._key, self._pair
        self._list.extend(_pair(_key(val), val) for val in iterable)

    def insert(self, index, value):
        """Insert *value* at *index* (the underlying list checks order)."""
        pair = self._pair(self._key(value), value)
        self._list.insert(index, pair)

    def pop(self, index=-1):
        """Remove and return the value at *index* (default: last)."""
        return self._list.pop(index)[1]

    def index(self, value, start=None, stop=None):
        """Return the index of the first value equal to *value* in [start, stop).

        Negative *start*/*stop* count from the end, as with `list.index`.

        :raises ValueError: if no matching value lies in the range.
        """
        _list = self._list
        _key = self._key(value)
        _pair = self._pair(_key, value)

        if self._ordered:
            return _list.index(_pair, start, stop)

        _len = _list._len

        if start is None:
            start = 0
        if start < 0:
            start += _len
        if start < 0:
            start = 0

        if stop is None:
            stop = _len
        if stop < 0:
            stop += _len
        if stop > _len:
            stop = _len

        if stop <= start:
            raise ValueError

        for pos, idx, stored in self._scan_key(_key, _pair):
            if value == stored.value:
                loc = _list._loc(pos, idx)
                if loc >= stop:
                    # Matches are visited in ascending index order, so no
                    # later match can fall inside [start, stop).
                    break
                if loc >= start:
                    return loc

        raise ValueError

    def as_list(self):
        """Return the values as a plain list in sorted order."""
        return [tup[1] for tup in self._list.as_list()]

    def __add__(self, that):
        """Return a new list holding values from self and *that*."""
        result = SortedListWithKey(
            key=self._key,
            value_orderable=self._ordered,
            load=self._list._load)
        values = self.as_list()
        values.extend(that)
        result.update(values)
        return result

    def __iadd__(self, that):
        self.update(that)
        return self

    def __mul__(self, that):
        """Return a new list with every value repeated *that* times."""
        values = self.as_list() * that
        return SortedListWithKey(
            values,
            key=self._key,
            value_orderable=self._ordered,
            load=self._list._load)

    def __imul__(self, that):
        values = self.as_list() * that
        self.clear()
        self.update(values)
        return self

    # NOTE: the ordering comparisons below check the relation element-wise
    # over the common prefix (every zipped pair must satisfy it) plus a
    # length condition; this is NOT Python's lexicographic sequence ordering.
    # Kept as-is for backward compatibility.

    def __eq__(self, that):
        return ((len(self) == len(that))
                and all(lhs == rhs for lhs, rhs in zip(self, that)))

    def __ne__(self, that):
        return ((len(self) != len(that))
                or any(lhs != rhs for lhs, rhs in zip(self, that)))

    def __lt__(self, that):
        return ((len(self) <= len(that))
                and all(lhs < rhs for lhs, rhs in zip(self, that)))

    def __le__(self, that):
        return ((len(self) <= len(that))
                and all(lhs <= rhs for lhs, rhs in zip(self, that)))

    def __gt__(self, that):
        return ((len(self) >= len(that))
                and all(lhs > rhs for lhs, rhs in zip(self, that)))

    def __ge__(self, that):
        return ((len(self) >= len(that))
                and all(lhs >= rhs for lhs, rhs in zip(self, that)))

    @recursive_repr
    def __repr__(self):
        return '%s(%s, key=%r, value_orderable=%r, load=%r)' % (
            self.__class__.__name__,
            self.as_list(),
            self._key,
            self._ordered,
            self._list._load)