class MaxStack:
    """Stack that can also report and remove its largest element.

    Each pushed value is stored as a ``(value, timestamp)`` record in two
    sorted lists: ``by_time`` orders records newest first, ``by_val`` orders
    them by largest value first (newest first among equal values).
    """

    def __init__(self):
        # Keys are negated so that index 0 is the newest / largest record.
        self.by_time = SortedKeyList(key=lambda entry: -entry[1])
        self.by_val = SortedKeyList(key=lambda entry: (-entry[0], -entry[1]))
        self.time = -1

    def push(self, x: int) -> None:
        """Push ``x`` on top of the stack."""
        self.time += 1
        entry = (x, self.time)
        for ordering in (self.by_time, self.by_val):
            ordering.add(entry)

    def pop(self) -> int:
        """Remove the most recently pushed value and return it."""
        newest = self.by_time.pop(0)
        self.by_val.remove(newest)
        value, _stamp = newest
        return value

    def top(self) -> int:
        """Return the most recently pushed value without removing it."""
        return self.by_time[0][0]

    def peekMax(self) -> int:
        """Return the largest value without removing it."""
        return self.by_val[0][0]

    def popMax(self) -> int:
        """Remove the largest value (the newest one on ties) and return it."""
        largest = self.by_val.pop(0)
        self.by_time.remove(largest)
        value, _stamp = largest
        return value
def test_delete():
    """Removing every value, one at a time, leaves a completely empty list."""
    count = 20
    sorted_list = SortedKeyList(range(count), key=modulo)
    sorted_list._reset(4)
    sorted_list._check()

    for value in range(count):
        sorted_list.remove(value)
        # Internal invariants must hold after each individual removal.
        sorted_list._check()

    assert len(sorted_list) == 0
    assert sorted_list._maxes == []
    assert sorted_list._lists == []
def test_remove():
    """Check ``discard`` on an empty list and ``remove`` of a duplicated value."""
    slt = SortedKeyList(key=modulo)

    # Discarding a missing value is a silent no-op that returns None.
    assert slt.discard(0) is None
    assert len(slt) == 0
    slt._check()

    slt = SortedKeyList([1, 2, 2, 2, 3, 3, 5], key=modulo)
    slt._reset(4)

    slt.remove(2)
    slt._check()

    # Compare the whole sequence: zip() stops at the shorter input, so an
    # element-wise zip comparison would not notice a list that lost too many
    # elements.
    assert list(slt) == [1, 2, 2, 3, 3, 5]
class HeavyHitterList:
    """Bounded collection that keeps the items with the largest ``item[1]``.

    Items are held in a ``SortedKeyList`` ordered by their element at index 1,
    so the item with the smallest such element sits at position 0.
    """

    def __init__(self, threshold):
        self.data = SortedKeyList(key=itemgetter(1))
        self.threshold = threshold

    def append(self, x):
        """Offer ``x``: keep it if there is room or if it beats the smallest item."""
        if len(self.data) < self.threshold:
            self.data.add(x)
            return

        smallest = self.data[0]
        if x[1] > smallest[1]:
            # Replace the current smallest item with the newcomer.
            self.data.remove(smallest)
            self.data.add(x)

    def get_data(self):
        """Return the underlying sorted container (not a copy)."""
        return self.data

    def __str__(self):
        return str(self.data)
def test_remove_valueerror3():
    """``remove`` raises ``ValueError`` for a value that is not in the list."""
    initial = [1, 2, 2, 2, 3, 3, 5]
    sorted_list = SortedKeyList(initial, key=negate)
    with pytest.raises(ValueError):
        sorted_list.remove(4)
class HandlerManager:
    """Dispatches incoming friend/group messages to registered handlers.

    Each registration is stored as a dict in a ``SortedKeyList`` keyed on the
    negated ``"priority"`` entry, so iteration visits larger priority values
    first.
    """

    handlers: SortedKeyList

    def __init__(self, bcc: Broadcast):
        """Create the handler list and subscribe to message events on ``bcc``."""
        # Negate the priority so that the list is sorted by descending priority.
        self.handlers = SortedKeyList(key=lambda x: -x["priority"])

        @bcc.receiver(FriendMessage)
        async def on_receive_friend_message(app: GraiaMiraiApplication, friend: Friend, message: MessageChain):
            await self.__on_receive(app, friend, message)

        @bcc.receiver(GroupMessage)
        async def on_receive_group_message(app: GraiaMiraiApplication, group: Group, message: MessageChain):
            await self.__on_receive(app, group, message)

    def register(self, handler: AbstractMessageHandler, **kwargs):
        """Register ``handler``.

        Recognised keyword arguments: ``priority`` (default 8), ``allow_friend``
        and ``allow_group`` (default ``None``, which means no restriction).
        """
        self.handlers.add({
            "handler": handler,
            "priority": kwargs.get("priority", 8),
            "allow_friend": kwargs.get("allow_friend", None),
            "allow_group": kwargs.get("allow_group", None)
        })

    def unregister(self, handler: AbstractMessageHandler):
        """Remove the first registration whose handler equals ``handler``."""
        for h in self.handlers:
            if h["handler"] == handler:
                # Safe although we are iterating: we stop right after removing.
                self.handlers.remove(h)
                break

    async def __on_receive(self, app: GraiaMiraiApplication,
                           subject: T.Union[Group, Friend],
                           message: MessageChain) -> None:
        """Offer ``message`` to the handlers in priority order.

        Handlers put their replies on a single-slot queue; a background
        consumer task sends each reply. Processing stops at the first handler
        whose ``handle`` returns a truthy value.
        """
        src = message.get(Source)
        if len(src) == 0:
            src = None
        else:
            src = src[0]

        channel = asyncio.Queue(1)

        async def consumer(channel: asyncio.Queue):
            while True:
                try:
                    msg = await channel.get()
                    await reply(app, subject, msg, src)
                    channel.task_done()
                except asyncio.CancelledError:
                    break
                except Exception as exc:
                    logger.exception(exc)
                    # Still mark the item done so channel.join() cannot hang.
                    channel.task_done()

        consumer_task = asyncio.create_task(consumer(channel))

        try:
            # Iterate over a snapshot: handlers may be registered or removed
            # by other tasks while we are suspended in an await below.
            for h in tuple(self.handlers):
                handler: AbstractMessageHandler = h["handler"]
                allow_group: T.Optional[T.Sequence[int]] = h["allow_group"]
                allow_friend: T.Optional[T.Sequence[int]] = h["allow_friend"]

                # Check whether the sender is allowed to use this handler.
                if isinstance(subject, Group):
                    if (allow_group is not None) and (subject.id not in allow_group):
                        continue
                elif isinstance(subject, Friend):
                    if (allow_friend is not None) and (subject.id not in allow_friend):
                        continue

                # Stop as soon as a handler intercepts the message.
                try:
                    if await handler.handle(app, subject, message, channel):
                        break
                except Exception as exc:
                    # A failing handler must not prevent the others from running.
                    logger.exception(exc)
        finally:
            # Wait until every queued reply was processed, then stop the consumer.
            await channel.join()
            consumer_task.cancel()
class ParamsTaskQueue(disp.TaskQueue, ps.SweepListener):
    """A greedy-style task queue restricted to tasks that amount to calling
    model_manager.train for a particular set of models.

    Tasks are selected as follows:
        - If a hyperparameter sweep is queued and there are enough cores to
          run one, the sweep is dispatched.
        - Otherwise as many trials as possible are performed. Trials can only
          be run for parameters whose hyperparameters are already known.

    :ivar int total_cores: the number of physical cores that are available
    :ivar int sweep_cores: the number of cores to use for sweeping (not
        greater than the number of total cores)
    :ivar str module: the module which provides the model/dataset/etc
    :ivar HyperparameterSettings hparams: strategy for hyperparameters
    :ivar str folder: the folder containing the points folder
    :ivar list[SweepListener] listeners: the sweep listeners; contains self
    :ivar set[tuple] in_progress: parameters which are currently in progress
    :ivar list[ParamsTask] sweeps: the sweeps that still need to be
        performed, in arbitrary order
    :ivar dict[tuple, int] params_to_sweeps_ind: maps the parameters of a
        task to its index in sweeps while a sweep is still necessary
    :ivar SortedKeyList[ParamsTask] trials: the trials that need to be
        performed, in ascending order of the number of trials
    :ivar dict[tuple, int] params_to_id: maps a set of parameters to the
        unique identifier for that set of parameters
    :ivar dict[tuple, int] params_to_ntrials: maps a set of parameters to the
        number of trials that have been dispatched for those parameters
    :ivar int next_id: the next id to hand out to a set of parameters; it is
        incremented afterwards
    :ivar dict[tuple, ParamsTask] params_to_task: maps parameters to their
        trial task; only contains tasks which have not been fully dispatched
        and does not include the tasks stored in sweeps
    :ivar int _len: number of actual tasks currently in the queue
    :ivar bool expecting_more_trials: True to prevent reporting that the
        queue is out of trials, False otherwise
    """

    def __init__(self, total_cores: int, sweep_cores: int, module: str,
                 hparams: hyperparams.HyperparameterSettings,
                 folder: str, listeners: typing.List[ps.SweepListener]):
        self.total_cores = total_cores
        self.sweep_cores = sweep_cores
        self.module = module
        self.hparams = hparams
        self.folder = folder
        # Copy the caller's listeners and subscribe this queue as well.
        self.listeners = [*listeners, self]
        self.in_progress = set()
        self.sweeps = []
        self.params_to_sweeps_ind = {}
        self.trials = SortedKeyList(key=lambda tsk: tsk.trials)
        self.params_to_id = {}
        self.params_to_ntrials = {}
        self.next_id = 0
        self.params_to_task = {}
        self._len = 0
        self.expecting_more_trials = False

    def add_task_by_params(self, params: tuple, trials: int) -> None:
        """Queue ``trials`` trials for the given model parameters.

        Whatever the value of ``trials``, a hyperparameter sweep is queued for
        ``params`` the first time they are seen.
        """
        if self.params_to_id.get(params) is None:
            new_id = self.next_id
            self.next_id += 1
            self.sweeps.append(ParamsTask(params, 0))
            self.params_to_sweeps_ind[params] = len(self.sweeps) - 1
            self.params_to_id[params] = new_id
            self.params_to_ntrials[params] = 0
            self._len += 1

        if trials <= 0:
            return

        existing = self.params_to_task.get(params)
        if existing is None:
            fresh = ParamsTask(params, trials)
            self.params_to_task[params] = fresh
            self.trials.add(fresh)
            self._len += 1
            return

        # The sort key is the trial count, so the task has to leave the sorted
        # list before the count changes and be re-inserted afterwards.
        self.trials.remove(existing)
        existing.trials += trials
        self.trials.add(existing)

    def set_total_cores(self, total_cores):
        self.total_cores = total_cores

    def on_hparams_found(self, values, lr_min, lr_max, batch_size):
        logger.debug('Found hyperparameters for %s: lr=(%s, %s), bs=%s',
                     values, lr_min, lr_max, batch_size)
        self.in_progress.remove(values)

    def on_trials_completed(self, values, perfs_train, losses_train,
                            perfs_val, losses_val):
        logger.info(
            'Completed some trials for %s - mean train/val perf = %s / %s',
            values, perfs_train.mean(), perfs_val.mean())
        logger.debug('%s - perf: %s, loss: %s, val - perf: %s, loss: %s',
                     values, perfs_train, losses_train, perfs_val, losses_val)
        self.in_progress.remove(values)
        self.params_to_ntrials[values] += len(losses_train)

    def _launch(self, tsk, n_trials):
        """Mark ``tsk.params`` as in progress and build its dispatchable task."""
        self.in_progress.add(tsk.params)
        return tsk.as_task(self.module, self.hparams, self.folder,
                           self.params_to_id[tsk.params], n_trials,
                           self.sweep_cores, self.listeners,
                           self.params_to_ntrials[tsk.params])

    def _get_next_task(self, cores):
        # Outline:
        #   1. With enough cores for a sweep and a sweep queued, pop the last
        #      sweep (so only one stored index changes), forget its index,
        #      mark it in progress and return it.
        #   2. Otherwise take the eligible task with the most trials,
        #      skipping parameters that are in progress or not swept yet.
        #   3. If it has more trials than cores, dispatch ``cores`` trials and
        #      keep the remainder queued; otherwise dispatch all of it and
        #      drop it from params_to_task.
        if cores <= 0:
            return None

        if cores >= self.sweep_cores and self.sweeps:
            sweep: ParamsTask = self.sweeps.pop()
            del self.params_to_sweeps_ind[sweep.params]
            self._len -= 1
            return self._launch(sweep, 0)

        if not self.trials:
            return None

        # Scan from the task with the most trials towards the one with fewest.
        for position in range(len(self.trials) - 1, -1, -1):
            candidate = self.trials[position]
            if (candidate.params not in self.in_progress
                    and candidate.params not in self.params_to_sweeps_ind):
                break
        else:
            return None

        chosen = self.trials.pop(position)
        if chosen.trials > cores:
            # Only part of the task fits; the rest stays queued.
            chosen.trials -= cores
            self.trials.add(chosen)
            return self._launch(chosen, cores)

        del self.params_to_task[chosen.params]
        self._len -= 1
        return self._launch(chosen, chosen.trials)

    def get_next_task(self, cores):
        task = self._get_next_task(cores)
        if task is not None:
            logger.debug('starting task %s', str(task))
        return task

    def have_more_tasks(self):
        return (self.expecting_more_trials or self._len > 0)

    def __len__(self) -> int:
        return self._len
def test_remove_valueerror2():
    """``remove`` raises ``ValueError`` for a value the list never contained."""
    size = 100
    sorted_list = SortedKeyList(range(size), key=modulo)
    sorted_list._reset(10)
    with pytest.raises(ValueError):
        sorted_list.remove(size)
class Market:
    """Order book that matches buy and sell orders and keeps daily statistics.

    Daily statistics (volume, high, low, closing price) are kept in deques
    holding at most 30 entries; ``tick`` closes the current session.
    """

    def __init__(self):
        # Implementation note: orders are kept sorted so that the most
        # competitive orders come first in the list, and older orders have
        # priority over newer ones.
        self.sell_orders = SortedKeyList(
            key=lambda order: (order.offer_price, order.order_number))
        self.buy_orders = SortedKeyList(
            key=lambda order: (-order.offer_price, order.order_number))
        self.order_number = 1
        self._todays_volume = 0
        self._daily_volumes = deque(maxlen=30)
        self._todays_high = None
        self._daily_highs = deque(maxlen=30)
        self._todays_low = None
        self._daily_lows = deque(maxlen=30)
        self._last_price = 0
        self._daily_closing_price = deque(maxlen=30)

    def has_buy_orders(self):
        return len(self.buy_orders) > 0

    def has_sell_orders(self):
        return len(self.sell_orders) > 0

    def get_30d_avg_volume(self):
        """Mean of the recorded daily volumes, or 0 if none were recorded."""
        if len(self._daily_volumes) == 0:
            return 0
        return mean(self._daily_volumes)

    def get_30d_avg_price(self):
        """Mean of the recorded closing prices, or 0 if none were recorded."""
        if len(self._daily_closing_price) == 0:
            return 0
        return mean(self._daily_closing_price)

    def get_30d_sigma_price(self):
        """Population standard deviation of the recorded closing prices.

        Returns 0 when no closing price has been recorded yet.
        """
        # Imported locally so this method does not depend on how the module
        # imports ``mean``.
        from statistics import pstdev

        if len(self._daily_closing_price) == 0:
            return 0
        return pstdev(self._daily_closing_price)

    def last_session_open(self):
        """Closing price of the session before the last one (0 if unknown)."""
        if len(self._daily_closing_price) < 2:
            return 0
        return self._daily_closing_price[-2]

    def last_session_close(self):
        if len(self._daily_closing_price) < 1:
            return 0
        return self._daily_closing_price[-1]

    def last_session_volume(self):
        if len(self._daily_volumes) < 1:
            return 0
        return self._daily_volumes[-1]

    def last_session_high(self):
        if len(self._daily_highs) < 1:
            return 0
        return self._daily_highs[-1]

    def last_session_low(self):
        if len(self._daily_lows) < 1:
            return 0
        return self._daily_lows[-1]

    def best_buy_orders(self):
        """
        Get all buy orders that share the most competitive (highest) offer
        price. Orders are returned according to the same sorting as the
        larger order list.
        """
        best_buy_order_price = self.buy_orders[0].offer_price
        # Order numbers lie between 0 and self.order_number, so this key
        # range selects every order at exactly the best price.
        return self.buy_orders.irange_key(
            (-best_buy_order_price, 0),
            (-best_buy_order_price, self.order_number))

    def best_sell_orders(self):
        """
        Get all sell orders that share the most competitive (lowest) offer
        price. Orders are returned according to the same sorting as the
        larger order list.
        """
        best_sell_order_price = self.sell_orders[0].offer_price
        return self.sell_orders.irange_key(
            (best_sell_order_price, 0),
            (best_sell_order_price, self.order_number))

    def lowest_sell_offer(self):
        """
        Get the most competitive sell price
        """
        return self.sell_orders[0].offer_price

    def highest_buy_offer(self):
        """
        Get the most competitive buy price
        """
        return self.buy_orders[0].offer_price

    def place_buy_order(self, offer_price: float, quantity: int,
                        callback: OrderCallback) -> BuyOrder:
        """
        Places an order, which is returned (unfilled) to the caller.
        Upon fulfilment, the order's callback is invoked as
        `callback(order, amount_filled)`.

        This market is not doing escrow, and it is assumed that the person
        has kept the needed money in hand to be removed now in exchange for
        goods.
        """
        assert quantity > 0
        assert offer_price > 0
        order = BuyOrder(self.order_number, offer_price, quantity, callback)
        self.order_number += 1
        self.buy_orders.add(order)
        return order

    def place_sell_order(self, offer_price: float, quantity: int,
                         callback: OrderCallback) -> SellOrder:
        """
        Places an order, which is returned (unfilled) to the caller.
        Upon fulfilment, the order's callback is invoked as
        `callback(order, amount_filled)`.

        This market is not doing escrow, and it is assumed that the person
        has kept the needed goods in hand to be removed now in exchange for
        money.
        """
        assert quantity > 0
        assert offer_price > 0
        order = SellOrder(self.order_number, offer_price, quantity, callback)
        self.order_number += 1
        self.sell_orders.add(order)
        return order

    def cancel_buy_order(self, order: BuyOrder):
        """
        Cancels a buy order.
        If the order is not in the market, raises a ValueError
        """
        self.buy_orders.remove(order)

    def cancel_sell_order(self, order: SellOrder):
        """
        Cancels a sell order.
        If the order is not in the market, raises a ValueError
        """
        self.sell_orders.remove(order)

    def _resolve_orders(self, orders, num_resolved):
        """
        Resolve a number of orders as much as possible. In the trivial case,
        all orders are totally filled. In more complicated resolutions,
        orders are filled in the given sequence, leaving some orders unfilled
        or partially filled.

        Returns a list of (order, amount_filled) pairs for the orders that
        have been wholly or partially filled.
        """
        remaining = num_resolved
        modified_orders = []
        for order in orders:
            to_fill = min(order.quantity_unfilled(), remaining)
            if to_fill > 0:
                remaining -= to_fill
                order._fill(to_fill)
                modified_orders.append((order, to_fill))
            if remaining == 0:
                break
        return modified_orders

    def execute_orders(self):
        """Match buy and sell orders while the best bid meets the best ask."""
        while (len(self.buy_orders) > 0 and len(self.sell_orders) > 0
               and self.highest_buy_offer() >= self.lowest_sell_offer()):
            best_buy_offers = list(self.best_buy_orders())
            best_sell_offers = list(self.best_sell_orders())
            buy_quantity = count_quantity(best_buy_offers)
            sell_quantity = count_quantity(best_sell_offers)
            quantity_resolved = min(buy_quantity, sell_quantity)
            assert quantity_resolved > 0

            self._todays_volume += quantity_resolved
            # The trade is struck at the best buy price.
            strike_price = best_buy_offers[0].offer_price
            self._last_price = strike_price
            if self._todays_high is None:
                self._todays_high = strike_price
            else:
                self._todays_high = max(self._todays_high, strike_price)
            if self._todays_low is None:
                self._todays_low = strike_price
            else:
                self._todays_low = min(self._todays_low, strike_price)

            executed_buy_orders = self._resolve_orders(best_buy_offers,
                                                       quantity_resolved)
            executed_sell_orders = self._resolve_orders(
                best_sell_offers, quantity_resolved)
            assert len(executed_buy_orders) > 0
            assert len(executed_sell_orders) > 0

            # Drop completely filled orders from the book before notifying.
            for order, _ in executed_buy_orders:
                if order.is_filled():
                    self.buy_orders.remove(order)
            for order, _ in executed_sell_orders:
                if order.is_filled():
                    self.sell_orders.remove(order)

            # Finally, inform the actors that the orders are executed
            for order, num_filled in executed_buy_orders:
                order.callback(order, num_filled)
            for order, num_filled in executed_sell_orders:
                order.callback(order, num_filled)

    def tick(self):
        """Close the current session and reset the per-session counters."""
        self._daily_closing_price.append(self._last_price)
        self._daily_volumes.append(self._todays_volume)
        self._todays_volume = 0
        # High/low are recorded as None for a session without any trade.
        self._daily_highs.append(self._todays_high)
        self._todays_high = None
        self._daily_lows.append(self._todays_low)
        self._todays_low = None
class SortedIntvls:
    """
    Collection of ``(start, end, data)`` interval tuples, indexed both by
    start offset and by end offset.

    The key ranges used by the query methods assume that ``data`` (index 2 of
    each tuple) is an integer between 0 and ``sys.maxsize``.
    """

    def __init__(self):
        # sorted by increasing start offset, then by increasing data value
        self._by_start = SortedKeyList(key=lambda x: (x[0], x[2]))
        # sorted by end offset only
        self._by_end = SortedKeyList(key=lambda x: x[1])

    def add(self, start, end, data):
        """
        Adds an interval.
        """
        self._by_start.add((start, end, data))
        self._by_end.add((start, end, data))

    def update(self, tupleiterable):
        """
        Updates from an iterable of intervals.
        """
        # Materialise first: a generator or other one-shot iterable would be
        # exhausted by the first update and leave _by_end without the items.
        intervals = list(tupleiterable)
        self._by_start.update(intervals)
        self._by_end.update(intervals)

    def remove(self, start, end, data):
        """
        Removes an interval, exception if the interval does not exist.
        """
        self._by_start.remove((start, end, data))
        self._by_end.remove((start, end, data))

    def discard(self, start, end, data):
        """
        Removes an interval, do nothing if the interval does not exist.
        """
        self._by_start.discard((start, end, data))
        self._by_end.discard((start, end, data))

    def __len__(self):
        """
        Returns the number of intervals.
        """
        return len(self._by_start)

    def starting_at(self, offset):
        """
        Returns an iterable of (start, end, data) tuples where start==offset
        """
        return self._by_start.irange_key(min_key=(offset, 0),
                                         max_key=(offset, sys.maxsize))

    def ending_at(self, offset):
        """
        Returns an iterable of (start, end, data) tuples where end==offset
        """
        return self._by_end.irange_key(min_key=offset, max_key=offset)

    def at(self, start, end):
        """
        Yields the intervals whose start and end offsets equal the given
        start and end.
        """
        for intvl in self._by_start.irange_key(min_key=(start, 0),
                                               max_key=(start, sys.maxsize)):
            if intvl[1] == end:
                yield intvl

    def within(self, start, end):
        """
        Returns intervals which are fully contained within start...end
        """
        # get all the intervals that start within the range, then keep those
        # which also end within the range
        for intvl in self._by_start.irange_key(min_key=(start, 0),
                                               max_key=(end, sys.maxsize)):
            if intvl[1] <= end:
                yield intvl

    def starting_from(self, offset):
        """
        Returns intervals that start at or after offset.
        """
        return self._by_start.irange_key(min_key=(offset, 0))

    def starting_before(self, offset):
        """
        Returns intervals that start before offset.
        """
        return self._by_start.irange_key(max_key=(offset - 1, sys.maxsize))

    def ending_to(self, offset):
        """
        Returns intervals that end before or at the given end offset.
        """
        return self._by_end.irange_key(max_key=offset)

    def ending_after(self, offset):
        """
        Returns intervals that end after the given offset.
        """
        return self._by_end.irange_key(min_key=offset + 1)

    def covering(self, start, end):
        """
        Returns intervals that contain the given range.
        """
        # All intervals that start at or before the start and end at or after
        # the end offset: first get the intervals that start before or at the
        # start, then filter by end.
        for intvl in self._by_start.irange_key(max_key=(start, sys.maxsize)):
            if intvl[1] >= end:
                yield intvl

    def overlapping(self, start, end):
        """
        Returns intervals that overlap with the given range.
        """
        # Look at all intervals whose start offset is before the end of the
        # range. This still includes intervals that end before the range
        # starts, so additionally require that the interval's end is larger
        # than the start of the range. (Comparing against start + 1 would
        # drop intervals that overlap the range by exactly one offset.)
        for intvl in self._by_start.irange_key(max_key=(end - 1, sys.maxsize)):
            if intvl[1] > start:
                yield intvl

    def firsts(self):
        """
        Yields all intervals which start at the smallest known offset.
        """
        laststart = None
        for intvl in self._by_start.irange_key():
            if laststart is None:
                laststart = intvl[0]
                yield intvl
            elif intvl[0] == laststart:
                yield intvl
            else:
                return

    def lasts(self):
        """
        Yields all intervals which start at the last known start offset.
        """
        laststart = None
        for intvl in reversed(self._by_start):
            if laststart is None:
                laststart = intvl[0]
                yield intvl
            elif intvl[0] == laststart:
                yield intvl
            else:
                return

    def min_start(self):
        """
        Returns the smallest known start offset.
        """
        return self._by_start[0][0]

    def max_end(self):
        """
        Returns the biggest known end offset.
        """
        return self._by_end[-1][1]

    def irange(self, minoff=None, maxoff=None, reverse=False,
               inclusive=(True, True)):
        """
        Returns an iterator of intervals with a start offset between minoff
        and maxoff.

        Args:
            minoff: minimum offset, default None indicates any
            maxoff: maximum offset, default None indicates any
            reverse: if `True` yield in reverse order
            inclusive: if the minoff and maxoff values should be inclusive,
                default is (True, True)

        Returns:
            an iterator over the matching (start, end, data) tuples
        """
        # The keys of _by_start are (start, data) tuples, so the plain offsets
        # have to be turned into tuple bounds; a bare int cannot be compared
        # with the stored keys. The 1-tuple (off,) sorts before every
        # (off, data) key and (off, sys.maxsize) sorts after them.
        min_inclusive, max_inclusive = inclusive
        min_key = None
        max_key = None
        if minoff is not None:
            min_key = (minoff,) if min_inclusive else (minoff, sys.maxsize)
        if maxoff is not None:
            max_key = (maxoff, sys.maxsize) if max_inclusive else (maxoff,)
        return self._by_start.irange_key(min_key=min_key, max_key=max_key,
                                         reverse=reverse, inclusive=inclusive)

    def __repr__(self):
        return "SortedIntvls({},{})".format(self._by_start, self._by_end)
def test_remove_valueerror5():
    """``remove`` raises ``ValueError`` for 12, which is not among the stored values."""
    stored = [1, 1, 1, 2, 2, 2]
    sorted_list = SortedKeyList(stored, key=modulo)
    with pytest.raises(ValueError):
        sorted_list.remove(12)
def test_remove_valueerror1():
    """``remove`` on an empty list raises ``ValueError``."""
    empty_list = SortedKeyList(key=modulo)
    with pytest.raises(ValueError):
        empty_list.remove(0)
class GuidedExchangeOperator:
    """Selects atom pairs to swap, guided by per-atom exchange energies.

    For each of the two symbols the atom indices are kept in a
    ``SortedKeyList`` ordered by that atom's exchange energy.
    """

    def __init__(self, environment_energies, feature_key):
        self.n_envs = int(len(environment_energies) / 2)
        half = self.n_envs
        # Difference between entry k of the first half and entry k of the
        # second half of the energies.
        self.env_energy_differences = [
            environment_energies[k] - environment_energies[k + half]
            for k in range(half)
        ]
        self.feature_key = feature_key

        self.symbol1_exchange_energies = {}
        self.symbol2_exchange_energies = {}
        # The sort keys are looked up in the energy dicts, so an index must be
        # removed from a list before its energy entry is changed or deleted.
        self.symbol1_indices = SortedKeyList(
            key=lambda atom: self.symbol1_exchange_energies[atom])
        self.symbol2_indices = SortedKeyList(
            key=lambda atom: self.symbol2_exchange_energies[atom])

        self.n_symbol1_atoms = 0
        self.n_symbol2_atoms = 0

    def env_from_feature(self, x):
        """Map a feature value to its environment index (``x`` modulo ``n_envs``)."""
        return x % self.n_envs

    def guided_exchange(self, particle):
        """Swap the first atom of each sorted list and return the two indices."""
        first = self.symbol1_indices[0]
        second = self.symbol2_indices[0]
        particle.swap_symbols([(first, second)])
        return first, second

    def basin_hop_step(self, particle):
        """Swap the first examined pair with a positive summed exchange energy.

        Positions 1, 2, ... (taken modulo the list lengths) are examined. If no
        examined pair has a positive sum, the pair at the last position reached
        is swapped instead.
        """
        gain = -1
        position = 0
        while gain <= 0 and position < min(self.n_symbol1_atoms,
                                           self.n_symbol2_atoms):
            position += 1
            first = self.symbol1_indices[position % self.n_symbol1_atoms]
            first_energy = self.symbol1_exchange_energies[first]
            second = self.symbol2_indices[position % self.n_symbol2_atoms]
            second_energy = self.symbol2_exchange_energies[second]

            gain = first_energy + second_energy
            if gain > 0:
                particle.swap_symbols([(first, second)])
                return first, second

        # No profitable pair: fall back to the pair at the final position.
        first = self.symbol1_indices[position % self.n_symbol1_atoms]
        second = self.symbol2_indices[position % self.n_symbol2_atoms]
        particle.swap_symbols([(first, second)])
        return first, second

    def bind_particle(self, particle):
        """Populate the energy tables and sorted index lists from ``particle``."""
        ordered_symbols = sorted(particle.atoms.get_all_symbols())
        symbol1, symbol2 = ordered_symbols[0], ordered_symbols[1]

        symbol1_atoms = particle.get_indices_by_symbol(symbol1)
        symbol2_atoms = particle.get_indices_by_symbol(symbol2)
        self.n_symbol1_atoms = len(symbol1_atoms)
        self.n_symbol2_atoms = len(symbol2_atoms)

        atom_features = particle.get_atom_features(self.feature_key)

        # Symbol-1 atoms store the negated difference, symbol-2 atoms the
        # difference itself.
        for atom in symbol1_atoms:
            env = self.env_from_feature(atom_features[atom])
            self.symbol1_exchange_energies[atom] = \
                -self.env_energy_differences[env]
            self.symbol1_indices.add(atom)

        for atom in symbol2_atoms:
            env = self.env_from_feature(atom_features[atom])
            self.symbol2_exchange_energies[atom] = \
                self.env_energy_differences[env]
            self.symbol2_indices.add(atom)

    def update(self, particle, indices, exchange_indices):
        """Refresh the bookkeeping for ``indices`` after a change to ``particle``.

        ``exchange_indices`` are the atoms whose symbol was swapped. The work
        is done in three passes (remove, recompute energies, re-insert) because
        the sorted lists read their keys from the energy tables.
        """
        symbol1 = sorted(particle.atoms.get_all_symbols())[0]
        atom_features = particle.get_atom_features(self.feature_key)

        # Pass 1: take every index out of the list it was in before the
        # change. An exchanged atom sat in the list of the other symbol.
        for atom in indices:
            exchanged = atom in exchange_indices
            is_symbol1 = particle.get_symbol(atom) == symbol1
            if is_symbol1 != exchanged:
                self.symbol1_indices.remove(atom)
            else:
                self.symbol2_indices.remove(atom)

        # Pass 2: recompute the energies; exchanged atoms also lose their
        # entry in the table of their previous symbol.
        for atom in indices:
            env = self.env_from_feature(atom_features[atom])
            energy = self.env_energy_differences[env]
            exchanged = atom in exchange_indices
            if particle.get_symbol(atom) == symbol1:
                self.symbol1_exchange_energies[atom] = -energy
                if exchanged:
                    del self.symbol2_exchange_energies[atom]
            else:
                self.symbol2_exchange_energies[atom] = energy
                if exchanged:
                    del self.symbol1_exchange_energies[atom]

        # Pass 3: insert every index into the list of its current symbol.
        for atom in indices:
            if particle.get_symbol(atom) == symbol1:
                self.symbol1_indices.add(atom)
            else:
                self.symbol2_indices.add(atom)
class CacheLRU:
    """Least-recently-used cache with per-resource expiration times.

    ``memmory`` maps resource ids to resources, ``usage_list`` holds the ids
    ordered from most recently used (left) to least recently used (right), and
    ``sorted_expiration_time_list`` orders the resources by expiration time.
    Additions and uses can be forwarded to other nodes via HTTP requests.
    """

    def __init__(self, max_cache_size, nodes_url, nodes_endpoints):
        self.memmory = dict()
        self.max_cache_size = max_cache_size
        self.usage_list = dllist()
        self.sorted_expiration_time_list = SortedKeyList(
            key=lambda item: item.expiration_datetime)
        self.nodes_url = nodes_url
        self.nodes_endpoints = nodes_endpoints

    def add_resource_to_cache(
        self,
        resource_id,
        resource_value,
        mimetype,
        notify_nodes,
        expiration_datetime,
    ):
        """Insert or replace a resource and mark it as most recently used.

        When the id is new and the cache is full, least recently used
        resources are evicted first. If ``notify_nodes`` is true the insertion
        is forwarded to the other nodes.
        """
        if not self.contains_cached_resource(resource_id):
            # The emptiness check stops the loop once nothing is left to
            # evict (e.g. when max_cache_size is 0).
            while self.memmory and self._is_memmory_full():
                self._free_space()
        else:
            # Replacing: drop the bookkeeping of the previous version.
            self.remove_usage_list(resource_id)
            self.sorted_expiration_time_list.remove(self.memmory[resource_id])

        resource = Resource(resource_id, resource_value, mimetype,
                            expiration_datetime)
        self.memmory[resource_id] = resource
        self.sorted_expiration_time_list.add(resource)
        self.usage_list.appendleft(resource_id)

        if notify_nodes:
            self._notify_nodes_add(resource_id, resource_value,
                                   expiration_datetime)

    def find_resource(self, resource_id):
        """Return ``(value, mimetype)`` of a cached resource, else ``(None, None)``.

        Expired resources are purged before the lookup; a hit marks the
        resource as most recently used and notifies the other nodes.
        """
        self._delete_resources_by_expiration_time()
        if self.contains_cached_resource(resource_id):
            self.update_usage_list(resource_id, True)
            resource = self.memmory[resource_id]
            return resource.value, resource.mimetype
        return None, None

    def update_usage_list(self, used_resource_id, notify_nodes):
        """Move ``used_resource_id`` to the most-recently-used position."""
        self.remove_usage_list(used_resource_id)
        self.usage_list.appendleft(used_resource_id)
        if notify_nodes:
            self._notify_nodes_used(used_resource_id)

    def remove_usage_list(self, resource_id):
        """Remove the first occurrence of ``resource_id`` from the usage list."""
        # Walk the nodes directly instead of indexing the linked list
        # position by position.
        node = self.usage_list.first
        while node is not None:
            if resource_id == node.value:
                self.usage_list.remove(node)
                break
            node = node.next

    def delete_resource_from_cache(self, resource):
        """Drop ``resource`` from all internal structures if it is cached."""
        if self.contains_cached_resource(resource.identifier):
            self.remove_usage_list(resource.identifier)
            self.sorted_expiration_time_list.remove(resource)
            del self.memmory[resource.identifier]

    def contains_cached_resource(self, resource_id):
        return resource_id in self.memmory

    def _is_memmory_full(self):
        # ">=" rather than "==" so that a cache that somehow exceeded its
        # limit is still treated as full.
        return len(self.memmory) >= self.max_cache_size

    def _free_space(self, itens=1):
        """Evict up to ``itens`` least recently used resources."""
        for _ in range(itens):
            # Ids are pushed on the left when used, so the least recently
            # used id is the right-most (last) node.
            oldest = self.usage_list.last
            if oldest is None:
                break
            self.delete_resource_from_cache(self.memmory[oldest.value])

    def _delete_resources_by_expiration_time(self):
        """Purge every resource whose expiration time has passed."""
        curr_date = datetime.datetime.now()
        expiring = self.sorted_expiration_time_list
        # Always look at the head instead of iterating: removing elements
        # while iterating over the sorted list would skip entries.
        while expiring and expiring[0].expiration_datetime <= curr_date:
            resource = expiring.pop(0)
            if self.contains_cached_resource(resource.identifier):
                self.remove_usage_list(resource.identifier)
                del self.memmory[resource.identifier]

    def _notify_nodes_add(self, resource_id, resource_value,
                          expiration_datetime):
        """POST the new resource to the "insert" endpoint of every node."""
        for node_url in self.nodes_url:
            url = node_url + self.nodes_endpoints["insert"]
            payload = {
                # False so that the receiving node does not notify again.
                'notify_nodes': False,
                'expiration_datetime': expiration_datetime
            }
            requests.post(f"{url}/{resource_id}",
                          data=resource_value,
                          params=payload)

    def _notify_nodes_used(self, resource_id):
        """GET the "used" endpoint of every node for ``resource_id``."""
        for node_url in self.nodes_url:
            url = node_url + self.nodes_endpoints["used"]
            payload = {'notify_nodes': False}
            requests.get(f"{url}/{resource_id}", params=payload)
class Stream:
    """A stream identified by its Twitch key, holding its segments.

    Segments are kept in a ``SortedKeyList`` ordered by their integer offset.
    The stream type (joined, no-chat or default) is derived from the key in
    ``__attrs_post_init__``.
    """

    _key: str = attr.ib()
    _data: List[dict] = attr.ib()
    meta: Dict[str, Any] = attr.ib(factory=dict)
    streams: List['Stream'] = attr.ib(factory=list)  # for joined streams

    twitch: str = attr.ib(init=False)
    type: StreamType = attr.ib(init=False)
    games: List[Tuple['Game', SegmentReference]] = attr.ib(init=False)
    segments: List[Segment] = attr.ib(init=False)
    timecodes: Timecodes = attr.ib(init=False)

    @staticmethod
    def _segment_key(s) -> int:
        """Sort key of a segment: its fallback offset if present, else ``offset()``."""
        if hasattr(s, 'fallbacks') and 'offset' in s.fallbacks:
            offset = s.fallbacks['offset']
        else:
            offset = s.offset()
        return int(offset)

    def __attrs_post_init__(self):
        self.twitch = self._key

        # A comma marks a joined stream, a leading "00" a stream without chat.
        if ',' in self.twitch:
            self.type = StreamType.JOINED
        elif self.twitch.startswith('00'):
            self.type = StreamType.NO_CHAT
        else:
            self.type = StreamType.DEFAULT

        self.games = []
        self.segments = SortedKeyList(key=self._segment_key)
        self.timecodes = Timecodes(timecodes.get(self.twitch) or {})

        for segment in self._data:
            # The result is not stored here; presumably Segment registers
            # itself with the stream it is given - confirm in Segment.
            Segment(stream=self, **segment)

    # Workaround for SortedKeyList.__init__
    def __new__(cls, *args, **kwargs):
        return object.__new__(cls)

    @property
    @cached('duration-twitch-{0[0].twitch}')
    def _duration(self) -> int:
        """Duration parsed from the last line of the subtitles file.

        Returns ``None`` when ``last_line`` yields nothing.
        """
        line = last_line(self.subtitles_path)
        if line is not None:
            return int(Timecode(line.split(' ')[2].split('.')[0]))
        return None

    @property
    def duration(self) -> Timecode:
        if self.type is StreamType.JOINED:
            return Timecode(sum(int(s.duration) for s in self.streams))
        elif self.type is StreamType.NO_CHAT:
            return Timecode(max(int(s.abs_end) for s in self))
        else:
            return Timecode(self._duration)

    @property
    def abs_start(self) -> Timecode:
        return Timecode(0)

    @property
    def abs_end(self) -> Timecode:
        return self.duration

    @property
    @cached('date-{0[0].twitch}')
    def _unix_time(self) -> int:
        """Authored date of the first commit whose diff mentions this stream."""
        args = ['--pretty=oneline', '--reverse', '-S', self.twitch]
        rev = repo.git.log(args).split(' ')[0]
        return repo.commit(rev).authored_date

    @property
    def date(self) -> datetime:
        if self.type is StreamType.JOINED:
            return self.streams[0].date
        elif self.type is StreamType.NO_CHAT:
            # Characters 2..7 of a no-chat key encode the date as YYMMDD.
            return datetime.strptime(self.twitch[2:8], '%y%m%d')
        else:
            return datetime.fromtimestamp(self._unix_time)

    @property
    def subtitles_prefix(self) -> str:
        """Returns public URL prefix of subtitles for this stream."""
        year = str(self.date.year)
        key = f'$PREFIX/chats/{year}'
        if key not in config['repos']['mounts']:
            raise Exception(f'Repository for year {year} is not configured')
        prefix = config['repos']['mounts'][key]['prefix']
        return prefix

    @property
    def subtitles(self) -> str:
        """Returns full public URL of subtitles for this stream.

        Returns ``None`` for streams without chat.
        """
        if self.type is StreamType.NO_CHAT:
            return None
        return f'{self.subtitles_prefix}/v{self.twitch}.ass'

    @property
    def subtitles_path(self) -> str:
        """Returns relative path of subtitles in current environment."""
        return _(f'chats/{self.date.year}/v{self.twitch}.ass')

    @property
    def subtitles_style(self) -> SubtitlesStyle:
        style = SubtitlesStyle(tcd_config['ssa_style_format'],
                               tcd_config['ssa_style_default'])

        if self.meta.get('chromakey'):
            style['Alignment'] = '5'
        else:
            style['Alignment'] = '1'

        return style

    @cached_property
    def blacklist(self) -> BlacklistTimeline:
        """Blacklist timeline combined from all subrefs of all segments."""
        bl = BlacklistTimeline()
        for segment in self:
            for subref in segment.all_subrefs:
                bl.add(subref.blacklist, subref.abs_start, subref.abs_end)
        return bl

    @property
    @cached('messages-{0[0].twitch}')
    def _messages(self) -> int:
        """Line count of the subtitles file minus 10, or ``None`` if there are no lines."""
        lines = count_lines(self.subtitles_path)
        return (lines - 10) if lines else None

    @property
    def messages(self) -> int:
        if self.type is StreamType.JOINED:
            return sum(s.messages for s in self.streams)
        else:
            return self._messages or 0

    def __getitem__(self, index: int) -> Segment:
        return self.segments[index]

    def __contains__(self, segment: Segment) -> bool:
        return segment in self.segments

    def __len__(self) -> int:
        return len(self.segments)

    def index(self, segment: Segment) -> int:
        return self.segments.index(segment)

    def add(self, segment: Segment):
        self.segments.add(segment)

    def remove(self, index: int):
        """Remove the segment at position ``index``.

        ``SortedKeyList.remove`` removes by value, so an integer position has
        to be deleted by index. A non-integer argument is still treated as the
        segment to remove.
        """
        if isinstance(index, int):
            del self.segments[index]
        else:
            self.segments.remove(index)

    @join()
    def to_json(self) -> str:
        """Serialize the segments: a JSON array if there are several, else
        the JSON of the only segment."""
        if len(self) > 1:
            yield '[\n'

            first = True
            for segment in self:
                if not first:
                    yield ',\n'
                else:
                    first = False

                yield indent(segment.to_json(), 2)

            yield '\n]'
        else:
            yield self[0].to_json()

    def __str__(self) -> str:
        return self.to_json()