Exemple #1
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """Stream requests over a websocket connection and yield the responses.

        Requests produced by ``self._get_requests(**kwargs)`` are sent through a
        ``WebsocketClientlet`` connected to ``ws(s)://{host}:{port}/``. A
        background task reads incoming messages and resolves the future that was
        registered for each ``request_id``; ``RequestStreamer`` awaits those
        futures, and every response is passed to the callbacks and yielded.

        Connection failures (``aiohttp.ClientError`` / ``ConnectionError``) are
        logged and routed to ``on_error`` / ``on_always`` when either is given;
        otherwise they are re-raised.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = ProgressBar(total_length=self._inputs_length,
                                  disable=not (self.show_progress))
                p_bar = stack.enter_context(cm1)

                proto = 'wss' if self.args.tls else 'ws'
                url = f'{proto}://{self.args.host}:{self.args.port}/'
                iolet = await stack.enter_async_context(
                    WebsocketClientlet(url=url, logger=self.logger))

                # maps request_ids to the futures awaiting their responses
                request_buffer: Dict[str, asyncio.Future] = dict()

                # Strong references to fire-and-forget send tasks. The event
                # loop keeps only weak references to tasks, so an unreferenced
                # task may be garbage-collected before it has run to completion.
                background_tasks = set()

                def _spawn(coro) -> 'asyncio.Task':
                    """Schedule ``coro`` as a task and keep it referenced until it is done."""
                    task = asyncio.create_task(coro)
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    return task

                def _result_handler(result):
                    return result

                async def _receive():
                    """Await messages from WebsocketGateway and resolve the matching futures in the request buffer.

                    When the message stream ends, every request still waiting
                    for a response has its future cancelled.
                    """

                    def _response_handler(response):
                        if response.header.request_id in request_buffer:
                            future = request_buffer.pop(
                                response.header.request_id)
                            # The future may already have been cancelled by its
                            # consumer; setting a result on it would raise
                            # InvalidStateError and stop this receive loop.
                            if not future.done():
                                future.set_result(response)
                        else:
                            self.logger.warning(
                                f'discarding unexpected response with request id {response.header.request_id}'
                            )

                    try:
                        async for response in iolet.recv_message():
                            _response_handler(response)
                    finally:
                        if request_buffer:
                            self.logger.warning(
                                f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                            )
                            for future in request_buffer.values():
                                future.cancel()
                            request_buffer.clear()

                def _handle_end_of_iter():
                    """Send End of iteration signal to the Gateway"""
                    _spawn(iolet.send_eoi())

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For websocket requests from client, for each request in the iterator, we send the request
                    using `iolet.send_message()` in a background task.
                    Then add {<request-id>: <an-empty-future>} to the request buffer.
                    This empty future is used to track the `result` of this request during `receive`.
                    :param request: current request in the iterator
                    :return: asyncio Future that is resolved with the response
                    """
                    future = get_or_reuse_loop().create_future()
                    request_buffer[request.header.request_id] = future
                    _spawn(iolet.send_message(request))
                    return future

                streamer = RequestStreamer(
                    args=self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                    end_of_iter_handler=_handle_end_of_iter,
                )

                receive_task = asyncio.create_task(_receive())

                if receive_task.done():
                    raise RuntimeError(
                        'receive task not running, can not send messages')
                try:
                    async for response in streamer.stream(request_iterator):
                        callback_exec(
                            response=response,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield response
                finally:
                    # An internal-error close from the server is surfaced as a
                    # ConnectionError so the handler below reports it.
                    if iolet.close_code == status.WS_1011_INTERNAL_ERROR:
                        raise ConnectionError(iolet.close_message)
                    await receive_task

            except (aiohttp.ClientError, ConnectionError) as e:
                self.logger.error(
                    f'Error while streaming response from websocket server {e!r}'
                )

                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise
Exemple #2
0
async def notify_many(*conditions: CustomCondition):
    """Acquire every condition, then notify all waiters on each of them.

    The conditions are entered one at a time, in the order given, on a single
    ``AsyncExitStack``. Entering them sequentially (instead of concurrently
    with ``asyncio.gather``) guarantees that, if entering one condition fails,
    every condition acquired so far is already registered on the stack and is
    released when the stack unwinds. With ``gather``, the remaining entries
    keep running after the first failure and could register their exit
    callbacks on a stack that has already been unwound.

    :param conditions: the conditions to notify; each is used as an async
        context manager and its ``notify_all()`` must return an awaitable.
    """
    async with AsyncExitStack() as stack:
        for condition in conditions:
            await stack.enter_async_context(condition)
        await asyncio.gather(
            *(condition.notify_all() for condition in conditions))
Exemple #3
0
    def __init__(
        self,
        *,
        budget: Union[float, Decimal],
        strategy: BaseMarketStrategy,
        event_consumer: Callable[[events.Event], None],
        subnet_tag: Optional[str] = None,
        payment_driver: Optional[str] = None,
        payment_network: Optional[str] = None,
        stream_output: bool = False,
        app_key: Optional[str] = None,
    ):
        """Initialize the engine.

        :param budget: maximum budget for payments; stored as a ``Decimal``
        :param strategy: market strategy used to select providers from the market
            (e.g. LeastExpensiveLinearPayuMS or DummyMS)
        :param event_consumer: callable executed directly on every Event this
            Engine creates. NOTE: it is expected to be fast or async - if not,
            it will block the _Engine.
        :param subnet_tag: use only providers in the subnet with this name. When
            not given, ``DEFAULT_SUBNET`` is used (documented as coming from the
            ``YAGNA_SUBNET`` environment variable, ``None`` if unset).
        :param payment_driver: name of the payment driver to use, lower-cased.
            When not given, ``DEFAULT_DRIVER`` is used (documented as coming from
            the ``YAGNA_PAYMENT_DRIVER`` environment variable, ``erc20`` if
            unset). Only payment platforms with the specified driver will be used.
        :param payment_network: name of the payment network to use, lower-cased.
            When not given, ``DEFAULT_NETWORK`` is used (documented as coming
            from the ``YAGNA_PAYMENT_NETWORK`` environment variable, ``rinkeby``
            if unset). Only payment platforms with the specified network will be
            used.
        :param stream_output: stream computation output from providers
        :param app_key: optional Yagna application key. If not provided, the
            default is to get the value from the ``YAGNA_APPKEY`` environment
            variable.
        """
        # REST API configuration and budget
        self._api_config = rest.Configuration(app_key)
        self._budget_amount = Decimal(budget)
        self._budget_allocations: List[rest.payment.Allocation] = []

        # market strategy and event sink
        self._strategy = strategy
        self._event_consumer = event_consumer

        # Network selection: explicit arguments win over the module defaults;
        # explicitly passed driver/network names are lower-cased.
        self._subnet: Optional[str] = subnet_tag or DEFAULT_SUBNET
        if payment_driver:
            self._payment_driver: str = payment_driver.lower()
        else:
            self._payment_driver = DEFAULT_DRIVER
        if payment_network:
            self._payment_network: str = payment_network.lower()
        else:
            self._payment_network = DEFAULT_NETWORK
        self._stream_output = stream_output

        # Jobs (computations or services) started by this engine, kept so the
        # engine can wait until all of them are finished.
        self._jobs: Set[Job] = set()

        # payment bookkeeping
        self._invoice_manager = InvoiceManager()
        self._agreements_accepting_debit_notes: Dict[
            JobId, Set[AgreementId]] = defaultdict(set)
        self._num_debit_notes: Dict[ActivityId, int] = defaultdict(int)
        self._num_payable_debit_notes: Dict[ActivityId, int] = defaultdict(int)
        self._activity_created_at: Dict[ActivityId, datetime] = {}
        self._payment_closing: bool = False
        self._process_invoices_job: Optional[asyncio.Task] = None

        # Async generators created by executors that use this engine, service
        # tasks, and an exit stack (presumably for resources released when the
        # engine stops - confirm against start/stop code).
        self._generators: Set[AsyncGenerator] = set()
        self._services: Set[asyncio.Task] = set()
        self._stack = AsyncExitStack()

        self._started = False

        # every agreement ever used within this Engine, keyed by its id
        self._all_agreements: Dict[AgreementId, Agreement] = {}
Exemple #4
0
 def __init__(self) -> None:
     """Start inactive, with a fresh, empty ``AsyncExitStack``."""
     self._stack = AsyncExitStack()
     self._active = False
Exemple #5
0
async def connect_to_mqtt_server(config):
    """Connect to the MQTT broker from ``config`` and run the chair-state handlers.

    When ``config`` has an ``MQTT`` section, a ``Client`` is built from its
    ``broker_address`` / ``usr`` / ``pswd`` options, a task is started for the
    unfiltered-message logger, one handler task per sensor topic, and one
    chair-state publisher. The coroutine then waits for all tasks. On exit
    (normal or error) the exit stack runs ``cancel_tasks`` on the task set and
    leaves the message managers and the client.

    :param config: object with ConfigParser-style ``has_section`` / ``get``.
    """
    async with AsyncExitStack() as stack:
        # Keep track of the asyncio tasks that we create, so that
        # we can cancel them on exit
        tasks = set()
        stack.push_async_callback(cancel_tasks, tasks)

        if config.has_section('MQTT'):
            # Connect to the MQTT broker
            client = Client(config.get('MQTT', 'broker_address'),
                            username=config.get('MQTT', 'usr'),
                            password=config.get('MQTT', 'pswd'))
            await stack.enter_async_context(client)

            chair_state = ChairState()

            # Messages that don't match any filter will get logged here
            messages = await stack.enter_async_context(
                client.unfiltered_messages())
            tasks.add(asyncio.create_task(
                log_messages(messages, "[unfiltered] {}")))

            # (topic, handler) pairs; each handler is called as
            # handler(client, messages, chair_state)
            subscriptions = (
                ("sensors/pressure", handle_sensors_pressure),
                ("sensors/angle", handle_sensors_angle),
                ("sensors/travel", handle_sensors_travel),
                ("sensors/alarm/state", handle_sensors_alarm),
            )
            for topic, handler in subscriptions:
                # Register the filter *before* subscribing: the broker may
                # deliver retained messages as soon as the subscription is
                # acknowledged, and without a matching filter they would be
                # routed to the unfiltered logger instead of this handler.
                messages = await stack.enter_async_context(
                    client.filtered_messages(topic))
                tasks.add(asyncio.create_task(
                    handler(client, messages, chair_state)))
                await client.subscribe(topic)

            # Start periodic publish of chair state
            tasks.add(asyncio.create_task(
                publish_chair_state(client, chair_state)))

        # Wait for everything to complete (or fail due to, e.g., network errors)
        await asyncio.gather(*tasks)
Exemple #6
0
async def connect_to_mqtt_server(config):
    """Connect to the MQTT broker from ``config`` and run the notification handlers.

    When ``config`` has an ``MQTT`` section, a ``Client`` is built from its
    ``broker_address`` / ``usr`` / ``pswd`` options, a ``NotificationFSM`` is
    created from ``config``, and tasks are started for the unfiltered-message
    logger, one handler per subscribed topic, and the FSM publisher. The
    coroutine then waits for all tasks. On exit (normal or error) the exit
    stack runs ``cancel_tasks`` on the task set and leaves the message managers
    and the client.

    :param config: object with ConfigParser-style ``has_section`` / ``get``.
    """
    async with AsyncExitStack() as stack:
        # Keep track of the asyncio tasks that we create, so that
        # we can cancel them on exit
        tasks = set()
        stack.push_async_callback(cancel_tasks, tasks)

        if config.has_section('MQTT'):
            # Connect to the MQTT broker
            client = Client(config.get('MQTT', 'broker_address'),
                            username=config.get('MQTT', 'usr'),
                            password=config.get('MQTT', 'pswd'))
            await stack.enter_async_context(client)

            fsm = NotificationFSM(config)

            # Messages that don't match any filter will get logged here
            messages = await stack.enter_async_context(
                client.unfiltered_messages())
            tasks.add(asyncio.create_task(
                log_messages(messages, "[unfiltered] {}")))

            # (topic, handler, extra args); each handler is called as
            # handler(client, messages, *extra_args)
            subscriptions = (
                ("sensors/chairState", handle_sensors_chair_state,
                 (fsm, config)),
                ("fsm/angle", handle_angle_fsm_state, (fsm, config)),
                ("fsm/travel", handle_travel_fsm_state, (fsm, config)),
                ("fsm/seating", handle_seating_fsm_state, (fsm, config)),
                ("config/notifications_settings",
                 handle_config_notifications_settings, (fsm,)),
            )
            for topic, handler, extra_args in subscriptions:
                # Register the filter *before* subscribing: the broker may
                # deliver retained messages as soon as the subscription is
                # acknowledged, and without a matching filter they would be
                # routed to the unfiltered logger instead of this handler.
                messages = await stack.enter_async_context(
                    client.filtered_messages(topic))
                tasks.add(asyncio.create_task(
                    handler(client, messages, *extra_args)))
                await client.subscribe(topic)

            # Start publishing the notification FSM state
            tasks.add(asyncio.create_task(publish_notification_fsm(client, fsm)))

        # Wait for everything to complete (or fail due to, e.g., network errors)
        await asyncio.gather(*tasks)
Exemple #7
0
 async def __aenter__(self):
     """Enter every item of ``self`` as an async context manager.

     Items are entered in iteration order. If entering any item raises, the
     items entered so far are exited (in reverse order) before the exception
     propagates; this matters because ``__aexit__`` is not called when
     ``__aenter__`` fails. On success the populated stack is stored in
     ``self._stack`` (presumably unwound by the matching ``__aexit__``, which
     is not visible here).

     :return: ``self``
     """
     async with AsyncExitStack() as stack:
         for item in self:
             await stack.enter_async_context(item)
         # Move the registered exit callbacks to a new stack so that leaving
         # this block does not run them on the success path.
         self._stack = stack.pop_all()
     return self
Exemple #8
0
    async def create(
        self,
        paths: Optional[List[Path]] = None,
        id: Optional[Union[str, 'DaemonID']] = None,
        complete: bool = False,
        *args,
        **kwargs,
    ) -> Optional['DaemonID']:
        """Create a workspace

        If ``id`` is given and the workspace already exists, it is returned
        without uploading anything. Otherwise ``paths`` are uploaded with a POST
        to ``self.store_api``, and the new workspace id is returned once
        ``self.wait`` reports success.

        :param paths: local file/directory paths to be uploaded to workspace, defaults to None
        :param id: workspace id (if already known), defaults to None
        :param complete: True if complete_path is used (used by JinadRuntime), defaults to False
        :param args: additional positional args
        :param kwargs: keyword args
        :return: workspace id, or None if creation or waiting failed
        """

        async with AsyncExitStack() as stack:
            console = Console()
            status = stack.enter_context(
                console.status('Workspace: ...', spinner='earth'))
            workspace_id = None
            if id:
                # When creating `Pods` with `shards > 1`, `JinadRuntime` knows
                # the workspace_id already.
                # For shards > 1:
                # - shard 0 throws TypeError & we create a workspace
                # - shard N (all other shards) wait for workspace creation &
                #   don't emit logs
                # For shards = 0 (NOTE(review): probably means a single shard):
                # - throws a TypeError & we create a workspace
                workspace_id = daemonize(id)
                try:
                    return (workspace_id if await self._get_helper(
                        id=workspace_id, status=status) else None)
                except (TypeError, ValueError):
                    self._logger.debug('workspace doesn\'t exist, creating..')

            status.update('Workspace: Getting files to upload...')
            data = stack.enter_context(
                FormData(paths=paths, logger=self._logger, complete=complete))
            status.update('Workspace: Sending request...')
            response = await stack.enter_async_context(
                aiohttp.request(
                    method='POST',
                    url=self.store_api,
                    params={'id': workspace_id} if workspace_id else None,
                    data=data,
                ))
            response_json = await response.json()

            if response.status != HTTPStatus.CREATED:
                self._logger.error(
                    f'{self._kind.title()} creation failed as: {error_msg_from(response_json)}'
                )
                return None

            # The new workspace id is the first key of the success body. It is
            # read only here so that an error body (possibly empty or null) is
            # never indexed.
            workspace_id = next(iter(response_json))
            status.update(f'Workspace: {workspace_id} added...')
            return (workspace_id if await self.wait(
                id=workspace_id, status=status, logs=True) else None)
Exemple #9
0
    async def _ws(self, init, keepalive=True, timeout=30):
        """Stream stats payloads from the ``/ws/stats`` websocket endpoint.

        Ensures the session is logged in, opens ``{self.url}/ws/stats`` with a
        periodic ping, sends the subscription message (``init`` merged over an
        empty ``UNSUBSCRIBE`` list, plus the current ``SESSION_ID``) and yields
        every decoded JSON payload after merging it into ``self.sysdata``.

        Incoming text is framed as a decimal length, a newline, then that many
        characters of JSON. Frames may be split across websocket messages or
        packed several per message, so text is buffered until a whole frame is
        available.

        A non-text websocket message makes the loop open a fresh connection.
        Any other error (including a receive ``timeout``) is logged; the
        connection is then reopened if ``keepalive`` is true, otherwise the
        generator ends. Cancellation ends the generator quietly.

        :param init: extra keys merged into the subscription message
        :param keepalive: reconnect after errors instead of stopping
        :param timeout: seconds to wait for each websocket message
        :yields: decoded JSON payloads
        """
        subscription = {'UNSUBSCRIBE': []}
        subscription.update(init)

        while True:
            try:
                # Make sure that before we launch the WebSocket we have a
                # valid session id
                while not await self.is_logged_in():
                    logger.warning("Session died, trying a manual login.")
                    await asyncio.sleep(5)
                    await self.login()

                async with AsyncExitStack() as stack:
                    ws = await stack.enter_async_context(
                        self.session.ws_connect(f"{self.url}/ws/stats",
                                                headers=self.headers,
                                                origin=self.url,
                                                ssl=self.ssl))
                    # ping the socket every 30 s while it is open
                    await stack.enter_async_context(
                        TaskEvery(ws_ping, ws, interval=30, sync_once=False))

                    subscription.update({'SESSION_ID': self.session_id})
                    await ws.send_str(as_statd_string(subscription))
                    buffer = ''
                    while True:
                        msg = await asyncio.wait_for(ws.receive(), timeout)

                        if msg.type != WSMsgType.TEXT:
                            logger.debug(
                                f"got non text websocket data {msg.data!r} this probbaly means the socket was closed so let's start a fresh one"
                            )
                            break

                        buffer += msg.data

                        # Consume every complete frame currently buffered and
                        # keep any partial frame for the next message.
                        # NOTE(review): lengths are compared in characters;
                        # confirm the server does not count bytes.
                        while True:
                            header, sep, rest = buffer.partition('\n')
                            if not sep:
                                break  # length header not complete yet
                            frame_len = int(header)
                            if len(rest) < frame_len:
                                break  # payload not complete yet
                            frame, buffer = rest[:frame_len], rest[frame_len:]
                            try:
                                payload = json.loads(frame)
                                self.sysdata.update(payload)
                                yield payload
                            except Exception as e:
                                logger.error(f"{e!r}")
            except asyncio.CancelledError:
                # NOTE(review): returning swallows the cancellation; the
                # consumer sees the generator end normally.
                return
            except Exception as err:
                logger.debug(f"websocket loop raised {err!r}, ignoring")
                if not keepalive:
                    return
Exemple #10
0
async def handle_event(bot: "Bot", event: "Event") -> None:
    """
    :Description:

       Handle a single event. Call this function to dispatch an event.

       Runs the event preprocessors, the trie-based command match, the
       matchers in ascending priority order (stopping after a priority level
       in which a matcher raised ``StopPropagation``), then the event
       postprocessors. All of them share one ``AsyncExitStack`` and one
       dependency cache for the duration of the call.

    :Parameters:

      * ``bot: Bot``: the Bot object
      * ``event: Event``: the Event object

    :Example:

    .. code-block:: python

        import asyncio
        asyncio.create_task(handle_event(bot, event))
    """
    prefix = f"<m>{escape_tag(bot.type.upper())} {escape_tag(bot.self_id)}</m> | "
    try:
        log_msg = prefix + event.get_log_string()
    except NoLogException:
        # the event opted out of logging: also skip the debug progress logs
        show_log = False
    else:
        show_log = True
        logger.opt(colors=True).success(log_msg)

    state: Dict[Any, Any] = {}
    dependency_cache: T_DependencyCache = {}

    async with AsyncExitStack() as stack:

        def _build_hook_coros(hooks):
            """Create one caught coroutine per hook, all sharing this dispatch's context."""
            return [
                _run_coro_with_catch(
                    hook(
                        bot=bot,
                        event=event,
                        state=state,
                        stack=stack,
                        dependency_cache=dependency_cache,
                    ))
                for hook in hooks
            ]

        preprocess_coros = _build_hook_coros(_event_preprocessors)
        if preprocess_coros:
            try:
                if show_log:
                    logger.debug("Running PreProcessors...")
                await asyncio.gather(*preprocess_coros)
            except IgnoredException:
                logger.opt(colors=True).info(
                    f"Event {escape_tag(event.get_event_name())} is <b>ignored</b>"
                )
                return
            except Exception as e:
                logger.opt(colors=True, exception=e).error(
                    "<r><bg #f8bbd0>Error when running EventPreProcessors. "
                    "Event ignored!</bg #f8bbd0></r>")
                return

        # Trie Match
        try:
            TrieRule.get_value(bot, event, state)
        except Exception as e:
            logger.opt(colors=True, exception=e).warning(
                "Error while parsing command for event")

        stop_propagation = False
        for priority in sorted(matchers.keys()):
            if stop_propagation:
                break

            if show_log:
                logger.debug(
                    f"Checking for matchers in priority {priority}...")

            # every matcher gets its own shallow copy of the state dict
            results = await asyncio.gather(
                *(
                    _check_matcher(priority, matcher, bot, event, state.copy(),
                                   stack, dependency_cache)
                    for matcher in matchers[priority]
                ),
                return_exceptions=True,
            )

            for exc in (r for r in results if isinstance(r, Exception)):
                if isinstance(exc, StopPropagation):
                    stop_propagation = True
                    logger.debug("Stop event propagation")
                else:
                    logger.opt(colors=True, exception=exc).error(
                        "<r><bg #f8bbd0>Error when checking Matcher.</bg #f8bbd0></r>"
                    )

        postprocess_coros = _build_hook_coros(_event_postprocessors)
        if postprocess_coros:
            try:
                if show_log:
                    logger.debug("Running PostProcessors...")
                await asyncio.gather(*postprocess_coros)
            except Exception as e:
                logger.opt(colors=True, exception=e).error(
                    "<r><bg #f8bbd0>Error when running EventPostProcessors</bg #f8bbd0></r>"
                )
Exemple #11
0
async def test_alexandria_network_stream_locations(
    tester,
    alice,
):
    """Alice's ``stream_locations`` returns local, remote and peer advertisements.

    Six peer networks each hold one advertisement for ``content_key`` in their
    remote advertisement database; alice holds one in her local and one in her
    remote database. The stream must yield at least five advertisements,
    include both of alice's, and contain nothing outside those sources.
    """
    content_key = b"test-key"
    shared_root = b"unicornsrainbowscupcakessparkles"

    async def _let_network_settle():
        """Yield to the scheduler many times so background work can progress."""
        with trio.fail_after(30):
            for _ in range(1000):
                await trio.lowlevel.checkpoint()

    async with AsyncExitStack() as stack:
        networks = await stack.enter_async_context(
            tester.alexandria.network_group(6))

        # seed every peer with one advertisement for the content key
        other_advertisements = tuple(
            AdvertisementFactory(
                content_key=content_key,
                hash_tree_root=shared_root,
            ) for _ in range(6))
        for advertisement, network in zip(other_advertisements, networks):
            network.remote_advertisement_db.add(advertisement)

        # give the network some time to interconnect
        await _let_network_settle()

        bootnodes = tuple(network.enr_manager.enr for network in networks)
        alice_alexandria_network = await stack.enter_async_context(
            alice.alexandria.network(bootnodes=bootnodes))

        local_advertisement = AdvertisementFactory(
            private_key=alice.private_key,
            content_key=content_key,
            hash_tree_root=shared_root,
        )
        alice_alexandria_network.local_advertisement_db.add(
            local_advertisement)
        remote_advertisement = AdvertisementFactory(
            private_key=alice.private_key,
            content_key=content_key,
            hash_tree_root=shared_root,
        )
        alice_alexandria_network.remote_advertisement_db.add(
            remote_advertisement)

        await _let_network_settle()

        locations_ctx = alice_alexandria_network.stream_locations(content_key)
        with trio.fail_after(60):
            async with locations_ctx as location_stream:
                found_advertisements = tuple(
                    [found async for found in location_stream])

        assert len(found_advertisements) >= 5

        assert local_advertisement in found_advertisements
        assert remote_advertisement in found_advertisements

        alices_own = (local_advertisement, remote_advertisement)
        for advertisement in found_advertisements:
            if advertisement not in alices_own:
                assert advertisement in other_advertisements
Exemple #12
0
async def test_alexandria_network_broadcast_api(
    tester,
    alice,
    alice_alexandria_network,
    autojump_clock,
):
    """Broadcasting an advertisement far from alice returns a non-empty result.

    Builds 10 peer networks that each ACK the first ``AdvertiseMessage`` they
    receive. It then picks an advertisement whose content is at least as far
    from alice as her furthest peer (trying up to 100 extra random candidates
    and keeping the furthest one), bonds alice with her closest peer, and
    broadcasts that advertisement.
    """
    async with AsyncExitStack() as stack:
        network_group = await stack.enter_async_context(
            tester.alexandria.network_group(10))

        def _distance_to_alice(network):
            return compute_distance(alice.node_id, network.local_node_id)

        furthest_network = max(network_group, key=_distance_to_alice)
        closest_network = min(network_group, key=_distance_to_alice)

        furthest_node_distance_from_alice = compute_distance(
            furthest_network.local_node_id,
            alice.node_id,
        )

        # Keep the furthest candidate seen so far; stop as soon as it is at
        # least as far from alice as her furthest peer.
        furthest_ad = AdvertisementFactory()
        furthest_ad_distance = compute_content_distance(
            alice.node_id, furthest_ad.content_id)
        for _ in range(100):
            if furthest_ad_distance >= furthest_node_distance_from_alice:
                break
            candidate = AdvertisementFactory()
            candidate_distance = compute_content_distance(
                alice.node_id, candidate.content_id)
            if candidate_distance > furthest_ad_distance:
                furthest_ad = candidate
                furthest_ad_distance = candidate_distance

        async with trio.open_nursery() as nursery:

            async def _respond(network, subscription):
                """ACK the first AdvertiseMessage received on ``subscription``."""
                request = await subscription.receive()
                await network.client.send_ack(
                    request.sender_node_id,
                    request.sender_endpoint,
                    advertisement_radius=network.local_advertisement_radius,
                    acked=(True, ) * len(request.message.payload),
                    request_id=request.request_id,
                )

            for network in network_group:
                subscription = await stack.enter_async_context(
                    network.client.subscribe(AdvertiseMessage))
                nursery.start_soon(_respond, network, subscription)

            alice_alexandria_network.enr_db.set_enr(
                closest_network.enr_manager.enr)
            await alice_alexandria_network.bond(closest_network.local_node_id)

            # yield to the scheduler many times so background work can progress
            for _ in range(10000):
                await trio.lowlevel.checkpoint()

            with trio.fail_after(30):
                result = await alice_alexandria_network.broadcast(furthest_ad)
                assert len(result) > 0

            nursery.cancel_scope.cancel()
Exemple #13
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """Stream requests to an HTTP gateway and yield the parsed responses.

        Each request from ``self._get_requests(**kwargs)`` is POSTed to
        ``http(s)://{host}:{port}/post``. Every JSON response is converted into
        a ``DataRequest``, with its ``data`` field decoded as a
        ``DocumentArray``, passed to the callbacks and yielded.

        ``aiohttp.ClientError`` is logged and routed to ``on_error`` /
        ``on_always`` when either is given; otherwise it is re-raised.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        :raises BadClient: if the endpoint answers with status 404
        :raises ValueError: if the endpoint answers with any other non-2xx status
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = ProgressBar(total_length=self._inputs_length,
                                  disable=not (self.show_progress))
                p_bar = stack.enter_context(cm1)

                proto = 'https' if self.args.tls else 'http'
                url = f'{proto}://{self.args.host}:{self.args.port}/post'
                iolet = await stack.enter_async_context(
                    HTTPClientlet(url=url, logger=self.logger))

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For HTTP Client, for each request in the iterator, we `send_message` using
                    http POST request and add it to the list of tasks which is awaited and yielded.
                    :param request: current request in the iterator
                    :return: asyncio Task for sending message
                    """
                    return asyncio.ensure_future(
                        iolet.send_message(request=request))

                def _result_handler(result):
                    return result

                streamer = RequestStreamer(
                    self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                )
                async for response in streamer.stream(request_iterator):
                    r_status = response.status
                    # Check for a missing endpoint before decoding the body: a
                    # non-JSON 404 page would make `response.json()` raise an
                    # aiohttp error and hide the BadClient below.
                    if r_status == 404:
                        raise BadClient(f'no such endpoint {url}')

                    r_str = await response.json()
                    if r_status < 200 or r_status >= 300:
                        raise ValueError(r_str)

                    da = None
                    if 'data' in r_str and r_str['data'] is not None:
                        from docarray import DocumentArray

                        da = DocumentArray.from_dict(r_str['data'])
                        del r_str['data']

                    resp = DataRequest(r_str)
                    if da is not None:
                        resp.data.docs = da

                    callback_exec(
                        response=resp,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield resp

            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while fetching response from HTTP server {e!r}')

                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise
Exemple #14
0
async def test_content_manager_enumerates_and_broadcasts_content(
    alice,
    bob,
):
    """Alice's content manager should advertise every pinned content key to bob.

    Alice pins ``num_contents`` pieces of content, but only every third one
    has an advertisement in her local advertisement database beforehand; the
    others must be created lazily. Bob answers alice's ping, find-nodes and
    advertise requests, and stops once he has received at least
    ``num_contents`` advertisements.
    """
    num_contents = 12

    #
    # Test Setup:
    #
    # 12 pieces of content; indices 0, 3, 6 and 9 get an advertisement up
    # front, so the remaining 8 exercise lazy advertisement creation.
    contents = tuple((b"\xfftest-content-" + bytes([idx]),
                      ContentFactory(256 + 25)[idx:idx + 256])
                     for idx in range(num_contents))

    advertisement_db = alice.alexandria.local_advertisement_db
    pinned_storage = alice.alexandria.pinned_content_storage

    # Pre-create advertisements for every third piece of content only.
    for content_key, content in contents[::3]:
        hash_tree_root = ssz.get_hash_tree_root(content, sedes=content_sedes)
        advertisement = Advertisement.create(
            content_key=content_key,
            hash_tree_root=hash_tree_root,
            private_key=alice.private_key,
        )
        advertisement_db.add(advertisement)

    content_keys = {content_key for content_key, _ in contents}

    # Every piece of content is pinned, with or without an advertisement.
    for content_key, content in contents:
        pinned_storage.set_content(content_key, content)

    received_advertisements = []

    async with AsyncExitStack() as stack:
        bob_alexandria_client = await stack.enter_async_context(
            bob.alexandria.client())
        ad_subscription = await stack.enter_async_context(
            bob_alexandria_client.subscribe(AdvertiseMessage))
        ping_subscription = await stack.enter_async_context(
            bob_alexandria_client.subscribe(PingMessage))
        find_nodes_subscription = await stack.enter_async_context(
            bob_alexandria_client.subscribe(FindNodesMessage))

        done = trio.Event()

        async def _do_found_nodes():
            # Answer every FindNodes request with an empty ENR list.
            async for request in find_nodes_subscription:
                await bob_alexandria_client.send_found_nodes(
                    request.sender_node_id,
                    request.sender_endpoint,
                    enrs=(),
                    request_id=request.request_id,
                )

        async def _do_pong():
            # Answer every Ping with the maximum advertisement radius.
            async for request in ping_subscription:
                await bob_alexandria_client.send_pong(
                    request.sender_node_id,
                    request.sender_endpoint,
                    advertisement_radius=2**256 - 1,
                    enr_seq=bob.enr.sequence_number,
                    request_id=request.request_id,
                )

        async def _do_ack():
            # Record and acknowledge every advertisement until all pieces of
            # content have been seen, then signal the main task.
            async for request in ad_subscription:
                received_advertisements.extend(request.message.payload)
                await bob_alexandria_client.send_ack(
                    request.sender_node_id,
                    request.sender_endpoint,
                    advertisement_radius=2**256 - 1,
                    acked=tuple(True for _ in request.message.payload),
                    request_id=request.request_id,
                )

                if len(received_advertisements) >= num_contents:
                    done.set()
                    break

        async with trio.open_nursery() as nursery:
            nursery.start_soon(_do_ack)
            nursery.start_soon(_do_pong)
            nursery.start_soon(_do_found_nodes)

            async with alice.alexandria.network() as alice_alexandria_network:
                await alice_alexandria_network.bond(bob.node_id)

                with trio.fail_after(90):
                    await done.wait()

            # Stop bob's responder tasks once alice's network has shut down.
            nursery.cancel_scope.cancel()

    # 1. Every piece of content should have been advertised.
    received_keys = {
        advertisement.content_key
        for advertisement in received_advertisements
    }
    assert len(received_keys) == num_contents
    assert received_keys == content_keys

    # 2. The contents that had no advertisement beforehand should have been
    #    lazily created, so every received advertisement is now in the db.
    assert all(
        advertisement_db.exists(advertisement)
        for advertisement in received_advertisements)
Exemple #15
0
 async def setup(self, **kwargs):
     """Create the aiohttp application, register its routes and set up the runner.

     Shared state is stored on ``self.app`` so request handlers can reach it:
     an already-entered ``AsyncExitStack`` plus empty registries for multicomm
     contexts/routes, sources, models and their contexts. ``kwargs`` are then
     applied on top of those entries. When ``self.mc_atomic`` is set, no
     built-in routes are registered. When ``self.cors_domains`` is non-empty,
     every registered route is also added to CORS.
     """
     app = web.Application(middlewares=[self.error_middleware])
     self.app = app

     if self.cors_domains:
         cors_defaults = {}
         for domain in self.cors_domains:
             cors_defaults[domain] = aiohttp_cors.ResourceOptions(
                 allow_headers=("Content-Type", ))
         self.cors = aiohttp_cors.setup(app, defaults=cors_defaults)

     # Long-lived resources live on the app object so handlers can access them:
     # http://docs.aiohttp.org/en/stable/faq.html#where-do-i-put-my-database-connection-so-handlers-can-access-it
     exit_stack = AsyncExitStack()
     app["exit_stack"] = exit_stack
     await exit_stack.__aenter__()
     app.on_shutdown.append(self.on_shutdown)

     initial_state = (
         ("multicomm_contexts", {"self": self}),
         ("multicomm_routes", {}),
         ("sources", {}),
         ("source_contexts", {}),
         ("source_repos_iterkeys", {}),
         ("models", {}),
         ("model_contexts", {}),
     )
     for state_key, state_value in initial_state:
         app[state_key] = state_value
     # Caller-supplied entries override the defaults above
     app.update(kwargs)

     if self.mc_atomic:
         # Atomic mode: only pre-registered routes are allowed
         self.routes = []
     else:
         self.routes = [
             # HTTP Service specific APIs
             ("POST", "/service/upload/{filepath:.+}", self.service_upload),
             ("GET", "/service/files", self.service_files),
             # DFFML APIs
             ("GET", "/list/sources", self.list_sources),
             ("POST", "/configure/source/{source}/{label}",
              self.configure_source),
             ("GET", "/context/source/{label}/{ctx_label}",
              self.context_source),
             ("GET", "/list/models", self.list_models),
             ("POST", "/configure/model/{model}/{label}",
              self.configure_model),
             ("GET", "/context/model/{label}/{ctx_label}",
              self.context_model),
             # MutliComm APIs (Data Flow)
             ("POST", "/multicomm/{label}/register", self.multicomm_register),
             # Source APIs
             ("GET", "/source/{label}/repo/{key}", self.source_repo),
             ("POST", "/source/{label}/update/{key}", self.source_update),
             ("GET", "/source/{label}/repos/{chunk_size}", self.source_repos),
             ("GET", "/source/{label}/repos/{iterkey}/{chunk_size}",
              self.source_repos_iter),
             # TODO route to delete iterkey before iteration has completed
             # Model APIs
             ("POST", "/model/{label}/train", self.model_train),
             ("POST", "/model/{label}/accuracy", self.model_accuracy),
             # TODO Provide an iterkey method for model prediction
             ("POST", "/model/{label}/predict/{chunk_size}",
              self.model_predict),
         ]

     for method, path, handler in self.routes:
         registered_route = app.router.add_route(method, path, handler)
         if self.cors_domains:
             self.cors.add(registered_route)

     self.runner = web.AppRunner(app)
     await self.runner.setup()
Exemple #16
0
async def main_loop(queue: Queue, ctx: Context, k8s_configmap_client: K8SConfigMapClient) -> None:
    """Fold entity events into an expected state and reconcile it with the controller.

    Every event read from ``queue`` is applied to the *expected* state.
    Whenever no event arrives within ``ctx.timeout`` seconds, the expected
    state is resolved. If resolution reports conflicts, the loop keeps waiting
    for more events. Otherwise a plan from the *current* state to the expected
    state is created and, if it needs applying, executed against the
    controller (or with no entity clients in dry-run mode). The loop never
    returns normally.

    :param queue: queue of ``AppgateEvent`` objects
    :param ctx: operator settings (controller, credentials, modes, api spec, ...)
    :param k8s_configmap_client: passed through to ``appgate_plan_apply``
    :raises AppgateException: when not in dry-run mode and ``ctx.device_id`` is None
    :raises SystemExit: with status 1 when applying a plan reports errors
    """
    namespace = ctx.namespace
    log.info('[appgate-operator/%s] Main loop started:', namespace)
    log.info('[appgate-operator/%s]   + namespace: %s', namespace, namespace)
    log.info('[appgate-operator/%s]   + host: %s', namespace, ctx.controller)
    log.info('[appgate-operator/%s]   + timeout: %s', namespace, ctx.timeout)
    log.info('[appgate-operator/%s]   + dry-run: %s', namespace, ctx.dry_run_mode)
    log.info('[appgate-operator/%s]   + cleanup: %s', namespace, ctx.cleanup_mode)
    log.info('[appgate-operator/%s]   + two-way-sync: %s', namespace, ctx.two_way_sync)
    log.info('[appgate-operator/%s] Getting current state from controller',
             namespace)
    current_appgate_state = await get_current_appgate_state(ctx=ctx)
    if ctx.cleanup_mode:
        # Start the expected state from only the entities carrying the builtin
        # tags; entities must be re-declared by events to stay in it.
        expected_appgate_state = AppgateState(
            {k: v.entities_with_tags(ctx.builtin_tags) for k, v in current_appgate_state.entities_set.items()})
    else:
        expected_appgate_state = deepcopy(current_appgate_state)
    log.info('[appgate-operator/%s] Ready to get new events and compute a new plan',
             namespace)
    while True:
        try:
            event: AppgateEvent = await asyncio.wait_for(queue.get(), timeout=ctx.timeout)
            log.info('[appgate-operator/%s] Event op: %s %s with name %s', namespace,
                     event.op, str(type(event.entity)), event.entity.name)
            expected_appgate_state.with_entity(EntityWrapper(event.entity), event.op, current_appgate_state)
        except asyncio.exceptions.TimeoutError:
            # The queue has been quiet for ctx.timeout seconds: reconcile now.
            # Log all entities in expected state
            log.info('[appgate-operator/%s] Expected entities:', namespace)
            for entity_type, xs in expected_appgate_state.entities_set.items():
                for entity_name, e in xs.entities_by_name.items():
                    log.info('[appgate-operator/%s] %s: %s: %s', namespace, entity_type, entity_name,
                             e.id)

            # Resolve entities now, in order
            # this will be the Topological sort
            total_conflicts = resolve_appgate_state(appgate_state=expected_appgate_state,
                                                    reverse=False,
                                                    api_spec=ctx.api_spec)
            if total_conflicts:
                log.error('[appgate-operator/%s] Found errors in expected state and plan can'
                          ' not be applied.', namespace)
                entities_conflict_summary(conflicts=total_conflicts, namespace=namespace)
                log.info('[appgate-operator/%s] Waiting for more events that can fix the state.',
                         namespace)
                continue

            if ctx.two_way_sync:
                # use current appgate state from controller instead of from memory
                current_appgate_state = await get_current_appgate_state(ctx=ctx)

            # Create a plan
            # Need to copy?
            # Now we use dicts so resolving update the contents of the keys
            plan = create_appgate_plan(current_appgate_state, expected_appgate_state,
                                       ctx.builtin_tags,)
            if plan.needs_apply:
                log.info('[appgate-operator/%s] No more events for a while, creating a plan',
                         namespace)
                # The exit stack closes the controller client (if any) once
                # the plan has been applied.
                async with AsyncExitStack() as exit_stack:
                    appgate_client = None
                    if not ctx.dry_run_mode:
                        if ctx.device_id is None:
                            raise AppgateException('No device id specified')
                        appgate_client = await exit_stack.enter_async_context(AppgateClient(
                            controller=ctx.controller,
                            user=ctx.user, password=ctx.password, provider=ctx.provider,
                            device_id=ctx.device_id,
                            version=ctx.api_spec.api_version, no_verify=ctx.no_verify,
                            cafile=ctx.cafile))
                    else:
                        log.warning('[appgate-operator/%s] Running in dry-mode, nothing will be created',
                                    namespace)
                    # Without a client (dry-run) no entity clients are passed.
                    new_plan = await appgate_plan_apply(appgate_plan=plan, namespace=namespace,
                                                        entity_clients=generate_api_spec_clients(
                                                            api_spec=ctx.api_spec,
                                                            appgate_client=appgate_client)
                                                        if appgate_client else {},
                                                        k8s_configmap_client=k8s_configmap_client,
                                                        api_spec=ctx.api_spec)

                    if len(new_plan.errors) > 0:
                        log.error('[appgate-operator/%s] Found errors when applying plan:', namespace)
                        for err in new_plan.errors:
                            log.error('[appgate-operator/%s] Error %s:', namespace, err)
                        sys.exit(1)

                    # Only a real (non dry-run) apply changes the controller,
                    # so only then is the in-memory state advanced.
                    if appgate_client:
                        current_appgate_state = new_plan.appgate_state
                        expected_appgate_state = expected_appgate_state.sync_generations()
            else:
                log.info('[appgate-operator/%s] Nothing changed! Keeping watching!', namespace)
Exemple #17
0
    async def executor(self,
                       executor_protocol: ExecutorProtocol,
                       event_context,
                       extra_parameter=None,
                       lru_cache_sets=None):
        """Resolve the dependencies of ``executor_protocol`` and run its callable.

        Each entry of ``executor_protocol.dependencies`` is run first through
        ``executor_with_middlewares``. If any of them returns ``TRACEBACKED``,
        ``TRACEBACKED`` is returned immediately. Every parameter of
        ``executor_protocol.callable`` is then filled from one of three
        sources:

        * its ``Depend`` default, run through ``executor_with_middlewares``;
        * the factory registered for its annotation in
          ``get_annotations_mapping()``, called with ``event_context``;
        * ``extra_parameter``.

        Finally the callable is invoked through ``run_func`` with the
        protocol's middlewares entered.

        :param executor_protocol: the callable plus its dependencies and middlewares
        :param event_context: passed to dependencies and annotation factories
        :param extra_parameter: extra keyword arguments for the callable
            (``None`` means no extras)
        :param lru_cache_sets: maps an original dependency function to its
            cached wrapper. It is shared with nested executions so each cached
            dependency is wrapped only once. A new dict is created only when
            ``None`` is passed.
        :raises TypeError: if a class-based dependency is not callable
        :raises RuntimeError: if a parameter has a non-``Depend`` default, or an
            annotation that is neither mapped nor supplied via ``extra_parameter``
        :return: ``TRACEBACKED`` or whatever ``run_func`` returns
        """
        # Use ``None`` sentinels instead of a shared mutable default. The
        # ``is None`` test (rather than ``or``) keeps an empty dict supplied
        # by the caller, so that dict receives the cache wrappers.
        if extra_parameter is None:
            extra_parameter = {}
        if lru_cache_sets is None:
            lru_cache_sets = {}

        def _resolve(depend):
            # Return the function to run for ``depend``, wrapped in a
            # (shared) LRU cache when ``depend.cache`` is set.
            if not inspect.isclass(depend.func):
                depend_func = depend.func
            elif hasattr(depend.func, "__call__"):
                # Class-based dependency: run its ``__call__``
                depend_func = depend.func.__call__
            else:
                raise TypeError("must be callable.")

            if depend_func in lru_cache_sets and depend.cache:
                return lru_cache_sets[depend_func]
            if depend.cache:
                if inspect.iscoroutinefunction(depend_func):
                    cached_func = alru_cache(depend_func)
                else:
                    cached_func = lru_cache(depend_func)
                # Key by the original function so later lookups hit the cache
                lru_cache_sets[depend_func] = cached_func
                return cached_func
            return depend_func

        for depend in executor_protocol.dependencies:
            result = await self.executor_with_middlewares(
                _resolve(depend), depend.middlewares, event_context,
                lru_cache_sets)
            if result is TRACEBACKED:
                return TRACEBACKED

        param_signatures = argument_signature(executor_protocol.callable)
        place_annotation = self.get_annotations_mapping()
        call_params = {}
        for name, annotation, default in param_signatures:
            if default:
                if not isinstance(default, Depend):
                    raise RuntimeError("checked a unexpected default value.")
                call_params[name] = await self.executor_with_middlewares(
                    _resolve(default), default.middlewares, event_context,
                    lru_cache_sets)
            elif annotation in place_annotation:
                call_params[name] = place_annotation[annotation](
                    event_context)
            elif name not in extra_parameter:
                raise RuntimeError(
                    f"checked a unexpected annotation: {annotation}")
            # Otherwise the value comes from ``extra_parameter`` below.

        async with AsyncExitStack() as stack:
            sorted_middlewares = self.sort_middlewares(
                executor_protocol.middlewares)
            for async_middleware in sorted_middlewares['async']:
                await stack.enter_async_context(async_middleware)
            for normal_middleware in sorted_middlewares['normal']:
                stack.enter_context(normal_middleware)

            return await self.run_func(executor_protocol.callable,
                                       **call_params, **extra_parameter)
Exemple #18
0
    async def setup(self, **kwargs):
        """Create the aiohttp app, enter sources and models, register routes.

        The following are stored on ``self.app`` so request handlers can
        reach them:

        * an already-entered ``AsyncExitStack``;
        * multicomm registries;
        * the sources and models (classes are first instantiated with
          ``self.extra_config``), keyed by ``ENTRY_POINT_LABEL``;
        * their entered contexts.

        ``kwargs`` are then applied on top. Built-in routes are skipped in
        ``mc_atomic`` mode. ``/api.js`` is served when ``self.js`` is set,
        and static files under ``/`` when ``self.static`` is set.
        """
        app = web.Application(middlewares=[self.error_middleware])
        self.app = app

        if self.cors_domains:
            cors_defaults = {}
            for domain in self.cors_domains:
                cors_defaults[domain] = aiohttp_cors.ResourceOptions(
                    allow_headers=("Content-Type", ))
            self.cors = aiohttp_cors.setup(app, defaults=cors_defaults)

        # Long-lived resources live on the app object so handlers can access them:
        # http://docs.aiohttp.org/en/stable/faq.html#where-do-i-put-my-database-connection-so-handlers-can-access-it
        exit_stack = AsyncExitStack()
        app["exit_stack"] = exit_stack
        await exit_stack.__aenter__()
        app.on_shutdown.append(self.on_shutdown)
        app["multicomm_contexts"] = {"self": self}
        app["multicomm_routes"] = {}
        app["source_records_iterkeys"] = {}

        async def _enter_collection(collection, registry_key, contexts_key):
            # Swap not-yet-instantiated classes for configured instances
            for position, member in enumerate(collection):
                if inspect.isclass(member):
                    collection[position] = member.withconfig(self.extra_config)
            await exit_stack.enter_async_context(collection)
            app[registry_key] = {
                member.ENTRY_POINT_LABEL: member for member in collection
            }
            entered_contexts = await exit_stack.enter_async_context(
                collection())
            app[contexts_key] = {
                member_ctx.parent.ENTRY_POINT_LABEL: member_ctx
                for member_ctx in entered_contexts
            }

        await _enter_collection(self.sources, "sources", "source_contexts")
        await _enter_collection(self.models, "models", "model_contexts")

        # Caller-supplied entries override everything set above
        app.update(kwargs)

        if self.mc_atomic:
            # Atomic mode: only pre-registered routes are allowed
            self.routes = []
        else:
            self.routes = [
                # HTTP Service specific APIs
                ("POST", "/service/upload/{filepath:.+}", self.service_upload),
                ("GET", "/service/files", self.service_files),
                # DFFML APIs
                ("GET", "/list/sources", self.list_sources),
                ("POST", "/configure/source/{source}/{label}",
                 self.configure_source),
                ("GET", "/context/source/{label}/{ctx_label}",
                 self.context_source),
                ("GET", "/list/models", self.list_models),
                ("POST", "/configure/model/{model}/{label}",
                 self.configure_model),
                ("GET", "/context/model/{label}/{ctx_label}",
                 self.context_model),
                # MutliComm APIs (Data Flow)
                ("POST", "/multicomm/{label}/register",
                 self.multicomm_register),
                # Source APIs
                ("GET", "/source/{label}/record/{key}", self.source_record),
                ("POST", "/source/{label}/update/{key}", self.source_update),
                ("GET", "/source/{label}/records/{chunk_size}",
                 self.source_records),
                ("GET", "/source/{label}/records/{iterkey}/{chunk_size}",
                 self.source_records_iter),
                # TODO route to delete iterkey before iteration has completed
                # Model APIs
                ("POST", "/model/{label}/train", self.model_train),
                ("POST", "/model/{label}/accuracy", self.model_accuracy),
                # TODO Provide an iterkey method for model prediction
                ("POST", "/model/{label}/predict/{chunk_size}",
                 self.model_predict),
            ]

        if self.js:
            # Serve api.js
            self.routes.append(("GET", "/api.js", self.api_js))

        # Register every route, adding CORS to each one when configured
        for method, path, handler in self.routes:
            registered_route = app.router.add_route(method, path, handler)
            if self.cors_domains:
                self.cors.add(registered_route)

        if self.static:
            # Serve static content
            app.router.add_static("/", self.static)

        self.runner = web.AppRunner(app)
        await self.runner.setup()
Exemple #19
0
    async def dispatcher_pair(
            self, node_a: NodeAPI, node_b: NodeAPI
    ) -> AsyncIterator[Tuple[DispatcherAPI, DispatcherAPI]]:
        """Yield two dispatchers for ``node_a`` and ``node_b`` wired to each other.

        The nodes are ordered by ``node_id`` (``left`` has the smaller id), so
        a given pair always maps to the same key in ``self._running_dispatchers``.

        For each node, an existing entry in ``self._managed_dispatchers`` is
        reused. Otherwise a new ``Dispatcher`` is created, registered there,
        and run via ``background_trio_service``; it is entered on the local
        ``AsyncExitStack`` and so lives until this context exits.

        Each node's outbound envelope channel is then passed to ``staple``
        together with a clone of the *other* node's inbound send channel.
        Presumably this forwards one node's outgoing envelopes to the other;
        verify against ``staple``.

        Yields ``(dispatcher for node_a, dispatcher for node_b)``, in argument
        order.

        :raises Exception: if both nodes have the same ``node_id``, or if
            dispatchers for this pair are already running.
        """
        # Normalise the order so the pair key is independent of argument order.
        if node_a.node_id < node_b.node_id:
            left = node_a
            right = node_b
        elif node_b.node_id < node_a.node_id:
            left = node_b
            right = node_a
        else:
            raise Exception("Cannot pair with self")

        key = (left.node_id, right.node_id)
        # NOTE(review): this check runs before several awaits, but ``key`` is
        # only added to ``_running_dispatchers`` just before yielding, so two
        # concurrent calls for the same pair could both pass it — confirm.
        if key in self._running_dispatchers:
            raise Exception("Already running dispatchers for: "
                            f"{humanize_node_id(left.node_id)} <-> "
                            f"{humanize_node_id(right.node_id)}")

        self.logger.info("setting up dispatcher pair: %s <> %s", left, right)

        async with AsyncExitStack() as stack:
            left_pool, left_channels = self._get_or_create_pool_for_node(left)

            if left.node_id in self._managed_dispatchers:
                self.logger.info("dispatcher already present for %s", left)
                left_managed_dispatcher = self._managed_dispatchers[
                    left.node_id]
            else:
                self.logger.info("setting up new dispatcher for %s", left)
                (
                    left_inbound_envelope_send_channel,
                    left_inbound_envelope_receive_channel,
                ) = trio.open_memory_channel[InboundEnvelope](256)

                left_dispatcher = Dispatcher(
                    left_inbound_envelope_receive_channel,
                    left_channels.inbound_message_receive_channel,
                    left_pool,
                    left.enr_db,
                    events=left.events,
                )
                left_managed_dispatcher = ManagedDispatcher(
                    dispatcher=left_dispatcher,
                    send_channel=left_inbound_envelope_send_channel,
                )
                # NOTE(review): this entry is never removed here, although the
                # background service below is torn down when ``stack`` exits;
                # a later call may reuse a stopped dispatcher — confirm intent.
                self._managed_dispatchers[
                    left.node_id] = left_managed_dispatcher
                await stack.enter_async_context(
                    background_trio_service(left_dispatcher))

            right_pool, right_channels = self._get_or_create_pool_for_node(
                right)

            # Same setup as above, for the right-hand node.
            if right.node_id in self._managed_dispatchers:
                self.logger.info("dispatcher already present for %s", right)
                right_managed_dispatcher = self._managed_dispatchers[
                    right.node_id]
            else:
                self.logger.info("setting up new dispatcher for %s", right)
                (
                    right_inbound_envelope_send_channel,
                    right_inbound_envelope_receive_channel,
                ) = trio.open_memory_channel[InboundEnvelope](256)

                right_dispatcher = Dispatcher(
                    right_inbound_envelope_receive_channel,
                    right_channels.inbound_message_receive_channel,
                    right_pool,
                    right.enr_db,
                    events=right.events,
                )
                right_managed_dispatcher = ManagedDispatcher(
                    dispatcher=right_dispatcher,
                    send_channel=right_inbound_envelope_send_channel,
                )
                self._managed_dispatchers[
                    right.node_id] = right_managed_dispatcher
                await stack.enter_async_context(
                    background_trio_service(right_dispatcher))

            # Return the dispatchers in the caller's (node_a, node_b) order.
            if left is node_a:
                dispatchers = (
                    left_managed_dispatcher.dispatcher,
                    right_managed_dispatcher.dispatcher,
                )
            elif left is node_b:
                dispatchers = (
                    right_managed_dispatcher.dispatcher,
                    left_managed_dispatcher.dispatcher,
                )
            else:
                raise Exception("Invariant")

            # Cross-wire each node's outbound channel with a clone of the
            # other node's inbound send channel (see ``staple``).
            async with staple(
                    left,
                    left_channels.outbound_envelope_receive_channel,
                    right_managed_dispatcher.send_channel.clone(
                    ),  # type: ignore
            ):
                async with staple(
                        right,
                        right_channels.outbound_envelope_receive_channel,
                        left_managed_dispatcher.send_channel.clone(
                        ),  # type: ignore
                ):
                    # Mark the pair as running only while the caller holds it.
                    self._running_dispatchers.add(key)
                    try:
                        yield dispatchers
                    finally:
                        self._running_dispatchers.remove(key)