async def _reader(self, *, result: ValueEvent = None):
    """Main loop for reading messages from the RPC socket.

    Runs inside a shielded cancel scope. If *result* is given, that scope
    is handed to the caller via ``result.set()`` so the caller can cancel
    the reader explicitly.

    Each decoded msgpack object is either a response header or, when
    ``_handle_msg`` returned true for the previous header, that header's
    body. While a body is outstanding, each read is limited to 5 seconds.
    On timeout, the pending request's handler receives a
    :class:`SerfTimeout`; if no handler is registered, the timeout is
    raised instead.

    On exit every remaining handler is cancelled and ``self._handlers``
    is set to ``None``.

    :raises SerfClosedError: if the peer closes the connection.
    :raises SerfTimeout: if a body times out and no handler is waiting.
    """
    unpacker = msgpack.Unpacker(object_hook=self._decode_addr_key)
    cur_msg = None

    async with anyio.open_cancel_scope(shield=True) as s:
        if result is not None:
            await result.set(s)

        try:
            while self._socket is not None:
                if cur_msg is not None:
                    logger.debug("%d:wait for body", self._conn_id)
                try:
                    # Only bound the wait while a body is outstanding.
                    async with anyio.fail_after(
                            5 if cur_msg is not None else math.inf):
                        buf = await self._socket.receive(
                            self._socket_recv_size)
                except TimeoutError:
                    seq = cur_msg.head.get(b"Seq", None)
                    handler = self._handlers.get(seq, None)
                    if handler is None:
                        raise SerfTimeout(cur_msg) from None
                    await handler.set_error(SerfTimeout(cur_msg))
                    # Nothing new was read: go back to receiving instead of
                    # falling through and re-feeding the previous buffer to
                    # the unpacker (which would decode old data twice).
                    # cur_msg is kept so a late body is still consumed as a
                    # body and the stream framing stays intact.
                    continue
                except anyio.ClosedResourceError:
                    return  # closed by us
                except OSError as err:
                    if err.errno == errno.EBADF:
                        return
                    raise
                if len(buf) == 0:  # Connection was closed.
                    raise SerfClosedError("Connection closed by peer")
                unpacker.feed(buf)

                for msg in unpacker:
                    if cur_msg is not None:
                        logger.debug("%d::Body=%s", self._conn_id, msg)
                        cur_msg.body = msg
                        await self._handle_msg(cur_msg)
                        cur_msg = None
                    else:
                        logger.debug("%d:Recv =%s", self._conn_id, msg)
                        msg = SerfResult(msg)
                        if await self._handle_msg(msg):
                            cur_msg = msg
        finally:
            hdl, self._handlers = self._handlers, None
            # The handler map may already have been cleared elsewhere.
            if hdl is not None:
                async with anyio.open_cancel_scope(shield=True):
                    for m in hdl.values():
                        await m.cancel()
async def test_nested_shield():
    """A fail_after inside a shielded scope still fires under outer cancellation.

    Another task cancels the outer scope while the body runs inside a
    shielded inner scope. The sleep must therefore end through the 0.2 s
    ``fail_after``, which raises TimeoutError.
    """
    async def cancel_when_idle(target):
        await wait_all_tasks_blocked()
        await target.cancel()

    with pytest.raises(TimeoutError):
        async with create_task_group() as group:
            async with open_cancel_scope() as outer:
                async with open_cancel_scope(shield=True):
                    await group.spawn(cancel_when_idle, outer)
                    async with fail_after(0.2):
                        await sleep(2)
async def server_task(evt, cb, address, port, ssl_context, listener, listener_name):
    """Run a TCP listener and hand each accepted connection to *cb*.

    Once the listener is bound, the enclosing cancel scope is passed to the
    caller via ``evt.set()`` so the caller can stop the server. When
    *ssl_context* is set, every connection is TLS-wrapped first.
    Connections whose handshake fails are logged, closed and dropped,
    without stopping the server.
    """
    sock = None
    async with anyio.open_cancel_scope() as scope:
        try:
            sock = await anyio.create_tcp_listener(
                local_port=port, local_host=address)
            await evt.set(scope)

            async def _maybe_wrap(listener, listener_name, conn):
                if ssl_context:
                    try:
                        conn = await anyio.streams.tls.TLSStream.wrap(
                            conn, ssl_context=ssl_context, server_side=True)
                    except Exception as exc:
                        self.logger.error(
                            "Listener '%s' (type '%s'): TLS handshake failed: %r",
                            listener_name, listener["type"], exc,
                        )
                        # The raw stream is still open; don't leak it.
                        await anyio.aclose_forcefully(conn)
                        return
                await cb(conn)

            await sock.serve(partial(_maybe_wrap, listener, listener_name))
        finally:
            # sock is None when create_tcp_listener itself failed. Closing
            # must not mask that error, and must not be cut short by a
            # pending cancellation.
            if sock is not None:
                async with anyio.open_cancel_scope(shield=True):
                    await sock.aclose()
async def _run(self, client):
    """Run ``self(client)`` once inside a fresh cancel scope and record the outcome.

    While running, the scope is stored in ``self.scope`` so it can be
    cancelled from outside. It is reset to ``None`` afterwards.

    Bookkeeping on ``self.state``:

    * ``n_run`` is incremented on every run.
    * A failure increments ``n_fail`` and ``fail_count`` and appends
      ``True`` to ``fail_map``. A failure is an exception, or a
      cancellation while ``self.scope`` is still set. ``exc`` holds the
      formatted traceback lines, or ``"Canceled"``.
    * A success resets ``fail_count`` and appends ``False`` to ``fail_map``.
    * If ``fail_map`` contains any failure, it is trimmed to its last 20
      entries; otherwise it is cleared.

    Cancellation is re-raised; other exceptions are recorded and swallowed.
    """
    state = self.state
    async with anyio.open_cancel_scope() as sc:
        self.scope = sc
        try:
            logger.debug("START %s", self.name)
            await self(client)
        except anyio.get_cancelled_exc_class():
            state.exc = "Canceled"
            # Presumably self.scope is cleared by whoever stops this run on
            # purpose, so only an unexpected cancellation counts as failure.
            if self.scope is not None:
                state.n_fail += 1
                state.fail_count += 1
                state.fail_map.append(True)
            raise
        except Exception:
            state.exc = traceback.format_exc().split('\n')
            state.n_fail += 1
            state.fail_count += 1
            state.fail_map.append(True)
        else:
            state.fail_count = 0
            state.fail_map.append(False)
        finally:
            state.n_run += 1
            if any(state.fail_map):
                del state.fail_map[:-20]
            else:
                state.fail_map = []  # zero out after 20 successes in sequence
            self.scope = None
            logger.debug("END %s", self.name)
async def broadcast_dollar_sys_topics_loop(self, interval, evt):
    """Broadcast the ``$SYS`` topics every *interval* seconds, forever.

    The loop's cancel scope is published as ``self.sys_handle`` before
    *evt* is set, so the caller can stop the loop by cancelling it.

    :param interval: delay before each broadcast, passed to ``anyio.sleep``.
    :param evt: event set once ``self.sys_handle`` is available.
    """
    async def _tick():
        await anyio.sleep(interval)
        await self.broadcast_dollar_sys_topics()

    async with anyio.open_cancel_scope() as loop_scope:
        self.sys_handle = loop_scope
        await evt.set()
        while True:
            await _tick()
async def _run_reconnected(self, val: ValueEvent):
    """Run ``self._run_one`` repeatedly, retrying after connection errors.

    While the loop runs, its cancel scope is stored in
    ``self._current_run``; it is reset to ``None`` on exit.

    :param val: optional ValueEvent passed to ``_run_one``. If a
        connection error occurs while ``val`` is not yet set, the error is
        delivered through ``val.set_error()`` and this method returns.
        Otherwise the error is logged, ``val`` is dropped (later
        ``_run_one`` calls receive ``None``), and ``_run_one`` is retried
        after ``self._backoff`` seconds. The backoff is multiplied by 1.5
        after each failure while it is below 10.

    Cancellation is always propagated.
    """
    try:
        async with anyio.open_cancel_scope() as scope:
            self._current_run = scope
            while True:
                try:
                    await self._run_one(val)
                except anyio.get_cancelled_exc_class():
                    raise
                except (
                    BrokenPipeError,
                    TimeoutError,
                    EnvironmentError,
                    anyio.IncompleteRead,
                    ConnectionResetError,
                    anyio.ClosedResourceError,
                    StopAsyncIteration,
                ) as exc:
                    if val is not None and not val.is_set():
                        # Nobody has been told about a result yet (presumably
                        # the first connection failed): report the error to
                        # the waiter instead of retrying.
                        await val.set_error(exc)
                        return
                    logger.error("Disconnected")
                    val = None
                    await anyio.sleep(self._backoff)
                    # NOTE(review): _backoff is never reset here after a
                    # successful run — confirm whether that is intended.
                    if self._backoff < 10:
                        self._backoff *= 1.5
                    else:
                        pass
                # NOTE(review): if _run_one returns normally, the loop calls
                # it again immediately with no backoff.
    finally:
        self._current_run = None
async def _listen(self, evt):
    """Wait for incoming calls on the worker's destination and record the channel.

    Subscribes to start-of events for ``w.call.dst.name`` and sets *evt*
    once subscribed. It then spawns ``w.url_open`` and/or ``w.exec_open``
    on the client's task group when the call has ``url`` / ``exec``
    attributes. The first incoming event's ``'channel'`` is stored in
    ``self._in_channel`` and ``self._evt`` is set; any further event is
    logged as a duplicate. The cancel scope is kept in ``self._in_scope``
    (presumably so the listener can be stopped from outside).
    """
    w = self.worker
    # Prefer the call's own number; fall back to the destination's.
    number = w.call.number or w.call.dst.number
    async with anyio.open_cancel_scope() as sc:
        self._in_scope = sc
        self.worker.in_logger.debug("Wait for call: using %s", w.call.dst.name)
        async with w.client.on_start_of(w.call.dst.name) as d:
            # Signal readiness only after the subscription exists, so no
            # incoming call can be missed.
            await evt.set()
            url = getattr(w.call, 'url', None)
            if url is not None:
                await self.worker.client.taskgroup.spawn(
                    w.url_open, number, url)
            args = getattr(w.call, 'exec', None)
            if args is not None:
                await self.worker.client.taskgroup.spawn(
                    w.exec_open, number, args)
            async for ic_, evt_ in d:
                if self._in_channel is None:
                    self._in_channel = ic_['channel']
                    await self._evt.set()
                else:
                    self.worker.in_logger.error(
                        "Duplicate incall on %s %s %s", w.call.dst.name, ic_, evt_)
async def test_escaping_cancelled_error_from_cancelled_task():
    """Regression test for issue #88.

    A nested ``move_on_after`` expires first, then the enclosing scope is
    cancelled. No CancelledError may escape the enclosing scope.
    """
    async with open_cancel_scope() as enclosing:
        async with move_on_after(0.1):
            await sleep(1)
        await enclosing.cancel()
async def handle_connection_close(self, evt):
    """Wait for the broker connection to drop, then tear down or reconnect.

    Stores its cancel scope in ``self._disconnect_task`` and sets *evt*
    before waiting. When ``self._handler.wait_disconnect()`` returns:

    * the client API is blocked (``_connected_state`` is cleared);
    * the handler is stopped and detached;
    * the session transitions to disconnected.

    If ``config["auto_reconnect"]`` is true, ``reconnect()`` is attempted.
    If that raises ConnectException, or auto-reconnect is off, pending
    client work is cancelled.
    """

    async def cancel_tasks():
        # Signal that no further connection will be made, then cancel the
        # pending client task (``self.client_task``; deliver_message stores
        # its cancel scope there), clearing the attribute first.
        await self._no_more_connections.set()
        if self.client_task:
            task, self.client_task = self.client_task, None
            await task.cancel()

    async with anyio.open_cancel_scope() as scope:
        self._disconnect_task = scope
        await evt.set()
        self.logger.debug("Wait for broker disconnection")
        # Wait for disconnection from broker (like connection lost)
        await self._handler.wait_disconnect()
        self.logger.warning("Disconnected from broker")
        # Block client API
        self._connected_state.clear()
        # stop and clean handler
        await self._handler.stop()
        await self._handler.detach()
        self.session.transitions.disconnect()
        if self.config.get("auto_reconnect", False):
            # Try reconnection
            self.logger.debug("Auto-reconnecting")
            try:
                await self.reconnect()
            except ConnectException:
                # Cancel client pending tasks
                await cancel_tasks()
        else:
            # Cancel client pending tasks
            await cancel_tasks()
async def deliver_message(self, codec=None):
    """
    Deliver the next message received from the broker.

    Waits until the session has a message available. No timeout is applied
    here; wrap the call in a timeout scope (e.g. ``anyio.fail_after``) if
    one is needed.

    This method is a *coroutine*.

    :param codec: Codec to decode the message with. ``None`` selects
        ``self.codec``; a string is looked up by name.
    :return: instance of :class:`distmqtt.session.ApplicationMessage`
        with ``data`` set to the decoded payload, or ``None`` if there is
        no session or the wait is cancelled by closing the connection.
    :raises RuntimeError: if another task is already waiting here.
    """
    if isinstance(codec, str):
        codec = _codecs[codec]()
    elif codec is None:
        codec = self.codec

    async with anyio.open_cancel_scope() as listen_scope:
        if self.session is None:
            return None
        if self.client_task is not None:
            raise RuntimeError("You can't listen in more than one task")
        # Published so a connection close can cancel this wait.
        self.client_task = listen_scope
        try:
            message = await self.session.get_next_message()
            raw_payload = message.publish_packet.data
            message.data = codec.decode(raw_payload)
        finally:
            self.client_task = None
        return message
async def _sender_loop(self, evt):
    """Send queued packets to the stream until a ``None`` sentinel arrives.

    Stores its cancel scope in ``self._sender_task`` and sets *evt* before
    looping. Each packet taken from ``self._send_q`` is written with
    ``packet.to_stream(self.stream)``, then EVENT_MQTT_PACKET_SENT is
    fired. If nothing is queued within the session keep-alive interval,
    ``handle_write_timeout()`` is awaited instead. A keep-alive of 0 or
    less disables this timeout.

    Exception handling:

    * ConnectionResetError is handled via ``handle_connection_closed()``.
    * Cancellation propagates.
    * Any other exception is logged and re-raised.

    On exit, ``self._sender_stopped`` is set (shielded, limited to 2
    seconds) and ``self._sender_task`` is cleared.
    """
    keepalive_timeout = self.session.keep_alive
    if keepalive_timeout <= 0:
        keepalive_timeout = None  # None: move_on_after never expires
    try:
        async with anyio.open_cancel_scope() as scope:
            self._sender_task = scope
            await evt.set()
            while True:
                # Stays None if the timeout below expires before get() returns.
                packet = None
                async with anyio.move_on_after(keepalive_timeout):
                    packet = await self._send_q.get()
                    if packet is None:  # closing
                        break
                if packet is None:  # timeout
                    await self.handle_write_timeout()
                    continue
                # self.logger.debug("%s > %r",'B' if 'Broker' in type(self).__name__ else 'C', packet)
                await packet.to_stream(self.stream)
                await self.plugins_manager.fire_event(
                    EVENT_MQTT_PACKET_SENT, packet=packet, session=self.session)
    except ConnectionResetError:
        await self.handle_connection_closed()
    except anyio.get_cancelled_exc_class():
        raise
    except BaseException as e:
        self.logger.warning("Unhandled exception", exc_info=e)
        raise
    finally:
        # Shielded with a time limit so the stop signal goes out even during
        # cancellation, but cannot hang shutdown.
        async with anyio.fail_after(2, shield=True):
            await self._sender_stopped.set()
            self._sender_task = None
async def _delivery_loop(self):
    """Server: process incoming messages.

    Repeats forever:

    * take the next application message from this session;
    * fire EVENT_BROKER_MESSAGE_RECEIVED;
    * broadcast the message through the broker with its QoS and retain
      flag.

    The loop's cancel scope is stored in ``self._delivery_task``. On exit
    that attribute is cleared and ``self._delivery_stopped`` is set, inside
    a shielded 2-second timeout.
    """
    # Bound before ``try`` so the ``finally`` clause can always use it.
    # Binding it inside, as before, risked an UnboundLocalError in
    # ``finally`` masking the real error.
    broker = self._broker
    try:
        async with anyio.open_cancel_scope() as scope:
            self._delivery_task = scope
            broker.logger.debug("%s handling message delivery", self.client_id)
            while True:
                app_message = await self.get_next_message()
                await self._plugins_manager.fire_event(
                    EVENT_BROKER_MESSAGE_RECEIVED,
                    client_id=self.client_id, message=app_message)
                await broker.broadcast_message(
                    self, app_message.topic, app_message.data,
                    qos=app_message.qos,
                    retain=app_message.publish_packet.retain_flag)
    finally:
        async with anyio.fail_after(2, shield=True):
            broker.logger.debug("%s finished message delivery", self.client_id)
            self._delivery_task = None
            await self._delivery_stopped.set()
async def _run(self, evt: anyio.abc.Event = None):
    """Connect to the WebSocket and process messages until it ends.

    Blocks until all messages have been received from the WebSocket, or
    until this client has been closed. The application(s) to connect for
    are taken from ``self._apps``: either a list of names or a
    comma-separated string. The first name is exposed as ``self._app``
    while this runs.

    :param evt: optional event, set once the WebSocket is open.

    This is a coroutine. Don't call it directly, it's autostarted by the
    context manager.
    """
    app_spec = self._apps
    if isinstance(app_spec, list):
        self._app = app_spec[0]
        app_spec = ','.join(app_spec)
    else:
        self._app = app_spec.split(',', 1)[0]

    websocket = None
    try:
        websocket = await self.swagger.events.eventWebsocket(app=app_spec)
        self.websockets.add(websocket)
        if evt is not None:
            await evt.set()
        await self.__run(websocket)
    finally:
        if websocket is not None:
            self.websockets.remove(websocket)
            # Close even if this task is being cancelled.
            async with anyio.open_cancel_scope(shield=True):
                await websocket.close()
        del self._app
async def open_mqttclient(client_id=None, config=None):
    """
    MQTT client implementation.

    MQTTClient instances provide an API for connecting to a broker and
    sending/receiving messages using the MQTT protocol.

    :param client_id: MQTT client ID to use when connecting to the broker. If none, it will be generated randomly by :func:`hbmqtt.utils.gen_client_id`
    :param config: Client configuration
    :return: the :class:`MQTTClient` instance (yielded)

    This is an async context manager. On exit the client is disconnected
    (shielded from cancellation) and its task group is cancelled.

    Example usage::

        async with open_mqttclient() as client:
            await client.connect("mqtt://my-broker.example")
            await client.subscribe([
                    ('$SYS/broker/uptime', QOS_1),
                    ('$SYS/broker/load/#', QOS_2),
                ])
            async for msg in client:
                packet = msg.publish_packet
                print("%s => %s" % (packet.variable_header.topic_name, str(packet.payload.data)))

    """
    async with anyio.create_task_group() as tg:
        client = MQTTClient(tg, client_id, config)
        try:
            yield client
        finally:
            # Shielded so the disconnect completes even while the caller's
            # block is being cancelled.
            async with anyio.open_cancel_scope(shield=True):
                await client.disconnect()
                await tg.cancel_scope.cancel()
async def __aexit__(slf, *tb):
    """Drain and log leftover events on a clean exit, then drop the queue.

    NOTE(review): the parameter is named ``slf`` while the body uses
    ``self``. This presumably refers to an enclosing object captured by
    closure — confirm against the surrounding code.
    """
    clean_exit = tb[1] is None
    if clean_exit and not self._event_queue.empty():
        async with anyio.open_cancel_scope(shield=True):
            while not self._event_queue.empty():
                leftover = await self._event_queue.get()
                if leftover is None:
                    continue
                logger.error("Unprocessed: %s",leftover)
    self._event_queue = None
async def __aexit__(self, *tb):
    """Cancel the AMQP consumer and close both internal queue ends.

    Everything runs in a shielded scope so cleanup completes even during
    cancellation. An already-closed connection is ignored when cancelling
    the consumer. Both queue ends are closed even if cancelling the
    consumer, or closing the first end, raises; any such error still
    propagates. Exceptions from the ``async with`` body are not suppressed.
    """
    async with anyio.open_cancel_scope(shield=True):
        try:
            try:
                await self.channel.basic_cancel(self.consumer_tag)
            except AmqpClosedConnection:
                pass
        finally:
            try:
                await self._q_w.aclose()
            finally:
                await self._q_r.aclose()
async def server_task(evt, cb, address, port, ssl_context):
    """Accept TCP connections forever, spawning *cb* for each one.

    The enclosing cancel scope is passed to the caller via ``evt.set()``
    before the server is created, so the caller can stop it. Each accepted
    connection is handed to ``cb`` in a new task on ``self._tg``.
    """
    async with anyio.open_cancel_scope() as server_scope:
        await evt.set(server_scope)
        tcp_server = await anyio.create_tcp_server(
            port, interface=address, ssl_context=ssl_context)
        async with tcp_server as server:
            async for client_conn in server.accept_connections():
                await self._tg.spawn(cb, client_conn)
async def _handle_disconnect(self, disconnect, wait=True):
    """Shut down after a client disconnect.

    Marks the disconnect as not clean. Then, inside a shielded scope so the
    shutdown completes even under cancellation, it optionally waits up to
    ``self.session.keep_alive`` seconds for the reader to stop and finally
    calls ``self.stop()``.

    :param disconnect: not consulted yet.
    :param wait: whether to wait for ``self._reader_stopped`` first.
    """
    self.logger.debug("Client disconnecting")
    # TODO: should depend on *disconnect* (if set).
    self.clean_disconnect = False
    async with anyio.open_cancel_scope(shield=True):
        if wait:
            # NOTE(review): a keep_alive of 0 means no wait at all here.
            grace_period = self.session.keep_alive
            async with anyio.move_on_after(grace_period):
                await self._reader_stopped.wait()
        await self.stop()
async def test_nested_fail_after():
    """Outer cancellation must win over a nested, not-yet-expired fail_after.

    Another task cancels the outer scope while the body sleeps under
    ``fail_after(1)`` in an unshielded inner scope. None of the
    ``pytest.fail`` lines may be reached, and the outer scope must report
    ``cancel_called``.
    """
    async def cancel_when_idle(target):
        await wait_all_tasks_blocked()
        await target.cancel()

    async with create_task_group() as group:
        async with open_cancel_scope() as outer:
            async with open_cancel_scope():
                await group.spawn(cancel_when_idle, outer)
                async with fail_after(1):
                    await sleep(2)
                    pytest.fail('Execution should not reach this point')
                pytest.fail('Execution should not reach this point either')
            pytest.fail('Execution should also not reach this point')

    assert outer.cancel_called
async def process(self):
    """Receive packets until the client is closed, then clean up.

    Each packet's opcode and payload are passed to ``process_packet``.
    When ``recv`` raises ClosedResourceError, ``cleanup()`` runs inside a
    shielded scope and the loop ends.
    """
    while True:
        try:
            incoming = await self.client.recv()
        except anyio.exceptions.ClosedResourceError:
            async with anyio.open_cancel_scope(shield=True):
                await self.cleanup()
            return
        else:
            await self.process_packet(incoming.opcode, incoming.payload)
async def _add_task(self, val, proc, *args):
    """Run ``proc(*args)`` in its own cancel scope.

    The scope is handed to the caller through ``val.set()`` before *proc*
    starts. When *proc* finishes, normally or not, the scope is removed
    from ``self._tasks`` if it is still there (presumably the caller
    registered it).
    """
    async with anyio.open_cancel_scope() as task_scope:
        await val.set(task_scope)
        try:
            await proc(*args)
        finally:
            try:
                self._tasks.remove(task_scope)
            except KeyError:
                # Already removed, or never registered: nothing to do.
                pass
async def work(self):
    """Hold ``self.lock`` and yield a ChannelState for a new channel.

    The channel is always hung up on exit, inside a shielded scope, even
    if the code using the yielded state fails or is cancelled.
    """
    async with self.lock, Channel.new(self.client) as channel:
        try:
            state = ChannelState(channel)
            async with state.task:
                yield state
        finally:
            async with anyio.open_cancel_scope(shield=True):
                await channel.hang_up()
async def _spawn(self, val, proc, args, kw):
    """
    Run ``proc(*args, **kw)`` inside a new cancel scope.

    The scope is first delivered to the caller through *val* (a
    :class:`ValueEvent`), so the caller can cancel the task.
    """
    async with anyio.open_cancel_scope() as task_scope:
        await val.set(task_scope)
        coro = proc(*args, **kw)
        await coro
async def process_packet(self, packet):
    """Validate an incoming packet and dispatch it by type and flags.

    Validation order: signature, destination port, then source port (once
    the remote port is known). Any ACK clears the matching entry in
    ``self.ack_events`` and removes its scheduler handle. Dispatch:

    * SYN / SYN-ACK go to ``process_syn`` / ``process_syn_ack``.
    * Every other type requires ``self.syn_complete``.
    * CONNECT / CONNECT-ACK go to ``process_connect`` / ``process_connect_ack``.
    * Aggregate ACKs go to ``handle_aggregate_ack``.
    * Remaining packets must have a valid substream ID and session ID.
      Non-ACK packets are ACKed if requested. Reliable packets go to
      ``process_reliable``; unreliable DATA is decoded onto
      ``self.unreliable_packets``.
    * An ACK of DISCONNECT while closing runs cleanup and closes the client.

    :raises ValueError: on any validation failure or unexpected packet.
    """
    if packet.signature != self.packet_encoder.calc_packet_signature(
            packet, self.session_key, self.local_signature):
        raise ValueError("Received packet with invalid signature")
    if packet.dest_port != self.local_port:
        raise ValueError("Received packet with invalid destination port")
    if self.remote_port is not None and packet.source_port != self.remote_port:
        raise ValueError("Received packet with invalid source port")

    if packet.flags & FLAG_ACK:
        key = (packet.type, packet.substream_id, packet.packet_id)
        if key in self.ack_events:
            handle = self.ack_events.pop(key)
            self.scheduler.remove(handle)

    if packet.type == TYPE_SYN:
        if packet.flags & FLAG_ACK:
            await self.process_syn_ack(packet)
        else:
            await self.process_syn(packet)
        return

    if not self.syn_complete:
        raise ValueError("Expected SYN packet")

    if packet.type == TYPE_CONNECT:
        if packet.flags & FLAG_ACK:
            await self.process_connect_ack(packet)
        else:
            await self.process_connect(packet)
        return

    if packet.flags & FLAG_MULTI_ACK:
        self.handle_aggregate_ack(packet)
        return

    if packet.substream_id > self.max_substream_id:
        # Format the message: ValueError does not interpolate extra args.
        raise ValueError(
            "Received packet with invalid substream id: %i"
            % packet.substream_id)
    if packet.session_id != self.remote_session_id:
        raise ValueError("Received packet with invalid session id")

    if not packet.flags & FLAG_ACK:
        if packet.flags & FLAG_NEED_ACK:
            await self.send_ack(packet)
        if packet.flags & FLAG_RELIABLE:
            await self.process_reliable(packet)
        elif packet.type == TYPE_DATA:
            data = self.payload_encoder.decode(packet)
            await self.unreliable_packets.put(data)
    elif packet.type == TYPE_DISCONNECT:
        if not self.closing:
            raise ValueError("Received unexpected DISCONNECT/ACK")
        # Finish teardown even if this task is being cancelled.
        async with anyio.open_cancel_scope(shield=True):
            await self.cleanup()
            await self.client.close()
async def _run(proc, args, kw, result):
    """
    Run ``proc(*args, **kw)`` inside a new cancel scope.

    If *result* (a :class:`ValueEvent`) is not None, the scope is passed to
    it before *proc* starts, so the caller can cancel the task.
    """
    async with anyio.open_cancel_scope() as task_scope:
        wants_scope = result is not None
        if wants_scope:
            await result.set(task_scope)
        coro = proc(*args, **kw)
        await coro
async def __aexit__(self, *exc):
    """Stop the stream.

    Marks the stream as no longer running. If ``send_stop`` is set:

    * a ``stop`` RPC for this stream's sequence number is sent inside a
      shielded scope, so it goes out even while the caller is cancelled;
    * the stream's entry is dropped from the connection's handler map, if
      that map still exists.

    Exceptions from the ``async with`` body are not suppressed.
    """
    self._running = False
    hdl = self._conn._handlers
    if self.send_stop:
        async with anyio.open_cancel_scope(shield=True):
            await self._conn.call("stop", params={b'Stop': self.seq}, expect_body=False)
        if hdl is not None:
            # TODO remember this for a while?
            # pop() rather than del: the entry may already be gone, and a
            # KeyError here would mask whatever exception is propagating.
            hdl.pop(self.seq, None)
async def test_cancel_from_shielded_scope():
    """Cancelling a task group from inside a shielded scope.

    After the shielded block exits, every subsequent checkpoint in the
    task group must raise the backend's cancellation exception.
    """
    cancelled_exc = get_cancelled_exc_class()
    async with create_task_group() as tg:
        with open_cancel_scope(shield=True) as shielded:
            assert shielded.shield
            tg.cancel_scope.cancel()

        for _ in range(2):
            with pytest.raises(cancelled_exc):
                await sleep(0.01)
async def ping_wait(self, evt):
    """Watchdog: fail if no ping arrives within three ping periods.

    Publishes its cancel scope as ``self._scope`` and sets *evt*. It then
    loops, waiting up to ``self._freq * 3`` for ``self._wait`` to be set
    and replacing it with a fresh event after each successful wait. A
    missed ping is logged and the TimeoutError propagates.
    """
    async with anyio.open_cancel_scope() as watchdog_scope:
        self._scope = watchdog_scope
        await evt.set()
        while True:
            limit = self._freq * 3
            try:
                async with anyio.fail_after(limit):
                    await self._wait.wait()
            except TimeoutError:
                logger.error("PING missing %r", self)
                raise
            else:
                self._wait = anyio.create_event()
async def _client_consumer(self, channel, client_future, done):
    """Consume from ``client_queue_name`` and report each delivery.

    Stores its cancel scope in ``self._client_scope`` and sets *done* once
    the consumer is registered. For every delivery it stores
    ``(body, envelope, properties)`` on ``client_future.test_result`` and
    sets ``client_future``.
    """
    async with anyio.open_cancel_scope() as consumer_scope:
        self._client_scope = consumer_scope
        consumer = channel.new_consumer(queue_name=client_queue_name)
        async with consumer as deliveries:
            await done.set()
            logger.debug('Client consuming messages')
            async for body, envelope, properties in deliveries:
                logger.debug('Client received message')
                delivered = (body, envelope, properties)
                client_future.test_result = delivered
                await client_future.set()
async def _total_timer_(self, evt: anyio.abc.Event=None):
    """Enforce an overall deadline ``self.total_timeout`` seconds from now.

    The deadline is stored in ``self._total_deadline`` and re-read on every
    wakeup (presumably so it can be moved while the timer runs). The cancel
    scope is published as ``self._total_timer`` before *evt*, if given, is
    set. Once the deadline passes, ``_stop_playing()`` is awaited and
    NumberTimeoutError is raised.
    """
    timeout = self.total_timeout
    self._total_deadline = timeout + await anyio.current_time()
    async with anyio.open_cancel_scope() as timer_scope:
        self._total_timer = timer_scope
        if evt is not None:
            await evt.set()
        while True:
            remaining = self._total_deadline - await anyio.current_time()
            if remaining > 0:
                await anyio.sleep(remaining)
                continue
            await self._stop_playing()
            raise NumberTimeoutError(self.num) from None
async def _reader(self, scope):
    """Main loop for reading messages from the RPC socket.

    Runs inside a shielded cancel scope, which is handed to the caller via
    ``scope.set()``. Each decoded msgpack object is either a response
    header or, when ``_handle_msg`` returned true for the previous header,
    that header's body. On exit every remaining handler is cancelled and
    ``self._handlers`` is set to ``None``.

    TODO: add a timeout for receiving message bodies.

    :raises SerfClosedError: if the peer closes the connection.
    """
    unpacker = msgpack.Unpacker(object_hook=self._decode_addr_key)
    cur_msg = None
    async with anyio.open_cancel_scope(shield=True) as s:
        await scope.set(s)
        try:
            while self._socket is not None:
                if cur_msg is not None:
                    logger.debug("%d:wait for body", self._conn_id)
                try:
                    buf = await self._socket.receive_some(
                        self._socket_recv_size)
                except ClosedResourceError:
                    return  # closed by us
                if len(buf) == 0:  # Connection was closed.
                    raise SerfClosedError("Connection closed by peer")
                unpacker.feed(buf)

                for msg in unpacker:
                    if cur_msg is not None:
                        logger.debug("%d Body=%s", self._conn_id, msg)
                        cur_msg.body = msg
                        await self._handle_msg(cur_msg)
                        cur_msg = None
                    else:
                        logger.debug("%d:Recv =%s", self._conn_id, msg)
                        msg = SerfResult(msg)
                        if await self._handle_msg(msg):
                            cur_msg = msg
        finally:
            hdl, self._handlers = self._handlers, None
            # The handler map may already have been cleared, and None is an
            # expected state for it; don't raise AttributeError here.
            if hdl is not None:
                async with anyio.open_cancel_scope(shield=True):
                    for m in hdl.values():
                        await m.cancel()
async def _reader(self, val):
    """Read replies and hand each one to the oldest pending request.

    The cancel scope is given to the caller via ``val.set()``. Each reply
    ``(res, data)`` from ``self._msg_proto`` is passed to the request at
    the head of ``self.requests``. If that request is not done afterwards,
    it is put back at the head. ServerBusy is logged and otherwise ignored.

    Any of the following triggers ``self._reconnect()``, after which
    reading resumes on a fresh iterator:

    * end of stream;
    * a read taking longer than 15 seconds;
    * a connection error.
    """
    async with anyio.open_cancel_scope() as scope:
        await val.set(scope)
        it = self._msg_proto.__aiter__()
        while True:
            try:
                async with anyio.fail_after(15):
                    res, data = await it.__anext__()
            except ServerBusy:
                logger.info("Server %s busy", self.host)
            except (StopAsyncIteration, TimeoutError, IncompleteRead,
                    ConnectionResetError, ClosedResourceError):
                await self._reconnect()
                it = self._msg_proto.__aiter__()
            else:
                msg = self.requests.popleft()
                await msg.process_reply(res, data, self)
                if not msg.done():
                    # Presumably expects more replies: keep it first in line.
                    self.requests.appendleft(msg)
async def _writer(self, val):
    """Send queued requests to the server, forever.

    The cancel scope is handed to the caller via ``val.set()``. Each message
    taken from ``self._wqueue`` is appended to ``self.requests`` and then
    written to ``self._msg_proto``. If nothing is queued for 10 seconds, a
    NOPMsg is sent instead (presumably a no-op keep-alive).

    On IncompleteRead the stream is closed and the writer returns; the
    reader is expected to restart it. Any other exception, including
    cancellation, closes the stream and is re-raised.
    """
    async with anyio.open_cancel_scope() as scope:
        await val.set(scope)
        while True:
            try:
                async with anyio.fail_after(10):
                    msg = await self._wqueue.get()
            except TimeoutError:
                msg = NOPMsg()
            self.requests.append(msg)
            try:
                await msg.write(self._msg_proto)
            except IncompleteRead:
                await self.stream.close()
                return  # will be restarted by the reader
            except BaseException:
                # Shield the close: if this is a cancellation, an unshielded
                # await would be cancelled again at once and the stream
                # would be left open.
                async with anyio.open_cancel_scope(shield=True):
                    await self.stream.close()
                raise