Ejemplo n.º 1
0
class BroadcastChannel(Generic[T], trio.abc.AsyncResource):
    """\
    Bundles a set of trio channels so that messages are sent to all of them.
    When a receiver is closed, it is cleanly removed from the broadcast channel on the next send.
    Be careful about the buffer size chosen when adding a receiver; see `send()` for details.
    """

    def __init__(self) -> None:
        # Send ends of every registered receiver.
        self._send_channels: Set[trio.abc.SendChannel] = set()
        # Owns the send ends so that aclose() closes all of them at once.
        self._stack = AsyncExitStack()

    async def send(self, value: T) -> None:
        """\
        Sends the value to all receiver channels.
        Closed receivers are removed the next time a value is sent using this method.
        The value is sent to all receivers concurrently,
        and this method blocks until every receiver has accepted it.

        Suppose you have receivers A and B with buffer size zero, and you send to them:

            await channel.send(1)
            await channel.send(2)

        If only B is actually reading, then `send(2)` will not be called, because `send(1)` can't finish,
        meaning the `2` is not delivered to B either.
        To prevent this, close any receivers that are done, and/or poll receive in a timely manner.
        """
        broken = set()

        async def send_one(channel: trio.abc.SendChannel) -> None:
            try:
                await channel.send(value)
            except trio.BrokenResourceError:
                # The receive end was closed: release our end and mark it
                # for removal once all sends have finished.
                await channel.aclose()
                broken.add(channel)

        async with trio.open_nursery() as nursery:
            for channel in self._send_channels:
                nursery.start_soon(send_one, channel)

        # Prune only after the nursery is done, so the set is not modified
        # while it is being iterated above.
        self._send_channels -= broken

    def add_receiver(self, max_buffer_size) -> trio.abc.ReceiveChannel:
        """\
        Adds a receiver to this broadcast channel with the given buffer capacity.
        The send end of the receiver is closed when the broadcast channel is closed,
        and if the receive end is closed, it is discarded from the broadcast channel.

        :param max_buffer_size: buffer size passed to ``trio.open_memory_channel``.
        :return: the receive end of the new memory channel.
        """
        send, receive = trio.open_memory_channel(max_buffer_size)
        # Registering the send end on the stack makes aclose() close it.
        self._stack.push_async_exit(send)
        self._send_channels.add(send)
        return receive

    async def aclose(self):
        """\
        Closes the broadcast channel, causing all receive channels to stop iteration.
        """
        await self._stack.aclose()
Ejemplo n.º 2
0
 async def __aenter__(self) -> StorageManager:
     """Create a temporary storage directory and start watching it.

     Setup happens on a local AsyncExitStack. Only after every step has
     succeeded is it moved to ``self._raii`` via ``pop_all()``. If a step
     fails, what was already set up is unwound immediately.
     """
     async with AsyncExitStack() as setup:
         tmp_dir = setup.enter_context(tempfile.TemporaryDirectory())
         self._tmp = tmp_dir
         assert self._tmp is not None
         self._path = pathlib.Path(tmp_dir)
         await setup.enter_async_context(self._watch())
         # Success: transfer ownership of the cleanup callbacks.
         self._raii = setup.pop_all()
     return self
Ejemplo n.º 3
0
    async def test_live_reconnect(self):
        """A live page's websocket reconnects after the server restarts.

        A first server on ``port`` serves the event page until the websocket
        is ready. The server is then shut down and the socket must report
        not-ready. A second server on the same port must be picked up again
        without reloading the page. Errors are checked for the second server.
        """
        port = free_port()
        session = self.browser_session

        async def start_server(stack):
            # Start a tracker app with a single live test event.
            # Returns (url, client_errors, server_errors).
            app, client_errors, server_errors = await stack.enter_async_context(
                tracker_web_server_fixture())
            # safe_load: yaml.load() without a Loader is rejected by PyYAML >= 6.
            app['trackers.events']['test_event'] = event = Event(
                app, 'test_event',
                yaml.safe_load("""
                    title: Test Event
                    live: True
                    riders:
                        - name: Foo Bar
                          tracker: null
                    markers: []
                """), [])
            url = await stack.enter_async_context(
                web_server_fixture(self.loop, app, port))
            await on_new_event(event)
            return url, client_errors, server_errors

        async with AsyncExitStack() as stack:
            url, _, _ = await start_server(stack)
            await session.get(f'{url}/test_event')
            await wait_condition(ws_ready_is, session, True)

        await wait_condition(ws_ready_is, session, False)

        # Bring the server back up, reconnect
        async with AsyncExitStack() as stack:
            _, client_errors, server_errors = await start_server(stack)
            await wait_condition(ws_ready_is, session, True, timeout=10)

        self.check_no_errors(client_errors, server_errors)
Ejemplo n.º 4
0
 async def __call__(self, reader: StreamReader,
                    writer: StreamWriter) -> None:
     """Serve one client connection from start to disconnect.

     Builds the IMAP connection and its state, publishes a per-connection
     exit stack through ``connection_exit``, and runs the connection. The
     connection is closed via ``closing()`` when the stack unwinds.
     """
     connection = IMAPConnection(self.commands, self._config, reader, writer)
     conn_state = ConnectionState(self._login, self._config)
     async with AsyncExitStack() as on_exit:
         # connection_exit looks like a ContextVar. Exposing the stack
         # presumably lets per-connection code register cleanup on it
         # (TODO confirm).
         connection_exit.set(on_exit)
         on_exit.enter_context(closing(connection))
         await connection.run(conn_state)
Ejemplo n.º 5
0
async def ps(root: Root) -> None:
    """List all jobs.

    Opens the SDK client, the API-backed storage and a LiveRunner, in that
    order, on one exit stack (closed in reverse), then delegates to
    ``LiveRunner.ps()``.
    """
    async with AsyncExitStack() as resources:
        sdk_client = await resources.enter_async_context(neuro_sdk.get())
        api_storage: Storage = await resources.enter_async_context(
            ApiStorage(sdk_client))
        live_runner = await resources.enter_async_context(
            LiveRunner(root.config_dir, root.console, sdk_client, api_storage,
                       root))
        await live_runner.ps()
Ejemplo n.º 6
0
async def app_setup(app, settings):
    """Run the basic and the module-specific setup for *app*.

    Returns an AsyncExitStack owning the stacks returned by both setup
    steps. The caller is responsible for closing it. If the second step
    fails, the resources of the first are released before the exception
    propagates, instead of being leaked.
    """
    async with AsyncExitStack() as stack:
        await stack.enter_async_context(await app_setup_basic(app, settings))
        await stack.enter_async_context(await trackers.modules.config_modules(
            app, settings))
        # Success: hand the registered cleanup over to the caller so it is
        # not run when this ``async with`` block exits.
        return stack.pop_all()
Ejemplo n.º 7
0
async def test_automatic_reconnect(browser: Browser):
    """The client reconnects on its own when the backend comes back.

    A first backend on port 8000 shows ``OldComponent``. After it stops, the
    page must keep showing the stale view. A second backend on the same
    port shows ``NewComponent``, which must appear without a page refresh
    and must respond to state updates.
    """
    page = await browser.new_page()

    # Automatic reconnection is not instant, so give each wait more time.
    page.set_default_timeout(10000)

    @idom.component
    def OldComponent():
        return idom.html.p({"id": "old-component"}, "old")

    async with AsyncExitStack() as fixtures:
        backend = await fixtures.enter_async_context(BackendFixture(port=8000))
        view = await fixtures.enter_async_context(
            DisplayFixture(backend, driver=page)
        )

        await view.show(OldComponent)

        # Make sure the element is rendered before the backend goes away.
        await page.wait_for_selector("#old-component")

    # The backend is gone, but the last rendered view must still be there.
    await page.wait_for_selector("#old-component")

    state_setter = idom.Ref(None)

    @idom.component
    def NewComponent():
        state, state_setter.current = idom.hooks.use_state(0)
        return idom.html.p({"id": f"new-component-{state}"}, f"new-{state}")

    async with AsyncExitStack() as fixtures:
        backend = await fixtures.enter_async_context(BackendFixture(port=8000))
        view = await fixtures.enter_async_context(
            DisplayFixture(backend, driver=page)
        )

        await view.show(NewComponent)

        # No page refresh here: the client must reconnect by itself and
        # render the new view.
        await page.wait_for_selector("#new-component-0")

        # State updates must flow again after reconnecting.
        state_setter.current(1)
        await page.wait_for_selector("#new-component-1")
Ejemplo n.º 8
0
 async def async_fn():
     """Enter async context managers and register exit hooks on one stack.

     Registration order matters: the stack unwinds in reverse, so the
     callback registered last runs first.
     """
     async with AsyncExitStack() as exits:
         # Each manager is created and entered before the next is created.
         for factory in (A, amgr):
             await exits.enter_async_context(factory())
         # Raw __aexit__-style hooks.
         for exit_hook in (A().aexit, aexit2):
             exits.push_async_exit(exit_hook)
         exits.push_async_callback(acallback, "hi")
         await async_yield(None)
Ejemplo n.º 9
0
async def startSim():
    """Run the bedside-monitor simulation against the MQTT broker.

    Steps:
    - Connect to ``BROKER_ADDRESS``.
    - Start one ``log_messages`` task per topic filter, plus one for
      messages matching no filter.
    - Subscribe to everything.
    - Launch the patient and vital-sign producers for each bed.

    Waits for all tasks. ``cancel_tasks`` is registered on the exit stack
    for the task set.
    """
    global client, stack
    # change entries to modify test
    beds = [("W1", str(bed_number)) for bed_number in range(1, 5)]

    async with AsyncExitStack() as stack:
        # Connect to the MQTT broker
        client = Client(BROKER_ADDRESS)
        await stack.enter_async_context(client)

        tasks = set()
        stack.push_async_callback(cancel_tasks, tasks)

        outbound_topics = [
            "+/+/patientDetails",
            "+/+/HR",
            "+/+/spO2",
            "+/+/diaBP",
            "+/+/sysBP",
            "+/+/ppg",
            "+/+/ecg",
        ]
        inbound_topics = ["+/+/sendDetails"]

        # Pair every topic filter with its log template, outbound first.
        topic_filters = [
            (topic, f'Outbound -- [topic="{{}}"] {{}}')
            for topic in outbound_topics
        ] + [
            (topic, f'Inbound -- [topic="{{}}"] {{}}')
            for topic in inbound_topics
        ]
        for topic_filter, template in topic_filters:
            filtered = await stack.enter_async_context(
                client.filtered_messages(topic_filter))
            tasks.add(asyncio.create_task(log_messages(filtered, template)))

        # Messages that don't match any filter are logged here.
        unmatched = await stack.enter_async_context(
            client.unfiltered_messages())
        tasks.add(asyncio.create_task(
            log_messages(unmatched, f'Other -- [topic="{{}}"] {{}}')))

        await client.subscribe('#')  # subscribe to all messages

        for bed in beds:
            for producer in (onboardPatient, startHRProducer, startBPProducer,
                             startSpO2Producer, startECGProducer):
                tasks.add(asyncio.create_task(producer(bed, client)))
        await asyncio.gather(*tasks)
Ejemplo n.º 10
0
 async def __aenter__(
         self) -> "MemoryOperationImplementationNetworkContext":
     """Enter every operation implementation and index it by op name.

     All implementations are entered on one AsyncExitStack stored on the
     instance. ``__aexit__`` is not called when ``__aenter__`` raises, so if
     entering any implementation fails, the ones already entered are exited
     here before the exception propagates.
     """
     self.__stack = AsyncExitStack()
     await self.__stack.__aenter__()
     try:
         self.operations = {
             opimp.op.name: await self.__stack.enter_async_context(opimp)
             for opimp in self.opimps.values()
         }
     except BaseException:
         await self.__stack.aclose()
         raise
     return self
Ejemplo n.º 11
0
 async def mock_app_setup(app, settings):
     """Populate *app* with the minimal tracker state used by tests.

     Uses the enclosing ``repo`` as the data repository and registers a
     single ``'mock'`` start-event tracker. There is nothing to clean up,
     so an empty AsyncExitStack is returned.
     """
     initial_state = {
         'trackers.settings': settings,
         'trackers.data_repo': repo,
         'trackers.events': {},
         'analyse_processing_lock': asyncio.Lock(),
         'start_event_trackers': {'mock': None},
     }
     for key, value in initial_state.items():
         app[key] = value
     return AsyncExitStack()
Ejemplo n.º 12
0
async def _main(s: Settings) -> None:
    """Bind dependencies, build the app and serve it until the server stops.

    ``b.bind`` receives the exit stack (presumably to register cleanup on
    it). That cleanup runs after ``server.serve()`` returns or raises.
    """
    async with AsyncExitStack() as exit_stack:
        await b.bind(s, exit_stack)
        server = create_server(create_app(), s)

        logger.info("Starting...", port=s.server.port)
        await server.serve()
Ejemplo n.º 13
0
 async def _create_stubbed_client(self, service_name, *args, **kwargs):
     """Yield a real client for *service_name* with a Stubber attached.

     The Stubber is stored in ``self._client_stubs[service_name]`` before
     the client is yielded. The client's context is exited when the
     generator is resumed or closed.
     """
     async with AsyncExitStack() as exit_stack:
         parent_create = super(StubbedSession, self).create_client
         client = await exit_stack.enter_async_context(
             parent_create(service_name, *args, **kwargs))
         self._client_stubs[service_name] = Stubber(client)
         yield client
Ejemplo n.º 14
0
    async def listen(self):
        """Connect to the MQTT broker and dispatch incoming messages.

        Steps:
        - Merge the listeners of every data source into ``self.topics``:
          one MQTTListener per topic, with handlers from several sources
          concatenated.
        - Start one ``parse_messages`` task per topic filter.
        - Subscribe to all topics at QoS 0.
        - Wait on the tasks.

        A subscribe ValueError is logged and the method keeps listening.
        ``cancel_tasks`` is registered on the exit stack for the task set.
        """
        async with AsyncExitStack() as stack:
            # Track tasks. Pushed first, so this callback runs last on exit.
            tasks = set()
            stack.push_async_callback(cancel_tasks, tasks)

            # Connect to the MQTT broker
            client = Client(self.host,
                            self.port,
                            username=self.username,
                            password=self.password)
            await stack.enter_async_context(client)

            logging.info('MQTT client connected')
            # Convert each data source's listeners into 'prime' listeners,
            # one per topic.
            for ds in self.data_sources:
                for listener in ds.listeners():
                    topic = listener.topic
                    funcs = listener.handlers
                    if topic in self.topics:
                        # Add these handlers to existing top level topic handler
                        logging.debug(
                            f'Adding handlers for existing prime Listener: {topic}'
                        )
                        self.topics[topic].handlers.extend(funcs)
                    else:
                        # Add this instance as a new top level handler
                        logging.debug(
                            f'Creating new prime Listener for topic: {topic}')
                        # Pass a copy so that later extend() calls for the
                        # same topic cannot mutate the data source's own
                        # list, in case MQTTListener stores it as-is.
                        self.topics[topic] = MQTTListener(topic, list(funcs))

            # Add handlers for each topic as a filtered topic
            for topic in self.topics:
                messages = await stack.enter_async_context(
                    client.filtered_messages(topic))
                tasks.add(asyncio.create_task(self.parse_messages(messages)))

            # Subscribe to all topics
            # Assume QoS 0 for now
            all_topics = [(t, 0) for t in self.topics]
            logging.info(f'Subscribing to MQTT {len(all_topics)} topic(s)')
            logging.debug(f'Topics: {all_topics}')
            try:
                await client.subscribe(all_topics)
            except ValueError as err:
                logging.error(f'MQTT Subscribe error: {err}')

            # gather() returns only once every task has finished, so log
            # before waiting, not after.
            logging.info('Listening for MQTT updates')
            await asyncio.gather(*tasks)
Ejemplo n.º 15
0
 async def pipeline(self):
     """Stream endpoint results through ``transform`` into every sink.

     All sinks are entered first. Then an aiohttp ClientSession is opened
     and handed to the endpoint. Each result is transformed, and truthy
     outputs are ``put`` into each sink in order. Everything is exited in
     reverse order: the session first, then the sinks.
     """
     async with AsyncExitStack() as stack:
         opened_sinks = []
         for sink in self.sinks:
             opened_sinks.append(await stack.enter_async_context(sink))
         session = await stack.enter_async_context(ClientSession())
         await self.endpoint.set_session(session)
         async for obj in self.endpoint.iter_results():
             transformed = self.transform(obj)
             if not transformed:
                 continue
             for target in opened_sinks:
                 await target.put(transformed)
Ejemplo n.º 16
0
 def __init__(self,
              service_name,
              region_name="eu-west-1",
              max_pool_connections=10) -> None:
     """Store the client settings. No client is created yet.

     ``_client`` starts as None. ``_init_lock`` and ``_context_stack`` are
     presumably used to create the client lazily and to own its context.
     """
     # Connection settings.
     self._service_name = service_name
     self._region_name = region_name
     self._max_pool_connections = max_pool_connections
     # Lazily-created client state.
     self._client = None
     self._context_stack = AsyncExitStack()
     self._init_lock = asyncio.Lock()
Ejemplo n.º 17
0
async def commie_async_loop():
    """Start every commie server and run them alongside the expiry loop.

    Each server produced by ``start_commie_servers()`` is entered on one
    exit stack, so all of them are exited, in reverse order, when the
    gather finishes or fails.
    """
    async with AsyncExitStack() as stack:
        running = []
        async for srv in start_commie_servers():
            await stack.enter_async_context(srv)
            running.append(srv)

        await asyncio.gather(
            expire_loop(),
            *(srv.serve_forever() for srv in running))
Ejemplo n.º 18
0
async def launch_cluster(test_run, size, version):
    """Create and start an RMQ cluster, yield it, then tear it down.

    ``cluster.close`` is pushed before ``start()``, so a failed start still
    triggers it. The same stack is passed to the cluster, which may
    register further cleanup on it.
    """
    async with AsyncExitStack() as teardown:
        cluster = RMQCluster(test_run, size, version, teardown)
        teardown.push_async_callback(cluster.close)

        print(f'Creating cluster {cluster.name}…')
        await cluster.start()
        print(f'Created cluster {cluster.name}.')

        yield cluster
Ejemplo n.º 19
0
 async def setUp(self):
     """Start a patched test server and open an HTTP client session.

     Everything is registered on ``self.exit_stack``. If any step fails,
     the stack is closed right away. Under unittest semantics tearDown is
     not run when setUp raises (presumably the same for this async base),
     so the ServerRunner patch and any open session would otherwise leak
     into later tests.
     """
     self.exit_stack = AsyncExitStack()
     await self.exit_stack.__aenter__()
     try:
         self.tserver = await self.exit_stack.enter_async_context(
             ServerRunner.patch(Server))
         self.cli = Server(port=0, insecure=True)
         await self.tserver.start(self.cli.run())
         # Set up client
         self.session = await self.exit_stack.enter_async_context(
             aiohttp.ClientSession())
     except BaseException:
         await self.exit_stack.aclose()
         raise
Ejemplo n.º 20
0
    async def _wrapper(self, main_func: Callable[[_CT, _WI],
                                                 Coroutine[Any, Any, _WO]],
                       config: _CT) -> None:
        """Open the configured wires, run ``main_func`` and serve its outputs.

        Steps:
        - For each field of ``self._wires_in_type`` present in ``config``,
          create, configure and enter a wire. A required (non-Optional)
          wire without configuration raises RuntimeError.
        - Await ``main_func(config, wires_in)``. Its result must be an
          instance of ``self._wires_out_type``, else RuntimeError.
        - Configure and enter every output wire present in ``config``.
        - Wait, inside ``graceful_exit``, until the first output wire
          reports closed.

        All wires are exited when this returns.
        """
        async with AsyncExitStack() as stack:
            input_wires = {}
            for field in fields(self._wires_in_type):
                wire: Optional[Wire]
                if config.HasField(field.name):
                    wire_config = getattr(config, field.name)
                    if isinstance(field.type, type) and issubclass(
                            field.type, Wire):
                        wire_type = field.type
                    else:
                        # Optional[wire_type]
                        wire_type = field.type.__args__[0]
                        assert isinstance(wire_type, type) and issubclass(
                            wire_type, Wire), type(wire_type)
                    wire = wire_type()
                    wire.configure(wire_config)
                    await stack.enter_async_context(wire)
                else:
                    if isinstance(field.type, type) and issubclass(
                            field.type, Wire):
                        raise RuntimeError(f"Missing configuration for"
                                           f" required wire: {field.name}")
                    wire = None
                input_wires[field.name] = wire
            wires_in = self._wires_in_type(**input_wires)  # type: ignore

            wires_out = await main_func(config, wires_in)

            if not isinstance(wires_out, self._wires_out_type):
                raise RuntimeError(
                    f"{main_func} returned invalid type: {type(wires_out)!r}; "
                    f"expected: {self._wires_out_type!r}")

            waiters = set()
            output_wires = []
            try:
                for field in fields(self._wires_out_type):
                    if not config.HasField(field.name):
                        continue
                    wire = getattr(wires_out, field.name)
                    assert isinstance(wire, Wire), type(wire)
                    wire_config = getattr(config, field.name)
                    wire.configure(wire_config)
                    await stack.enter_async_context(wire)
                    # asyncio.wait() needs futures/tasks. Passing bare
                    # coroutines is deprecated since 3.8 and a TypeError
                    # since 3.11.
                    waiters.add(asyncio.ensure_future(wire.wait_closed()))
                    output_wires.append(wire)

                if output_wires:
                    with graceful_exit(output_wires):
                        await asyncio.wait(
                            waiters,
                            return_when=asyncio.FIRST_COMPLETED,
                        )
            finally:
                # FIRST_COMPLETED (or an error above) leaves the other
                # waiters running. Cancel them so none is left pending.
                for waiter in waiters:
                    waiter.cancel()
Ejemplo n.º 21
0
    async def create(
        self,
        paths: Optional[List[str]] = None,
        id: Optional[Union[str, 'DaemonID']] = None,
        complete: bool = False,
        *args,
        **kwargs,
    ) -> Optional['DaemonID']:
        """Create a workspace

        If ``id`` is given, ``_get_helper`` is asked about it first. Its
        truthiness decides whether the id or None is returned. Only when it
        raises TypeError/ValueError is a new workspace created with that id.
        Otherwise the files in ``paths`` are POSTed to the store API. On a
        ``201 Created`` response this waits for the workspace.

        :param paths: local file/directory paths to be uploaded to workspace, defaults to None
        :param id: workspace id (if already known), defaults to None
        :param complete: True if complete_path is used (used by JinadRuntime), defaults to False
        :param args: additional positional args
        :param kwargs: keyword args
        :return: workspace id, or None if creation or waiting failed
        """

        async with AsyncExitStack() as stack:
            console = Console()
            status = stack.enter_context(
                console.status('Workspace: ...', spinner='earth'))
            workspace_id = None
            if id:
                workspace_id = daemonize(id)
                try:
                    return (workspace_id if await self._get_helper(
                        id=workspace_id, status=status) else None)
                except (TypeError, ValueError):
                    self._logger.debug('workspace doesn\'t exist, creating..')

            status.update('Workspace: Getting files to upload...')
            # The stack is passed along, presumably so that opened files stay
            # open until the request below completes.
            data = (self._files_in(paths=paths,
                                   exitstack=stack,
                                   complete=complete) if paths else None)
            status.update('Workspace: Sending request...')
            response = await stack.enter_async_context(
                aiohttp.request(
                    method='POST',
                    url=self.store_api,
                    params={'id': workspace_id} if workspace_id else None,
                    data=data,
                ))
            response_json = await response.json()

            if response.status != HTTPStatus.CREATED:
                self._logger.error(
                    f'{self._kind.title()} creation failed as: {error_msg_from(response_json)}'
                )
                return None

            # Read the new id only on success. An empty error payload would
            # make next() raise StopIteration, which turns into a
            # RuntimeError inside a coroutine.
            workspace_id = next(iter(response_json))
            status.update(f'Workspace: {workspace_id} added...')
            return (workspace_id if await self.wait(
                id=workspace_id, status=status, logs=True) else None)
Ejemplo n.º 22
0
async def connect_chipcard(
    cls: Type[T],
    *,
    path: Optional[str] = None,
    sock: Optional[ByteStream] = None,
    reader_name: Optional[str] = None,
    reader_id: Optional[int] = None,
    atr: Optional[ByteString] = None,
    byteorder: str = "=",
    **kwargs: Any,
) -> AsyncGenerator[T, None]:
    """
    Open a chipcard that is compatible to cls.

    Connects to the PC/SC service through ``sock`` or ``path``. Then it
    connects to every reader that meets all of these conditions:
    - its state is NEGOTIABLE, POWERED and PRESENT;
    - it matches the optional ``reader_name`` / ``reader_id`` / ``atr``
      filters.

    Each such card is wrapped in ``cls``. Cards whose
    ``is_usable(**kwargs)`` is false or raises are disconnected and skipped.
    Exactly one usable card must remain. It is yielded, and its connection
    is closed when the generator finishes.

    Raises TypeError if both ``path`` and ``sock`` are given. Raises
    Exception if no card, or more than one card, matches.
    """

    if path and sock:
        raise TypeError("Cannot specify both path and sock")

    if sock:
        ctx = async_null_context(PcscClient(sock, byteorder))
    else:
        ctx = PcscClient.connect(path, byteorder)

    async with ctx as client, AsyncExitStack() as stack:
        usable_card = None
        wanted_state = ReaderState.NEGOTIABLE | ReaderState.POWERED | ReaderState.PRESENT
        for reader in await client.get_reader_state():
            if reader.reader_state & wanted_state != wanted_state:
                continue
            if reader_name is not None and reader_name != reader.name:
                continue
            if reader_id is not None and reader_id != reader.reader_id:
                continue
            if atr is not None and atr != reader.atr:
                continue
            # Each candidate gets its own stack. Rejecting a candidate then
            # closes only its connection, never that of a card already
            # accepted.
            async with AsyncExitStack() as card_stack:
                card = cls(await card_stack.enter_async_context(reader.connect()))
                try:
                    usable = await card.is_usable(**kwargs)
                except Exception:
                    logging.exception(f"Could not open card in {reader}")
                    continue
                if not usable:
                    continue

                if usable_card:
                    raise Exception(
                        f"Multiple cards match: {usable_card.reader} {card.reader}"
                    )

                usable_card = card
                # Keep this card's connection open until the generator ends.
                stack.push_async_exit(card_stack.pop_all())

        if not usable_card:
            raise Exception("No card matches the criteria")
        yield usable_card
Ejemplo n.º 23
0
    async def test_async_callback(self):
        """push_async_callback() forwards args/kwargs and returns the callback.

        Callbacks are pushed in reverse order, so LIFO unwinding records them
        in ``expected`` order. Each registered wrapper must expose the
        original via ``__wrapped__`` without copying its ``__name__``, and
        its ``__doc__`` must be None.

        The second part checks misuse:
        - calling with no positional callback raises TypeError;
        - passing the callback as ``callback=`` emits DeprecationWarning.

        NOTE(review): the DeprecationWarning expectation is version-dependent.
        Newer Pythons make the callback positional-only (TypeError). Confirm
        the targeted version.
        """
        expected = [
            ((), {}),
            ((1, ), {}),
            ((1, 2), {}),
            ((), dict(example=1)),
            ((1, ), dict(example=1)),
            ((1, 2), dict(example=1)),
        ]
        result = []

        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Each call form is exercised explicitly on purpose; do not
            # collapse these branches into one call.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            # Relies on the private _exit_callbacks layout: the second item
            # of each entry is the registered wrapper.
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                self.assertIsNone(wrapper[1].__doc__, _exit.__doc__)

        self.assertEqual(result, expected)

        # Rebinding result also rebinds the list _exit appends to, because
        # _exit closes over the variable.
        result = []
        async with AsyncExitStack() as stack:
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            # self.exit_stack is presumably the exit-stack class under test
            # (TODO confirm).
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            with self.assertWarns(DeprecationWarning):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [((), {'arg': 3})])
Ejemplo n.º 24
0
async def retry_download(*args, semaphore=None, **kwargs):  # noqa: E999
    """Retry download() calls.

    Runs ``download(*args, **kwargs)`` through ``retry_async``, with aiohttp
    ``ClientError`` and ``asyncio.TimeoutError`` as the retryable
    exceptions. If *semaphore* is given (and truthy), it is held for the
    whole retry sequence, not per attempt.
    """
    retryable = (aiohttp.ClientError, asyncio.TimeoutError)
    async with AsyncExitStack() as guard:
        if semaphore:
            await guard.enter_async_context(semaphore)
        await retry_async(download, retry_exceptions=retryable,
                          args=args, kwargs=kwargs)
Ejemplo n.º 25
0
async def main():
    """Run the socket server/client demo until the server finishes.

    Both sockets are entered as async context managers, so they are exited
    when the demo ends.
    """
    _client = aiosocket()
    _server = aiosocket()

    async with AsyncExitStack() as stack:
        # enter_async_context() is a coroutine and must be awaited.
        # Otherwise the sockets' __aenter__/__aexit__ never run.
        await stack.enter_async_context(_server)
        await stack.enter_async_context(_client)

        server_started = Future(callback=client(_client))
        await gather(server(_server, server_started))
        print('done')
Ejemplo n.º 26
0
async def _get_connection(
        conn: Union[asyncpg.Connection, asyncpg.pool.Pool],
        force_transaction: bool = False) -> asyncpg.Connection:
    """Yield a usable connection, optionally inside a transaction.

    A Pool is resolved by acquiring a connection from it, which is released
    on exit. A Connection is used as-is. With ``force_transaction``, a
    transaction is opened unless the connection is already in one. It is
    exited together with everything else on the stack.
    """
    async with AsyncExitStack() as stack:
        connection = conn
        if isinstance(conn, asyncpg.pool.Pool):
            connection = await stack.enter_async_context(conn.acquire())

        needs_transaction = (force_transaction
                             and not connection.is_in_transaction())
        if needs_transaction:
            await stack.enter_async_context(connection.transaction())

        yield connection
Ejemplo n.º 27
0
 def __init__(self, max_workers: int, id_: uuid.UUID = None) -> None:
     """Set up an unopened pool for up to *max_workers* workers.

     ``id_`` defaults to a fresh random UUID. A memory channel with buffer
     size *max_workers* carries ``WorkerProcessAPI`` items.
     """
     self.id = uuid.uuid4() if id_ is None else id_
     self._max_workers = max_workers
     self._num_workers = 0
     worker_channel = trio.open_memory_channel[WorkerProcessAPI](max_workers)
     self._send_channel, self._receive_channel = worker_channel
     self._exit_stack = AsyncExitStack()
     self._is_open = False
Ejemplo n.º 28
0
    def __init__(self, stream_name, endpoint_url=None, region_name=None):
        """Remember the stream and connection settings; nothing is opened yet.

        ``client``, ``shards`` and ``stream_status`` start as None, and the
        exit stack starts empty. They are presumably populated when the
        consumer is started.
        """
        self.stream_name = stream_name
        self.endpoint_url, self.region_name = endpoint_url, region_name
        self.exit_stack = AsyncExitStack()
        self.client = self.shards = self.stream_status = None
Ejemplo n.º 29
0
async def app_setup_basic(app, settings):
    """Install the core tracker state on *app*.

    Does the following:
    - registers the built-in start-event tracker factories;
    - stores *settings*;
    - opens the dulwich repo at ``settings['data_path']``;
    - creates the event registry and the analysis lock.

    Returns an AsyncExitStack that owns the repo (its context is exited when
    the stack is closed). The caller must close it.
    """
    stack = AsyncExitStack()

    general = trackers.general
    app['start_event_trackers'] = {
        'static': general.static_start_event_tracker,
        'cropped': partial(general.wrapped_tracker_start_event,
                           general.cropped_tracker_start),
        'filter_inaccurate': partial(general.wrapped_tracker_start_event,
                                     general.filter_inaccurate_tracker_start),
    }

    app['trackers.settings'] = settings
    data_repo = dulwich.repo.Repo(settings['data_path'])
    app['trackers.data_repo'] = stack.enter_context(data_repo)
    app['trackers.events'] = {}
    app['analyse_processing_lock'] = asyncio.Lock()

    return stack
async def p2pcs():
    """Yield a tuple with the client of each spawned p2pd pair.

    ``NUM_P2PDS`` pairs are created one after another with the module-level
    ``ENABLE_*`` feature flags. Each pair is entered on one exit stack, so
    all of them are exited when the generator finishes.
    """
    # TODO: Change back to gather style
    async with AsyncExitStack() as stack:
        pairs = []
        for _ in range(NUM_P2PDS):
            pair = await stack.enter_async_context(
                FUNC_MAKE_P2PD_PAIR(
                    enable_control=ENABLE_CONTROL,
                    enable_connmgr=ENABLE_CONNMGR,
                    enable_dht=ENABLE_DHT,
                    enable_pubsub=ENABLE_PUBSUB,
                ))
            pairs.append(pair)
        yield tuple(pair.client for pair in pairs)
Ejemplo n.º 31
0
 async def __aenter__(self) -> "BaseOrchestratorContext":
     """Enter the parent's checker and network contexts.

     Each context is handed to ``aenter_stack`` under an attribute-style
     key. The stack it returns is kept in ``self._stack``.
     """
     # NOTE(review): this placeholder is overwritten just below. It is kept
     # in case aenter_stack() reads self._stack while running; confirm and
     # drop if not.
     self._stack = AsyncExitStack()
     sub_contexts = {
         "rctx": self.parent.rchecker,
         "ictx": self.parent.input_network,
         "octx": self.parent.operation_network,
         "lctx": self.parent.lock_network,
         "nctx": self.parent.opimp_network,
     }
     self._stack = await aenter_stack(self, sub_contexts)
     return self
Ejemplo n.º 32
0
 def __init__(self) -> None:
     """Start with no receivers and an empty exit stack."""
     # Send ends of the registered receivers.
     self._send_channels: Set[trio.abc.SendChannel] = set()
     # Collects the send ends so they can all be closed together.
     self._stack = AsyncExitStack()