class CompoundQueue(GeneratorQueue):
    """Fan a single write-only queue interface out over several queues.

    Every lifecycle action (``start``, ``put``, ``stop``) is forwarded
    concurrently to each queue in ``queues``. Reading (``get``) is not
    supported.
    """

    stop_event = None
    ready = None
    loop = None
    queues = None

    def __init__(self, queues, loop):
        """
        :param queues: iterable of queue objects, each providing awaitable
            ``start()``, ``put(data)`` and ``stop()`` methods.
        :param loop: event loop used for the internal events and for
            ``wait()``.
        """
        self.ready = Event(loop=loop)
        self.stop_event = Event(loop=loop)
        self.queues = queues
        self.loop = loop

    async def start(self):
        """Start every underlying queue, then mark this queue as ready.

        :raises QueueError: if :meth:`stop` has already been called.
        """
        if self.stop_event.is_set():
            raise QueueError("Socket already stopped.")
        await self.do_action("start")
        # Only reached if every underlying start() succeeded (do_action
        # re-raises sub-queue failures).
        self.ready.set()

    @dies_on_stop_event
    async def get(self):
        """Not supported: a compound queue is write-only."""
        raise NotImplementedError()

    @dies_on_stop_event
    async def put(self, data):
        """Start lazily if needed, then put ``data`` on every underlying queue."""
        await self.setup()
        await self.ready.wait()
        await self.do_action("put", (data,))

    async def setup(self):
        """Setup the client."""
        if not self.ready.is_set():
            await self.start()

    async def stop(self):
        """Stop queue."""
        self.ready.clear()
        self.stop_event.set()
        await self.do_action("stop")

    async def do_action(self, name, args=()):
        """Call ``queue.<name>(*args)`` on every underlying queue concurrently.

        Returns once all calls have finished. If any call raised, one of the
        raised exceptions is re-raised here so failures are not silently
        lost. Every task's exception is retrieved, so asyncio does not log
        "exception was never retrieved".

        :param name: method name to call on each queue.
        :param args: positional arguments for that method.
        """
        coroutines = [getattr(queue, name)(*args) for queue in self.queues]
        if not coroutines:
            # wait() (assumed to be asyncio.wait) rejects an empty set, and
            # there is nothing to do anyway.
            return
        done, _ = await wait(coroutines, loop=self.loop)

        first_error = None
        for task in done:
            if task.cancelled():
                continue
            error = task.exception()  # also marks the exception as retrieved
            if error is not None and first_error is None:
                first_error = error
        if first_error is not None:
            raise first_error
class Stream:
    """
    API for working with streams, used by clients and request handlers
    """
    # HTTP/2 stream id; stays None for a client stream until send_request()
    # allocates one.
    id = None
    # Buffer for received DATA; created together with the stream id.
    __buffer__ = None

    def __init__(
        self,
        connection: Connection,
        h2_connection: H2Connection,
        transport: Transport,
        *,
        loop: AbstractEventLoop,
        stream_id: Optional[int] = None,
        wrapper: Optional[Wrapper] = None
    ) -> None:
        """
        :param connection: owning connection; provides ``write_ready`` and
            ``outbound_streams_limit``.
        :param h2_connection: h2 state machine that generates the frames.
        :param transport: transport the serialized frames are written to.
        :param loop: event loop for the internal queue, event and buffer.
        :param stream_id: id of an already-open stream; leave as None for a
            client stream that is opened later by :meth:`send_request`.
        :param wrapper: optional object whose ``cancel()`` is called when
            the stream is terminated (see :meth:`__terminated__`).
        """
        self._connection = connection
        self._h2_connection = h2_connection
        self._transport = transport
        self._wrapper = wrapper
        self._loop = loop

        if stream_id is not None:
            self.id = stream_id
            self.__buffer__ = Buffer(self.id, self._connection,
                                     self._h2_connection, loop=self._loop)

        self.__headers__ = Queue(loop=loop)  # type: Queue[List[Tuple[str, str]]]
        self.__window_updated__ = Event(loop=loop)

    async def recv_headers(self):
        """Wait for and return the next headers list queued for this stream."""
        return await self.__headers__.get()

    def recv_headers_nowait(self):
        """Return the next queued headers list, or None if none is queued."""
        try:
            return self.__headers__.get_nowait()
        except QueueEmpty:
            return None

    async def recv_data(self, size):
        """Read received data from the stream buffer (see ``Buffer.read``)."""
        return await self.__buffer__.read(size)

    async def send_request(self, headers, end_stream=False, *, _processor):
        """Open a new outbound stream by sending ``headers``.

        Allocates the stream id, creates the receive buffer and registers
        this stream with ``_processor``.

        :return: whatever ``_processor.register(self)`` returned
            (presumably a callback that releases the stream — confirm).
        """
        assert self.id is None, self.id
        while True:
            # this is the first thing we should check before even trying to
            # create new stream, because this wait() can be cancelled by timeout
            # and we wouldn't need to create new stream at all
            if not self._connection.write_ready.is_set():
                await self._connection.write_ready.wait()

            if self._connection.outbound_streams_limit.reached():
                await self._connection.outbound_streams_limit.wait()
                # while we were trying to create a new stream, write buffer
                # can became full, so we need to repeat checks from checking
                # if we can write() data
                continue

            # `get_next_available_stream_id()` should be as close to
            # `connection.send_headers()` as possible, without any async
            # interruptions in between, see the docs on the
            # `get_next_available_stream_id()` method
            stream_id = self._h2_connection.get_next_available_stream_id()
            try:
                self._h2_connection.send_headers(stream_id, headers,
                                                 end_stream=end_stream)
            except TooManyStreamsError:
                continue
            else:
                self._connection.outbound_streams_limit.acquire()
                self.id = stream_id
                self.__buffer__ = Buffer(self.id, self._connection,
                                         self._h2_connection, loop=self._loop)
                release_stream = _processor.register(self)
                self._transport.write(self._h2_connection.data_to_send())
                return release_stream

    async def send_headers(self, headers, end_stream=False):
        """Send a HEADERS frame on this (already open) stream."""
        assert self.id is not None
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.send_headers(self.id, headers,
                                         end_stream=end_stream)
        self._transport.write(self._h2_connection.data_to_send())

    async def send_data(self, data, end_stream=False):
        """Send ``data`` as DATA frames, respecting flow control.

        Each chunk is bounded by the stream's flow-control window and by the
        peer's maximum frame size. ``end_stream`` is set on the last chunk.
        """
        f = BytesIO(data)
        f_pos, f_last = 0, len(data)

        while True:
            if not self._connection.write_ready.is_set():
                await self._connection.write_ready.wait()

            window = self._h2_connection.local_flow_control_window(self.id)
            # The window can be zero *or negative* (a peer may shrink
            # SETTINGS_INITIAL_WINDOW_SIZE). A negative size passed to
            # BytesIO.read() would read everything and overflow the window.
            if window <= 0:
                self.__window_updated__.clear()
                await self.__window_updated__.wait()
                # While waiting, other streams may have used up the shared
                # connection window and the transport may have become
                # unwritable, so re-run every check from the top.
                continue

            max_frame_size = self._h2_connection.max_outbound_frame_size
            f_chunk = f.read(min(window, max_frame_size, f_last - f_pos))
            f_pos = f.tell()

            if f_pos == f_last:
                self._h2_connection.send_data(self.id, f_chunk,
                                              end_stream=end_stream)
                self._transport.write(self._h2_connection.data_to_send())
                break
            else:
                self._h2_connection.send_data(self.id, f_chunk)
                self._transport.write(self._h2_connection.data_to_send())

    async def end(self):
        """Half-close the stream (send END_STREAM)."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.end_stream(self.id)
        self._transport.write(self._h2_connection.data_to_send())

    async def reset(self, error_code=ErrorCodes.NO_ERROR):
        """Reset the stream, waiting until the transport is writable."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        self._transport.write(self._h2_connection.data_to_send())

    def reset_nowait(self, error_code=ErrorCodes.NO_ERROR):
        """Reset the stream without waiting.

        The RST_STREAM frame is queued in h2 immediately. It is written to
        the transport only if the transport is currently writable.
        """
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        if self._connection.write_ready.is_set():
            self._transport.write(self._h2_connection.data_to_send())

    def __ended__(self):
        """Mark the receive buffer as finished (EOF)."""
        self.__buffer__.eof()

    def __terminated__(self, reason):
        """Cancel the wrapper (if any) with a StreamTerminatedError."""
        if self._wrapper is not None:
            self._wrapper.cancel(StreamTerminatedError(reason))

    @property
    def closable(self):
        """True if the connection is open and the h2 stream exists and is open."""
        if self._h2_connection.state_machine.state is ConnectionState.CLOSED:
            return False
        stream = self._h2_connection.streams.get(self.id)
        if stream is None:
            return False
        return not stream.closed
class AsyncConnectionPool:
    """Object manages asynchronous connections.

    :param int size: size (number of connection) of the pool.
    :param float queue_timeout: time out when client is waiting connection
        from pool
    :param loop: event loop, if not passed then default will be used
    :param config: MySql connection config see
        `doc. <http://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html>`_
    :raise ValueError: if the `size` is inappropriate
    """

    def __init__(self, size=1, queue_timeout=15.0, *, loop=None, **config):
        # An explicit check instead of ``assert``. An assert would raise
        # AssertionError instead of the documented ValueError, and it is
        # stripped under ``python -O``.
        if size < 1:
            raise ValueError('DBPool.size is less than 1, '
                             'connections won"t be established')
        self._pool = set()          # every connection owned by the pool
        self._busy_items = set()    # connections currently handed out
        self._size = size
        self._pending_futures = deque()  # FIFO of get() callers waiting
        self._queue_timeout = queue_timeout
        self._loop = loop or asyncio.get_event_loop()
        self.config = config
        # Set = pool usable; cleared while shutdown() is running so get()
        # blocks until it finishes.
        self._shutdown_event = Event(loop=self._loop)
        self._shutdown_event.set()

    @property
    def queue_timeout(self):
        """Number of seconds to wait a connection from the pool,
        before TimeoutError occurred

        :rtype: float
        """
        return self._queue_timeout

    @queue_timeout.setter
    def queue_timeout(self, value):
        """Sets a timeout for :attr:`queue_timeout`

        :param float value: number of seconds
        :raise ValueError: if ``value`` is not a float or int
        """
        if not isinstance(value, (float, int)):
            raise ValueError('Float or integer type expected')
        self._queue_timeout = value

    @property
    def size(self):
        """Size of pool

        :rtype: int
        """
        return self._size

    def __len__(self):
        """Number of allocated pool's slots

        :rtype: int
        """
        return len(self._pool)

    @property
    def free_count(self):
        """Number of free pool's slots

        :rtype: int
        """
        return self.size - len(self._busy_items)

    @asyncio.coroutine
    def get(self):
        """Coroutine. Returns an opened connection from pool.

        If coroutine invoked when all connections have been issued, then
        caller will blocked until some connection will be released.

        Also, the class provides context manager for getting connection
        and automatically freeing it. Example:
        >>> with (yield from pool) as cnx:
        >>>     ...

        :rtype: AsyncMySQLConnection
        :raise: asyncio.TimeoutError if no connection is released within
            :attr:`queue_timeout` seconds
        """
        cnx = None
        # Block while shutdown() is in progress.
        yield from self._shutdown_event.wait()

        # 1. Reuse an idle connection if there is one.
        for free_client in self._pool - self._busy_items:
            cnx = free_client
            self._busy_items.add(cnx)
            break
        else:
            # 2. Otherwise open a new one if the pool is not at capacity.
            if len(self) < self.size:
                cnx = AsyncMySQLConnection(loop=self._loop)
                self._pool.add(cnx)
                self._busy_items.add(cnx)
                try:
                    yield from cnx.connect(**self.config)
                except BaseException:
                    # Give the slot back. Use discard(): shutdown() may
                    # have replaced the sets meanwhile, and a KeyError here
                    # would mask the original error.
                    self._pool.discard(cnx)
                    self._busy_items.discard(cnx)
                    raise

        # 3. Otherwise wait until release() hands us a connection.
        if cnx is None:
            queue_future = Future(loop=self._loop)
            self._pending_futures.append(queue_future)
            try:
                cnx = yield from asyncio.wait_for(queue_future,
                                                  self.queue_timeout,
                                                  loop=self._loop)
                self._busy_items.add(cnx)
            except asyncio.TimeoutError:
                raise asyncio.TimeoutError('Database pool is busy')
            finally:
                try:
                    self._pending_futures.remove(queue_future)
                except ValueError:
                    pass  # already popped by release()

        return cnx

    def release(self, connection):
        """Frees connection. After that the connection can be issued
        by :func:`get`.

        :param AsyncMySQLConnection connection: a connection received
            from :func:`get`
        """
        # Hand the connection straight to the oldest caller that is still
        # waiting. A waiter's future may already be done (its get() timed
        # out or was cancelled) without having been removed from the deque
        # yet. Calling set_result() on it would raise InvalidStateError and
        # leak the connection.
        while self._pending_futures:
            waiter = self._pending_futures.popleft()
            if not waiter.done():
                waiter.set_result(connection)
                return
        # No live waiter: the connection becomes idle. discard() tolerates
        # a release after shutdown() has reset the bookkeeping sets.
        self._busy_items.discard(connection)

    @asyncio.coroutine
    def shutdown(self):
        """Coroutine. Closes all connections and purge queue of a waiting
        for connection.
        """
        self._shutdown_event.clear()
        try:
            # Iterate over a snapshot: the set can change while we yield.
            for cnx in list(self._pool):
                yield from cnx.disconnect()
            for f in self._pending_futures:
                f.cancel()
            self._pending_futures.clear()
            self._pool = set()
            self._busy_items = set()
        finally:
            self._shutdown_event.set()

    def __enter__(self):
        raise RuntimeError(
            '"yield from" should be used as context manager expression')

    def __exit__(self, *args):
        # This must exist because __enter__ exists, even though that
        # always raises; that's how the with-statement works.
        pass

    @asyncio.coroutine
    def __iter__(self):
        # Enables ``with (yield from pool) as cnx:``.
        cnx = yield from self.get()
        return ContextManager(self, cnx)
class Writer:
    """Asynchronous, rate-limited consumer of queued log messages.

    Messages are put on :attr:`queue`. While :attr:`canWrite` is set,
    :meth:`start` takes them off the queue. It skips anything rejected by
    :meth:`filterMessage` and passes accepted messages to :meth:`emit` in
    the default executor. :meth:`acquire` is held around each batch and
    :meth:`release` is called afterwards. Subclasses implement
    :meth:`emit` and may override :meth:`check`, :meth:`acquire` and
    :meth:`release`.

    NOTE(review): ``@abstractmethod`` on :meth:`emit` is not enforced
    because this class does not use ``ABCMeta``.
    """
    terminator = "\n"
    delay = 1.0  # seconds slept after each batch (rate limit)

    def __init__(self, levelno=logging.DEBUG):
        self.queue = Queue()
        self.canWrite = Event()
        self.levelno = levelno

    def filterMessage(self, message) -> bool:
        """Return True if ``message`` is a "log" Message at or above ``self.levelno``."""
        if not isinstance(message, Message) or message.type != "log":
            return False  # ignore invalid
        if message.levelno < self.levelno:
            return False  # filter level
        return True

    async def start(self):
        """Run the writer loop until the task is cancelled.

        Any other exception is logged. Writing is then paused by clearing
        :attr:`canWrite`, and the loop waits for it to be set again.
        """
        loop = get_running_loop()
        while True:
            try:
                await self.canWrite.wait()
                if not self.check():
                    self.canWrite.clear()
                    continue
                message = await self.queue.get()
                if not self.filterMessage(message):
                    self.queue.task_done()
                    continue  # avoid acquiring the lock
                await loop.run_in_executor(None, self.acquire)
                try:
                    while True:
                        try:
                            if self.filterMessage(message):
                                await loop.run_in_executor(
                                    None, self.emit, message.msg, message.levelno
                                )
                        finally:
                            # Account for the message even if emit() failed,
                            # otherwise queue.join() would hang forever.
                            self.queue.task_done()
                        try:
                            # handle any other records while we have the lock
                            message = self.queue.get_nowait()
                        except QueueEmpty:
                            break
                finally:
                    # Always give the lock back, even if emit() raised.
                    await loop.run_in_executor(None, self.release)
                await sleep(self.delay)  # rate limit
            except CancelledError:
                break  # exit the writer
            except Exception:  # catch all
                logging.warning(f"Caught exception in {self.__class__.__name__}. Stopping", exc_info=True)
                self.canWrite.clear()

    def check(self) -> bool:
        """Return whether writing is currently possible (default: always)."""
        return True

    def acquire(self):
        """Hook run in the executor before a batch is emitted (default: no-op)."""
        pass

    @abstractmethod
    def emit(self, msg: str, levelno: int):
        """Write one message; must be implemented by subclasses."""
        raise NotImplementedError()

    def release(self):
        """Hook run in the executor after a batch is emitted (default: no-op)."""
        pass
class CamModel:
    """
    Camera model that runs its own asyncio event loop in this process and
    communicates with a controller over two pipes.

    Control messages received on ``msg_pipe`` are dispatched through
    ``self._switcher``. Frames read by ``StreamReader`` are passed through
    worker threads (optional text overlay). They are then queued for the
    ``StreamWriter`` while recording and/or resized and sent on
    ``img_pipe`` while the feed is shown.

    NOTE: the constructor blocks until ``_start_loop()`` returns, i.e. until
    ``_stop_event`` is set (see ``_stop()``).
    """
    def __init__(self, msg_pipe: Connection, img_pipe: Connection, cam_index: int = 0):
        """
        Set up state and run the main loop (blocks until stopped).
        :param msg_pipe: Pipe end for control messages. Each message is
            indexable: [0] is a key of self._switcher, [1] its argument or None.
        :param img_pipe: Pipe end that resized frames are sent on.
        :param cam_index: Index of the camera passed to StreamReader.
        """
        # Fresh event loop for this process/thread; later fetched via get_event_loop().
        set_event_loop(new_event_loop())
        self._msg_pipe = msg_pipe
        self._img_pipe = img_pipe
        self._cam_index = cam_index
        self._cam_reader = StreamReader(cam_index)
        self._size_gtr = SizeGetter(self._cam_reader)
        self._stop_event = Event()
        self._write_q = SimpleQueue()
        self._cam_writer = StreamWriter()
        self._tasks = []
        # Message key -> handler; used by _handle_pipe().
        self._switcher = {
            defs.ModelEnum.STOP: self._stop_writing,
            defs.ModelEnum.START: self._start_writing,
            defs.ModelEnum.SET_USE_CAM: self._use_cam,
            defs.ModelEnum.SET_USE_FEED: self._use_feed,
            defs.ModelEnum.CLEANUP: self.cleanup,
            defs.ModelEnum.INITIALIZE: self.init_cam,
            defs.ModelEnum.GET_FPS: self._get_fps,
            defs.ModelEnum.SET_FPS: self._set_fps,
            defs.ModelEnum.GET_RES: self._get_res,
            defs.ModelEnum.SET_RES: self._set_res,
            defs.ModelEnum.COND_NAME: self._update_cond_name,
            defs.ModelEnum.BLOCK_NUM: self._update_block_num,
            defs.ModelEnum.KEYFLAG: self._update_keyflag,
            defs.ModelEnum.EXP_STATUS: self._update_exp_status,
            defs.ModelEnum.LANGUAGE: self.set_lang,
            defs.ModelEnum.OVERLAY: self._toggle_overlay,
        }
        self._running = True        # keeps _handle_pipe/_await_reader_err looping
        self._process_imgs = False  # keeps the frame-processing threads looping
        self._writing = False       # frames are queued for the writer when True
        self._show_feed = False     # frames are sent on img_pipe when True
        self._frame_size = self._cam_reader.get_resolution()
        self._handle_frames = Event()
        self._loop = get_event_loop()
        self._strings = strings[LangEnum.ENG]
        self._fps = 30
        self._cam_name = "CAM_" + str(self._cam_index)
        self._cond_name = str()
        self._exp_status = self._strings[StringsEnum.EXP_STATUS_STOP]
        self._exp_running = False
        self._block_num = 0
        self._keyflag = str()
        self.set_lang()
        self._test_task = None
        self._num_img_workers = 2
        # Per-worker synchronization/shared buffers, filled by _start_frame_processing().
        self._sems1 = list()
        self._sems2 = list()
        self._sems3 = list()
        self._shm_ovl_arrs = list()
        self._shm_img_arrs = list()
        self._np_img_arrs = list()
        self._num_writes_arrs = list()
        self._use_overlay = True
        # Placeholder so is_alive()/join() can be called before init_cam().
        self._proc_thread = Thread(target=None, args=())
        cur_res = self._cam_reader.get_resolution()
        # Array shape is (res[1], res[0], 3); presumably (height, width, channels).
        self._cur_arr_shape = (int(cur_res[1]), int(cur_res[0]), 3)
        self._cur_arr_size = self._cur_arr_shape[0] * self._cur_arr_shape[
            1] * self._cur_arr_shape[2]
        self._executor = ThreadPoolExecutor()
        self._loop.run_until_complete(self._start_loop())

    async def _handle_pipe(self) -> None:
        """
        Poll the message pipe every 20 ms and dispatch each message to its
        handler in self._switcher. Unknown keys are ignored. Handlers run
        synchronously, so a slow handler blocks the event loop.
        Pipe errors (BrokenPipeError/OSError) end the loop silently.
        :return None:
        """
        try:
            while self._running:
                if self._msg_pipe.poll():
                    msg = self._msg_pipe.recv()
                    if msg[0] in self._switcher.keys():
                        if msg[1] is not None:
                            self._switcher[msg[0]](msg[1])
                        else:
                            self._switcher[msg[0]]()
                await asyncsleep(.02)
        # NOTE(review): BrokenPipeError is a subclass of OSError, and the
        # final re-raise is a no-op; both clauses could be simplified.
        except BrokenPipeError as bpe:
            pass
        except OSError as ose:
            pass
        except Exception as e:
            raise e

    def cleanup(self, discard: bool) -> None:
        """
        Cleanup this code and prep for app closure.
        Schedules _cleanup() as a task on the running loop.
        :param discard: Quit without saving.
        :return None:
        """
        create_task(self._cleanup(discard))

    def init_cam(self) -> None:
        """
        Begin initializing camera by scheduling _run_tests() as a task.
        :return None:
        """
        self._test_task = create_task(self._run_tests())

    def set_lang(self, lang: LangEnum = LangEnum.ENG) -> None:
        """
        Set this camera's language and refresh the derived texts.
        :param lang: The new language enum.
        :return None:
        """
        self._strings = strings[lang]
        self._set_texts()

    async def _run_tests(self) -> None:
        """
        Query the supported sizes while reporting progress. Report FAILURE
        if none are found. Otherwise report START with (fps, sizes) and
        start the frame-processing thread.
        :return None:
        """
        prog_tracker = create_task(self._monitor_init_progress())
        sizes = await self._size_gtr.get_sizes()
        if len(sizes) < 1:
            self._msg_pipe.send((defs.ModelEnum.FAILURE, None))
            prog_tracker.cancel()
            return
        self._msg_pipe.send((defs.ModelEnum.START, (self._fps, sizes)))
        prog_tracker.cancel()
        self._proc_thread = Thread(target=self._start_frame_processing, args=())
        self._proc_thread.start()

    async def _monitor_init_progress(self) -> None:
        """
        Send the size getter's status as STAT_UPD every 0.5 s until it
        reaches 100 (or this task is cancelled).
        :return None:
        """
        while True:
            if self._size_gtr.status >= 100:
                break
            self._msg_pipe.send(
                (defs.ModelEnum.STAT_UPD, self._size_gtr.status))
            await asyncsleep(.5)

    async def _await_reader_err(self) -> None:
        """
        Handle if reader fails: send FAILURE each time the reader reports an error.
        :return None:
        """
        while self._running:
            await self._cam_reader.await_err()
            self._msg_pipe.send((defs.ModelEnum.FAILURE, None))

    async def _cleanup(self, discard: bool) -> None:
        """
        Stop loops, finish or cancel the init task, stop reader/writer and
        notify the controller with CLEANUP.
        :param discard: Passed to the writer's cleanup (quit without saving).
        :return None:
        """
        self._running = False
        if self._test_task is not None:
            if self._test_task.done():
                await self._test_task
            else:
                self._test_task.cancel()
        self._size_gtr.stop()
        self._stop()
        self._cam_reader.cleanup()
        self._cam_writer.cleanup(discard)
        self._msg_pipe.send((defs.ModelEnum.CLEANUP, None))

    def _refresh_np_arrs(self) -> None:
        """
        Rebuild numpy views over each worker's shared image array, shaped to
        the current resolution. Requires _shm_img_arrs to be populated.
        :return None:
        """
        self._np_img_arrs = list()
        for j in range(self._num_img_workers):
            self._np_img_arrs.append(
                frombuffer(self._shm_img_arrs[j].get_obj(), count=self._cur_arr_size,
                           dtype=DTYPE).reshape(self._cur_arr_shape))

    def _update_block_num(self, num: int) -> None:
        """
        Update the block num shown on camera details.
        :param num: The new num to show.
        :return None:
        """
        self._block_num = num

    def _update_cond_name(self, name: str) -> None:
        """
        Update the condition name shown on camera details.
        :param name: The new name to show.
        :return None:
        """
        self._cond_name = name

    def _update_keyflag(self, flag: str) -> None:
        """
        Update the key flag shown on camera details.
        :param flag: The new key flag to show.
        :return None:
        """
        self._keyflag = flag

    def _update_exp_status(self, status: bool) -> None:
        """
        Update the experiment status shown on the camera details.
        :param status: The new status to show.
        :return None:
        """
        self._exp_running = status
        self._set_texts()

    def _toggle_overlay(self, is_active: bool) -> None:
        """
        toggle whether to use overlay on this camera.
        :param is_active: Whether to use overlay.
        :return None:
        """
        self._use_overlay = is_active

    def _get_res(self) -> None:
        """
        Send the current resolution of this camera as CUR_RES.
        :return None:
        """
        self._msg_pipe.send(
            (defs.ModelEnum.CUR_RES, self._cam_reader.get_resolution()))

    def _set_res(self, new_res: (float, float)) -> None:
        """
        Change the resolution on this camera. Stops reading and frame
        processing, applies the resolution, recomputes the array shape and
        restarts both. No-op if the resolution is unchanged.
        :param new_res: The new resolution to use.
        :return None:
        """
        if new_res == self._cam_reader.get_resolution():
            return
        self._show_feed = False
        self._cam_reader.stop_reading()
        self._stop_frame_processing()
        self._cam_reader.set_resolution(new_res)
        self._times = deque()
        cur_res = self._cam_reader.get_resolution()
        self._cur_arr_shape = (int(cur_res[1]), int(cur_res[0]), 3)
        self._cur_arr_size = self._cur_arr_shape[0] * self._cur_arr_shape[
            1] * self._cur_arr_shape[2]
        self._proc_thread = Thread(target=self._start_frame_processing, args=())
        self._cam_reader.start_reading()
        self._proc_thread.start()
        self._show_feed = True

    def _use_cam(self, is_active: bool) -> None:
        """
        Toggle whether this cam is being used.
        :param is_active: Whether this cam is being used.
        :return None:
        """
        if is_active:
            self._cam_reader.start_reading()
            self._handle_frames.set()
        else:
            self._cam_reader.stop_reading()
            self._handle_frames.clear()

    def _use_feed(self, is_active: bool) -> None:
        """
        Toggle whether this cam feed is being passed to the view.
        :param is_active: Whether this cam feed is being passed to the view.
        :return None:
        """
        self._show_feed = is_active

    def _start_writing(self, path: str) -> None:
        """
        Create new writer and set boolean to start putting frames in write queue.
        :param path: String prefix the file name "CAM_<index>.avi" is
            appended to (presumably a directory ending in a separator — confirm).
        :return None:
        """
        # filename = path + "CAM_" + str(self._cam_index) + "_" + format_current_time(datetime.now(), save=True) + ".avi"
        filename = path + "CAM_" + str(self._cam_index) + ".avi"
        x, y = self._cam_reader.get_resolution()
        self._frame_size = (int(x), int(y))
        self._write_q = SimpleQueue()
        self._cam_writer = StreamWriter()
        self._cam_writer.start(filename, int(self._fps), self._frame_size, self._write_q)
        self._writing = True

    def _stop_writing(self) -> None:
        """
        Stop queueing frames, wait for the write queue to drain, clean up the
        writer and send STOP.
        NOTE(review): this busy-waits with a blocking sleep and is called
        synchronously from _handle_pipe(), so it blocks the event loop
        until the queue is empty.
        :return None:
        """
        self._writing = False
        while not self._write_q.empty():
            tsleep(.05)
        self._cam_writer.cleanup()
        self._msg_pipe.send((defs.ModelEnum.STOP, None))

    async def _start_loop(self) -> None:
        """
        Run all async tasks in this model and wait for stop signal.
        (This method is the main loop for this process)
        :return None:
        """
        self._tasks.append(create_task(self._handle_pipe()))
        self._tasks.append(create_task(self._await_reader_err()))
        await self._stop_event.wait()

    def _start_frame_processing(self) -> None:
        """
        Create image processing threads and wait for stop signal.
        Allocates shared buffers sized for the largest entry of
        defs.common_resolutions. Starts one overlay worker thread per
        buffer, plus a distributor and a handler thread, then sleeps until
        _process_imgs is cleared.
        Per worker i: sems3[i] (initially 1) = buffer free, sems1[i] = frame
        ready for worker, sems2[i] = frame processed.
        NOTE(review): the worker/distributor/handler threads are daemon
        threads and are never joined. Once _process_imgs is cleared, workers
        blocked in sem1.acquire() and the handler blocked in sems2[i].acquire()
        never wake, so each restart (e.g. _set_res) leaves them behind.
        :return None:
        """
        self._process_imgs = True
        max_res = defs.common_resolutions[-1]
        max_img_arr_shape = (int(max_res[1]), int(max_res[0]), 3)
        max_img_arr_size = max_img_arr_shape[0] * max_img_arr_shape[
            1] * max_img_arr_shape[2]
        self._sems1 = list()
        self._sems2 = list()
        self._sems3 = list()
        self._shm_ovl_arrs = list()
        self._shm_img_arrs = list()
        self._np_img_arrs = list()
        self._num_writes_arrs = list()
        for i in range(self._num_img_workers):
            self._sems1.append(Semaphore(0))
            self._sems2.append(Semaphore(0))
            self._sems3.append(Semaphore(1))
            self._shm_ovl_arrs.append(Array(c_char, BYTESTR_SIZE))
            self._shm_img_arrs.append(Array('Q', max_img_arr_size))
            self._num_writes_arrs.append(Value('i', 1))
            worker_args = (self._shm_img_arrs[i], self._sems1[i], self._sems2[i], self._shm_ovl_arrs[i])
            worker = Thread(target=self._img_processor, args=worker_args, daemon=True)
            worker.start()
        self._refresh_np_arrs()
        distributor = Thread(target=self._distribute_frames, args=(), daemon=True)
        distributor.start()
        handler = Thread(target=self._handle_processed_frames, args=(), daemon=True)
        handler.start()
        while self._process_imgs:
            tsleep(1)

    def _stop_frame_processing(self) -> None:
        """
        Stop proc_thread and join it.
        :return None:
        """
        self._process_imgs = False
        self._proc_thread.join()

    def _stop(self) -> None:
        """
        Cancel all async tasks, stop and join the frame-processing thread,
        then set the stop event (which ends _start_loop()).
        :return None:
        """
        for task in self._tasks:
            task.cancel()
        self._process_imgs = False
        if self._proc_thread.is_alive():
            self._proc_thread.join()
        self._stop_event.set()

    def _get_fps(self) -> None:
        """
        Send the current fps setting of this camera as CUR_FPS.
        :return None:
        """
        self._msg_pipe.send(
            (defs.ModelEnum.CUR_FPS, self._cam_reader.get_fps_setting()))

    def _set_fps(self, new_fps: float) -> None:
        """
        Set new fps and reset fps tracking.
        :param new_fps: The new fps to use.
        :return None:
        """
        self._times = deque()
        self._cam_reader.set_fps(new_fps)
        self._fps = int(new_fps)

    def _distribute_frames(self) -> None:
        """
        Distribute new frames round-robin to the image worker threads.
        Polls the reader and sleeps 1 ms when no new frame is available.
        :return None:
        """
        i = 0
        while self._process_imgs:
            ret, val = self._cam_reader.get_next_new_frame()
            if ret:
                (frame, timestamp, num_writes) = val
                self._hand_out_frame(frame, timestamp, i, num_writes)
                i = self._increment_counter(i)
            else:
                tsleep(.001)

    def _hand_out_frame(self, frame, timestamp: datetime, i: int, num_writes: int) -> None:
        """
        Helper function for self._distribute_frames()
        Builds the overlay text and waits until buffer i is free (sems3[i]).
        Copies the frame, overlay and write count into worker i's shared
        slots, then signals the worker (sems1[i]).
        :param frame: The frame to put an overlay on.
        :param timestamp: A datetime object to add to the overlay.
        :param i: Which arrays to access.
        :param num_writes: The number of times to write this frame to save file.
        :return None:
        """
        overlay = shorten(self._cond_name, COND_NAME_WIDTH) + CM_SEP + \
                  format_current_time(timestamp, time=True, mil=True) + CM_SEP + self._exp_status + CM_SEP + \
                  str(self._block_num) + CM_SEP + str(self._keyflag) + CM_SEP + str(self._cam_reader.get_fps_actual())\
                  + "/" + str(self._fps)
        self._sems3[i].acquire()
        copyto(self._np_img_arrs[i], frame)
        self._shm_ovl_arrs[i].value = (overlay.encode())
        self._num_writes_arrs[i].value = num_writes
        self._sems1[i].release()

    def _increment_counter(self, num: int) -> int:
        """
        Helper function for self._distribute_frames()
        :param num: The integer to increment from.
        :return int: The incremented integer, wrapped modulo the worker count.
        """
        return (num + 1) % self._num_img_workers

    def _img_processor(self, sh_img_arr: Array, sem1: Semaphore, sem2: Semaphore, ovl_arr: Array) -> None:
        """
        Worker thread: for each frame signalled on sem1, draw the overlay text
        onto the top EDIT_HEIGHT rows of the shared image (if overlay is
        enabled), then signal sem2.
        :param sh_img_arr: The array containing the frame to work with.
        :param sem1: The entrance lock.
        :param sem2: The exit lock.
        :param ovl_arr: The array containing the overlay work with.
        :return None:
        """
        # View over only the first EDIT_HEIGHT rows of the frame buffer.
        img_dim = (EDIT_HEIGHT, self._cur_arr_shape[1], self._cur_arr_shape[2])
        img_size = int(EDIT_HEIGHT * img_dim[1] * img_dim[2])
        img_arr = frombuffer(sh_img_arr.get_obj(), count=img_size, dtype=DTYPE).reshape(img_dim)
        while self._process_imgs:
            sem1.acquire()
            if self._use_overlay:
                img_pil = Image.fromarray(img_arr)
                draw = ImageDraw.Draw(img_pil)
                draw.text(OVL_POS, text=ovl_arr.value.decode(), font=OVL_FONT, fill=OVL_CLR)
                processed_img = asarray(img_pil)
                copyto(img_arr, processed_img)
            sem2.release()

    def _handle_processed_frames(self) -> None:
        """
        Consume processed frames round-robin from the image worker threads.
        Queues copies for the writer (num_writes times) while writing and
        sends a 640-px-wide resize on img_pipe while the feed is shown.
        Each buffer is then marked free again (sems3[i]).
        :return None:
        """
        i = 0
        while self._process_imgs:
            self._sems2[i].acquire()
            frame = self._np_img_arrs[i]
            if self._writing:
                for p in range(self._num_writes_arrs[i].value):
                    self._write_q.put(copy(frame))
            if self._show_feed:
                to_send = self.image_resize(frame, width=640)
                self._img_pipe.send(to_send)
            self._sems3[i].release()
            i = self._increment_counter(i)

    def _set_texts(self) -> None:
        """
        Set the experiment-status text from the current language and _exp_running.
        :return None:
        """
        if self._exp_running:
            self._exp_status = self._strings[StringsEnum.EXP_STATUS_RUN]
        else:
            self._exp_status = self._strings[StringsEnum.EXP_STATUS_STOP]

    # from https://stackoverflow.com/questions/44650888/resize-an-image-without-distortion-opencv
    @staticmethod
    def image_resize(image, width=None, height=None, inter=INTER_AREA):
        """
        Resize ``image`` to the given width or height, keeping the aspect ratio.
        If both are None the image is returned unchanged. If both are given,
        width wins.
        """
        # initialize the dimensions of the image to be resized and
        # grab the image size
        dim = None
        (h, w) = image.shape[:2]

        # if both the width and height are None, then return the
        # original image
        if width is None and height is None:
            return image

        # check to see if the width is None
        if width is None:
            # calculate the ratio of the height and construct the
            # dimensions
            r = height / float(h)
            dim = (int(w * r), height)

        # otherwise, the height is None
        else:
            # calculate the ratio of the width and construct the
            # dimensions
            r = width / float(w)
            dim = (width, int(h * r))

        # resize the image
        resized = resize(image, dim, interpolation=inter)

        # return the resized image
        return resized
class Pool:
    """Bounded pool of in-flight requests keyed by (instanceId, analysisId).

    Newly added requests wait in ``_wait_tx`` until :meth:`stream` hands
    them out. Handed-out requests stay in ``_wait_rx`` until their Stream
    completes. Both collections count toward ``n_slots``.
    """

    def __init__(self, n_slots):
        self._n_slots = n_slots
        self._wait_tx = OrderedDict()   # queued, not yet handed out (FIFO)
        self._wait_tx_sem = Semaphore(0)  # wakes stream() when work arrives
        self._wait_rx = OrderedDict()   # handed out, awaiting completion
        self._not_full = Event()
        self._not_full.set()

    @property
    def n_slots(self):
        return self._n_slots

    @property
    def is_full(self):
        return len(self._wait_tx) + len(self._wait_rx) >= self._n_slots

    def full(self):
        """Method form of :attr:`is_full` (mirrors ``asyncio.Queue.full``)."""
        return self.is_full

    async def wait_not_full(self):
        await self._not_full.wait()

    @property
    def qsize(self):
        # NOTE: unlike asyncio.Queue.qsize(), this is a property and returns
        # the pool's capacity, not the number of queued items.
        return self._n_slots

    def _cancel_existing(self, key):
        """Cancel the stream registered under ``key``, if any.

        Pending (``_wait_tx``) entries are checked before handed-out
        (``_wait_rx``) ones.

        :return: True if a stream was found and cancelled, else False.
        """
        existing = self._wait_tx.get(key)
        if existing is None:
            existing = self._wait_rx.get(key)
        if existing is None:
            return False
        ex_request, ex_stream = existing
        log.debug('%s %s', 'cancelling', req_str(ex_request))
        ex_stream.cancel()
        return True

    def put_nowait(self, request):
        """Queue ``request`` and return a new Stream for its results.

        A stream already registered for the same (instanceId, analysisId)
        is cancelled first.

        :raises QueueFull: if the pool has no free slot.
        """
        key = (request.instanceId, request.analysisId)
        self._cancel_existing(key)

        if self.full():
            raise QueueFull

        stream = Stream()
        log.debug('%s %s', 'queueing', req_str(request))
        self._wait_tx[key] = (request, stream)

        # Wake the stream() consumer if it is waiting for work.
        if self._wait_tx_sem.locked():
            self._wait_tx_sem.release()

        stream.add_complete_listener(self._stream_complete)

        if self.is_full:
            self._not_full.clear()

        return stream

    def add(self, request):
        """Alias of :meth:`put_nowait`."""
        return self.put_nowait(request)

    def cancel(self, key):
        """Cancel the stream registered under ``key``.

        :raises KeyError: if no request is registered under ``key``.
        """
        if not self._cancel_existing(key):
            raise KeyError(key)

    def get(self, key):
        """Return the ``(request, stream)`` pair for ``key``, or None."""
        value = self._wait_tx.get(key)
        if value is not None:
            return value
        value = self._wait_rx.get(key)
        if value is not None:
            return value
        return None

    def _stream_complete(self):
        """Completion listener: drop finished entries and reopen the pool."""
        # iterate through a copied list so we can delete from the original
        for key, value in list(self._wait_rx.items()):
            request, stream = value
            if stream.is_complete:
                del self._wait_rx[key]
                log.debug('%s %s', 'removing', req_str(request))

        for key, value in list(self._wait_tx.items()):
            request, stream = value
            if stream.is_complete:
                del self._wait_tx[key]
                log.debug('%s %s', 'removing', req_str(request))

        if not self.is_full:
            self._not_full.set()

    async def stream(self):
        """Async generator yielding ``(request, stream)`` pairs as they arrive.

        Requests are handed out oldest-first and moved from ``_wait_tx`` to
        ``_wait_rx``.
        """
        while True:
            await self._wait_tx_sem.acquire()
            while self._wait_tx:
                # last=False -> FIFO; plain popitem() would hand out the
                # newest request first.
                key, value = self._wait_tx.popitem(last=False)
                self._wait_rx[key] = value
                request, stream = value
                log.debug('%s %s', 'yielding', req_str(request))
                yield value

    def __contains__(self, value):
        return value in self._wait_tx or value in self._wait_rx
class SessionBase(asyncio.Protocol):
    """Base class of networking sessions.

    Client and server sessions share this class. The only difference is
    which side initiated the connection. Once asyncio reports the connection
    via :meth:`connection_made`, the subclass's ``_receive_messages()``
    coroutine is started as a task. Each successfully established session
    should eventually be closed with :meth:`close` (or
    :meth:`synchronous_close`).

    Subclasses must provide :meth:`default_framer` unless a ``framer`` is
    passed to the constructor.
    """

    max_errors = 10

    def __init__(self, *, framer=None, loop=None):
        self.framer = framer or self.default_framer()
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)

        # Connection state; the transport is filled in by connection_made().
        self.transport = None
        self._address = None
        self._proxy_address = None

        # At >= 4 every framed message is logged at debug level.
        self.verbosity = 0

        # Set while the transport accepts writes; pause_writing() clears it.
        self._can_send = Event()
        self._can_send.set()

        self._pm_task = None
        self._task_group = TaskGroup(self.loop)

        # Seconds a blocked send may wait before the connection is aborted.
        self.max_send_delay = 60

        # Statistics; the RPC layer keeps its own separately.
        now = time.perf_counter()
        self.start_time = now
        self.errors = 0
        self.send_count = 0
        self.send_size = 0
        self.last_send = now
        self.recv_count = 0
        self.recv_size = 0
        self.last_recv = now
        self.last_packet_received = now

    async def _limited_wait(self, secs):
        """Wait up to ``secs`` for sending to resume; abort the session if it doesn't."""
        try:
            await asyncio.wait_for(self._can_send.wait(), secs)
        except asyncio.TimeoutError:
            self.abort()
            raise asyncio.TimeoutError(f'task timed out after {secs}s')

    async def _send_message(self, message):
        """Frame ``message`` and write it, waiting if the send buffer is full."""
        if not self._can_send.is_set():
            await self._limited_wait(self.max_send_delay)
        if self.is_closing():
            return
        framed = self.framer.frame(message)
        self.send_size += len(framed)
        self.send_count += 1
        self.last_send = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Sending framed message {framed}')
        self.transport.write(framed)

    def _bump_errors(self):
        """Count one error; close the connection once ``max_errors`` is reached."""
        self.errors += 1
        if self.errors < self.max_errors:
            return
        # Don't await self.close() because that is self-cancelling
        self._close()

    def _close(self):
        """Ask the transport (if any) to close gracefully."""
        transport = self.transport
        if transport:
            transport.close()

    # asyncio framework
    def data_received(self, framed_message):
        """Called by asyncio when a message comes in."""
        self.last_packet_received = time.perf_counter()
        if self.verbosity >= 4:
            self.logger.debug(f'Received framed message {framed_message}')
        self.recv_size += len(framed_message)
        self.framer.received_bytes(framed_message)

    def pause_writing(self):
        """Transport calls when the send buffer is full."""
        if self.is_closing():
            return
        self._can_send.clear()
        self.transport.pause_reading()

    def resume_writing(self):
        """Transport calls when the send buffer has room."""
        if self._can_send.is_set():
            return
        self._can_send.set()
        self.transport.resume_reading()

    def connection_made(self, transport):
        """Called by asyncio when a connection is established.

        Derived classes overriding this method must call this first."""
        self.transport = transport
        # get_extra_info() raised on a closed SSL transport until this was
        # fixed in asyncio in Python 3.6.1 and 3.5.4.
        peer = transport.get_extra_info('peername')
        if self._address:
            # Connected through a SOCKS proxy: _address already holds the
            # real remote address, so the socket peer is the proxy.
            self._proxy_address = peer
        else:
            self._address = peer
        self._pm_task = self.loop.create_task(self._receive_messages())

    def connection_lost(self, exc):
        """Called by asyncio when the connection closes.

        Tear down things done in connection_made."""
        self._address = None
        self.transport = None
        self._task_group.cancel()
        if self._pm_task:
            self._pm_task.cancel()
        # Wake any sender still blocked waiting for the buffer to drain.
        self._can_send.set()

    # External API
    def default_framer(self):
        """Return a default framer."""
        raise NotImplementedError

    def peer_address(self):
        """Return the peer's Python networking address, or None.

        This is what the transport reported as ``peername`` when the
        connection was made. It is None when there is no connection.
        """
        return self._address

    def peer_address_str(self):
        """Return the peer's IP address and port as a human-readable string."""
        if not self._address:
            return 'unknown'
        host, port = self._address[:2]
        # IPv6 hosts are bracketed so the port separator stays unambiguous.
        template = '[{}]:{}' if ':' in host else '{}:{}'
        return template.format(host, port)

    def is_closing(self):
        """Return True if the connection is closing."""
        transport = self.transport
        return not transport or transport.is_closing()

    def abort(self):
        """Forcefully close the connection."""
        transport = self.transport
        if transport:
            transport.abort()

    # TODO: replace with synchronous_close
    async def close(self, *, force_after=30):
        """Close the connection and return when closed.

        Waits up to ``force_after`` seconds for the receive task to finish,
        then aborts the transport and awaits the task.
        """
        self._close()
        if not self._pm_task:
            return
        with suppress(CancelledError):
            await asyncio.wait([self._pm_task], timeout=force_after)
            self.abort()
            await self._pm_task

    def synchronous_close(self):
        """Close the transport and cancel the receive task without awaiting."""
        self._close()
        task = self._pm_task
        if task and not task.done():
            task.cancel()
class Stream:
    """
    API for working with streams, used by clients and request handlers
    """

    def __init__(self, connection: Connection, h2_connection: H2Connection,
                 transport: Transport, stream_id: int, *,
                 loop: AbstractEventLoop) -> None:
        self._connection = connection
        self._h2_connection = h2_connection
        self._transport = transport
        self.id = stream_id

        # Incoming DATA for this stream; consumed by recv_data().
        self.__buffer__ = Buffer(loop=loop)
        self.__headers__: 'Queue[List[Tuple[str, str]]]' = Queue(loop=loop)
        # Waited on by send_data() when the flow-control window is exhausted.
        # Presumably set by the connection's event processing on WINDOW_UPDATE
        # (not visible here) -- TODO confirm.
        self.__window_updated__ = Event(loop=loop)

    async def recv_headers(self):
        """Wait for and return the next block of headers for this stream."""
        return await self.__headers__.get()

    async def recv_data(self, size=None):
        """Read data from the stream buffer and acknowledge it to h2.

        ``size`` is passed straight to ``Buffer.read``. The number of bytes
        actually returned is acknowledged so h2 can replenish the receive
        window.
        """
        data = await self.__buffer__.read(size)
        # NOTE(review): acknowledge_received_data() may queue a WINDOW_UPDATE
        # frame that is only flushed by the next data_to_send() write
        # elsewhere; confirm something flushes it for receive-only streams.
        self._h2_connection.acknowledge_received_data(len(data), self.id)
        return data

    async def send_headers(self, headers, end_stream=False):
        """Send ``headers`` on this stream once the transport is writable."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.send_headers(self.id, headers, end_stream=end_stream)
        self._transport.write(self._h2_connection.data_to_send())

    async def send_data(self, data, end_stream=False):
        """Send ``data`` in DATA frames, honouring HTTP/2 flow control.

        Each frame is limited by the current flow-control window and by the
        peer's maximum frame size. When the window is exhausted this waits for
        ``__window_updated__``. ``end_stream`` is attached to the last frame.
        """
        f = BytesIO(data)
        f_pos, f_last = 0, len(data)
        while True:
            if not self._connection.write_ready.is_set():
                await self._connection.write_ready.wait()

            window = self._h2_connection.local_flow_control_window(self.id)
            # The window can be zero or even negative (e.g. after the peer
            # lowers SETTINGS_INITIAL_WINDOW_SIZE).
            if window <= 0:
                self.__window_updated__.clear()
                await self.__window_updated__.wait()
                # While waiting, other streams may have consumed the window and
                # the transport may have become unwritable: start over.
                continue

            # h2 rejects DATA frames larger than the peer's max frame size.
            max_frame_size = self._h2_connection.max_outbound_frame_size
            f_chunk = f.read(min(window, max_frame_size, f_last - f_pos))
            f_pos = f.tell()

            if f_pos == f_last:
                self._h2_connection.send_data(self.id, f_chunk,
                                              end_stream=end_stream)
                self._transport.write(self._h2_connection.data_to_send())
                break
            self._h2_connection.send_data(self.id, f_chunk)
            self._transport.write(self._h2_connection.data_to_send())

    async def end(self):
        """Half-close the stream (send END_STREAM) once writable."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.end_stream(self.id)
        self._transport.write(self._h2_connection.data_to_send())

    async def reset(self, error_code=ErrorCodes.NO_ERROR):
        """Reset the stream with ``error_code`` once writable."""
        if not self._connection.write_ready.is_set():
            await self._connection.write_ready.wait()
        self._h2_connection.reset_stream(self.id, error_code=error_code)
        self._transport.write(self._h2_connection.data_to_send())

    def __ended__(self):
        """Mark the incoming data buffer as finished."""
        self.__buffer__.eof()
class Client:
    def __init__(self, intents, token=None, *, shard_count=None, shard_ids=None):
        """
        The client used to interact with the discord API.

        Parameters
        ----------
        intents: int
            The intents to use.
        token: Optional[str]
            Discord bot token to use.
        shard_count: Optional[int]
            How many shards the client should use.
        shard_ids: Optional[List[int]]
            A list of shard IDs to spawn. ``shard_count`` must be set for this to work.

        Raises
        ------
        TypeError
            ``shard_ids`` was set without ``shard_count``.
        """
        # Configurable stuff
        self.intents = int(intents)
        self.token = token
        self.shard_count = shard_count
        self.shard_ids = shard_ids

        # Things used by the lib, usually doesn't need to get changed but can if you want to.
        self.shards = []
        self.loop = get_event_loop()
        self.logger = getLogger("speedcord")
        self.http = None
        self.opcode_dispatcher = OpcodeDispatcher(self.loop)
        self.event_dispatcher = EventDispatcher(self.loop)
        self.connected = Event()
        self.exit_event = Event(loop=self.loop)
        self.remaining_connections = None
        self.connection_lock = Lock(loop=self.loop)
        self.fatal_exception = None
        self.connect_ratelimiter = None
        self.current_shard_count = shard_count if shard_count else None

        # Default event handlers
        self.opcode_dispatcher.register(0, self.handle_dispatch)

        # Check types
        if shard_count is None and shard_ids is not None:
            raise TypeError("You have to set shard_count if you use shard_ids")

    def run(self):
        """
        Starts the client.

        Blocks until the client stops. On ``KeyboardInterrupt`` the client is
        closed. If a fatal exception was recorded via :meth:`fatal`, it is
        re-raised afterwards.
        """
        try:
            self.loop.run_until_complete(self.start())
        except KeyboardInterrupt:
            self.loop.run_until_complete(self.close())
        if self.fatal_exception is not None:
            raise self.fatal_exception from None

    async def get_gateway(self):
        """
        Get details about the gateway

        Returns
        -------
        Tuple[str, int, int, int, int]
            A tuple consisting of the wss url to connect to, how many shards to use, how many gateway connections left,
            how many milliseconds until the gateway connection limit resets, and the maximum IDENTIFY concurrency.

        Raises
        ------
        Unauthorized
            Authentication failed.
        ConnectionsExceeded
            No gateway connections are left.
        """
        route = Route("GET", "/gateway/bot")
        try:
            r = await self.http.request(route)
        except Unauthorized:
            await self.close()
            raise
        data = await r.json()

        shards = data["shards"]
        session_limit = data["session_start_limit"]
        remaining_connections = session_limit["remaining"]
        connections_reset_after = session_limit["reset_after"]
        max_concurrency = session_limit["max_concurrency"]
        gateway_url = data["url"]

        if remaining_connections == 0:
            raise ConnectionsExceeded
        self.remaining_connections = remaining_connections
        self.logger.debug(f"{remaining_connections} gateway connections left!")

        return gateway_url, shards, remaining_connections, connections_reset_after, max_concurrency

    async def connect(self):
        """
        Connects to discord and spawns shards.

        :meth:`start()` has to be called first!

        Raises
        ------
        InvalidToken
            Provided token is invalid.
        """
        if self.token is None:
            raise InvalidToken
        if self.http is None:
            self.http = HttpClient(self.token, loop=self.loop)

        await self.spawn_shards(self.shards, shard_ids=self.shard_ids)

        self.connected.set()
        self.logger.info("All shards connected!")

    async def start(self):
        """
        Sets up the HTTP client, connects to Discord, and spawns shards.

        Waits until :attr:`exit_event` is set, then closes the client.

        Raises
        ------
        InvalidToken
            Provided token is invalid.
        """
        if self.token is None:
            raise InvalidToken
        self.http = HttpClient(self.token, loop=self.loop)

        await self.connect()

        await self.exit_event.wait()
        await self.close()

    async def close(self):
        """
        Closes the HTTP client and disconnects all shards.

        Safe to call before :meth:`start` has created the HTTP client.
        """
        self.connected.clear()
        self.exit_event.set()
        # ``http`` is still None if close() runs before start()/connect(),
        # e.g. on an early KeyboardInterrupt in run().
        if self.http is not None:
            await self.http.close()
        for shard in self.shards:
            await shard.close()

    async def fatal(self, exception):
        """
        Raises a fatal exception to the bot.
        Please do not use this for non-fatal exceptions.

        The exception is stored and re-raised by :meth:`run` after the client
        has been closed.
        """
        self.fatal_exception = exception
        await self.close()

    async def spawn_shards(self, shard_list, *, activate_automatically=True, shard_ids=None):
        """
        Connect shards and append them to ``shard_list``.

        Parameters
        ----------
        shard_list: List[DefaultShard]
            List that connected shards are appended to.
        activate_automatically: bool
            Value assigned to each new shard's ``active`` attribute.
        shard_ids: Optional[Iterable[int]]
            Shard IDs to spawn; defaults to every ID below the current shard count.

        Raises
        ------
        InvalidShardCount
            Discord recommends more shards than the explicitly configured ``shard_count``.
        """
        try:
            (gateway_url, shard_count, connections_left,
             connections_reset_after, max_concurrency) = await self.get_gateway()
        except Unauthorized as e:
            await self.fatal(e)
            return

        if self.connect_ratelimiter is None:
            self.connect_ratelimiter = TimesPer(max_concurrency, 5)

        if self.current_shard_count is None:
            self.current_shard_count = self.shard_count or shard_count

        if shard_count > self.current_shard_count:
            if self.shard_count:
                raise InvalidShardCount
            self.current_shard_count = shard_count

        if shard_ids is None:
            shard_ids = range(self.current_shard_count)

        async with self.connection_lock:
            for shard_id in shard_ids:
                connections_left -= 1
                if connections_left <= 1:
                    # ``reset_after`` is a relative duration in milliseconds
                    # (see get_gateway), not a timestamp, so it must not be
                    # compared against the wall clock.
                    sleep_time = connections_reset_after / 1000
                    if sleep_time > 0:
                        self.logger.warning("You have used up all your gateway IDENTIFYs. Sleeping until it resets.")
                        await sleep(sleep_time)
                    try:
                        (gateway_url, shard_count, connections_left,
                         connections_reset_after, max_concurrency) = await self.get_gateway()
                    except Unauthorized as e:
                        await self.fatal(e)
                        return

                await self.connect_ratelimiter.trigger()
                self.logger.info(f"Launching shard {shard_id}")
                shard = DefaultShard(shard_id, self, loop=self.loop)
                shard.active = bool(activate_automatically)
                self.logger.debug("Connecting shard")
                await shard.connect(gateway_url)
                self.logger.debug("Connected shard!")
                shard_list.append(shard)
            self.logger.debug("All shards connected")
        self.remaining_connections = connections_left

    def listen(self, event):
        """
        Listen to an event or opcode.

        Parameters
        ----------
        event: Union[int, str]
            An opcode or event name to listen to.

        Returns
        -------
        Callable
            A decorator that registers the function and returns it unchanged.

        Raises
        ------
        TypeError
            Invalid event type was passed.
        """

        def get_func(func):
            if isinstance(event, int):
                self.opcode_dispatcher.register(event, func)
            elif isinstance(event, str):
                self.event_dispatcher.register(event, func)
            else:
                raise TypeError("Invalid event type!")
            # Return the function so the decorated name stays usable.
            return func

        return get_func

    # Handle events
    async def handle_dispatch(self, data, shard):
        """
        Dispatches a event to the event handler.

        Parameters
        ----------
        data: Dict[str, Any]
            The data to dispatch.
        shard: DefaultShard
            Shard the event was received on.
        """
        self.event_dispatcher.dispatch(data["t"], data["d"], shard)
class AsyncCoordinator:
    """Keyed rendezvous between asyncio producers and consumers.

    ``put(key, value)`` stores a value under ``key``. ``get(key)`` waits until
    a value for that exact key is present, then removes and returns it.
    ``maxsize`` caps how many keys are stored at once (``<= 0`` means
    unbounded), and ``put`` waits while the store is full. ``task_done()`` and
    ``join()`` follow the ``asyncio.Queue`` protocol.
    """

    _loop: AbstractEventLoop
    _maxsize: int
    # Consumers waiting for a specific key, oldest first.
    _getters: 'Dict[str, List[Future]]'
    # Producers waiting for free capacity (capacity is shared by all keys).
    _putters: 'List[Future]'
    _unfinished_tasks: int
    _finished: Event
    _dict: Dict[str, Any]

    def __init__(self, maxsize=0, *, loop=None):
        self._loop = get_event_loop() if loop is None else loop
        self._maxsize = maxsize
        self._getters = {}
        self._putters = []
        self._unfinished_tasks = 0
        self._finished = Event(loop=self._loop)
        self._finished.set()
        self._dict = {}

    def get_dict_copy(self) -> Dict[str, Any]:
        """Return a shallow copy of the currently stored items."""
        return self._dict.copy()

    def _get(self, key: str) -> Any:
        # Remove and return the value; raises KeyError if absent.
        return self._dict.pop(key)

    def _put(self, key: str, value: Any) -> None:
        self._dict[key] = value

    def _wake_up_getter(self, key: str) -> None:
        """Wake the oldest consumer still waiting on ``key``, if any."""
        waiters = self._getters.get(key)
        while waiters:
            waiter = waiters.pop(0)
            if not waiter.done():
                waiter.set_result(None)
                break
        if waiters is not None and not waiters:
            del self._getters[key]

    def _wake_up_putter(self, key=None) -> None:
        """Wake the oldest producer waiting for capacity, if any.

        ``key`` is ignored: removing any item frees capacity for every key.
        The parameter is kept only for backward compatibility.
        """
        while self._putters:
            waiter = self._putters.pop(0)
            if not waiter.done():
                waiter.set_result(None)
                break

    def _discard_getter(self, key: str, getter: Future) -> None:
        """Remove ``getter`` from the waiters of ``key`` if still registered."""
        waiters = self._getters.get(key)
        if waiters is None:
            return
        try:
            waiters.remove(getter)
        except ValueError:
            # Already removed by _wake_up_getter().
            pass
        if not waiters:
            del self._getters[key]

    def __repr__(self) -> str:
        return f'<{type(self).__name__} at {id(self):#x} {self._format()}>'

    def __str__(self) -> str:
        return f'<{type(self).__name__} {self._format()}>'

    def _format(self) -> str:
        result = f'maxsize={self._maxsize!r}'
        if getattr(self, '_dict', None):
            result += f' _dict={dict(self._dict)!r}'
        if self._getters:
            waiting = sum(len(waiters) for waiters in self._getters.values())
            result += f' _getters[{waiting}]'
        if self._putters:
            result += f' _putters[{len(self._putters)}]'
        if self._unfinished_tasks:
            result += f' tasks={self._unfinished_tasks}'
        return result

    def qsize(self) -> int:
        """Number of stored items."""
        return len(self._dict)

    @property
    def maxsize(self) -> int:
        """Maximum number of stored items (``<= 0`` means unbounded)."""
        return self._maxsize

    def empty(self) -> bool:
        """Return True if no items are stored."""
        return not self._dict

    def full(self) -> bool:
        """Return True if ``maxsize`` items are stored (never for unbounded)."""
        if self._maxsize <= 0:
            return False
        return self.qsize() >= self._maxsize

    async def put(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, waiting while the store is full."""
        while self.full():
            putter = self._loop.create_future()
            self._putters.append(putter)
            try:
                await putter
            except BaseException:
                # BaseException so that CancelledError (a BaseException since
                # Python 3.8) also runs this cleanup.
                putter.cancel()  # Just in case putter is not done yet.
                try:
                    self._putters.remove(putter)
                except ValueError:
                    # Already removed by _wake_up_putter().
                    pass
                if not self.full() and not putter.cancelled():
                    # We were woken up but are giving up: pass the wake-up on.
                    self._wake_up_putter()
                raise
        self.put_nowait(key, value)

    def put_nowait(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` without waiting.

        Raises ``DictIsFull`` if the store is full.
        """
        if self.full():
            raise DictIsFull
        self._put(key, value)
        self._unfinished_tasks += 1
        self._finished.clear()
        self._wake_up_getter(key)

    async def get(self, key: str) -> Any:
        """Wait until a value for ``key`` is stored, then remove and return it."""
        while key not in self._dict:
            getter = self._loop.create_future()
            self._getters.setdefault(key, []).append(getter)
            try:
                await getter
            except BaseException:
                getter.cancel()  # Just in case getter is not done yet.
                self._discard_getter(key, getter)
                if key in self._dict and not getter.cancelled():
                    # We were woken up by put_nowait() but are giving up:
                    # pass the wake-up to the next consumer of this key.
                    self._wake_up_getter(key)
                raise
        return self.get_nowait(key)

    def get_nowait(self, key: str) -> Any:
        """Remove and return the value for ``key`` without waiting.

        Raises ``DictIsEmpty`` if no value for ``key`` is stored.
        """
        if key not in self._dict:
            raise DictIsEmpty
        item = self._get(key)
        self._wake_up_putter(key)
        return item

    def task_done(self) -> None:
        """Mark one previously stored item as fully processed."""
        if self._unfinished_tasks <= 0:
            raise ValueError('task_done() called too many times')
        self._unfinished_tasks -= 1
        if self._unfinished_tasks == 0:
            self._finished.set()

    async def join(self) -> None:
        """Wait until every stored item has been marked with task_done()."""
        if self._unfinished_tasks > 0:
            await self._finished.wait()
class Game:
    """One duel between two players, ``alpha`` and ``beta``, in ``channel``.

    Each player's ``loop`` task waits for ``aria command``, collects a spell
    with ``recv_command`` and fires it with ``raise_spell``. ``auto_heal_loop``
    regenerates MP every 10 seconds. ``start`` runs until
    ``game_finish_flag`` is set, then cancels the per-player tasks.
    """

    # Reward table used by ``win``. Each row is (upper bound of the
    # winner/loser strength ratio, gain rate, loss rate, whether the gain is
    # multiplied by a random 1x-2x bonus). Rows are checked in order; ratios
    # above the last bound give nothing.
    _REWARD_TABLE = (
        (0.5, 0.15, 0.15, True),   # won big
        (0.6, 0.12, 0.12, True),   # won by quite a bit
        (0.7, 0.1, 0.1, True),     # won decently
        (0.8, 0.07, 0.07, True),   # won a little
        (0.9, 0.06, 0.06, True),   # won just barely
        (1.1, 0.05, 0.05, True),   # about even
        (1.2, 0.03, 0.02, False),  # loser a bit weaker
        (1.4, 0.02, 0.02, False),  # loser quite a bit weaker
        (1.8, 0.01, 0.01, False),
    )

    def __init__(
        self,
        bot: 'Union[Aria, TestBot]',
        alpha: Union[discord.Member, TestMember],
        beta: Union[discord.Member, TestMember],
        channel: Union[discord.TextChannel, TestChannel],
        send_callable: Callable = _print,
    ) -> None:
        self.bot = bot
        self.alpha = alpha
        self.beta = beta
        self.channel = channel
        self.finish = False
        self.alpha_spell: Optional[Spell] = None
        self.beta_spell: Optional[Spell] = None
        self.alpha_loop: Optional[Task] = None
        self.beta_loop: Optional[Task] = None
        self.alpha_db_user: Optional[User] = None
        self.beta_db_user: Optional[User] = None
        self.alpha_hp = 100
        self.beta_hp = 100
        self.alpha_mp = 100
        self.beta_mp = 100
        self.ready_to_raise = False
        self.send_callable = send_callable
        self.battle_finish_flag = Event()
        self.game_finish_flag = Event()

    async def send(self, *args, **kwargs) -> None:  # type: ignore
        """Forward to ``send_callable``, awaiting it if it is a coroutine function."""
        if iscoroutinefunction(self.send_callable):
            await self.send_callable(*args, **kwargs)
        else:
            self.send_callable(*args, **kwargs)

    async def wait_for(self, *args, **kwargs) -> Message:  # type: ignore
        """Console stand-in for ``bot.wait_for``: reads one line from stdin.

        All arguments (event name, check, timeout) are ignored.
        """
        content = input()
        return Message(content, datetime.datetime.now())

    def alpha_check(self, message: discord.Message) -> bool:
        """True if ``message`` was sent by alpha in this game's channel."""
        return message.channel.id == self.channel.id and message.author.id == self.alpha.id

    def beta_check(self, message: discord.Message) -> bool:
        """True if ``message`` was sent by beta in this game's channel."""
        return message.channel.id == self.channel.id and message.author.id == self.beta.id

    async def recv_command(self, check: Callable, user: str) -> Optional[Spell]:
        """Collect spell commands from ``user`` until 'execute'/'discharge'.

        Returns the built spell, or None on timeout, MP exhaustion, a failed
        ``can_aria`` check, or an unrecognised command.
        """
        spell = Spell()
        while not self.bot.is_closed() and not self.finish:
            try:
                message = await self.wait_for('message', check=check, timeout=60)
            except TimeoutError:
                return None
            if message.content in ['execute', 'discharge']:
                if not self.use_mp(user, 5):
                    await self.send('システム: MPが枯渇しました。')
                    return None
                break
            if not spell.can_aria(message.created_at):
                return None
            mp, msg = spell.receive_command(message.content, message.created_at)
            if mp is not None:
                await self.send('システム: ' + msg)
                if not self.use_mp(user, mp):
                    await self.send('システム: MPが枯渇しました。')
                    return None
                continue
            return None
        return spell

    async def win(self, winner: Union[discord.Member, TestMember], loser: Union[discord.Member, TestMember]) -> None:
        """Announce ``winner`` and transfer HP or MP (chosen at random) from ``loser``.

        The amount depends on the ratio of the players' stored HP+MP totals.
        """
        await self.send(f'{winner.mention} の勝利!')
        winner_db_user = await self.bot.db.get_user(winner.id)
        loser_db_user = await self.bot.db.get_user(loser.id)

        hp_or_mp = random.choice([0, 1])  # 0 = HP, 1 = MP

        def get_num(_user: User) -> int:
            return _user.hp if not hp_or_mp else _user.mp

        loser_total = loser_db_user.hp + loser_db_user.mp
        if loser_total > 0:
            diff = (winner_db_user.hp + winner_db_user.mp) / loser_total
        else:
            # Loser has nothing left: treat as an overwhelming mismatch so the
            # rewards fall through to zero instead of dividing by zero.
            diff = float('inf')

        base = get_num(loser_db_user)
        get_ = 0
        lost_ = 0
        for limit, get_rate, lost_rate, randomized in self._REWARD_TABLE:
            if diff <= limit:
                if randomized:
                    get_ = int(base * get_rate * (random.random() + 1))
                else:
                    get_ = int(base * get_rate)
                lost_ = int(base * lost_rate)
                break

        if not hp_or_mp:
            await self.bot.db.update_user(winner.id, winner_db_user.hp + get_, winner_db_user.mp)
            await self.bot.db.update_user(loser.id, loser_db_user.hp - lost_, loser_db_user.mp)
            await self.send(
                f'{winner.mention}, HP: {winner_db_user.hp} -> {winner_db_user.hp + get_}'
            )
            await self.send(
                f'{loser.mention}, HP: {loser_db_user.hp} -> {loser_db_user.hp - lost_}'
            )
        else:
            await self.bot.db.update_user(winner.id, winner_db_user.hp, winner_db_user.mp + get_)
            await self.bot.db.update_user(loser.id, loser_db_user.hp, loser_db_user.mp - lost_)
            await self.send(
                f'{winner.mention}, MP: {winner_db_user.mp} -> {winner_db_user.mp + get_}'
            )
            await self.send(
                f'{loser.mention}, MP: {loser_db_user.mp} -> {loser_db_user.mp - lost_}'
            )

    async def raise_spell(self, wait_time: int = 5) -> None:
        """After ``wait_time`` seconds, resolve both spells and apply damage.

        Ends the game if either player's HP reaches 0. Otherwise reports the
        status and wakes players waiting on ``battle_finish_flag``.
        """
        if self.finish:
            return
        self.ready_to_raise = True
        await sleep(wait_time)
        alpha_to_beta_damage = _calc_damage(self.alpha_spell, self.beta_spell)
        beta_to_alpha_damage = _calc_damage(self.beta_spell, self.alpha_spell)
        await self.send(
            f'{self.alpha.mention} から {self.beta.mention} に {alpha_to_beta_damage} ダメージ!'
        )
        await self.send(
            f'{self.beta.mention} から {self.alpha.mention} に {beta_to_alpha_damage} ダメージ!'
        )
        self.alpha_hp -= beta_to_alpha_damage
        self.beta_hp -= alpha_to_beta_damage
        if self.alpha_hp <= 0 and self.beta_hp <= 0:
            await self.send('相打ち!両者HPが0になったため、相打ちとなります。')
            self.finish = True
            self.game_finish_flag.set()
        elif self.alpha_hp <= 0:
            await self.win(self.beta, self.alpha)
            self.finish = True
            self.game_finish_flag.set()
        elif self.beta_hp <= 0:
            await self.win(self.alpha, self.beta)
            self.finish = True
            self.game_finish_flag.set()
        else:
            await self.send(
                f'{self.alpha.mention}\n HP: {self.alpha_hp}\n MP: {self.alpha_mp}',
                allowed_mentions=discord.AllowedMentions(users=False))
            await self.send(
                f'{self.beta.mention}\n HP: {self.beta_hp}\n MP: {self.beta_mp}',
                allowed_mentions=discord.AllowedMentions(users=False))
            # Pulse the flag: set() wakes current waiters, clear() re-arms it.
            self.battle_finish_flag.set()
            self.battle_finish_flag.clear()
        self.alpha_spell = None
        self.beta_spell = None
        self.ready_to_raise = False

    def use_mp(self, user: str, mp: int = 1) -> bool:
        """Spend ``mp`` for ``user`` ('alpha' or anything else = beta).

        Returns False (and clamps MP to 0) if the player could not afford it.
        """
        if user == 'alpha':
            self.alpha_mp -= mp
            if self.alpha_mp < 0:
                self.alpha_mp = 0
                return False
            return True
        else:
            self.beta_mp -= mp
            if self.beta_mp < 0:
                self.beta_mp = 0
                return False
            return True

    async def force_end_game(self) -> None:
        """Announce an inactivity timeout and end the game."""
        await self.send('入力がなかったためゲームを終了します。')
        self.game_finish_flag.set()
        self.battle_finish_flag.set()
        self.finish = True

    async def loop(self, check: Callable, user: str) -> None:
        """Per-player task: accept 'aria command', build a spell and fire it."""
        while not self.bot.is_closed() and not self.finish:
            try:
                message = await self.wait_for('message', check=check, timeout=1000)
            except TimeoutError:
                return await self.force_end_game()
            if message.content != 'aria command':
                continue
            await self.send('システム: 魔法の発動開始を確認。物質生成フェーズへ移行します。')
            spell = await self.recv_command(check, user)
            if spell is None:
                await self.send('システム: 魔法の発動に失敗しました。')
                continue
            await self.send('システム: 魔法の発動を開始します。')
            if user == 'alpha':
                self.alpha_spell = spell
            else:
                self.beta_spell = spell
            if self.ready_to_raise:
                # The other player's raise_spell() will resolve this spell too.
                await self.battle_finish_flag.wait()
                continue
            await self.raise_spell(5 - spell.burst)

    async def auto_heal_loop(self) -> None:
        """Every 10 seconds, regenerate each player's MP by 1/50 of their stored MP, capped at it."""
        while not self.finish:
            # Always sleep between checks: a bare ``continue`` here would spin
            # without yielding and freeze the event loop until the DB users
            # are loaded.
            if self.alpha_db_user is not None and self.beta_db_user is not None:
                self.alpha_mp += (self.alpha_db_user.mp // 50)
                self.beta_mp += (self.beta_db_user.mp // 50)
                if self.alpha_db_user.mp < self.alpha_mp:
                    self.alpha_mp = self.alpha_db_user.mp
                if self.beta_db_user.mp < self.beta_mp:
                    self.beta_mp = self.beta_db_user.mp
            await sleep(10)

    async def start(self) -> None:
        """Load both players' stats, run the game tasks and wait for the end."""
        alpha_db_user = await self.bot.db.get_user(self.alpha.id)
        beta_db_user = await self.bot.db.get_user(self.beta.id)
        self.alpha_hp = alpha_db_user.hp
        self.alpha_mp = alpha_db_user.mp
        self.beta_hp = beta_db_user.hp
        self.beta_mp = beta_db_user.mp
        self.alpha_db_user = alpha_db_user
        self.beta_db_user = beta_db_user
        await self.send('ゲームスタート!')
        tasks = [
            self.bot.loop.create_task(self.loop(self.alpha_check, 'alpha')),
            self.bot.loop.create_task(self.loop(self.beta_check, 'beta')),
            self.bot.loop.create_task(self.auto_heal_loop())
        ]
        await self.game_finish_flag.wait()
        for task in tasks:
            if not task.done():
                task.cancel()
class TwitterCog(commands.Cog):
    """
    Follow twitter feeds
    """

    def __init__(self, bot):
        self.bot = bot
        self.__cog_name__ = "Twitter"
        self.to_update_stream = False
        self.api = None
        self.key_set = 1
        self.feeds_list = []
        self.stream = None
        # Filled by MyStreamListener; drained by tweet_handler.
        self.tweet_queue = Queue()
        self.tweet_event = Event()
        self.initialize.start()
        self.tweet_handler.start()

    @tasks.loop(count=1)
    async def initialize(self):
        """Load the feeds from the database and start the stream once."""
        try:
            await self.update_feeds()
            self.restart_stream()
        except Exception:
            traceback.print_exc()

    @tasks.loop(seconds=0)
    async def tweet_handler(self):
        """Post queued tweets to every channel that follows their author."""
        try:
            await self.tweet_event.wait()
            # Clear *before* draining: a tweet queued while we are posting
            # re-sets the event and is handled on the next iteration, instead
            # of being stranded when the event is cleared after the drain.
            self.tweet_event.clear()
            while self.tweet_queue.qsize() > 0:
                tweet = self.tweet_queue.get()
                for channel_id in tweet["channel_ids"]:
                    try:
                        channel = await self.bot.fetch_channel(channel_id)
                        await channel.send(tweet["url"])
                    except Exception:
                        print(f'Cannot post in channel {channel_id}')
        except Exception:
            traceback.print_exc()

    async def moderator_role_check(ctx):
        """Allow administrators, or members at/above the guild's moderator role."""
        if ctx.author.guild_permissions.administrator is True:
            return True
        try:
            moderator_role_id = await guilds.db_get_moderator_role_id(ctx.guild.id)
            moderator_role = await commands.RoleConverter().convert(ctx, moderator_role_id)
        except Exception as e:
            print(e)
            return False
        return ctx.author.top_role >= moderator_role

    @commands.group(case_insensitive=True)
    @commands.guild_only()
    @commands.check(moderator_role_check)
    async def twitter(self, ctx):
        if ctx.invoked_subcommand is None:
            await ctx.invoke(self.feeds)

    @twitter.command(
        description="Follow a user to a discord channel",
        usage="Args:\n" +
              "- Twitter username\n" +
              "- Text channel"
    )
    async def add(self, ctx, *, arg=""):
        args = shlex.split(arg)
        if len(args) != 2:
            await ctx.send('Invalid arguments. Please do `twitter add <username> <text channel>`')
            return
        try:
            channel = await commands.TextChannelConverter().convert(ctx, args[1])
        except Exception:
            await ctx.send('Invalid text channel.')
            return
        try:
            user = self.api.get_user(screen_name=args[0])
        except Exception as e:
            print(e)
            await ctx.send(e)
            return

        result = await feeds_twitter.db_add_feed(user.id, channel.id)
        print(result)
        if result is True:
            await ctx.send(
                f'{ctx.guild.me.mention} will now update new tweets of `{user.screen_name}` to {channel.mention}.'
            )
            # Debounce: several additions within 20s share one stream restart.
            self.to_update_stream = True
            await sleep(20)
            await self.update_feeds()
            if self.to_update_stream:
                self.restart_stream()
        elif result is False:
            await self.update_feeds()
            await ctx.send(
                f'{ctx.guild.me.mention} will now update new tweets of `{user.screen_name}` to {channel.mention}.'
            )
        else:
            await ctx.send('Already existing.')

    @twitter.command(
        description="Stop selected feeds. Get the index needed using `feeds`.\n" +
                    "For example `twitter delete 1 2 3 4` will delete feeds 1, 2, 3, and 4.",
        usage="Args:\n" +
              "- Feed index/es"
    )
    async def delete(self, ctx, *args):
        """Delete the feeds whose 1-based numbers (as listed by `feeds`) are in ``args``."""
        if not args:
            await ctx.send('Please include the feed numbers to delete `twitter delete <feed number> <feed number> ...`\nYou can use `twitter feeds` to list the feeds.')
            return

        self.feeds_list[:] = await feeds_twitter.db_get_feeds()
        guild_channels_ids = {str(channel.id) for channel in ctx.guild.channels}
        response = ""
        deleted_feed_counter = 0
        feed_counter = 0
        for feed in self.feeds_list:
            try:
                user = self.api.get_user(id=feed["_id"])
            except Exception:
                traceback.print_exc()
                # Never fall back to the previous iteration's ``user``: that
                # would report and delete the wrong account's feed.
                user = None
            for feed_channel_id in feed['channelIds']:
                if feed_channel_id not in guild_channels_ids:
                    continue
                # Count even when the lookup failed so numbering stays aligned
                # with the `feeds` listing.
                feed_counter += 1
                if user is None or str(feed_counter) not in args:
                    continue
                await feeds_twitter.db_delete_feed(user.id, feed_channel_id)
                response = response + f'`{feed_counter}` Stopped following `{user.screen_name}` in <#{feed_channel_id}>\n'
                deleted_feed_counter += 1

        if deleted_feed_counter == 0:
            response = "Deleted nothing. Make sure your feed numbers are correct.\n" +\
                       "Use `twitter feeds` to see the feed numbers.\n" +\
                       "Then use `twitter delete 1 5` to delete feeds 1, and 5, for example"
        else:
            response = 'Stopped following:\n' + response
        await ctx.send(response)
        await self.update_feeds()

    @twitter.command(
        description="List all the twitter feeds in this server",
        usage="Args: None"
    )
    async def feeds(self, ctx):
        """List this guild's feeds, numbered in the order `delete` expects."""
        self.feeds_list[:] = await feeds_twitter.db_get_feeds()
        guild_channels_ids = {str(channel.id) for channel in ctx.guild.channels}
        response = f'Followed accounts for `{ctx.guild.name}`:\n'
        feed_counter = 0
        for feed in self.feeds_list:
            matching_ids = [channel_id for channel_id in feed['channelIds']
                            if channel_id in guild_channels_ids]
            if not matching_ids:
                continue
            # One lookup per feed rather than one per matching channel.
            user = self.api.get_user(feed["_id"])
            for feed_channel_id in matching_ids:
                feed_counter += 1
                response = response + f'`{feed_counter}` Following `{user.screen_name}` in <#{feed_channel_id}>\n'
        await ctx.send(response)

    async def update_feeds(self):
        """Reload ``feeds_list`` in place from the database."""
        self.feeds_list[:] = await feeds_twitter.db_get_feeds()

    def restart_stream(self):
        """Rotate to the next of the three API key sets and restart the stream."""
        print("Updated stream")
        self.to_update_stream = False
        self.key_set = (self.key_set + 1) % 3
        auth = tweepy.OAuthHandler(
            TWT_CONSUMER_KEY.split(" ")[self.key_set],
            TWT_CONSUMER_SECRET.split(" ")[self.key_set]
        )
        auth.set_access_token(
            TWT_ACCESS_TOKEN.split(" ")[self.key_set],
            TWT_ACCESS_TOKEN_SECRET.split(" ")[self.key_set]
        )
        self.api = tweepy.API(auth)
        if self.stream is not None:
            self.stream.disconnect()
        self.stream = MyStream(
            twitter_cog=self,
            auth=self.api.auth,
            listener=MyStreamListener(self.feeds_list, self.tweet_queue, self.tweet_event)
        )
        self.stream.filter(follow=[feed["_id"] for feed in self.feeds_list], is_async=True)
class MapAsyncIterator:
    """Map an AsyncIterable over a callback function.

    Given an AsyncIterable and a callback function, return an AsyncIterator
    which produces values mapped via calling the callback function.

    When the resulting AsyncIterator is closed, the underlying AsyncIterable will also
    be closed.
    """

    def __init__(
        self,
        iterable: AsyncIterable,
        callback: Callable,
        reject_callback: Optional[Callable] = None,
    ) -> None:
        self.iterator = iterable.__aiter__()
        self.callback = callback
        self.reject_callback = reject_callback
        # Set when the iterator is closed; also used to interrupt a pending
        # __anext__ (see is_closed).
        self._close_event = Event()

    def __aiter__(self) -> "MapAsyncIterator":
        """Get the iterator object."""
        return self

    async def __anext__(self) -> Any:
        """Get the next value of the iterator."""
        if self.is_closed:
            if not isasyncgen(self.iterator):
                raise StopAsyncIteration
            value = await self.iterator.__anext__()
            result = self.callback(value)

        else:
            # Race the next item against closing, so aclose() from another
            # task ends a pending __anext__ with StopAsyncIteration.
            aclose = ensure_future(self._close_event.wait())
            anext = ensure_future(self.iterator.__anext__())

            try:
                pending: Set[Task] = (
                    await wait([aclose, anext], return_when=FIRST_COMPLETED)
                )[1]
            except CancelledError:
                # cancel underlying tasks and close
                aclose.cancel()
                anext.cancel()
                await self.aclose()
                raise  # re-raise the cancellation

            for task in pending:
                task.cancel()

            if aclose.done():
                raise StopAsyncIteration

            error = anext.exception()
            if error:
                if not self.reject_callback or isinstance(
                    error, (StopAsyncIteration, GeneratorExit)
                ):
                    raise error
                result = self.reject_callback(error)
            else:
                value = anext.result()
                result = self.callback(value)

        return await result if isawaitable(result) else result

    async def athrow(
        self,
        type_: Union[BaseException, Type[BaseException]],
        value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Throw an exception into the asynchronous iterator.

        If the underlying iterator has ``athrow``, the exception is forwarded
        to it. Otherwise this iterator is closed and the exception is raised
        here. ``type_`` may be an exception class or an exception instance.
        Does nothing if the iterator is already closed.
        """
        if self.is_closed:
            return
        athrow = getattr(self.iterator, "athrow", None)
        if athrow:
            await athrow(type_, value, traceback)
            return
        await self.aclose()
        if value is None:
            if traceback is None:
                raise type_
            # Use ``type_`` as-is when it is already an instance; only
            # instantiate it when it is an exception class. (Checking
            # ``value`` here would always be False, since it is None.)
            value = (
                type_
                if isinstance(type_, BaseException)
                else cast(Type[BaseException], type_)()
            )
        if traceback is not None:
            value = value.with_traceback(traceback)
        raise value

    async def aclose(self) -> None:
        """Close the iterator."""
        if not self.is_closed:
            aclose = getattr(self.iterator, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except RuntimeError:
                    pass
            self.is_closed = True

    @property
    def is_closed(self) -> bool:
        """Check whether the iterator is closed."""
        return self._close_event.is_set()

    @is_closed.setter
    def is_closed(self, value: bool) -> None:
        """Mark the iterator as closed."""
        if value:
            self._close_event.set()
        else:
            self._close_event.clear()
class PubSubInstance:
    """One subscriber's handle on a shared pub/sub object.

    Messages delivered through ``_add_message`` are queued locally and returned
    by ``message()``. Use as an async context manager, or call ``aclose()``,
    to unregister.
    """

    __slots__ = '_pubsub', '_encoder', '_decoder', '_closed', '_messages', '_event'

    def __init__(self, pubsub, encoder, decoder):
        self._pubsub = pubsub
        self._encoder = encoder or utf8_encode
        self._decoder = decoder
        self._closed = False
        self._messages = deque()
        # Set whenever a message is queued (or on close) to wake message().
        self._event = Event()

    async def __aenter__(self):
        if self._closed:
            raise RedisError('Pub/sub instance closed')
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.aclose()

    async def aclose(self):
        """Unregister from the pub/sub and release all state.

        Any task blocked in ``message()`` is woken and receives ``None``.
        """
        if self._closed:
            return
        self._closed = True
        try:
            await self._pubsub.unregister(self)
        except Exception:
            # Best effort: closing must not fail because unregistering did.
            pass
        event = self._event
        self._messages = None
        self._decoder = None
        self._encoder = None
        self._pubsub = None
        self._event = None
        # Wake waiters in message(); otherwise they would wait forever on an
        # event nobody can set any more.
        event.set()

    async def add(self, channels=None, patterns=None):
        """Subscribe to additional ``channels`` and/or ``patterns``."""
        await self._cmd(self._pubsub.register, channels, patterns)

    # TODO (question) should we removed the self._messages that are not related to this channels and patterns (left overs)?
    async def remove(self, channels=None, patterns=None):
        """Unsubscribe from ``channels`` and/or ``patterns``."""
        await self._cmd(self._pubsub.unregister, channels, patterns)

    async def message(self, timeout=None):
        """Return the next message, waiting for one if none is queued.

        Returns ``None`` if ``timeout`` seconds pass without a message or if
        the instance is closed while waiting. Raises ``RedisError`` if it is
        already closed.
        """
        if self._closed:
            raise RedisError('Pub/sub instance closed')
        msg = self._get_message()
        if msg is not None:
            return msg
        # Keep a local reference: aclose() drops self._event (after setting it).
        event = self._event
        event.clear()
        # We check connection here to notify the end user if there is an connection error...
        await self._pubsub.check_connection(self)
        if timeout is None:
            await event.wait()
        else:
            try:
                await wait_for(event.wait(), timeout)
            except AsyncIOTimeoutError:
                pass
        return self._get_message()

    async def ping(self, message=None):
        """Send a PING through the pub/sub connection."""
        if self._closed:
            raise RedisError('Pub/sub instance closed')
        await self._pubsub.ping(message)

    async def _cmd(self, cmd, channels, patterns):
        # Normalise single names to lists and encode before calling ``cmd``.
        if self._closed:
            raise RedisError('Pub/sub instance closed')
        if channels:
            if isinstance(channels, (str, bytes)):
                channels = [channels]
            channels = [self._encoder(x) for x in channels]
        if patterns:
            if isinstance(patterns, (str, bytes)):
                patterns = [patterns]
            patterns = [self._encoder(x) for x in patterns]
        await cmd(self, channels, patterns)

    def _add_message(self, msg):
        # Dropped silently once the instance is closed.
        if self._messages is not None:
            self._messages.append(msg)
            self._event.set()

    def _get_message(self):
        # Pop the oldest queued message, decode it, and raise it if it is an
        # exception. Returns None when nothing is queued or after close.
        messages = self._messages
        if not messages:
            return None
        msg = messages.popleft()
        if self._decoder:
            msg = self._decoder(msg)
        if isinstance(msg, Exception):
            raise msg
        return msg