def __init__(self, plugins_manager: PluginManager, session: Session = None):
    """Initialise protocol-handler state.

    :param plugins_manager: plugin manager; its ``_tg`` task group is stored
        on this handler as ``self._tg``.
    :param session: optional session. When truthy it is installed via
        ``self._init_session``; otherwise ``self.session`` is set to ``None``.
    """
    # Logger named "<module>.<ConcreteClassName>".
    self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
    if session:
        self._init_session(session)
    else:
        self.session = None
    self.stream = None
    self.plugins_manager = plugins_manager
    # Task group borrowed from the plugin manager.
    self._tg = plugins_manager._tg
    self._reader_task = None
    self._sender_task = None
    self._reader_stopped = anyio.create_event()
    self._sender_stopped = anyio.create_event()
    # Outgoing queue bounded to 10 entries.
    self._send_q = anyio.create_queue(10)
    # One waiter map per acknowledgement packet type.
    # NOTE(review): presumably keyed by packet id — confirm.
    self._puback_waiters = dict()
    self._pubrec_waiters = dict()
    self._pubrel_waiters = dict()
    self._pubcomp_waiters = dict()
    self._disconnecting = False
    self._disconnect_waiter = None
    # Lock guarding writes to the stream (presumably to keep packets atomic).
    self._write_lock = anyio.create_lock()
async def test_reply_to(self, amqp):
    """Exchange a request/reply pair using ``correlation_id`` and ``reply_to``.

    A server helper (``self._server``) and then a client helper
    (``self._client``) are spawned. Each one is awaited via its own
    ``done_here`` event before the test moves on. The test then waits on
    ``server_future`` and ``client_future`` and unpacks their
    ``test_result`` attribute as ``(body, envelope, properties)``.

    NOTE(review): the helpers, which are not visible here, are expected to
    attach ``test_result`` and set those events.
    """
    # These two names were used below without ever being assigned in this
    # function (a NameError unless some unseen global happens to exist).
    # Binding them here makes the test self-contained.
    exchange_name = 'exchange_name'
    server_routing_key = 'reply_test'

    server_future = anyio.create_event()
    async with anyio.create_task_group() as n:
        done_here = anyio.create_event()
        await n.spawn(self._server, amqp, server_future, exchange_name,
                      server_routing_key, done_here)
        await done_here.wait()

        correlation_id = 'secret correlation id'
        client_routing_key = 'secret_client_key'

        client_future = anyio.create_event()
        done_here = anyio.create_event()
        await n.spawn(self._client, amqp, client_future, exchange_name,
                      server_routing_key, correlation_id, client_routing_key,
                      done_here)
        await done_here.wait()

        logger.debug('Waiting for server to receive message')
        await server_future.wait()
        server_body, server_envelope, server_properties = \
            server_future.test_result
        assert server_body == b'client message'
        assert server_properties.correlation_id == correlation_id
        assert server_properties.reply_to == client_routing_key
        assert server_envelope.routing_key == server_routing_key

        logger.debug('Waiting for client to receive message')
        await client_future.wait()
        client_body, client_envelope, client_properties = \
            client_future.test_result
        assert client_body == b'reply message'
        assert client_properties.correlation_id == correlation_id
        assert client_envelope.routing_key == client_routing_key

        # Stop the helper tasks so the task group can exit.
        await n.cancel_scope.cancel()
async def test_cancel_worker_thread():
    """
    Test that when a task running a worker thread is cancelled, the cancellation is not acted on
    until the thread finishes.

    The thread writes ``last_active = 'thread'`` before finishing, and the
    task's ``finally`` writes ``'task'``. If cancellation is deferred until
    the thread returns, the task's write happens last.
    """
    def thread_worker():
        # Runs in the worker thread: announce start, block, record, announce end.
        nonlocal last_active
        run_async_from_thread(sleep_event.set)
        time.sleep(0.2)
        last_active = 'thread'
        run_async_from_thread(finish_event.set)

    async def task_worker():
        nonlocal last_active
        try:
            await run_in_thread(thread_worker)
        finally:
            # Runs on normal completion and on cancellation alike.
            last_active = 'task'

    sleep_event = create_event()
    finish_event = create_event()
    last_active = None
    async with create_task_group() as tg:
        await tg.spawn(task_worker)
        # Cancel only once the thread is known to be running.
        await sleep_event.wait()
        await tg.cancel_scope.cancel()
        await finish_event.wait()

    assert last_active == 'task'
async def __call__(self):
    """Exchange random DTMF digits in both directions over a dual call.

    ``run_in`` sends ``in_dtmf`` on the inbound leg and expects ``out_dtmf``.
    ``run_out`` expects ``in_dtmf`` on the outbound leg and sends
    ``out_dtmf``. Events ``sync1``/``sync2`` gate each send on the opposite
    side's ``ExpectDTMF`` being ready. ``sync3`` marks completion.
    """
    async with self.dual_call() as (icm, ocm):
        # Coordinate between the tasks; DTMF sending must wait
        # until the receiver is listening
        sync1 = anyio.create_event()  # outbound side ready to receive in_dtmf
        sync3 = anyio.create_event()  # inbound side received out_dtmf (done)
        sync2 = anyio.create_event()  # inbound side ready to receive out_dtmf
        in_dtmf = random_dtmf(len=self.call.dtmf.len)
        out_dtmf = random_dtmf(len=self.call.dtmf.len)

        async def run_in():
            await self.connect_in(icm)
            await sync1.wait()
            await icm.channel.sendDTMF(dtmf=in_dtmf, between=0.5)
            # NOTE(review): ExpectDTMF presumably sets ``ready`` once listening.
            await ExpectDTMF(icm, dtmf=out_dtmf, ready=sync2, may_repeat=self.call.dtmf.may_repeat)
            await sync3.set()

        async def run_out():
            await self.connect_out(ocm)
            await ExpectDTMF(ocm, dtmf=in_dtmf, ready=sync1, may_repeat=self.call.dtmf.may_repeat)
            await sync2.wait()
            await ocm.channel.sendDTMF(dtmf=out_dtmf, between=0.5)
            await sync3.wait()

        await icm.taskgroup.spawn(run_in)
        await ocm.taskgroup.spawn(run_out)
        # Keep the dual call open until the round-trip has completed.
        await sync3.wait()
def __init__(self, tg: anyio.abc.TaskGroup, client_id=None, config=None, codec=None):
    """Initialise client state.

    :param tg: task group stored as ``self._tg`` and handed to the plugin
        manager.
    :param client_id: client identifier. If ``None``, one is generated with
        ``distmqtt.utils.gen_client_id`` and logged.
    :param config: optional mapping merged over a deep copy of ``_defaults``.
    :param codec: passed to ``get_codec`` together with the merged config.
    """
    self.logger = logging.getLogger(__name__)
    # Deep copy so updates never mutate the shared defaults.
    self.config = copy.deepcopy(_defaults)
    if config is not None:
        self.config.update(config)

    if client_id is not None:
        self.client_id = client_id
    else:
        from distmqtt.utils import gen_client_id

        self.client_id = gen_client_id()
        self.logger.debug("Using generated client ID : %s", self.client_id)

    self.session = None
    self._tg = tg
    self._handler = None
    self._disconnect_task = None
    self._connected_state = anyio.create_event()
    self._no_more_connections = anyio.create_event()
    self.extra_headers = {}
    self.codec = get_codec(codec, config=self.config)
    self._subscriptions = None

    # Init plugins manager
    context = ClientContext(self.config)
    self.plugins_manager = PluginManager(tg, "distmqtt.client.plugins", context)
    self.client_task = None
async def is_running(self, partition: "Partition"):
    """Hold ``partition.lock`` while the caller runs as the partition's owner.

    This is an async generator that yields exactly once.
    NOTE(review): presumably wrapped as an async context manager by a
    decorator not shown here.

    On entry it sets and then replaces ``self.new_joiner``. On exit it
    removes the partition from ``self.partitions``, waits for in-flight
    fetches of it, resets the partition, and sets and replaces
    ``self.new_leaver``.
    """
    async with partition.lock:
        # Wake waiters on the current joiner event, then arm a fresh one.
        await self.new_joiner.set()
        self.new_joiner = anyio.create_event()
        try:
            yield
        finally:
            logger.debug("consumer-cleanup")
            # Remove the partition from the set managed by this Runner. This will
            # cause any pending fetches to abort when they complete.
            del self.partitions[partition]

            # Wait until any pending fetches for the partition complete. Otherwise,
            # there's a race condition between the current fetcher and the one that
            # will own the partition next. This fetcher will read, but not ack, and
            # any events will stay in the PEL. Since we read as the same Redis
            # group/consumer, depending on its startup time the new owner might miss
            # the events read by this fetcher.
            # NOTE(review): relies on ``fetch_completed`` being re-armed by the
            # fetcher between iterations; a permanently-set event would spin here.
            while partition in self.fetching:
                await self.fetch_completed.wait()

            # Reset the partition's pointer so that the next task to take ownership
            # starts by reading the pending entries list. This is necessary because
            # we may have read, but not processed, some events.
            partition.reset()

            # Notify any waiting tasks that we are about to exit and relinquish our
            # run lock.
            await self.new_leaver.set()
            self.new_leaver = anyio.create_event()
            logger.debug("consumer-exit")
async def test_consume_callaback_synced(self, amqp): self.consume_future = anyio.create_event() # declare async with amqp.new_channel() as channel: await channel.queue_declare("q", exclusive=True, no_wait=False) await channel.exchange_declare("e", "fanout") await channel.queue_bind("q", "e", routing_key='') # get a different channel async with amqp.new_channel() as channel: # publish await channel.publish( b"coucou", "e", routing_key='', ) sync_future = anyio.create_event() async def callback(channel, body, envelope, properties): assert sync_future.is_set() await channel.basic_consume(callback, queue_name="q") await sync_future.set()
async def __aenter__(self):
    """Enter the context, counting nested entries in ``self._loans``.

    ``_schedules_event`` and ``_jobs_event`` are created only on the first
    (outermost) entry. The parent's ``__aenter__`` is awaited on every entry
    and its result returned.
    """
    self._loans += 1
    if self._loans == 1:
        self._schedules_event = create_event()
        self._jobs_event = create_event()

    return await super().__aenter__()
async def test_query(self):
    """Run a query between two Serf clients.

    ``answer_query`` is started on ``serf2`` and the test waits on ``ev1``.
    Only then is ``ask_query`` started on ``serf1``, and the test waits on
    ``ev2``.

    NOTE(review): the helpers (not shown) are presumably responsible for
    setting their events.
    """
    async with anyio.create_task_group() as tg:
        async with serf_client(codec=UTF8Codec()) as serf1:
            async with serf_client(codec=UTF8Codec()) as serf2:
                ev1 = anyio.create_event()
                ev2 = anyio.create_event()
                await tg.spawn(self.answer_query, serf2, ev1)
                # Make sure the answering side is ready before asking.
                await ev1.wait()
                await tg.spawn(self.ask_query, serf1, ev2)
                await ev2.wait()
async def flush(self):
    """
    Send our write-buffered data.

    If the heap is non-empty, this creates a ``_done`` event, sets
    ``_ending`` and waits for ``_done``. Afterwards ``_ending`` is replaced
    with a fresh event and ``_done`` is cleared.

    NOTE(review): assumes the sender task sets ``_done`` once it has drained
    the heap — confirm.
    """
    if self._heap:
        self._done = anyio.create_event()
        # Signal the sender (presumably) to release queued messages without delay.
        await self._ending.set()
        await self._done.wait()
        # Re-arm for the next flush.
        self._ending = anyio.create_event()
        self._done = None
async def start(self):
    """Start the DistKV helper tasks, then the base plugin.

    The session task is always started. The retain task is started only
    when the ``distkv`` config section contains a ``'retain'`` key. Each
    task is spawned on ``self._tg`` with that section and a fresh event, and
    this method waits on the event before continuing.
    """
    distkv_cfg = self.config['distkv']

    async def spawn_and_wait(worker):
        # NOTE(review): relies on ``worker`` setting ``ready`` once it is up.
        ready = anyio.create_event()
        await self._tg.spawn(worker, distkv_cfg, ready)
        await ready.wait()

    await spawn_and_wait(self.__session)
    if 'retain' in distkv_cfg:
        await spawn_and_wait(self.__retain)
    await super().start()
def __init__(self, settings, client, group):
    """Initialise a PRUDP connection object.

    :param settings: settings mapping; the ``prudp.*`` keys read below must
        be present.
    :param client: transport client; its ``remote_address()`` feeds the
        local connection signature.
    :param group: passed to ``scheduler.Scheduler``.
    """
    # Protocol parameters copied out of the settings.
    self.fragment_size = settings["prudp.fragment_size"]
    self.max_substream_id = settings["prudp.max_substream_id"]
    self.supported_functions = settings["prudp.supported_functions"]
    self.minor_ver = settings["prudp.minor_version"]
    self.resend_timeout = settings["prudp.resend_timeout"]
    self.resend_limit = settings["prudp.resend_limit"]
    self.ping_timeout = settings["prudp.ping_timeout"]
    self.version = settings["prudp.version"]

    self.settings = settings
    self.client = client
    self.group = group

    self.payload_encoder = PayloadEncoder(settings)
    self.packet_encoder = MessageEncoder(settings)
    self.sequence_mgr = SequenceMgr(settings)
    self.scheduler = scheduler.Scheduler(group)

    self.ack_events = {}
    self.ping_event = None

    self.local_port = None
    self.remote_port = None
    # Random session id in the range 0..0xFF.
    self.local_session_id = random.randint(0, 0xFF)
    self.remote_session_id = None
    self.local_signature = self.packet_encoder.calc_connection_signature(client.remote_address())
    self.remote_signature = None

    # One sliding window, fragment buffer and packet queue per substream
    # (ids 0..max_substream_id inclusive).
    substreams = self.max_substream_id + 1
    self.sliding_windows = [SlidingWindow() for i in range(substreams)]
    self.fragment_buffers = [b""] * substreams
    self.packets = [socketutils.PacketQueue() for i in range(substreams)]
    self.unreliable_packets = socketutils.PacketQueue()

    self.credentials = None
    self.server_key = None
    self.session_key = b""
    self.user_pid = None
    self.user_cid = None
    # Random 32-bit value (0..0xFFFFFFFF).
    self.connection_check = random.randint(0, 0xFFFFFFFF)

    self.serving = False
    self.syn_complete = False
    self.connect_ack = None
    self.closing = False
    self.handshake = anyio.create_event()
    self.closed = anyio.create_event()
async def start(self):
    """Start the base plugin, then the DistKV session and retain-reader tasks.

    Each helper is spawned on ``self._tg`` with the ``distkv`` config
    section and a fresh event. This method waits on that event before
    starting the next helper.
    """
    distkv_cfg = self.config["distkv"]
    await super().start()

    async def spawn_and_wait(worker):
        # NOTE(review): relies on ``worker`` setting ``ready`` once it is up.
        ready = anyio.create_event()
        await self._tg.spawn(worker, distkv_cfg, ready)
        await ready.wait()

    await spawn_and_wait(self.__session)
    await spawn_and_wait(self.__retain_reader)
async def test_consume_multiple_queues(self, amqp):
    """Consumers on two queues each receive only their own message.

    ``q1`` and ``q2`` are bound to direct exchange ``e`` with routing keys
    equal to their names. Each callback stores ``(body, envelope,
    properties)`` on ``self`` and sets its event. The test checks the
    consumer tag and body of each delivery.
    """
    self.consume_future = anyio.create_event()
    async with amqp.new_channel() as channel:
        await channel.queue_declare("q1", exclusive=True, no_wait=False)
        await channel.queue_declare("q2", exclusive=True, no_wait=False)
        await channel.exchange_declare("e", "direct")
        await channel.queue_bind("q1", "e", routing_key="q1")
        await channel.queue_bind("q2", "e", routing_key="q2")

    # get a different channel
    async with amqp.new_channel() as channel:
        q1_future = anyio.create_event()

        async def q1_callback(channel, body, envelope, properties):
            self.q1_result = (body, envelope, properties)
            await q1_future.set()

        q2_future = anyio.create_event()

        async def q2_callback(channel, body, envelope, properties):
            self.q2_result = (body, envelope, properties)
            await q2_future.set()

        # start consumers
        result = await channel.basic_consume(q1_callback, queue_name="q1")
        ctag_q1 = result['consumer_tag']
        result = await channel.basic_consume(q2_callback, queue_name="q2")
        ctag_q2 = result['consumer_tag']

        # put message in q1
        await channel.publish(b"coucou1", "e", "q1")

        # get it
        await q1_future.wait()
        body1, envelope1, properties1 = self.q1_result
        assert ctag_q1 == envelope1.consumer_tag
        assert envelope1.delivery_tag is not None
        assert b"coucou1" == body1
        assert isinstance(properties1, Properties)

        # put message in q2
        await channel.publish(b"coucou2", "e", "q2")

        # get it
        await q2_future.wait()
        body2, envelope2, properties2 = self.q2_result
        assert ctag_q2 == envelope2.consumer_tag
        assert b"coucou2" == body2
        assert isinstance(properties2, Properties)
async def run(self, client, updated=None):
    """
    Background task runner for this test, stores exceptions.

    :param client: passed to ``self._run`` on every iteration.
    :param updated: Callback that's fired when this test's status changes.
        It is called as ``updated(self)`` and awaited. The accumulated test
        status is in the ``state`` attribute.

    If ``self.test.skip`` is true, the test runs only on demand: each run
    happens after ``self._delay`` is set. Otherwise it runs in a loop. After
    each run it waits ``self.test.retry`` seconds if ``fail_count > 0``, or
    ``self.test.repeat`` seconds otherwise; setting ``self._delay`` cuts the
    wait short. This method loops forever and never returns normally.
    """
    if updated is None:
        # No-op stand-in so the loops below can call it unconditionally.
        async def updated():
            pass
    else:
        updated = partial(updated, self)

    state = self.state
    state.update({
        "n_run": 0,  # total
        "n_fail": 0,  # total
        "running": False,
        "last_exc": None,
        "fail_map": [],  # last 20 or whatever
        "fail_count": 0,
        "retry_after": self.test.retry,
        "repeat_after": self.test.repeat,
        "timeout": self.timeout,
    })

    if self.test.skip:
        # on demand only
        while True:
            await updated()
            self._delay = anyio.create_event()
            await self._delay.wait()
            await updated()
            await self._run(client)
    else:
        while True:
            await updated()
            await self._run(client)
            await updated()
            self._delay = anyio.create_event()
            # NOTE(review): attribute access on ``state``; assumes it is an
            # attribute-style dict (plain ``dict`` would raise) — confirm.
            if state.fail_count > 0:
                dly = self.test.retry
            else:
                dly = self.test.repeat
            async with anyio.move_on_after(dly):
                await self._delay.wait()
def __init__(self):
    """Set up the service state machine and its lifecycle primitives.

    NOTE(review): the ``states``/``transitions``/``initial``/
    ``auto_transitions`` keywords look like the ``transitions`` library's
    ``Machine`` constructor — confirm against the base class.
    """
    service_state = type(self).ServiceState
    # Each entry presumably reads [trigger, source(s), destination].
    # Restart path: started -> "restarting" -> ServiceRestartState.stopping
    # -> ServiceRestartState.stopped -> ServiceRestartState.starting -> started.
    transitions = [
        [
            "initialized", service_state.initializing,
            service_state.initialized
        ],
        [
            "starting", [service_state.initialized, service_state.stopped],
            service_state.starting
        ],
        [
            "started", [service_state.starting, ServiceRestartState.starting],
            service_state.started
        ],
        ["restarting", service_state.started, "restarting"],
        ["stopping", "restarting", ServiceRestartState.stopping],
        [
            "stopped", ServiceRestartState.stopping,
            ServiceRestartState.stopped
        ],
        [
            "starting", ServiceRestartState.stopped,
            ServiceRestartState.starting
        ],
        ["stopping", service_state.started, service_state.stopping],
        ["stopped", service_state.stopping, service_state.stopped],
        ["crashed", "*", service_state.crashed],
    ]
    super().__init__(states=service_state,
                     transitions=transitions,
                     initial=service_state.initializing,
                     auto_transitions=False)
    self._restart_count: int = 0
    self._cancel_scope: anyio.CancelScope = anyio.open_cancel_scope()
    self._exit_stack: AsyncExitStack = AsyncExitStack()
    self._started_event: anyio.Event = anyio.create_event()
    self._shutdown_event: anyio.Event = anyio.create_event()
    self._dependencies = []

    # State-entry hooks: count restarts and keep the started/shutdown
    # events in sync with the lifecycle.
    self.on_enter_restarting(self._increase_restart_count)
    self.on_enter_started(self._notify_started)
    self.on_enter_stopping(self._reset_started_event)
    self.on_enter_starting(self._reset_shutdown_event)
    self.on_enter_stopped(self._notify_shutdown)
async def start(self):
    """Start the reader loop and retry pending deliveries.

    :raises ProtocolHandlerException: if ``self._is_attached()`` is false.
    """
    self._disconnect_waiter = anyio.create_event()
    if not self._is_attached():
        raise ProtocolHandlerException(
            "Handler is not attached to a stream")
    # Spawn the reader loop and wait until it sets ``evt``.
    evt = anyio.create_event()
    await self._tg.spawn(self._reader_loop, evt)
    await evt.wait()
    self.logger.debug("Handler tasks started")
    await self._retry_deliveries()
    # Label depends on whether the concrete class name contains "Broker".
    self.logger.debug(
        "%s %s ready",
        "Broker" if "Broker" in type(self).__name__ else "Client",
        self.session.client_id if self.session else "?",
    )
async def command(self, msgtype, payload=b""):
    """Send one framed command and wait for its response.

    ``payload`` may be a ``list``/``dict`` (JSON-encoded), a ``str``
    (UTF-8 encoded) or ``bytes`` (sent unchanged). Any other type raises
    ``ValueError`` carrying that type.

    Commands are serialised by ``self._command_lock``. For each command the
    header ``struct.pack(self._FORMAT, self._MAGIC, len(body), msgtype)`` is
    sent, followed by the body. The call then waits for
    ``self._pending["event"]`` and returns ``self._pending["response"]``.

    NOTE(review): the event and response are presumably filled in by a
    reader task not shown here.
    """
    if isinstance(payload, (list, dict)):
        body = json.dumps(payload).encode()
    elif isinstance(payload, str):
        body = payload.encode("utf-8")
    elif isinstance(payload, bytes):
        body = payload
    else:
        raise ValueError(type(payload))

    async with self._command_lock:
        try:
            self._pending = dict(
                event=anyio.create_event(),
                msgtype=msgtype,
                response=None,
            )
            header = struct.pack(self._FORMAT, self._MAGIC, len(body), msgtype)
            await self._sock.send(header)
            await self._sock.send(body)
            await self._pending["event"].wait()
            return self._pending["response"]
        finally:
            # Always clear the slot, including on send errors or cancellation.
            self._pending = None
async def test_run_deadline_missed(self, store):
    """A job whose start deadline lies in the past yields ``JobDeadlineMissed``.

    The job is scheduled for 2020-09-14 with a start deadline one hour later,
    which is a fixed past date. Exactly one worker event is expected, and it
    must carry the job's identifiers and scheduled fire time.
    """
    async def listener(worker_event):
        worker_events.append(worker_event)
        await event.set()

    scheduled_start_time = datetime(2020, 9, 14)
    worker_events = []
    event = create_event()
    job = Job('task_id', fail_func, args=(), kwargs={}, schedule_id='foo',
              scheduled_fire_time=scheduled_start_time,
              start_deadline=datetime(2020, 9, 14, 1))
    async with AsyncWorker(store) as worker:
        worker.subscribe(listener)
        await store.add_job(job)
        # Fail rather than hang if no event arrives.
        async with fail_after(5):
            await event.wait()

    assert len(worker_events) == 1
    assert isinstance(worker_events[0], JobDeadlineMissed)
    assert worker_events[0].job_id == job.id
    assert worker_events[0].task_id == 'task_id'
    assert worker_events[0].schedule_id == 'foo'
    assert worker_events[0].scheduled_fire_time == scheduled_start_time
async def test_basic_nack_requeue(self, amqp):
    """A message nacked with ``requeue=True`` is delivered again.

    The first delivery is nacked with requeue. The second delivery is acked
    and sets ``qfuture``, which ends the test.
    """
    queue_name = 'queue_name'
    exchange_name = 'exchange_name'
    routing_key = ''
    await self.publish(amqp, queue_name, exchange_name, routing_key, b"payload")

    qfuture = anyio.create_event()
    called = False

    async with amqp.new_channel() as channel:

        async def qcallback(channel, body, envelope, _properties):
            nonlocal called
            if not called:
                # First delivery: reject and ask the broker to requeue it.
                called = True
                await channel.basic_client_nack(envelope.delivery_tag, requeue=True)
            else:
                # Redelivery: accept it and finish.
                await channel.basic_client_ack(envelope.delivery_tag)
                await qfuture.set()

        await channel.basic_consume(qcallback, queue_name=queue_name)
        await qfuture.wait()
def __init__(self, channel, rpc_name):
    """Track one pending RPC on ``channel`` and register it there.

    ``event`` is a fresh anyio event. ``result`` and ``exc`` start as
    ``None``. The instance registers itself last via
    ``channel._add_future``, so it is fully initialised when registered.
    """
    self.channel, self.rpc_name = channel, rpc_name
    self.event = anyio.create_event()
    self.result = self.exc = None
    channel._add_future(self)
async def test_big_consume(self, amqp): self.consume_future = anyio.create_event() # declare async with amqp.new_channel() as channel: await channel.queue_declare("q", exclusive=True, no_wait=False) await channel.exchange_declare("e", "fanout") await channel.queue_bind("q", "e", routing_key='') # get a different channel async with amqp.new_channel() as channel: # publish await channel.publish( b"a" * 1000000, "e", routing_key='', ) # start consume await channel.basic_consume(self.callback, queue_name="q") # get one body, envelope, properties = await self.get_callback_result() assert envelope.consumer_tag is not None assert envelope.delivery_tag is not None assert b"a" * 1000000 == body assert isinstance(properties, Properties)
async def test_wrong_callback_argument(self):
    """``basic_consume`` with a zero-argument callback raises ``TypeError``.

    NOTE(review): the ``pytest.raises`` block spans the whole connection, so
    a ``TypeError`` raised by any step inside it would satisfy the test.
    """
    self.consume_future = anyio.create_event()

    def badcallback():
        pass

    self.reset_vhost()
    proto = testcase.connect(virtualhost=self.vhost, )
    with pytest.raises(TypeError):
        async with proto as amqp:
            async with amqp.new_channel() as chan:
                await chan.queue_declare("q", exclusive=True, no_wait=False)
                await chan.exchange_declare("e", "fanout")
                await chan.queue_bind("q", "e", routing_key='')

            # get a different channel
            async with amqp.new_channel() as channel:
                # publish
                await channel.publish(
                    "coucou",
                    "e",
                    routing_key='',
                )

                # assert there is a message to consume
                await self.check_messages(amqp, "q", 1)

                # start consume
                await channel.basic_consume(badcallback, queue_name="q")
                await anyio.sleep(1)
async def _run_ctx(self, evt: anyio.abc.Event=None):
    """Enter ``self.task`` and stay inside it until ``self._done`` is set.

    :param evt: optional event that is set once ``self.task`` has been
        entered.
    """
    # Guards against a concurrent or repeated run. ``_done`` is not reset to
    # None here, so a second call fails unless something else clears it.
    assert self._done is None
    self._done = anyio.create_event()
    async with self.task:
        if evt is not None:
            await evt.set()
        await self._done.wait()
async def run(self, evt: anyio.abc.Event=None):
    """
    Process my events.

    Override+call this e.g. for overall timeouts::

        async def run(self):
            async with anyio.fail_after(30):
                await super().run()

    This method creates a runner task that does the actual event processing.
    A new runner is started if processing an event takes longer than 0.1 seconds.

    Do not replace this method. Do not call it directly.
    """
    log.debug("SetupRun %r < %r", self, getattr(self, '_prev', None))
    # Tell whoever started this task that it is running.
    if evt is not None:
        await evt.set()
    await self.on_start()
    if self._ready is not None:
        await self._ready.set()

    self._proc_lock = anyio.create_lock()
    while True:
        # Spawn a worker only while none is counted as active.
        if self._n_proc == 0:
            await self.taskgroup.spawn(self._process, name="Worker " + self.ref_id)
        # Fresh check event each round; wait at least 0.1 s, then until it is set.
        self._proc_check = anyio.create_event()
        await anyio.sleep(0.1)
        await self._proc_check.wait()
def __init__(self, watermark: int, direction="lt"):
    """Create a queue backed by an unbounded anyio memory object stream.

    :param watermark: queue-size threshold, stored as ``self.watermark``.
    :param direction: threshold mode, stored as ``self.direction``; the
        default is ``"lt"``.
    """
    self.watermark, self.direction = watermark, direction
    self._event = anyio.create_event()
    # max_buffer_size=math.inf: senders never block for lack of buffer space.
    self._send_stream, self._receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=math.inf)
async def _next_to_send(self):
    """
    Returns the next message on the heap.

    A message is released once the heap head (``self._heap[0]``) has a
    ``time`` at least ``self._delay`` seconds before the current time, or as
    soon as ``self._ending`` is set. When the heap is empty, this awaits
    ``self.flush_buf()``, sets ``self._done`` if one is pending, and then
    waits on a fresh ``self._heap_item`` event before looking again.
    """
    while True:
        while self._heap:
            # Once the heap falls below half of ``_heap_max``, set and drop
            # ``_heap_large`` (presumably releasing producer back-pressure).
            if self._heap_large is not None and len(
                    self._heap) < self._heap_max / 2:
                await self._heap_large.set()
                self._heap_large = None
            # Fast path: flushing, or the head was already due at the cached time.
            if self._ending.is_set(
            ) or self._heap[0].time <= self._t - self._delay:
                return heapq.heappop(self._heap)
            # Refresh the cached clock and re-check the head.
            self._t = time.time()
            if self._heap[0].time <= self._t - self._delay:
                return heapq.heappop(self._heap)
            # Sleep until the head becomes due, waking early if ``_ending`` is set.
            async with anyio.move_on_after(
                    max(self._delay + self._heap[0].time - self._t, 0)):
                await self._ending.wait()
        # Heap drained: flush, report completion to a pending flush, then
        # wait for new items.
        await self.flush_buf()
        if self._done is not None:
            await self._done.set()
        self._heap_item = anyio.create_event()
        await self._heap_item.wait()
async def get(self):
    """Receive and return the next item from the underlying memory stream.

    In ``"lt"`` mode, if the queue size is below ``watermark`` after
    receiving, the current ``_event`` is set (waking its waiters) and
    replaced with a fresh one.
    """
    item = await self._receive_stream.receive()
    below_watermark = self.direction == "lt" and self.qsize() < self.watermark
    if below_watermark:
        await self._event.set()
        self._event = anyio.create_event()
    return item
async def test_read_pbmsg_safe_readexactly_fails():
    """``read_pbmsg_safe`` raises ``IncompleteRead`` when the peer closes early.

    A client connects and immediately closes its stream. The server-side
    handler should then hit EOF while reading a protobuf message and set
    ``event``.
    """
    host = "127.0.0.1"
    port = 5566

    event = anyio.create_event()

    async with anyio.create_task_group() as tg, await anyio.create_tcp_server(
            port=port, interface=host) as server:

        async def handler_stream(stream):
            pb_msg = p2pd_pb.Response()
            try:
                await read_pbmsg_safe(stream, pb_msg)
            except anyio.exceptions.IncompleteRead:
                await event.set()

        async def server_serve():
            async for client in server.accept_connections():
                await tg.spawn(handler_stream, client)

        await tg.spawn(server_serve)

        stream = await anyio.connect_tcp(address=host, port=port)
        # close the stream. Therefore the handler should receive EOF, and then `readexactly` raises.
        await stream.close()

        async with anyio.fail_after(5):
            await event.wait()

        # The accept loop never ends on its own. Cancel it explicitly so that
        # leaving the task group does not depend on what closing the server
        # does to a pending accept (a hang, or an error raised out of
        # ``server_serve``).
        await tg.cancel_scope.cancel()
async def distkv_server(n):
    """Run a broker plus DistKV server, yield the server, then tear down.

    This is an async generator that yields the ``Server`` once.
    NOTE(review): presumably wrapped as an async context manager by a
    decorator not shown here.

    While the caller runs, a monitor task collects ``msg_monitor`` messages
    into ``msgs``. After teardown, an error is logged if the number of
    collected messages differs from ``n``; no exception is raised.
    """
    msgs = []
    async with anyio.create_task_group() as tg:
        async with create_broker(test_config, plugin_namespace="distmqtt.test.plugins"):
            s = Server("test", cfg=broker_config["distkv"], init="test")
            # Wait until the server reports readiness via ``ready_evt``.
            evt = anyio.create_event()
            await tg.spawn(partial(s.serve, ready_evt=evt))
            await evt.wait()

            async with open_client(**broker_config["distkv"]) as cl:

                async def msglog(task_status=trio.TASK_STATUS_IGNORED):
                    async with cl._stream(
                        "msg_monitor", topic="*"
                    ) as mon:  # , topic=broker_config['distkv']['topic']) as mon:
                        log.info("Monitor Start")
                        task_status.started()
                        async for m in mon:
                            log.info("Monitor Msg %r", m.data)
                            msgs.append(m.data)

                await cl.scope.spawn(msglog)
                yield s
                # Stop the monitor, then the server task.
                await cl.scope.cancel()
            await tg.cancel_scope.cancel()
    if len(msgs) != n:
        log.error("MsgCount %d %d", len(msgs), n)