async def _check_runtime(self, recv, evt: anyio.abc.Event = None):
    """Watchdog for the duration of individual processing steps.

    ``recv`` is streamed a message when processing begins and ``None``
    when it ends, over and over. Receiving ``False`` stops this task.

    When the end marker does not arrive within 0.2 seconds the delay is
    logged; the step then gets one more second (99 if ``pdb`` has been
    imported) to finish before the second timeout is left to propagate.

    :param recv: object whose ``get()`` is awaited for the next marker.
    :param evt: optional event, set as soon as this task is running.
    """
    if evt is not None:
        await evt.set()

    while True:
        # Start-of-processing marker: anything except None / False.
        marker = await recv.get()
        if marker is False:
            return
        assert marker is not None

        try:
            async with anyio.fail_after(0.2):
                # End-of-processing marker: must be None.
                marker = await recv.get()
                if marker is False:
                    return
                assert marker is None
        except TimeoutError:
            log.error("Processing delayed: %s", marker)
            delayed_at = await anyio.current_time()

            # don't hard-fail that fast when debugging
            grace = 99 if 'pdb' in sys.modules else 1
            async with anyio.fail_after(grace):
                marker = await recv.get()
                if marker is False:
                    return
                assert marker is None

            # processing was delayed, you have a problem
            log.error(
                "Processing recovered after %.2f sec",
                (await anyio.current_time()) - delayed_at,
            )
async def test_coro():
    """Publish a retained message and check that it survives a broker restart.

    A first broker receives the retained publish and must hand it to a
    fresh subscriber. A second broker, started after the first one has
    been closed, must deliver it too. Finally the storage server is
    expected to hold exactly one matching entry.

    ``self``, ``data`` and ``broker_config`` come from the enclosing scope.
    """
    uri = 'mqtt://127.0.0.1/'

    async def expect_retained():
        # A new subscriber has to be handed the retained message promptly.
        async with open_mqttclient() as subscriber:
            await subscriber.connect(uri)
            self.assertIsNotNone(subscriber.session)
            granted = await subscriber.subscribe([
                ('test_topic', QOS_0),
            ])
            self.assertEqual(granted[0], QOS_0)
            async with anyio.fail_after(0.5):
                message = await subscriber.deliver_message()
            self.assertIsNotNone(message)
            self.assertIsNotNone(message.publish_packet)
            self.assertEqual(message.data, data)

    async with distkv_server(0) as s:
        async with create_broker(
                broker_config, plugin_namespace="hbmqtt.test.plugins"):
            async with open_mqttclient() as client:
                await client.connect(uri)
                self.assertIsNotNone(client.session)
                async with open_mqttclient() as client_pub:
                    await client_pub.connect(uri)
                    await client_pub.publish('test_topic', data, QOS_0, retain=True)
            await anyio.sleep(1)
            await expect_retained()

        # Second broker instance: the retained message must still be there.
        async with create_broker(
                broker_config, plugin_namespace="hbmqtt.test.plugins"):
            await expect_retained()

        # Use the first port whose host does not start with a colon.
        for host, port, *_ in s.ports:
            if host[0] != ":":
                break

        seen = 0
        async with open_client(host=host, port=port) as storage:
            async for entry in storage.get_tree(min_depth=1):
                del entry['tock']
                del entry['seq']
                assert entry == {
                    'path': ('test', 'retain', 'test_topic'),
                    'value': b'data 1234'
                }
                seen += 1
        assert seen == 1
async def _reader_loop(self, done):
    """Read frames from the stream and dispatch them until the link closes.

    The loop runs inside its own shielded cancel scope, which is stored in
    ``self._reader_scope`` while the loop is active. ``done`` is set once
    the loop is ready. On exit ``connection_closed`` is always set.

    :param done: event that is set when the reader has started.
    :raises exceptions.HeartbeatTimeoutError: when no frame arrives within
        twice the server heartbeat interval.
    :raises exceptions.AmqpClosedConnection: when the connection is found
        closed while the state is neither ``CLOSING`` nor ``CLOSED``.
    """
    async with anyio.open_cancel_scope(shield=True) as scope:
        self._reader_scope = scope
        try:
            await done.set()
            while True:
                try:
                    if self._stream is None:
                        raise exceptions.AmqpClosedConnection

                    # Without a heartbeat there is no deadline for a frame.
                    if self.server_heartbeat:
                        timeout = self.server_heartbeat * 2
                    else:
                        timeout = inf

                    async with anyio.fail_after(timeout):
                        try:
                            channel, frame = await self.get_frame()
                        except anyio.ClosedResourceError:
                            # the stream is now *really* closed …
                            return

                    try:
                        await self.dispatch_frame(channel, frame)
                    except Exception as exc:
                        # We want to raise this exception so that the
                        # nursery ends the protocol, but we need to keep
                        # going for now (need to process the close-OK
                        # message). Thus we start a new task that
                        # raises the actual error, somewhat later.
                        if self._nursery is None:
                            raise

                        async def owch(exc):
                            await anyio.sleep(0.01)
                            raise exc

                        # The message needs a placeholder for its argument;
                        # without one the logging module reports a
                        # formatting error instead of the record.
                        logger.error("Queue: %r", exc)
                        await self._nursery.spawn(owch, exc)

                except TimeoutError:
                    await self.connection_closed.set()
                    raise exceptions.HeartbeatTimeoutError(self) from None
                except exceptions.AmqpClosedConnection:
                    logger.debug("Remote closed connection")
                    if self.state in (CLOSING, CLOSED):
                        return
                    raise
        finally:
            self._reader_scope = None
            async with anyio.fail_after(2, shield=True):
                await self.connection_closed.set()
async def do_pub(client, arguments):
    """Connect ``client`` and publish every message selected by ``arguments``.

    All messages are published concurrently from one task group. The
    client is disconnected at the end in any case, within a shielded
    two-second limit.

    :param client: MQTT client offering ``connect``, ``publish`` and
        ``disconnect``.
    :param arguments: mapping of parsed command line options.
    """
    logger.info("%s Connecting to broker", client.client_id)

    connect_options = dict(
        uri=arguments["--url"],
        cleansession=arguments["--clean-session"],
        cafile=arguments["--ca-file"],
        capath=arguments["--ca-path"],
        cadata=arguments["--ca-data"],
        extra_headers=_get_extra_headers(arguments),
    )
    await client.connect(**connect_options)

    try:
        qos = _get_qos(arguments)
        topic, retain = arguments["-t"], arguments["-r"]

        # Leaving the task group waits for all publish calls to finish.
        async with anyio.create_task_group() as publishers:
            for payload in _get_message(arguments):
                logger.info("%s Publishing to '%s'", client.client_id, topic)
                await publishers.spawn(client.publish, topic, payload, qos, retain)

        logger.info("%s Disconnected from broker", client.client_id)
    except KeyboardInterrupt:
        logger.info("%s Disconnected from broker", client.client_id)
    except ConnectException as err:
        logger.fatal("connection to '%s' failed: %r", arguments["--url"], err)
    finally:
        async with anyio.fail_after(2, shield=True):
            await client.disconnect()
async def __aexit__(self, *tb):
    """Drop the queue and unsubscribe from the client.

    The unsubscribe call is shielded and limited to two seconds. A
    ``ClientException`` raised by it is ignored on purpose.

    :param tb: exception details of the ``async with`` body; unused.
    """
    self._q = None
    try:
        async with anyio.fail_after(2, shield=True):
            await self.client._unsubscribe(self)
    except ClientException:
        # Best effort only, nothing left to clean up.
        return
async def open_mqttclient(client_id=None, config=None, codec=None):
    """
    MQTT client implementation.

    Creates an ``MQTTClient`` that lives inside its own task group. The
    client provides the API for connecting to a broker and for sending and
    receiving messages using the MQTT protocol. When ``config`` is a dict
    containing a ``uri`` key, the client connects right away, using
    ``config`` as keyword arguments.

    On exit the client is disconnected (shielded, at most two seconds) and
    the task group is cancelled.

    :param client_id: MQTT client ID to use when connecting to the broker.
        If none, it will be generated randomly by
        :func:`distmqtt.utils.gen_client_id`
    :param config: Client configuration
    :param codec: Codec to default to, the config or "no-op" if not given.
    :return: async context manager returning a class instance

    Example usage::

        async with open_mqttclient(config=dict(uri="mqtt://my-broker.example")) as client:
            # await client.connect("mqtt://my-broker.example")  # alternate use
            await client.subscribe([
                ('$SYS/broker/uptime', QOS_1),
                ('$SYS/broker/load/#', QOS_2),
            ])
            async for msg in client:
                packet = msg.publish_packet
                print("%s => %s" % (packet.variable_header.topic_name, str(packet.payload.data)))
    """
    async with anyio.create_task_group() as tg:
        client = MQTTClient(tg, client_id, config=config, codec=codec)
        try:
            auto_connect = isinstance(config, dict) and 'uri' in config
            if auto_connect:
                await client.connect(**config)
            yield client
        finally:
            async with anyio.fail_after(2, shield=True):
                await client.disconnect()
            await tg.cancel_scope.cancel()
async def event_login_encryption(cls, evt: ConfirmEncryptionEvent):
    """Finish the encryption handshake and register the player.

    Checks the verification token, enables the connection cipher with the
    shared secret, asks the session server whether the user has joined and
    stores the returned UUID on the connection.

    :param evt: event carrying ``verify``, ``secret`` and the connection.
    :returns: the result of ``PlayerRegistry.add_player``.
    :raises Exception: if the verification token does not match.
    """
    conn = evt._conn

    if evt.verify != conn.verify_token:
        raise Exception("Invalid verification token!")

    conn.cipher.enable(evt.secret)

    query = {
        "username": conn.name,
        "serverId": make_digest(
            conn.server_id.encode(), evt.secret, ServerCore.pubkey
        ),
    }
    if ServerCore.options["prevent-proxy-connections"]:
        query["ip"] = conn.client.server_hostname

    endpoint = "https://sessionserver.mojang.com/session/minecraft/hasJoined"
    async with fail_after(ServerCore.auth_timeout):
        # FIXME: Fails on curio
        response = await asks.get(endpoint, params=query)

    profile = response.json()
    info(profile)

    conn.uuid = UUID(profile["id"])
    conn.packet_decoder.status = 3
    return PlayerRegistry.add_player(conn)
async def test_receive_signals():
    """Signals sent to this process show up in the receiver, in order."""
    expected = (signal.SIGUSR1, signal.SIGUSR2)

    async with open_signal_receiver(signal.SIGUSR1, signal.SIGUSR2) as sigiter:
        # Raise both signals from a worker thread first ...
        for signum in expected:
            await run_sync_in_worker_thread(os.kill, os.getpid(), signum)

        # ... then both must be delivered within one second.
        async with fail_after(1):
            for signum in expected:
                assert await sigiter.__anext__() == signum
async def test_run_job_nonscheduled_success(sync, fail):
    """Run an unscheduled job and check the three events it produces.

    :param sync: submit the synchronous job function if true, the
        coroutine function otherwise.
    :param fail: make the job function raise instead of returning.
    """
    def remember(args, kwargs):
        # Shared body of both job functions: record the call, then finish.
        nonlocal received_args, received_kwargs
        received_args, received_kwargs = args, kwargs
        if fail:
            raise Exception('failing as requested')
        return 'success'

    def sync_func(*args, **kwargs):
        return remember(args, kwargs)

    async def async_func(*args, **kwargs):
        return remember(args, kwargs)

    def assert_unscheduled(event):
        assert event.job_id == job.id
        assert event.task_id == 'task_id'
        assert event.schedule_id is None
        assert event.scheduled_start_time is None

    received_args = received_kwargs = None
    events = []
    target = sync_func if sync else async_func

    async with LocalExecutor() as worker:
        await worker.subscribe(events.append)

        job = Job('task_id', target, args=(1, 2), kwargs={'x': 'foo'})
        await worker.submit_job(job)

        # Yield to the event loop until all three events have arrived.
        async with fail_after(1):
            while len(events) < 3:
                await sleep(0)

    assert received_args == (1, 2)
    assert received_kwargs == {'x': 'foo'}

    added, updated, finished = events[:3]

    assert isinstance(added, JobAdded)
    assert_unscheduled(added)

    assert isinstance(updated, JobUpdated)
    assert_unscheduled(updated)

    assert_unscheduled(finished)
    if fail:
        assert isinstance(finished, JobFailed)
        assert type(finished.exception) is Exception
        assert isinstance(finished.formatted_traceback, str)
    else:
        assert isinstance(finished, JobSuccessful)
        assert finished.return_value == 'success'
async def test_start__consumer_cancelled__reconnect_and_process_message(
    rabbitmq,
    make_consumer,
    get_channel,
    queue_name,
    publish,
):
    """A consumer whose queue is deleted must still process a later message.

    The consumer is started, its queue is deleted through a separate
    channel, and a message is published afterwards. The callback has to
    receive that message within three seconds.
    """
    received = asyncio.Future()

    async def callback(message):
        received.set_result(message.body)

    consumer = make_consumer(Process(callback))

    async with anyio.create_task_group() as tg:
        tg.start_soon(consumer.start)
        await asyncio.sleep(1)  # give the consumer time to start

        async with get_channel() as channel:
            await channel.queue_delete(queue_name)
        await asyncio.sleep(1)  # presumably lets the consumer recover

        await publish(b'1')
        with anyio.fail_after(3):
            body = await received

        # Stop the consumer, otherwise the task group would never exit.
        tg.cancel_scope.cancel()

    assert body == b'1'
async def test_run_deadline_missed():
    """A job submitted after its start deadline is reported, not executed."""
    def func():
        pytest.fail('This function should never be run')

    def assert_scheduled(event):
        assert event.job_id == job.id
        assert event.task_id == 'task_id'
        assert event.schedule_id == 'foo'
        assert event.scheduled_start_time == scheduled_start_time

    scheduled_start_time = datetime(2020, 9, 14)
    events = []

    async with LocalExecutor() as worker:
        await worker.subscribe(events.append)

        job = Job(
            'task_id',
            func,
            args=(),
            kwargs={},
            schedule_id='foo',
            scheduled_start_time=scheduled_start_time,
            start_deadline=datetime(2020, 9, 14, 1),
        )
        await worker.submit_job(job)

        # Yield to the event loop until both events have arrived.
        async with fail_after(1):
            while len(events) < 2:
                await sleep(0)

    added, missed = events[:2]

    assert isinstance(added, JobAdded)
    assert_scheduled(added)

    assert isinstance(missed, JobDeadlineMissed)
    assert_scheduled(missed)
async def wait_running(*workers, delay=4):
    """Wait until every worker has started and its executors are rebalanced.

    :param workers: objects exposing ``started`` and ``executors``.
    :param delay: time limit in seconds; it covers all workers together.
    """
    async with anyio.fail_after(delay):
        for worker in workers:
            # ``started`` is polled every 5 ms.
            while not worker.started:
                await anyio.sleep(0.005)

            for executor in worker.executors:
                await executor.wait_for(lambda: executor.rebalanced)
async def test_send_after_eof(self, socket_path):
    """The server can still send after the client has shut down its write side.

    A raw non-blocking UNIX socket connects, half-closes for writing and
    must then receive the greeting within one second.
    """
    async def greet(stream):
        async with stream:
            await stream.send(b'Hello\n')

    async with await create_unix_listener(
            socket_path) as listener, create_task_group() as tg:
        tg.spawn(listener.serve, greet)
        await wait_all_tasks_blocked()

        with socket.socket(socket.AF_UNIX) as raw_client:
            raw_client.connect(str(socket_path))
            raw_client.shutdown(socket.SHUT_WR)
            raw_client.setblocking(False)

            with fail_after(1):
                while True:
                    try:
                        received = raw_client.recv(10)
                    except BlockingIOError:
                        # No data yet: let the server task run, then retry.
                        await sleep(0)
                        continue

                    assert received == b'Hello\n'
                    break

        tg.cancel_scope.cancel()
async def introspect(self, bus_name: str, path: str, timeout: float = 30.0) -> intr.Node:
    """Get introspection data for the node at the given path from the
    given bus name.

    Calls the standard ``org.freedesktop.DBus.Introspectable.Introspect``
    on the bus for the path.

    :param bus_name: The name to introspect.
    :type bus_name: str
    :param path: The path to introspect.
    :type path: str
    :param timeout: Maximum time, in seconds, to wait for the reply.
    :type timeout: float

    :returns: The introspection data for the name at the path.
    :rtype: :class:`Node <asyncdbus.introspection.Node>`

    :raises:
        - :class:`InvalidObjectPathError <asyncdbus.InvalidObjectPathError>` \
                - If the given object path is not valid.
        - :class:`InvalidBusNameError <asyncdbus.InvalidBusNameError>` - If \
                the given bus name is not valid.
        - :class:`DBusError <asyncdbus.DBusError>` - If the service threw \
                an error for the method call or returned an invalid result.
        - :class:`Exception` - If a connection error occurred.
        - :class:`TimeoutError` - If no reply arrived within ``timeout``.
    """
    # The parent implementation delivers its result through this event.
    reply = ValueEvent()
    super().introspect(bus_name, path, reply)

    with anyio.fail_after(timeout):
        node = await reply
        return node
async def test_send_fds(self, server_sock, socket_path, tmp_path):
    """File descriptors sent with ``send_fds`` arrive and are readable.

    A helper thread accepts the connection on ``server_sock``, receives
    the message together with its ancillary data and reads the text
    behind every descriptor it was handed.

    An exception raised in a plain thread does not reach the test, so any
    failure in the helper is captured and raised again after ``join()``.
    """
    thread_errors = []

    def serve():
        try:
            fds = array.array('i')
            client, _ = server_sock.accept()
            msg, ancdata, *_ = client.recvmsg(
                10, socket.CMSG_LEN(2 * fds.itemsize))
            client.close()
            assert msg == b'test'

            for cmsg_level, cmsg_type, cmsg_data in ancdata:
                assert cmsg_level == socket.SOL_SOCKET
                assert cmsg_type == socket.SCM_RIGHTS
                # Drop trailing bytes that do not form a whole descriptor.
                fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

            chunks = []
            for fd in fds:
                with os.fdopen(fd) as file:
                    chunks.append(file.read())

            assert ''.join(chunks) == 'Hello, World!'
        except BaseException as exc:
            # Hand the failure over to the test task.
            thread_errors.append(exc)

    path1 = tmp_path / 'file1'
    path2 = tmp_path / 'file2'
    path1.write_text('Hello, ')
    path2.write_text('World!')

    with path1.open() as file1, path2.open() as file2, fail_after(2):
        async with await connect_unix(socket_path) as stream:
            thread = Thread(target=serve, daemon=True)
            thread.start()
            await stream.send_fds(b'test', [file1, file2])
            thread.join()

            if thread_errors:
                raise thread_errors[0]
async def test_send_after_eof(self, family):
    """The server can still send after the client has shut down its write side.

    A raw non-blocking TCP socket of the given address family connects,
    half-closes for writing and must then receive the greeting within one
    second.

    :param family: socket address family used by listener and client.
    """
    async def greet(stream):
        async with stream:
            await stream.send(b'Hello\n')

    listener = await create_tcp_listener(family=family, local_host='localhost')
    async with listener, create_task_group() as tg:
        tg.spawn(listener.serve, greet)
        await wait_all_tasks_blocked()

        with socket.socket(family) as raw_client:
            raw_client.connect(listener.extra(SocketAttribute.local_address))
            raw_client.shutdown(socket.SHUT_WR)
            raw_client.setblocking(False)

            with fail_after(1):
                while True:
                    try:
                        received = raw_client.recv(10)
                    except BlockingIOError:
                        # No data yet: let the server task run, then retry.
                        await sleep(0)
                        continue

                    assert received == b'Hello\n'
                    break

        tg.cancel_scope.cancel()
async def receive_log():
    """Consume log messages for the severities named on the command line.

    Declares the ``direct_logs`` exchange and an anonymous auto-delete
    queue, binds the queue once per severity taken from ``sys.argv[1:]``
    and starts consuming with ``callback``. Without any severity a usage
    line is printed and the process exits with status 1.
    """
    try:
        async with async_amqp.connect_amqp() as protocol:
            channel = await protocol.channel()

            exchange_name = 'direct_logs'
            await channel.exchange(exchange_name, 'direct')

            declared = await channel.queue(queue_name='', durable=False, auto_delete=True)
            queue_name = declared['queue']

            severities = sys.argv[1:]
            if not severities:
                print("Usage: %s [info] [warning] [error]" % (sys.argv[0], ))
                sys.exit(1)

            # One binding per severity, the severity is the routing key.
            for severity in severities:
                await channel.queue_bind(
                    exchange_name=exchange_name,
                    queue_name=queue_name,
                    routing_key=severity,
                )

            print(' [*] Waiting for logs. To exit press CTRL+C')

            async with anyio.fail_after(10):
                await channel.basic_consume(callback, queue_name=queue_name)
    except async_amqp.AmqpClosedConnection:
        print("closed connections")
        return
async def test_coro():
    """A ``+`` wildcard subscription must not receive ``$SYS`` topics.

    The subscriber gets the message published to
    ``/test/monitor/Clients``; the one published to
    ``$SYS/monitor/Clients`` must not arrive, so waiting for it times out.

    ``self`` and ``test_config`` come from the enclosing scope.
    """
    async with create_broker(
            test_config, plugin_namespace="distmqtt.test.plugins") as broker:
        broker.plugins_manager._tg = broker._tg
        self.assertTrue(broker.transitions.is_started())

        async with open_mqttclient() as sub_client:
            await sub_client.connect('mqtt://127.0.0.1')
            granted = await sub_client.subscribe([('+/monitor/Clients', QOS_0)])
            self.assertEqual(granted, [QOS_0])

            # Ordinary topic: has to be delivered.
            await self._client_publish('/test/monitor/Clients', b'data', QOS_0)
            delivered = await sub_client.deliver_message()
            self.assertIsNotNone(delivered)

            # $SYS topic: must not match the wildcard.
            await self._client_publish('$SYS/monitor/Clients', b'data', QOS_0)
            delivered = None
            with self.assertRaises(TimeoutError):
                async with anyio.fail_after(2):
                    delivered = await sub_client.deliver_message()
            self.assertIsNone(delivered)

    self.assertTrue(broker.transitions.is_stopped())
async def test_coro():
    """A non-retained publish reaches a subscriber of the same topic.

    Storage server, broker and both clients are opened from
    ``broker_config``. The message has to be delivered within half a
    second.

    ``self``, ``data`` and ``broker_config`` come from the enclosing scope.
    """
    client_config = broker_config["broker"]

    # Server first, then the broker; they are left in reverse order.
    async with distkv_server(1), create_broker(
            broker_config, plugin_namespace="distmqtt.test.plugins"):
        async with open_mqttclient(config=client_config) as client:
            self.assertIsNotNone(client.session)
            granted = await client.subscribe([("test/vis/foo", QOS_0)])
            self.assertEqual(granted[0], QOS_0)

            async with open_mqttclient(config=client_config) as client_pub:
                await client_pub.publish("test/vis/foo", data, QOS_0, retain=False)

            async with anyio.fail_after(0.5):
                message = await client.deliver_message()

            self.assertIsNotNone(message)
            self.assertIsNotNone(message.publish_packet)
            self.assertEqual(message.data, data)
        # client has been closed here
    # broker and server have been closed here
async def test_read_pbmsg_safe_readexactly_fails():
    """``read_pbmsg_safe`` raises ``IncompleteRead`` when the peer closes early.

    A local TCP server hands each connection to a handler that tries to
    read one message. The client closes right after connecting, so the
    handler has to signal the failure within five seconds.
    """
    host, port = "127.0.0.1", 5566
    read_failed = anyio.create_event()

    async with anyio.create_task_group() as tg, await anyio.create_tcp_server(
            port=port, interface=host) as server:

        async def handle_connection(conn):
            response = p2pd_pb.Response()
            try:
                await read_pbmsg_safe(conn, response)
            except anyio.exceptions.IncompleteRead:
                await read_failed.set()

        async def accept_loop():
            async for conn in server.accept_connections():
                await tg.spawn(handle_connection, conn)

        await tg.spawn(accept_loop)

        client_stream = await anyio.connect_tcp(address=host, port=port)
        # Close the stream at once: the handler then sees EOF and the
        # underlying exact read fails.
        await client_stream.close()

        async with anyio.fail_after(5):
            await read_failed.wait()
async def tcpsocket(address=None, connect_timeout=None, tcp_keepalive=None,
                    tcp_nodelay=None, **kwargs):
    """Open a TCP connection and apply the requested socket options.

    :param address: ``(host, port)`` pair; ``("localhost", 6379)`` if None.
    :param connect_timeout: limit for establishing the connection, passed
        to ``anyio.fail_after``.
    :param tcp_keepalive: keep-alive time; a false value leaves keep-alive
        untouched.
    :param tcp_nodelay: ``None`` leaves ``TCP_NODELAY`` untouched, any
        other value switches it on or off according to its truth value.
    :param kwargs: accepted and ignored.
    :returns: the connected socket object.
    """
    if address is None:
        address = ("localhost", 6379)

    async with anyio.fail_after(connect_timeout):
        sock = await anyio.connect_tcp(address[0], address[1])

    if tcp_nodelay is not None:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 if tcp_nodelay else 0)

    if not tcp_keepalive:
        return sock

    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
    probe_interval = tcp_keepalive // 3

    if platform == "linux":
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keepalive)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, probe_interval)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3)
    elif platform == "darwin":
        # NOTE(review): 0x10 is presumably TCP_KEEPALIVE on macOS — confirm.
        sock.setsockopt(socket.IPPROTO_TCP, 0x10, probe_interval)
    elif platform == "windows":
        # NOTE(review): if ``platform`` is ``sys.platform``, Windows reports
        # "win32" and this branch can never run — confirm.
        sock.ioctl(socket.SIO_KEEPALIVE_VALS,
                   (1, tcp_keepalive * 1000, probe_interval * 1000))

    return sock
async def _delivery_loop(self):
    """Server: process incoming messages.

    Runs inside its own cancel scope, kept in ``self._delivery_task`` so
    that the loop can be cancelled from elsewhere. Each message fires
    ``EVENT_BROKER_MESSAGE_RECEIVED`` and is then broadcast by the broker.

    On exit ``self._delivery_task`` is cleared and ``_delivery_stopped``
    is set, shielded and limited to two seconds.
    """
    # Bound before the ``try``: the ``finally`` clause logs through it and
    # would otherwise fail with UnboundLocalError (hiding the real error
    # and never setting ``_delivery_stopped``) whenever the cancel scope
    # could not be entered.
    broker = self._broker
    try:
        async with anyio.open_cancel_scope() as scope:
            self._delivery_task = scope
            broker.logger.debug("%s handling message delivery", self.client_id)

            while True:
                app_message = await self.get_next_message()
                await self._plugins_manager.fire_event(
                    EVENT_BROKER_MESSAGE_RECEIVED,
                    client_id=self.client_id,
                    message=app_message)
                await broker.broadcast_message(
                    self,
                    app_message.topic,
                    app_message.data,
                    qos=app_message.qos,
                    retain=app_message.publish_packet.retain_flag)
    finally:
        async with anyio.fail_after(2, shield=True):
            broker.logger.debug("%s finished message delivery", self.client_id)
            self._delivery_task = None
            await self._delivery_stopped.set()
async def _sender_loop(self, evt):
    """Write queued packets to the stream until ``None`` is queued.

    The loop runs inside its own cancel scope, kept in
    ``self._sender_task``. When no packet shows up within the session's
    keep-alive interval, ``handle_write_timeout()`` is called instead. A
    keep-alive of zero or less turns that timeout off.

    :param evt: event that is set once the loop is ready to send.
    """
    keep_alive = self.session.keep_alive
    keepalive_timeout = keep_alive if keep_alive > 0 else None

    try:
        async with anyio.open_cancel_scope() as scope:
            self._sender_task = scope
            await evt.set()

            while True:
                packet = None
                async with anyio.move_on_after(keepalive_timeout):
                    packet = await self._send_q.get()
                    if packet is None:
                        # A queued None is the request to shut down.
                        break

                if packet is None:
                    # Still None here: the wait above has timed out.
                    await self.handle_write_timeout()
                    continue

                # self.logger.debug("%s > %r",'B' if 'Broker' in type(self).__name__ else 'C', packet)
                await packet.to_stream(self.stream)
                await self.plugins_manager.fire_event(
                    EVENT_MQTT_PACKET_SENT,
                    packet=packet,
                    session=self.session)
    except ConnectionResetError:
        await self.handle_connection_closed()
    except anyio.get_cancelled_exc_class():
        # Cancellation is expected: pass it on without a warning.
        raise
    except BaseException as exc:
        self.logger.warning("Unhandled exception", exc_info=exc)
        raise
    finally:
        async with anyio.fail_after(2, shield=True):
            await self._sender_stopped.set()
        self._sender_task = None
async def test_run_deadline_missed(self, store):
    """A stored job that is past its start deadline yields one event.

    :param store: job store shared by the test and the worker.
    """
    async def listener(worker_event):
        worker_events.append(worker_event)
        await notified.set()

    scheduled_start_time = datetime(2020, 9, 14)
    worker_events = []
    notified = create_event()

    job = Job(
        'task_id',
        fail_func,
        args=(),
        kwargs={},
        schedule_id='foo',
        scheduled_fire_time=scheduled_start_time,
        start_deadline=datetime(2020, 9, 14, 1),
    )

    async with AsyncWorker(store) as worker:
        worker.subscribe(listener)
        await store.add_job(job)

        # The listener sets the event on the first worker event it sees.
        async with fail_after(5):
            await notified.wait()

    assert len(worker_events) == 1
    missed = worker_events[0]
    assert isinstance(missed, JobDeadlineMissed)
    assert missed.job_id == job.id
    assert missed.task_id == 'task_id'
    assert missed.schedule_id == 'foo'
    assert missed.scheduled_fire_time == scheduled_start_time
async def create_broker(config=None, plugin_namespace=None):
    """MQTT 3.1.1 compliant broker implementation

    The broker is started inside its own task group. When ``config``
    contains a ``distkv`` key, ``DistKVbroker`` is used in place of the
    plain ``Broker``. On exit the broker is shut down (shielded, at most
    two seconds) and the task group is cancelled.

    :param config: Example Yaml config
    :param plugin_namespace: Plugin namespace to use when loading plugin
        entry_points. Defaults to ``distmqtt.broker.plugins``

    This is an async context manager::

        async with create_broker() as broker:
            while True:
                await anyio.sleep(99999)
    """
    async with anyio.create_task_group() as tg:
        if "distkv" in (config or {}):
            # Imported only on demand.
            from .distkv_broker import DistKVbroker as broker_class
        else:
            broker_class = Broker

        broker = broker_class(tg, config, plugin_namespace)
        try:
            await broker.start()
            yield broker
        finally:
            async with anyio.fail_after(2, shield=True):
                await broker.shutdown()
            await tg.cancel_scope.cancel()
async def test_receive_signals() -> None:
    """Signals sent to this process show up in the receiver, in order."""
    expected = (signal.SIGUSR1, signal.SIGUSR2)

    with open_signal_receiver(signal.SIGUSR1, signal.SIGUSR2) as sigiter:
        # Raise both signals from a worker thread first ...
        for signum in expected:
            await to_thread.run_sync(os.kill, os.getpid(), signum)

        # ... then both must be delivered within one second.
        with fail_after(1):
            for signum in expected:
                assert await sigiter.__anext__() == signum
async def callback():
    """Enter the resource's async context and publish the entered value.

    ``resource``, ``ctx``, ``value_event`` and ``self`` come from the
    enclosing scope. Everything runs under the time limit given by
    ``resource.__timeout__``.

    :raises NotAResourceError: when any step raises ``AttributeError``,
        which includes a missing ``__timeout__`` attribute.
    """
    try:
        async with anyio.fail_after(resource.__timeout__):
            manager = resource(self, ctx)
            entered = await self._exit_stack.enter_async_context(manager)
            await value_event.set(entered)
    except AttributeError as err:
        raise NotAResourceError() from err
async def _task_teardown(self, *tb):
    """Run ``teardown()`` and then the parent class's task teardown.

    Both calls share one shielded ``anyio.fail_after`` block with a limit
    of two seconds.

    :param tb: passed through unchanged to ``super()._task_teardown``.
    :returns: whatever ``super()._task_teardown(*tb)`` returns.
    """
    async with anyio.fail_after(2, shield=True):
        await self.teardown()
        return await super()._task_teardown(*tb)

    # NOTE(review): the code below can never run, the ``return`` above
    # always leaves the function. It looks like the event hand-over has
    # been switched off on purpose — confirm before removing or reviving.
    # NOTE(review): ``self._q.get()`` is not awaited; if this loop gets
    # revived, check whether ``get`` is a coroutine here.

    # Any unprocessed events get relegated to the parent
    while True:
        try:
            async with anyio.fail_after(0.001):
                if self._q is None:
                    break
                if self._q.empty():
                    break
                evt = self._q.get()
        except TimeoutError:
            break
        await self._handle_prev(evt)
async def watch_for_connection(watchdog_queue: asyncio.Queue):
    """Log watchdog messages and fail once they stop arriving.

    The timeout is read from a fresh ``ChatRuntimeSettings()`` for every
    message that is waited for.

    :param watchdog_queue: queue fed with the messages to log.
    :raises WatchdogException: when no message arrives within
        ``WATCHDOG_TIMEOUT``.
    """
    while True:
        try:
            timeout = ChatRuntimeSettings().WATCHDOG_TIMEOUT
            async with fail_after(timeout):
                heartbeat = await watchdog_queue.get()
            watchdog_logger.info(heartbeat)
        except TimeoutError:
            raise WatchdogException("not see ping or human messages")
async def close(self, no_wait=False):
    """Close connection (and all channels)

    Does nothing when the connection is already ``CLOSED``. When a close
    is already in progress (``CLOSING``), this only waits for it to
    finish, unless ``no_wait`` is set.

    :param no_wait: if true, do not wait for the close to complete.
    """
    if self.state == CLOSED:
        return
    if self.state == CLOSING:
        if not no_wait:
            await self.wait_closed()
        return

    try:
        self.state = CLOSING
        # Remember whether the closed event had been set before we set it;
        # the close request below is only sent when it had not.
        got_close = self.connection_closed.is_set()
        await self.connection_closed.set()
        if not got_close:
            await self._close_channels()

            # If the closing handshake is in progress, let it complete.
            request = pamqp.specification.Connection.Close(reply_code=0, reply_text='', class_id=0, method_id=0)
            try:
                await self._write_frame(0, request)
            except anyio.ClosedResourceError:
                # The stream is gone already, there is nothing to send.
                pass
            except Exception:
                logger.exception("Error while closing")
            else:
                # Wait for the close to finish, for at most half a
                # heartbeat interval. No heartbeat means no waiting.
                if not no_wait and self.server_heartbeat:
                    async with anyio.move_on_after(self.server_heartbeat / 2):
                        await self.wait_closed()

    except BaseException as exc:
        # Hand the error to the channels before passing it on.
        async with anyio.fail_after(2, shield=True):
            await self._close_channels(exception=exc)
        raise
    finally:
        async with anyio.fail_after(2, shield=True):
            try:
                await self._cancel_all()
                await self._stream.aclose()
            finally:
                # The state is reset even if the cleanup above fails.
                self._nursery = None
                self.state = CLOSED
async def _reader(self, val):
    """Read replies and feed each one to the oldest pending request.

    The loop runs inside its own cancel scope, which is handed to the
    caller through ``val``. Each read may take up to 15 seconds. A
    timeout, the end of the iterator or a broken connection causes a
    reconnect and a fresh iterator.

    :param val: receives this loop's cancel scope via ``val.set(scope)``.
    """
    async with anyio.open_cancel_scope() as scope:
        await val.set(scope)
        replies = self._msg_proto.__aiter__()

        while True:
            try:
                async with anyio.fail_after(15):
                    res, data = await replies.__anext__()
            except ServerBusy:
                logger.info("Server %s busy", self.host)
            except (StopAsyncIteration, TimeoutError, IncompleteRead,
                    ConnectionResetError, ClosedResourceError):
                await self._reconnect()
                replies = self._msg_proto.__aiter__()
            else:
                request = self.requests.popleft()
                await request.process_reply(res, data, self)
                if not request.done():
                    # Not finished yet: it stays at the head of the queue.
                    self.requests.appendleft(request)
async def _writer(self, val):
    """Send queued requests; send a ``NOPMsg`` after ten idle seconds.

    The loop runs inside its own cancel scope, which is handed to the
    caller through ``val``. Each message is appended to ``self.requests``
    before it is written.

    :param val: receives this loop's cancel scope via ``val.set(scope)``.
    """
    async with anyio.open_cancel_scope() as scope:
        await val.set(scope)

        while True:
            try:
                async with anyio.fail_after(10):
                    outgoing = await self._wqueue.get()
            except TimeoutError:
                # Nothing was queued in time: send a no-op instead.
                outgoing = NOPMsg()

            self.requests.append(outgoing)
            try:
                await outgoing.write(self._msg_proto)
            # except trio.ClosedResourceError:
            #     # will get restarted by .reconnect()
            #     return
            except IncompleteRead:
                await self.stream.close()
                # will be restarted by the reader
                return
            except BaseException:
                await self.stream.close()
                raise