async def run_server(task_status=trio.TASK_STATUS_IGNORED):
    """Open the listeners, report a ``Server`` handle via *task_status*, then serve.

    ``sock``, ``ssl``, ``port``, ``host``, ``protocol_factory`` and ``nursery``
    are not parameters: they come from an enclosing scope not visible here.
    Connection handlers run in a dedicated nursery, so cancelling the
    ``Server``'s cancel scope stops accepting without killing live connections.
    """
    # Listener source, in priority order: a pre-made socket, TLS over TCP, plain TCP.
    if sock:
        listeners = [trio.SocketListener(from_stdlib_socket(sock.sock))]
    else:
        if ssl:
            listeners = await trio.open_ssl_over_tcp_listeners(port, ssl, host=host)
        else:
            listeners = await trio.open_tcp_listeners(port, host=host)
    # An outer nursery to run connection tasks in.
    server = None
    async with trio.open_nursery() as connection_nursery:
        # We can cancel serve_listeners() via this scope, without cancelling
        # the connection tasks.
        with trio.open_cancel_scope() as cancel_scope:
            # NOTE(review): this passes the enclosing-scope `nursery`, not
            # `connection_nursery` -- confirm that is intended.
            server = Server(listeners=listeners, nursery=nursery, cancel_scope=cancel_scope)
            task_status.started(server)
            await trio.serve_listeners(functools.partial(
                handle_connection, protocol_factory), listeners, handler_nursery=connection_nursery)
    # Reached only once the connection nursery has exited normally.
    if server is not None:
        server._closed.set()
def _cancel_when(self, flag):
    """Generator body for a context manager that runs its block in a fresh cancel scope.

    While the block runs, the scope is a member of ``self._cancel_scopes[flag]``.
    It is removed again on exit, however the block exits.
    """
    with trio.open_cancel_scope() as block_scope:
        try:
            self._cancel_scopes[flag].add(block_scope)
            yield
        finally:
            self._cancel_scopes[flag].remove(block_scope)
async def connect(self, broker):
    """Connect to the AMQP server and yield ``self`` while connected.

    This is an async generator. Presumably it is wrapped by an
    async-context-manager decorator that is not part of this block.

    Stores a weak reference to *broker*, then connects with
    ``self.cfg.server``. The connect phase is bounded by
    ``self.cfg.timeout.connect`` seconds; a falsy value means no limit.
    Once connected, it starts ``self.rpc_workers`` RPC workers and
    ``self.alert_workers`` alert workers in ``self.nursery``, and sets up the
    channels before yielding.

    After connecting, on exit (normal or not):
    - any exception is logged at debug level and re-raised;
    - ``is_disconnected`` is set;
    - the worker nursery is cancelled;
    - ``self.amqp`` is cleared.
    """
    self.broker = weakref.ref(broker)
    self.is_disconnected = trio.Event()
    timeout = self.cfg.timeout.connect
    if not timeout:
        timeout = math.inf
    with trio.open_cancel_scope() as scope:
        # Trio cancel scopes are bounded through ``deadline`` (an absolute
        # time). They have no ``timeout`` attribute, so assigning one never
        # limited the connect phase.
        # NOTE(review): on timeout this scope absorbs the cancellation and the
        # generator ends without yielding.
        scope.deadline = trio.current_time() + timeout
        async with trio_amqp.connect_amqp(**self.cfg.server) as amqp:
            # Connected: lift the connect deadline for the rest of the session.
            scope.deadline = math.inf
            self.amqp = amqp
            nursery = self.nursery
            try:
                for _ in range(self.rpc_workers):
                    await nursery.start(self._work_rpc)
                for _ in range(self.alert_workers):
                    await nursery.start(self._work_alert)
                await self.setup_channels()
                yield self
            except BaseException as exc:
                logger.debug("Problem in connection", exc_info=exc)
                raise
            finally:
                self.is_disconnected.set()
                nursery.cancel_scope.cancel()
                self.amqp = None
async def __aexit__(self, *tb):
    """Cancel this consumer on the channel, then drop the queue reference.

    The cancel runs shielded, so cleanup still happens while the surrounding
    task is being cancelled. A connection that has already closed is tolerated.
    """
    with trio.open_cancel_scope(shield=True):
        tag = self.consumer_tag
        try:
            await self.channel.basic_cancel(tag)
        except AmqpClosedConnection:
            pass
        del self._q
async def _recv_task(self, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
    """Receive loop: read one command byte at a time and dispatch it to its handler.

    Runs inside the cancel scope ``self._cancel_recv``, so another task can
    cancel it. When there is no data task, no outstanding ping and no pending
    message id, it parks on ``self._recv_parker`` instead of reading.

    Raises:
        exception.UnknownCommand: the received byte maps to no handler.
        exception.IncompleteMessage: a handler hit EOFError mid-message.
        Exception: a Cancelled arrived while reading or handling a command.
    """
    # Command byte -> coroutine method that reads the rest of that message.
    handlers = {
        Command.PING: self._recv_ping,
        Command.PONG: self._recv_pong,
        Command.DATA: self._recv_data,
        Command.DATA_ACK: self._recv_data_ack,
        Command.ACK: self._recv_ack,
        Command.ERROR: self._recv_error,
        Command.ERROR_UNDEF: self._recv_error_undef,
    }
    with trio.open_cancel_scope() as self._cancel_recv:
        task_status.started()
        while True:
            # Nothing in flight: park until woken. A Cancelled here propagates
            # normally (this is outside the try below).
            if not self._recv_data_task and not self._pings and not len(self._message_ids):
                await self._recv_parker.park()
            try:
                cmd = await self.stream.recv_byte()
                try:
                    handler = handlers[cmd]
                except KeyError as ex:
                    raise exception.UnknownCommand(cmd) from ex
                try:
                    await handler()
                except EOFError as ex:
                    raise exception.IncompleteMessage() from ex
            except trio.Cancelled as ex:
                # NOTE(review): turning Cancelled into a plain Exception means no
                # cancel scope will catch it. This looks like a deliberate
                # "must never be cancelled mid-message" assertion -- confirm.
                raise Exception('Cancelled while receiving!') from ex
async def _trio_claim_user(
    queue,
    qt_on_done,
    qt_on_error,
    config,
    addr,
    device_id,
    token,
    use_pkcs11,
    password=None,
    pkcs11_token=None,
    pkcs11_key=None,
):
    """Claim *device_id* on the backend and save the resulting device locally.

    Puts a ``BlockingTrioPortal`` and then this task's cancel scope on *queue*,
    so the caller can reach into and cancel the task. The device is saved with
    PKCS#11 credentials when *use_pkcs11* is true, otherwise with *password*.
    Emits ``qt_on_done`` on success. On ``BackendCmdsBadResponse`` it emits
    ``qt_on_error`` with the response status.
    """
    queue.put(trio.BlockingTrioPortal())
    with trio.open_cancel_scope() as claim_scope:
        queue.put(claim_scope)
        try:
            async with backend_anonymous_cmds_factory(addr) as cmds:
                device = await core_claim_user(cmds, device_id, token)
            if not use_pkcs11:
                devices_manager.save_device_with_password(config.config_dir, device, password)
            else:
                devices_manager.save_device_with_pkcs11(
                    config.config_dir, device, pkcs11_token, pkcs11_key
                )
            qt_on_done.emit()
        except BackendCmdsBadResponse as err:
            qt_on_error.emit(err.status)
async def run(self, msg):
    """Invoke the endpoint function for *msg* according to ``self.call_conv``.

    CC_TASK hands the message to ``self._run`` in the connection's nursery.
    Any other convention calls ``self.fn`` through ``coro_wrapper``:
      * CC_DICT: ``msg.data`` as keyword arguments (``None`` means none);
      * CC_DATA: ``msg.data`` as the only positional argument;
      * anything else: the message itself.
    A non-None result is sent back with ``msg.reply``, and an ``Exception`` is
    reported with ``msg.error``. Afterwards the message is always closed,
    shielded and with a one-second deadline.
    """
    if self.call_conv == CC_TASK:
        await msg.conn.nursery.start(self._run, self.fn, msg)
        return

    args, kwargs = (), {}
    if self.call_conv == CC_DICT:
        kwargs = msg.data
        if not isinstance(kwargs, Mapping):
            assert kwargs is None, kwargs
            kwargs = {}
    elif self.call_conv == CC_DATA:
        args = (msg.data,)
    else:
        args = (msg,)

    try:
        result = await coro_wrapper(self.fn, *args, **kwargs)
        if result is not None:
            await msg.reply(result)
    except Exception as err:
        await msg.error(err, _exit=self.debug)
    finally:
        with trio.open_cancel_scope(shield=True, deadline=trio.current_time() + 1):
            with suppress(AmqpClosedConnection):
                await msg.aclose()
def _cancel_when_something_happens(self):
    """Generator body for a context manager that runs its block in a tracked cancel scope.

    The scope is a member of ``self._cancel_scopes`` only while the block runs.
    Presumably this lets the owner cancel every active block.
    """
    with trio.open_cancel_scope() as tracked:
        try:
            self._cancel_scopes.add(tracked)
            yield
        finally:
            self._cancel_scopes.remove(tracked)
async def __trio_thread_main(self):
    """Run this loop's trio side and execute jobs handed over through blocking queues.

    Sets up the main loop inside a nursery, then signals readiness by putting
    True on ``__blocking_result_queue``. Next it repeatedly takes
    ``(async_fn, args)`` jobs from ``__blocking_job_queue`` and puts back the
    captured outcome (presumably for a caller on another thread -- confirm).
    A ``None`` job stops the loop. On shutdown it runs ``_main_loop_exit``
    shielded, puts ``None`` on the result queue and cancels the nursery.
    """
    # The non-context-manager equivalent of open_loop()
    async with trio.open_nursery() as nursery:
        asyncio.set_event_loop(self)
        await self._main_loop_init(nursery)
        self.__blocking_result_queue.put(True)
        while not self._closed:
            # This *blocks*: a synchronous get, so this thread (and its trio
            # run) does nothing else until a job arrives.
            req = self.__blocking_job_queue.get()
            if req is None:
                self.stop()
                break
            async_fn, args = req

            result = await trio.hazmat.Result.acapture(async_fn, *args)
            # Hand a Cancelled outcome back as a RuntimeError instead.
            if type(result) == trio.hazmat.Error and type(
                    result.error) == trio.Cancelled:
                res = RuntimeError("Main loop cancelled")
                # NOTE(review): copies Cancelled's own __cause__ rather than
                # chaining the Cancelled itself -- confirm intended.
                res.__cause__ = result.error.__cause__
                result = trio.hazmat.Error(res)
            self.__blocking_result_queue.put(result)
        with trio.open_cancel_scope(shield=True):
            await self._main_loop_exit()
        self.__blocking_result_queue.put(None)
        nursery.cancel_scope.cancel()
async def __aexit__(self, *tb):
    """Close the channel on context exit, unless it is already closed.

    The close runs shielded from cancellation. A connection that has already
    gone away is not treated as an error.
    """
    if self.channel.is_open:
        with trio.open_cancel_scope(shield=True):
            try:
                await self.channel.close()
            except exceptions.AmqpClosedConnection:
                pass
def cancel_on_graceful_shutdown(self):
    """Generator body for a context manager whose block is cancelled on graceful shutdown.

    The block's cancel scope is kept in ``self._cancel_scopes`` while the block
    runs. If shutdown has already begun (``self._shutting_down``), the scope is
    cancelled immediately, so the block's first checkpoint raises Cancelled.
    """
    with trio.open_cancel_scope() as shutdown_scope:
        self._cancel_scopes.add(shutdown_scope)
        if self._shutting_down:
            shutdown_scope.cancel()
        try:
            yield
        finally:
            self._cancel_scopes.remove(shutdown_scope)
async def monitor_messages(backend_online, fs, event_bus, *, task_status=trio.TASK_STATUS_IGNORED):
    """Keep processing backend messages, restarting after every reconnection.

    Each round works like this:
    - Run ``fs.process_last_messages()`` once, between the
      ``message_monitor.reconnection_message_processing.started`` and
      ``.done`` events.
    - Run it again every time ``backend.message.received`` or
      ``backend.message.polling_needed`` fires.
    - A ``backend.offline`` event cancels the round; ``BackendNotAvailable``
      also ends it.
    - Then wait until ``backend.online`` has been seen and start the next round.

    Never returns.
    """
    msg_arrived = trio.Event()
    backend_online_event = trio.Event()
    # Cancel scope of the current round; read by _on_backend_offline via closure.
    process_message_cancel_scope = None

    def _on_msg_arrived(event, index=None):
        msg_arrived.set()

    # NOTE(review): weak=True presumably means the bus holds weak references;
    # these local callbacks stay alive as long as this frame does.
    event_bus.connect("backend.message.received", _on_msg_arrived, weak=True)
    event_bus.connect("backend.message.polling_needed", _on_msg_arrived, weak=True)

    def _on_backend_online(event):
        backend_online_event.set()

    def _on_backend_offline(event):
        backend_online_event.clear()
        # Abort the round in progress, if any.
        if process_message_cancel_scope:
            process_message_cancel_scope.cancel()

    event_bus.connect("backend.online", _on_backend_online, weak=True)
    event_bus.connect("backend.offline", _on_backend_offline, weak=True)
    if backend_online:
        _on_backend_online(None)

    task_status.started()
    while True:
        try:
            with trio.open_cancel_scope() as process_message_cancel_scope:
                event_bus.send(
                    "message_monitor.reconnection_message_processing.started")
                try:
                    await fs.process_last_messages()
                finally:
                    event_bus.send(
                        "message_monitor.reconnection_message_processing.done")
                while True:
                    await msg_arrived.wait()
                    msg_arrived.clear()
                    try:
                        await fs.process_last_messages()
                    except SharingError:
                        logger.exception("Invalid message from backend")
        except BackendNotAvailable:
            pass
        process_message_cancel_scope = None
        # Drop notifications received meanwhile; the next round starts with a full pass anyway.
        msg_arrived.clear()
        await backend_online_event.wait()
async def _recv_callback(self):
    """
    Callable passed to the ASGI awaitable that consumes from the event queue.

    Each pending receive runs in its own cancel scope, kept in
    ``self._cancels`` while it waits, so the owner can cancel outstanding
    receives. A receive cancelled through that scope returns None.
    """
    with trio.open_cancel_scope() as receive_scope:
        self._cancels.add(receive_scope)
        try:
            event = await self._event_queue.get()
            return event
        finally:
            self._cancels.remove(receive_scope)
async def _consumer(in_queue, out_queue, *, task_status=trio.TASK_STATUS_IGNORED):
    """Forever turn ``(x, y)`` pairs from *in_queue* into ``"x + y = sum"`` strings on *out_queue*.

    This task's cancel scope is passed to ``task_status.started`` so the
    starter can stop it.
    """
    with trio.open_cancel_scope() as consumer_scope:
        task_status.started(consumer_scope)
        while True:
            x, y = await in_queue.get()
            await trio.sleep(0)  # yield to the scheduler before computing
            line = '%s + %s = %s' % (x, y, x + y)
            await out_queue.put(line)
async def _run_idle(self, task_status=trio.TASK_STATUS_IGNORED):
    """
    Run the "idle proc" under a separate scope so that it can be cancelled
    when the connection comes back.

    ``self._idle`` holds that scope while the proc runs and is reset to None
    afterwards, whatever the outcome. Startup is reported as soon as the scope
    is published, so ``await nursery.start(self._run_idle)`` returns right
    away instead of blocking until the idle proc ends. Without that report,
    ``nursery.start`` would raise RuntimeError when the task finished.
    """
    try:
        with trio.open_cancel_scope() as idle_scope:
            self._idle = idle_scope
            task_status.started()
            await self.idle_proc()
    finally:
        self._idle = None
async def _dispatch(self, mode, channel, body, envelope, properties):
    """Decode an incoming delivery and run the endpoint registered for it.

    Routing key (used for logging): the ``routing-key`` header if present,
    else the envelope's key.

    Endpoint lookup uses ``mode + '.' + msg.routing_key``. Without an exact
    match, it tries successively shorter dotted prefixes with a ``.#``
    wildcard suffix.

    Failure handling:
    - no matching endpoint: the delivery is logged and rejected;
    - any other exception, including cancellation: the delivery is rejected
      (shielded) and the exception is re-raised.
    """
    try:
        routing_key = properties.headers['routing-key']
    except (KeyError, AttributeError):
        # No headers at all, or no such header: use the envelope's key.
        routing_key = envelope.routing_key
    if routing_key != envelope.routing_key:
        logger.debug(
            "read %s %s on %s for %s: %s", mode, envelope.delivery_tag, envelope.routing_key,
            routing_key, body
        )
    else:
        logger.debug("read %s %s for %s: %s", mode, envelope.delivery_tag, routing_key, body)
    try:
        codec = get_codec(properties.content_type)
        msg = codec.decode(body)
        msg = BaseMsg.load(
            msg,
            envelope,
            properties,
            channel=channel,
            type='server',
            reply_channel=self._ch_reply.channel,
            reply_exchange=self._ch_reply.exchange,
            conn=self
        )
        n = mode + '.' + msg.routing_key
        try:
            rpc = self.rpcs[n]
        except KeyError:
            # Wildcard fallback: strip the last dotted component and look up
            # "<prefix>.#". Once no prefix is left, the bare `raise` re-raises
            # the original KeyError.
            while True:
                i = n.rfind('.')
                if i < 1:
                    raise
                n = n[:i]
                rpc = self.rpcs.get(n + '.#', None)
                if rpc is not None:
                    break
        await rpc.run(msg)
    except KeyError:
        # NOTE(review): this also catches any KeyError raised by get_codec,
        # BaseMsg.load or rpc.run, which would then be reported as "Unknown
        # message". Consider narrowing the try.
        logger.info(
            "Unknown message %s %s on %s for %s: %s", mode, envelope.delivery_tag,
            envelope.routing_key, routing_key, body
        )
        await channel.basic_reject(envelope.delivery_tag)
    except BaseException:
        # Reject even while being cancelled; a dead connection is ignored here.
        with trio.open_cancel_scope(shield=True):
            with suppress(AmqpClosedConnection):
                await channel.basic_reject(envelope.delivery_tag)
        raise
async def _handle_invite_and_create_user(queue, qt_on_done, qt_on_error, core, username, token, is_admin):
    """Invite and create *username* on the backend, reporting the outcome to Qt.

    This task's cancel scope is put on *queue* first, so the GUI can cancel it.

    Signals emitted:
    - ``qt_on_done`` on success;
    - ``qt_on_error(status)`` on ``BackendCmdsBadResponse``;
    - ``qt_on_error(None)`` on any other ``Exception``.

    Cancellation and other non-``Exception`` errors also emit
    ``qt_on_error(None)``, but are then re-raised instead of being swallowed.
    """
    try:
        with trio.open_cancel_scope() as cancel_scope:
            queue.put(cancel_scope)
            await invite_and_create_user(core.device, core.backend_cmds, username, token, is_admin)
            qt_on_done.emit()
    except BackendCmdsBadResponse as exc:
        qt_on_error.emit(exc.status)
    except Exception:
        qt_on_error.emit(None)
    except BaseException:
        # trio.Cancelled / KeyboardInterrupt must keep propagating or trio's
        # cancellation breaks. Still notify the GUI first.
        qt_on_error.emit(None)
        raise
async def __aexit__(self, *tb):
    """Tear down on context exit: cancel background tasks, then close the connection.

    The close runs shielded from cancellation. Any exception it raises is
    logged at debug level and re-raised. ``self.conn`` and ``self._running``
    are reset regardless.
    """
    self.nursery.cancel_scope.cancel()
    with trio.open_cancel_scope(shield=True):
        try:
            conn = self.conn
            if conn is not None:
                await conn.aclose()
        except BaseException as err:
            logger.debug("Conn ended", exc_info=err)
            raise
        finally:
            self.conn = None
            self._running = False
async def run_asyncio_loop(nursery, *, task_status=trio.TASK_STATUS_IGNORED):
    """Run a trio-asyncio loop until cancelled, then set ``asyncio_loop_closed``.

    From inside the loop context it starts ``work_in_trio_no_matter_what`` in
    *nursery*. It then reports its own cancel scope to the starter and sleeps
    until that scope is cancelled.
    """
    try:
        with trio.open_cancel_scope() as loop_scope:
            async with trio_asyncio.open_loop():
                # Starting a coroutine from here makes it inherit access to
                # the asyncio loop context manager.
                await nursery.start(work_in_trio_no_matter_what)
                task_status.started(loop_scope)
                await trio.sleep_forever()
    finally:
        asyncio_loop_closed.set()
async def _run(self, fn, msg, task_status=trio.TASK_STATUS_IGNORED):
    """Task body for CC_TASK endpoints: call ``fn(msg)`` and answer the message.

    Startup is reported immediately, so the starter does not wait for *fn*.
    A non-None result is sent as a reply, and an ``Exception`` is reported via
    ``msg.error``. The message is always closed afterwards, shielded and with
    a one-second deadline; an already-closed connection is ignored there.
    """
    task_status.started()
    try:
        try:
            outcome = await fn(msg)
        except Exception as err:
            await msg.error(err, _exit=self.debug)
        else:
            if outcome is not None:
                await msg.reply(outcome)
    finally:
        close_deadline = trio.current_time() + 1
        with trio.open_cancel_scope(shield=True, deadline=close_deadline):
            with suppress(AmqpClosedConnection):
                await msg.aclose()
async def work_in_trio_no_matter_what(*, task_status=trio.TASK_STATUS_IGNORED):
    """Test helper: run asyncio work, then stay alive (shielded) until the asyncio loop closes.

    Calls ``work_in_asyncio`` through ``aio_as_trio`` once up front. After the
    shielded wait on ``asyncio_loop_closed`` ends, the same call is expected
    to raise RuntimeError.
    """
    await trio_asyncio.aio_as_trio(work_in_asyncio)()
    try:
        # KeyboardInterrupt won't cancel this coroutine thanks to the shield
        with trio.open_cancel_scope(shield=True):
            task_status.started()
            await asyncio_loop_closed.wait()
    finally:
        # Hence this call will be executed after run_asyncio_loop is cancelled
        with pytest.raises(RuntimeError):
            await trio_asyncio.aio_as_trio(work_in_asyncio)()
async def _writer_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED):
    """Call *handle* synchronously each time *fd* becomes writable, until it is cancelled.

    The loop's cancel scope is published as ``handle._scope`` and cleared on
    exit. After each callback it awaits ``self.synchronize()``. An
    ``Exception`` raised along the way is routed to ``_h_raise(handle, ...)``.
    """
    with trio.open_cancel_scope() as writer_scope:
        handle._scope = writer_scope
        task_status.started()
        try:
            while not handle._cancelled:  # pragma: no branch
                await _wait_writable(fd)
                handle._call_sync()
                await self.synchronize()
        except Exception as err:
            _h_raise(handle, err)
            return
        finally:
            handle._scope = None
async def _send_task(self, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
    """Drain ``self._send_buf`` onto the stream until EOF is requested and nothing is left.

    Runs under the cancel scope ``self._cancel_send``, so the sender can be
    cancelled from outside. ``self._send_task_finished`` is always set on exit.
    """
    try:
        with trio.open_cancel_scope() as self._cancel_send:
            task_status.started()
            while self._send_buf or not self._send_eof:
                await self._stream.wait_send_all_might_not_block()
                if not self._send_buf:
                    # Nothing queued yet: sleep until a producer unparks us.
                    await self._send_parker.park()
                chunk, self._send_buf = self._send_buf, bytearray()
                await self._stream.send_all(chunk)
    finally:
        self._send_task_finished.set()
async def _run():
    """Run a logged-in core until cancelled, publishing handles on ``self.core_queue``.

    Puts, in order: a ``BlockingTrioPortal``, this task's cancel scope, then
    the core object. Any ``Exception`` raised along the way is put on the
    queue instead.
    """
    try:
        self.core_queue.put(trio.BlockingTrioPortal())
        with trio.open_cancel_scope() as core_scope:
            self.core_queue.put(core_scope)
            async with logged_core_factory(self.core_config, self.current_device) as core:
                self.core_queue.put(core)
                await trio.sleep_forever()
    # If we have an exception, we may never have put the core object in the
    # queue. Since the main thread expects something to be there, we put the
    # exception.
    except Exception as err:
        self.core_queue.put(err)
async def _client_consumer(self, channel, client_future, task_status=trio.TASK_STATUS_IGNORED):
    """Consume from ``client_queue_name`` and record each delivery on *client_future*.

    The cancel scope is stored in ``self._client_scope`` so consumption can be
    stopped. Every delivery overwrites ``client_future.test_result`` with
    ``(body, envelope, properties)`` and sets the future.
    """
    with trio.open_cancel_scope() as consumer_scope:
        self._client_scope = consumer_scope
        async with channel.new_consumer(queue_name=client_queue_name) as deliveries:
            task_status.started()
            logger.debug('Client consuming messages')
            async for body, envelope, properties in deliveries:
                logger.debug('Client received message')
                client_future.test_result = (body, envelope, properties)
                client_future.set()
async def aclose(self, no_wait=False):
    """Close connection (and all channels)

    Early exits:
    - already CLOSED: return at once;
    - another close in progress: wait for it unless *no_wait*, then return.

    Otherwise, unless the peer already closed:
    - close the channels;
    - send a clean Connection.Close frame;
    - unless *no_wait*, wait up to half the server heartbeat for the close to
      complete.

    Any BaseException during this phase is passed on to the channels via
    ``_close_channels(exception=...)`` and re-raised. Finally, shielded: all
    tasks are cancelled, the stream is closed and the state becomes CLOSED.
    """
    if self.state == CLOSED:
        return
    if self.state == CLOSING:
        if not no_wait:
            await self.wait_closed()
        return

    try:
        self.state = CLOSING
        got_close = self.connection_closed.is_set()
        self.connection_closed.set()
        if not got_close:
            self._close_channels()

            # If the closing handshake is in progress, let it complete.
            frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, 0)
            frame.declare_method(
                amqp_constants.CLASS_CONNECTION, amqp_constants.CONNECTION_CLOSE
            )
            encoder = amqp_frame.AmqpEncoder()
            # we request a clean connection close:
            # reply-code 0, empty reply-text, class-id 0, method-id 0
            encoder.write_short(0)
            encoder.write_shortstr('')
            encoder.write_short(0)
            encoder.write_short(0)
            try:
                await self._write_frame(frame, encoder)
            except (trio.BrokenResourceError, trio.ClosedResourceError):
                # The stream is already gone, so there is nothing to send.
                # Use the current exception names, as the reader loop does;
                # ClosedStreamError is only a deprecated alias.
                pass
            except Exception:
                logger.exception("Error while closing")
            else:
                if not no_wait and self.server_heartbeat:
                    with trio.move_on_after(self.server_heartbeat / 2):
                        await self.wait_closed()
    except BaseException as exc:
        self._close_channels(exception=exc)
        raise
    finally:
        with trio.open_cancel_scope(shield=True):
            self._cancel_all()
            await self._stream.aclose()
        self._nursery = None
        self.state = CLOSED
async def _call_async(self, task_status=trio.TASK_STATUS_IGNORED):
    """Run this handle's coroutine callback under a fresh cancel scope.

    Does nothing if the handle was cancelled before it started. While the
    callback runs, the scope is published as ``self._scope``; it is reset to
    None afterwards. The callback runs through ``self._context.run``. It
    receives the handle itself when ``self._is_sync`` is None, otherwise
    ``*self._args``. An ``Exception`` from it is passed to ``self._raise``.
    """
    assert not self._is_sync
    if self._cancelled:
        return
    task_status.started()
    try:
        with trio.open_cancel_scope() as call_scope:
            self._scope = call_scope
            call_args = (self,) if self._is_sync is None else self._args
            await self._context.run(self._callback, *call_args)
    except Exception as err:
        self._raise(err)
    finally:
        self._scope = None
async def task_wrapper():
    """Await ``async_func`` and record its outcome on ``task``.

    ``async_func`` may be a coroutine object, which is awaited directly, or a
    callable returning an awaitable. The cancel scope is exposed as
    ``task._cancel_scope``.

    Outcome recorded on ``task``:
    - an ``Exception`` is recorded as the task's exception;
    - a cancellation caught by the scope marks the task cancelled;
    - otherwise the result is recorded.
    """
    with trio.open_cancel_scope() as wrapper_scope:
        task._cancel_scope = wrapper_scope
        try:
            awaitable = async_func if inspect.iscoroutine(async_func) else async_func()
            result = await awaitable
        except Exception as err:
            task_set_exception(task, err)
            return
    if not wrapper_scope.cancelled_caught:
        task_set_result(task, result)
    else:
        task_set_cancelled(task)
async def _reader_loop(self, task_status=trio.TASK_STATUS_IGNORED):
    """Read and dispatch incoming frames until the connection ends.

    Runs in a shielded cancel scope stored as ``self._reader_scope``. Outer
    cancellation does not reach it; cancelling that scope directly does.
    With a server heartbeat configured, going more than twice the heartbeat
    interval without a frame raises ``HeartbeatTimeoutError``. If the stream
    is broken or closed, it returns quietly. ``connection_closed`` is always
    set on exit.
    """
    with trio.open_cancel_scope(shield=True) as scope:
        self._reader_scope = scope
        try:
            task_status.started()
            while True:
                try:
                    if self._stream is None:
                        raise exceptions.AmqpClosedConnection

                    # Allow two heartbeat intervals between frames; no heartbeat, no limit.
                    if self.server_heartbeat:
                        timeout = self.server_heartbeat * 2
                    else:
                        timeout = inf
                    with trio.fail_after(timeout):
                        try:
                            frame = await self.get_frame()
                        except (trio.BrokenResourceError, trio.ClosedResourceError):
                            # the stream is now *really* closed …
                            return
                    try:
                        await self.dispatch_frame(frame)
                    except Exception as exc:
                        # We want to raise this exception so that the
                        # nursery ends the protocol, but we need keep
                        # going for now (need to process the close-OK
                        # message). Thus we start a new task that
                        # raises the actual error, somewhat later.
                        async def owch(exc):
                            await trio.sleep(0)
                            raise exc
                        self._nursery.start_soon(owch, exc)
                except trio.TooSlowError:
                    # fail_after expired: the peer went silent.
                    self.connection_closed.set()
                    raise exceptions.HeartbeatTimeoutError(self) from None
                except exceptions.AmqpClosedConnection as exc:
                    logger.debug("Remote closed connection")
                    raise
        finally:
            self._reader_scope = None
            self.connection_closed.set()
async def __run_trio(self, h):
    """Helper for copying the result of a Trio task to an asyncio future

    ``h._args`` unpacks to ``(future, async function, *positional args)``.
    The function runs in a cancel scope published as ``h._scope``:
    - a cancellation caught by that scope cancels the future;
    - any other exception is set on the future;
    - otherwise the result is set on it.
    A future that is already cancelled is left untouched.
    """
    fut, async_fn, *fn_args = h._args
    if fut.cancelled():  # pragma: no cover
        return
    try:
        with trio.open_cancel_scope() as run_scope:
            h._scope = run_scope
            value = await async_fn(*fn_args)
        if run_scope.cancelled_caught:
            fut.cancel()
            return
    except BaseException as err:
        if not fut.cancelled():  # pragma: no branch
            fut.set_exception(err)
        return
    if not fut.cancelled():  # pragma: no branch
        fut.set_result(value)
async def _task(*, task_status=trio.TASK_STATUS_IGNORED):
    """Run ``func(*args, **kwargs)`` as a named task that its starter can stop.

    The task is named ``<module>.<qualname>`` of ``func``. ``task_status``
    receives an async ``stop`` callable, which cancels this task's scope and
    waits until the task has fully exited.
    """
    task_name = f"{func.__module__}.{func.__qualname__}"
    trio.hazmat.current_task().name = task_name
    stopped = trio.Event()
    try:
        with trio.open_cancel_scope() as run_scope:

            async def stop():
                run_scope.cancel()
                await stopped.wait()

            task_status.started(stop)
            await func(*args, **kwargs)
    finally:
        stopped.set()
async def _event_pump(self, *, task_status=trio.TASK_STATUS_IGNORED):
    """Subscribe to backend events for this device and pump them until cancelled.

    Steps:
    1. Open a backend command pool (``max_pool=1``).
    2. Subscribe to message and beacon events.
    3. Request a message poll.
    4. Report this task's cancel scope to the starter.
    5. Hand over to ``_event_pump_do``.
    """
    with trio.open_cancel_scope() as pump_scope:
        device = self.device
        async with backend_cmds_factory(
            device.organization_addr,
            device.device_id,
            device.signing_key,
            max_pool=1,
        ) as cmds:
            # Subscribe with a snapshot of `self._subscribed_beacons` so that
            # concurrent modifications of the set cannot interfere.
            beacons_snapshot = self._subscribed_beacons.copy()
            await cmds.events_subscribe(message_received=True, beacon_updated=beacons_snapshot)
            # Given the backend won't notify us for messages that arrived while
            # we were offline, we must actively check this ourself.
            self.event_bus.send("backend.message.polling_needed")
            task_status.started(pump_scope)
            await self._event_pump_do(cmds)
async def _join_fuse_thread(mountpoint, fuse_operations, fuse_thread_stopped, stop=False):
    """Wait, shielded from cancellation, for the fuse thread to terminate.

    With *stop*, it first schedules the fuse exit and triggers one fuse
    operation so the exit gets processed. Returns immediately if
    *fuse_thread_stopped* is already set.
    """
    if fuse_thread_stopped.is_set():
        return

    def _wakeup_fuse():
        # Ask for a dummy file just to force a fuse operation, which processes
        # the pending `fuse_exit` from a valid context. Errors do not matter.
        try:
            (mountpoint / "__shutdown_fuse__").exists()
        except OSError:
            pass

    with trio.open_cancel_scope(shield=True):
        if stop:
            logger.info("Stopping fuse thread...")
            fuse_operations.schedule_exit()
            # The python fs api is blocking, so run it inside a worker thread
            # to avoid blocking the trio loop and ending up in a deadlock.
            await trio.run_sync_in_worker_thread(_wakeup_fuse)
        await trio.run_sync_in_worker_thread(fuse_thread_stopped.wait)
        logger.info("Fuse thread stopped")
async def _server_consumer(self, channel, server_future, task_status=trio.TASK_STATUS_IGNORED):
    """Consume from ``server_queue_name``, replying to each message on its ``reply_to`` key.

    The cancel scope is stored in ``self._server_scope``. Each reply carries
    the request's ``correlation_id``. After replying, the delivery is recorded
    as ``server_future.test_result`` and the future is set.
    """
    with trio.open_cancel_scope() as consumer_scope:
        self._server_scope = consumer_scope
        async with channel.new_consumer(queue_name=server_queue_name) as deliveries:
            logger.debug('Server consuming messages')
            task_status.started()
            async for body, envelope, properties in deliveries:
                logger.debug('Server received message')
                reply_props = {'correlation_id': properties.correlation_id}
                logger.debug('Replying to %r', properties.reply_to)
                await channel.publish(b'reply message', exchange_name, properties.reply_to, reply_props)
                server_future.test_result = (body, envelope, properties)
                server_future.set()
                logger.debug('Server replied')
async def _writer_loop(self, task_status=trio.TASK_STATUS_IGNORED):
    """Write queued frames to the stream, sending heartbeats while idle.

    Runs in a shielded cancel scope stored as ``self._writer_scope`` until
    ``self.state`` becomes CLOSED. With a server heartbeat configured, a
    heartbeat is sent whenever no frame has been queued for half the
    heartbeat interval. If the stream is broken or closed, it stops quietly.
    """
    with trio.open_cancel_scope(shield=True) as scope:
        self._writer_scope = scope
        task_status.started()
        while self.state != CLOSED:
            timeout = self.server_heartbeat / 2 if self.server_heartbeat else inf
            with trio.move_on_after(timeout) as timeout_scope:
                frame, encoder = await self._send_queue.get()
            if timeout_scope.cancelled_caught:
                await self.send_heartbeat()
                continue

            f = frame.get_frame(encoder)
            try:
                await self._stream.send_all(f)
            except (trio.BrokenResourceError, trio.ClosedResourceError):
                # Current names, as used by the reader loop; the *StreamError
                # spellings are deprecated aliases.
                # The reader will raise the error also, so just stop writing.
                return
async def _keep_connected(self, task_status=trio.TASK_STATUS_IGNORED):
    """Task which keeps a connection going

    Loops until ``self._stop`` is set. Each iteration:
    1. Connects and sets ``self.conn`` and ``_connected``.
    2. Cancels a running idle proc, calls ``_do_regs()``, then waits for the
       connection to drop.
    3. Closes the connection, shielded with a one-second deadline.
    4. Starts the optional idle proc, then waits up to 10 seconds (or until
       stop) before reconnecting.

    ``task_status.started()`` is reported only once, on the first successful
    connection.
    """
    class TODOexception(Exception):
        pass

    self.restarting = None
    reported_started = False
    while not self._stop.is_set():
        try:
            self._reg_endpoints = set()
            async with Connection(self.cfg, self.uuid).connect(self) as conn:
                self.restarting = False
                self.conn = conn
                self._connected.set()
                if self._idle is not None:
                    self._idle.cancel()
                    self._idle = None
                await self._do_regs()
                if not reported_started:
                    # A task status may only be reported once. Reporting again
                    # after a reconnect raises RuntimeError under nursery.start().
                    task_status.started()
                    reported_started = True
                await conn.is_disconnected.wait()
        except TODOexception:
            # NOTE(review): nothing raises this local class, so this handler
            # never runs and any other error ends the loop. Placeholder.
            self._connected.clear()
            logger.exception("Error. TODO Reconnecting after a while.")
        finally:
            c, self.conn = self.conn, None
            if c is not None:
                with trio.open_cancel_scope(shield=True, deadline=trio.current_time() + 1):
                    await c.aclose()
        self.restarting = True
        if self._stop.is_set():
            break
        if self.idle_proc is not None:
            # NOTE(review): nursery.start() returns only once _run_idle calls
            # task_status.started().
            await self.nursery.start(self._run_idle)
        with trio.move_on_after(10):
            await self._stop.wait()
async def _on_dead_rpc(self, channel, body, envelope, properties):
    """
    This handler is responsible for receiving dead-lettered messages.
    It builds an error reply and sends it to the client, ensuring that the
    error is discovered instantly, instead of waiting for a timeout.

    One-way messages (no ``reply_to``) are only logged. The dead-lettered
    delivery is always acknowledged at the end, shielded and with a
    one-second deadline.
    """
    try:
        codec = get_codec(properties.content_type)
        msg = BaseMsg.load(
            codec.decode(body),
            envelope,
            properties,
            conn=self,
            type="server",
            reply_channel=self._ch_reply.channel,
            reply_exchange=self._ch_reply.exchange,
        )
        reply = msg.make_response(self)
        reply_to = getattr(msg, 'reply_to', None)

        # Name the exchange the message was originally sent to: when it came
        # through a "dead..." exchange, take it from the first x-death entry.
        origin_exchange = envelope.exchange_name
        if origin_exchange.startswith("dead"):
            origin_exchange = properties.headers['x-death'][0]['exchange']
        error = DeadLettered(origin_exchange, envelope.routing_key)

        if reply_to is None:
            # usually, this is no big deal: call debug(), not exception().
            logger.debug("Undeliverable one-way message", exc_info=error)
            return

        reply.set_error(error, envelope.routing_key)
        payload, props = reply.dump(self, codec=self.codec)
        logger.debug("DeadLetter %s to %s", envelope.routing_key, self._ch_reply.exchange)
        await self._ch_reply.channel.publish(
            payload, self._ch_reply.exchange, reply_to, properties=props
        )
    finally:
        with trio.open_cancel_scope(shield=True, deadline=trio.current_time() + 1):
            await channel.basic_client_ack(envelope.delivery_tag)
async def server(tree=None, msgs=(), options=None, events=None, polling=False, **kw):
    """Async-generator fixture: serve a fake owserver and yield an ``OWFS`` attached to it.

    Setup:
    - ``some_server(tree, msgs, options)`` is served on an ephemeral
      127.0.0.1 port;
    - *events*, if given, is started in the same nursery with the OWFS;
    - the server is registered with the OWFS (``ow.test_server``).

    On exit, shielded: the server is dropped, a ``None`` event is pushed, and
    everything gets 0.1 s to settle. Then the nursery is cancelled. Extra
    keyword arguments go to ``OWFS``.
    """
    # Fresh dicts per call. With ``{}`` defaults one object would be shared
    # by every call, so any mutation by the fake server would leak between
    # tests.
    if tree is None:
        tree = {}
    if options is None:
        options = {}
    async with OWFS(**kw) as ow:
        async with trio.open_nursery() as n:
            s = None
            try:
                listeners = await n.start(
                    partial(trio.serve_tcp, host="127.0.0.1"),
                    partial(some_server, tree, msgs, options), 0
                )
                if events is not None:
                    await n.start(events, ow)
                addr = listeners[0].socket.getsockname()
                s = await ow.add_server(*addr, polling=polling)
                ow.test_server = s
                await yield_(ow)
            finally:
                ow.test_server = None
                with trio.open_cancel_scope(shield=True):
                    if s is not None:
                        await s.drop()
                    await ow.push_event(None)
                    await trio.sleep(0.1)
                n.cancel_scope.cancel()
async def register(self, ep):
    """Register endpoint *ep* and start consuming its messages.

    A name ending in ``.#`` is a wildcard; its queue name uses an ``._all_``
    suffix instead.

    Queue and channel per endpoint kind:
    - durable: a persistent queue on a new channel, named ``ep.durable`` if
      that is a string, else ``cfg.queues['msg'] + ep.tag``;
    - "rpc": an auto-delete queue on a new channel;
    - any other type: the shared alert channel and queue.

    Message TTL and dead-lettering arguments come from ``ep.ttl`` or the
    per-type config. On failure, the registration is rolled back, any
    channel opened here is closed (shielded, one-second deadline) and the
    exception is re-raised.

    Raises:
        RuntimeError: the name contains an unsupported ``#`` or *ep.tag* is
            already registered.
    """
    assert ep.tag not in self.rpcs
    ch = getattr(self, '_ch_' + ep.type)
    cfg = self.cfg
    dn = n = ep.name
    if ep.name.endswith('.#'):
        n = n[:-2]
        dn = n + '._all_'
    if len(n) > 1 and '#' in n:
        raise RuntimeError("I won't find that")
    if ep.tag in self.rpcs:
        raise RuntimeError("multiple registration of " + ep.tag)
    self.rpcs[ep.tag] = ep

    chan = None
    # True only for a channel opened by this call. The shared alert channel
    # is used by other endpoints and must never be closed here.
    own_chan = False
    try:
        d = {}
        ttl = ep.ttl or cfg.ttl[ep.type]
        if ttl:
            d["x-dead-letter-exchange"] = cfg.queues['dead']
            d["x-message-ttl"] = int(1000 * ttl)

        # Open a channel only in the branches that need one. Previously an
        # extra channel was opened up front and then overwritten, leaking it.
        if ep.durable:
            if isinstance(ep.durable, str):
                dn = ep.durable
            else:
                dn = self.cfg.queues['msg'] + ep.tag
            chan = await self.amqp.channel()
            own_chan = True
            q = await chan.queue_declare(
                dn, auto_delete=False, passive=False, exclusive=False, durable=True, arguments=d
            )
        elif ep.type == "rpc":
            chan = await self.amqp.channel()
            own_chan = True
            q = await chan.queue_declare(
                cfg.queues[ep.type] + ep.name,
                auto_delete=True,
                durable=False,
                passive=False,
                arguments=d
            )
        else:
            chan = self._ch_alert.channel
            q = self._ch_alert.queue

        logger.debug("Chan %s: bind %s %s %s", ch.channel, ep.exchange, ep.name, q['queue'])
        await chan.queue_bind(q['queue'], ep.exchange, routing_key=ep.name)
        await chan.basic_qos(prefetch_count=1, prefetch_size=0, connection_global=False)
        logger.debug("Chan %s: read %s", ch, q['queue'])
        await chan.basic_consume(
            queue_name=q['queue'],
            callback=self._on_rpc_in if ep.type == "rpc" else self._on_alert_in
        )
        ep._c_channel = chan
        ep._c_queue = q
    except BaseException:  # pragma: no cover
        del self.rpcs[ep.tag]
        if own_chan:
            with trio.open_cancel_scope(shield=True, deadline=trio.current_time() + 1):
                await chan.close()
        raise
async def check_cancel(proc, seen):
    """Check that calling *proc* via ``call_t_a`` raises asyncio.CancelledError.

    The surrounding trio scope must not have been cancelled. On success, sets
    bit 4 in ``seen.flag``.
    """
    with trio.open_cancel_scope() as guard:
        with pytest.raises(asyncio.CancelledError):
            await self.call_t_a(proc, seen, loop=loop)
    assert not guard.cancel_called
    seen.flag |= 4