コード例 #1
0
async def run(pc):
    """Drive WebRTC signalling for ``pc`` over a websocket.

    Connects to the signalling server and handles JSON text frames:
    ``offerOrAnswer`` frames set the remote description (and, for an offer,
    add a ``FlagVideoStreamTrack`` and send back an answer); ``candidate``
    frames add a remote ICE candidate.

    Args:
        pc: peer connection providing ``setRemoteDescription``, ``addTrack``,
            ``createAnswer``, ``setLocalDescription``, ``localDescription`` and
            ``addIceCandidate``.
    """
    # The session belongs to this coroutine; the context manager closes it on
    # every exit path (the original leaked it).
    async with ClientSession() as session:
        async with session.ws_connect("ws://39.102.116.49:8080") as ws:
            async for msg in ws:
                if msg.type != WSMsgType.TEXT:
                    continue
                data = json.loads(msg.data)

                if data["type"] == "offerOrAnswer":
                    await pc.setRemoteDescription(
                        object_from_string(json.dumps(data["msg"])))

                    if data["msg"]["type"] == "offer":
                        pc.addTrack(FlagVideoStreamTrack())
                        await pc.setLocalDescription(await pc.createAnswer())
                        await ws.send_str(
                            json.dumps({
                                "type":
                                "offerOrAnswer",
                                "msg":
                                json.loads(
                                    object_to_string(pc.localDescription)),
                            }))
                elif data["type"] == "candidate":
                    try:
                        await pc.addIceCandidate(
                            object_from_string(json.dumps(data["msg"])))
                    except Exception:
                        # Best effort: a bad candidate must not end signalling.
                        # Unlike a bare `except:`, this lets cancellation through.
                        pass
コード例 #2
0
ファイル: ws.py プロジェクト: tkuhrt/aries-cloudagent-python
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    schemes = ("ws", "wss")

    def __init__(self) -> None:
        """Initialize an `WsTransport` instance."""
        super(WsTransport, self).__init__()
        self.logger = logging.getLogger(__name__)
        # Created by `start()`; `None` while the transport is not running.
        self.client_session = None

    async def start(self):
        """Start the outbound transport.

        Returns:
            this transport, so the call can be chained
        """
        self.client_session = ClientSession(cookie_jar=DummyCookieJar())
        return self

    async def stop(self):
        """Stop the outbound transport.

        Safe to call before `start()` or more than once.
        """
        if self.client_session is not None:
            await self.client_session.close()
        self.client_session = None

    async def handle_message(self, message: OutboundMessage):
        """
        Handle message from queue.

        Opens a websocket to ``message.endpoint``, sends the payload as a
        binary frame (``bytes``) or a text frame (anything else), then closes
        the socket.

        Args:
            message: `OutboundMessage` to send over transport implementation
        """
        # aiohttp should automatically handle websocket sessions
        async with self.client_session.ws_connect(message.endpoint) as ws:
            if isinstance(message.payload, bytes):
                await ws.send_bytes(message.payload)
            else:
                await ws.send_str(message.payload)
コード例 #3
0
async def ws_backend(cb, session: aiohttp.ClientSession = None):
    """Forward upstream websocket text frames to ``cb``, reconnecting forever.

    Args:
        cb: coroutine function awaited as ``cb(text, type_="StreamingData")``
            for every TEXT frame.
        session: optional session to use.  When omitted, a private session is
            created and closed when this coroutine exits (e.g. on
            cancellation); a caller-supplied session is left open.
    """
    ws_logger = logging.getLogger("rcmproxy.run.ws_backend")

    owns_session = session is None
    if owns_session:
        session = aiohttp.ClientSession()

    try:
        while True:
            try:
                async with session.ws_connect(
                        f"http://{UPSTREAM_IP}:{UPSTREAM_PORT_WS}/") as ws:
                    async for msg in ws:
                        msg: aiohttp.WSMessage

                        if msg.type == aiohttp.WSMsgType.TEXT:
                            ws_logger.debug(
                                "Received WSMessage of type WSMsgType.TEXT")
                            await cb(msg.data, type_="StreamingData")

                        else:
                            ws_logger.warning(
                                f"Unexpexted WSMsgType: {msg.type}")
                            ws_logger.debug(msg)

            except aiohttp.ClientError as e:
                ws_logger.debug(e, exc_info=True)
                ws_logger.warning(e)
                # Short pause so a failing upstream does not cause a busy loop.
                await asyncio.sleep(0.1)
    finally:
        # Only close what we created; the original never closed it at all.
        if owns_session:
            await session.close()
コード例 #4
0
async def game_client(host, gameid, player):
    """Simulate one player: log into room ``gameid`` and send moves periodically.

    Updates the module-level ``stats`` list: index 0 once per connection,
    index 1 once per move sent, index 2 by the length of each TEXT frame
    received.

    Args:
        host: server host name or IP address (IPv6 literals are bracketed).
        gameid: room name sent in the login message.
        player: player number; used as the login name and to stagger moves.
    """
    if ":" in host:
        host = "[" + host + "]"  # IPv6 literal
    async with ClientSession() as session:
        async with session.ws_connect("http://%s:8888/ws" % host) as ws:
            stats[0] += 1
            # send_json/send_str are coroutines: without `await` nothing is sent.
            await ws.send_json({
                "type": "login",
                "data": {
                    "room": gameid,
                    "name": str(player)
                }
            })

            async def make_moves():
                # Stagger the requests a bit
                tm = (time.time() - SECONDS_BETWEEN_MOVES +
                      SECONDS_BETWEEN_MOVES // PLAYERS_PER_GAME * player +
                      random.randrange(SECONDS_BETWEEN_MOVES // PLAYERS_PER_GAME))
                while not ws.closed:
                    tm += SECONDS_BETWEEN_MOVES
                    delay = tm - time.time()
                    if delay > 0:
                        await asyncio.sleep(delay)
                        if ws.closed:
                            break
                    stats[1] += 1
                    await ws.send_str(move_data)

            mover = asyncio.ensure_future(make_moves())
            try:
                async for msg in ws:
                    if msg.type == WSMsgType.TEXT:
                        stats[2] += len(msg.data)
            finally:
                # Stop sending moves once the connection is finished.
                mover.cancel()
コード例 #5
0
async def ws_client():
    """Prompt once, then print every server message and prompt again.

    ``promt`` is defined elsewhere; presumably it reads input and sends it on
    ``ws`` — TODO confirm.
    """
    # Context-managed so the session is closed when the connection ends.
    async with ClientSession() as session:
        async with session.ws_connect('http://0.0.0.0:8080/ws') as ws:
            await promt(ws)
            async for msg in ws:
                print('Receive from server: ', msg.data)
                await promt(ws)
コード例 #6
0
ファイル: test_app.py プロジェクト: MikhailMS/ws_currency
async def call_ws(loop, data, expected=None):
    """Send ``data`` as JSON to the bank websocket and check the JSON reply.

    Args:
        loop: event loop passed through to ``ClientSession``.
        data: JSON-serialisable request payload.
        expected: if not ``None``, the reply must equal this value.

    Raises:
        AssertionError: the reply differs from ``expected``.
    """
    # `async with` closes the session even when the assertion fails
    # (the original only closed it on success).
    async with ClientSession(loop=loop) as session:
        async with session.ws_connect('http://localhost:5000/bank') as websocket:
            await websocket.send_json(data)
            response = await websocket.receive_json()

            print('resp', response, type(response))
            # Check falsy expectations ({}, [], 0, "") too; only None skips.
            if expected is not None:
                print('exp', expected, type(expected))
                assert response == expected
コード例 #7
0
ファイル: main.py プロジェクト: among-us-bot/aque-worker
async def handle_worker():
    """Connect to the worker manager, identify, and dispatch its messages.

    Each TEXT frame is decoded with ``loads`` into ``{"t": opcode, "d": data}``.
    Frames whose opcode has an entry in ``handlers`` are scheduled as tasks on
    ``client.loop`` rather than awaited, so the receive loop is not blocked.
    Unknown opcodes are ignored.  The open socket is stored in the global
    ``ws``.
    """
    global ws
    # Context-managed so the session is closed when the connection ends.
    async with ClientSession() as session:
        async with session.ws_connect(
                f"ws://{env['WORKER_MANAGER_HOST']}:6060/workers") as ws:
            await ws.send_json({"t": "identify", "d": None})
            message: WSMessage
            async for message in ws:
                if message.type != WSMsgType.TEXT:
                    continue
                data = message.json(loads=loads)
                handler = handlers.get(data["t"])
                if handler is None:
                    continue
                client.loop.create_task(handler(data["d"]))
コード例 #8
0
ファイル: ws.py プロジェクト: szdong/pybotters
async def ws_run_forever(
    url: StrOrURL,
    session: aiohttp.ClientSession,
    event: asyncio.Event,
    *,
    send_str: Optional[str] = None,
    send_json: Optional[Any] = None,
    hdlr_str=None,
    hdlr_json=None,
    **kwargs: Any,
) -> None:
    """Keep a websocket connection to ``url`` alive until ``session`` is closed.

    After every successful connect, ``event`` is set and the optional
    ``send_str`` / ``send_json`` messages are sent.  Each TEXT frame goes to
    ``hdlr_str(data, ws)``.  If the frame also parses as JSON, it goes to
    ``hdlr_json(obj, ws)``.  Handlers may be plain or coroutine functions;
    exceptions they raise are logged, not propagated.  A new connection
    attempt starts no sooner than 60 seconds after the previous one.

    Args:
        url: websocket endpoint.
        session: session used for connecting; the loop ends once it is closed.
        event: set each time a connection is established.
        send_str: text frame sent after each connect.
        send_json: JSON payload sent after each connect.
        hdlr_str: handler for raw text frames.
        hdlr_json: handler for JSON-decoded text frames.
        **kwargs: forwarded to ``session.ws_connect``.
    """
    iscorofunc_str = asyncio.iscoroutinefunction(hdlr_str)
    iscorofunc_json = asyncio.iscoroutinefunction(hdlr_json)
    while not session.closed:
        # Started before connecting, so reconnects are spaced >= 60 s apart.
        separator = asyncio.create_task(asyncio.sleep(60.0))
        try:
            async with session.ws_connect(url, **kwargs) as ws:
                event.set()
                if send_str is not None:
                    await ws.send_str(send_str)
                if send_json is not None:
                    await ws.send_json(send_json)
                async for msg in ws:
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        if hdlr_str is not None:
                            try:
                                if iscorofunc_str:
                                    await hdlr_str(msg.data, ws)
                                else:
                                    hdlr_str(msg.data, ws)
                            except Exception as e:
                                logger.error(repr(e))
                        if hdlr_json is not None:
                            try:
                                data = msg.json()
                            except json.decoder.JSONDecodeError:
                                pass
                            else:
                                try:
                                    if iscorofunc_json:
                                        await hdlr_json(data, ws)
                                    else:
                                        hdlr_json(data, ws)
                                except Exception as e:
                                    logger.error(repr(e))
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        break
        except (aiohttp.WSServerHandshakeError,
                aiohttp.ClientConnectionError) as e:
            # Transient failures (rejected handshake, refused/reset connection):
            # log and retry, instead of letting them end the "forever" loop.
            logger.warning(repr(e))
        except BaseException:
            # Leaving the loop (cancellation or an unexpected error): do not
            # leave the pending sleep task behind.
            separator.cancel()
            raise
        await separator
コード例 #9
0
async def main():
    """Send a greeting to the local websocket server and print its first text reply.

    Stops after the first TEXT frame, on an ERROR frame, or when the socket
    closes.
    """
    # Context-managed so the session is closed on exit (the original leaked it).
    async with ClientSession() as session:
        async with session.ws_connect('http://0.0.0.0:8080/ws') as ws:

            await ws.send_str('Hello server! It is WS Client!')

            async for msg in ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    print(msg.data)
                    break

                if msg.type == aiohttp.WSMsgType.ERROR:
                    break
コード例 #10
0
class ReceiverService(object):
    """Receive events from the egress websocket and pass them to a consumer."""

    def __init__(self, billing_id: str, client_id: str, client_secret: str,
                 config: ClientConfig):
        self._logger = config.get_logger(__name__)

        self.auth_service = AuthService(purpose=self.__class__.__name__,
                                        billing_id=billing_id,
                                        client_id=client_id,
                                        client_secret=client_secret,
                                        config=config)

        self._config = config
        self._client = None
        self._running = False
        # The HTTP session only exists while `start()` runs.  Creating it in
        # this synchronous constructor left a session that nothing ever closed.
        self._session = None

    async def start_timer(self):
        """Start only the auth service, without opening the websocket."""
        await self.auth_service.start()

    async def start(self, as_json: bool, consumer: Callable[[Any], Any]):
        """Receive messages until `close()` is called.

        The connection is reopened as long as the service is running, and the
        access token is read again for every connection.  The HTTP session is
        closed when this coroutine returns.

        Args:
            as_json: request JSON events (appends ``?asJson=true`` to the URI).
            consumer: coroutine function awaited with the data of each TEXT frame.
        """
        self._running = True

        await self.auth_service.start()
        uri = self._config.egress_uri + ("?asJson=true" if as_json else "")

        async with ClientSession() as session:
            self._session = session
            try:
                while self._running:
                    headers = {
                        'Authorization':
                        f'Bearer {self.auth_service.get_access_token()}',
                        'Strm-Driver-Version':
                        self._config.version.brief_string(),
                        'Strm-Driver-Build':
                        self._config.version.release_string()
                    }
                    async with session.ws_connect(uri, headers=headers) as ws:
                        async for msg in ws:
                            if msg.type == aiohttp.WSMsgType.TEXT:
                                await consumer(msg.data)
                            elif msg.type == aiohttp.WSMsgType.ERROR:
                                self._logger.debug(
                                    "Error upon receiving data from websocket")
                                break
                    # `async for` stops on close frames, so a CLOSED branch
                    # inside the loop can never run; log the close here.
                    self._logger.debug("Websocket connection closed")
            finally:
                self._session = None

    def close(self):
        """Ask `start()` to stop reconnecting.

        The current connection is not interrupted; the loop exits after that
        connection ends.
        """
        self._running = False
コード例 #11
0
async def connect_to_service_provider(task):
    """Run one websocket connection to the service provider described by ``task``.

    ``task`` must have ``'url'`` and ``'desc'`` keys.  ``task['active']`` is
    set while the connection runs, so a second call for the same task returns
    immediately.  Text frames are handed to
    ``handle_service_provider_message``; a frame containing ``'close'`` closes
    the socket.  On exit, the provider registered under this socket's id is
    destroyed and ``task['active']`` is cleared.

    Exceptions are logged and swallowed.  Cancellation and other
    BaseExceptions are logged and re-raised.
    """
    global APP
    if 'active' in task and task['active']:
        return
    task['active'] = True
    desc = task['desc']

    ws_id = ''
    try:
        # `async with` closes the session on every path.  The original could
        # hit a NameError in `finally` if ClientSession() itself failed.
        async with ClientSession() as session:
            async with session.ws_connect(task['url'], heartbeat=30) as ws:
                ws_id = ws.id = str(uuid.uuid4())
                async for msg in ws:
                    if msg.type == WSMsgType.TEXT:
                        if msg.data == 'close':
                            await ws.close()
                        else:
                            await handle_service_provider_message(
                                APP, msg.data, ws)
                    elif msg.type == WSMsgType.ERROR:
                        log.server_logger.info(
                            '%s connection closed with error %s', desc,
                            ws.exception())
                        break

        log.server_logger.info('%s connection closed normally', desc)
    except Exception as e:
        log.server_logger.exception('%s connection closed with exception: %s',
                                    desc, e)
    except BaseException:
        # e.g. CancelledError: log it, but do not swallow the cancellation.
        log.server_logger.exception('%s uncaught exception!', desc)
        raise
    finally:
        if ws_id:
            await destroy_service_provider(APP, ws_id)
        task['active'] = False
コード例 #12
0
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    schemes = ("ws", "wss")

    def __init__(self, **kwargs) -> None:
        """Initialize an `WsTransport` instance."""
        super().__init__(**kwargs)
        self.logger = logging.getLogger(__name__)
        # Created by `start()`; `None` while the transport is not running.
        self.client_session = None

    async def start(self):
        """Start the outbound transport.

        Returns:
            this transport, so the call can be chained
        """
        self.client_session = ClientSession(cookie_jar=DummyCookieJar(),
                                            trust_env=True)
        return self

    async def stop(self):
        """Stop the outbound transport.

        Safe to call before `start()` or more than once.
        """
        if self.client_session is not None:
            await self.client_session.close()
        self.client_session = None

    async def handle_message(
        self,
        profile: Profile,
        payload: Union[str, bytes],
        endpoint: str,
        metadata: dict = None,
        api_key: str = None,
    ):
        """
        Handle message from queue.

        Opens a websocket to ``endpoint``, sends ``payload`` as a binary frame
        (``bytes``) or a text frame (``str``), then closes the socket.

        Args:
            profile: the profile that produced the message
            payload: message payload in string or byte format
            endpoint: URI endpoint for delivery
            metadata: Additional metadata associated with the payload; sent
                as the websocket handshake headers
            api_key: accepted for interface compatibility; not used by this
                transport
        """
        # aiohttp should automatically handle websocket sessions
        async with self.client_session.ws_connect(endpoint,
                                                  headers=metadata) as ws:
            if isinstance(payload, bytes):
                await ws.send_bytes(payload)
            else:
                await ws.send_str(payload)
コード例 #13
0
ファイル: service.py プロジェクト: BruceZhang1993/danmu-cli
class WebsocketDanmuService:
    """Hold a websocket open, send initial payloads and heartbeats, and pass
    every received message to a callback."""

    def __init__(self, ws_address: str, payloads: list, hb: bytes,
                 interval: int, callback: Callable[[WSMessage],
                                                   Optional[Awaitable]]):
        """
        Args:
            ws_address: websocket URL to connect to.
            payloads: binary frames sent once, right after connecting.
            hb: binary heartbeat frame.
            interval: seconds to wait between heartbeats.
            callback: called with every received message; awaited if it is a
                coroutine function, otherwise run in the default executor.
        """
        self.ws: Optional[ClientWebSocketResponse] = None
        self.session = ClientSession()
        self.ws_address = ws_address
        self.payloads = payloads
        self.hb = hb
        self.interval = interval
        self.heartbeat_task: Optional['Future'] = None
        self.cb = callback

    async def connect(self):
        """Connect and process messages until the socket closes."""
        async with self.session.ws_connect(self.ws_address) as ws:
            self.ws = ws
            try:
                await self.running()
            finally:
                # Without this the heartbeat keeps firing at a closed socket
                # after the connection ends.
                if self.heartbeat_task is not None:
                    self.heartbeat_task.cancel()

    async def send_heartbeat_once(self):
        """Send a single heartbeat frame."""
        await self.ws.send_bytes(self.hb)

    async def send_heartbeat(self):
        """Send a heartbeat every ``interval`` seconds, forever."""
        while True:
            await asyncio.sleep(self.interval)
            await self.send_heartbeat_once()

    async def running(self):
        """Send the initial payloads, start heartbeats, and dispatch messages."""
        for payload in self.payloads:
            await self.ws.send_bytes(payload)
        self.heartbeat_task = asyncio.ensure_future(self.send_heartbeat())
        async for msg in self.ws:
            msg: WSMessage
            if asyncio.iscoroutinefunction(self.cb):
                await self.cb(msg)
            else:
                loop = asyncio.get_event_loop()
                await loop.run_in_executor(None, self.cb, msg)

    async def stop(self):
        """Cancel heartbeats and close the socket and the session.

        Safe to call even if `connect()` never ran.
        """
        if self.heartbeat_task is not None:
            self.heartbeat_task.cancel()
            with suppress(asyncio.CancelledError):
                await self.heartbeat_task
        if self.ws is not None:
            await self.ws.close()
        await self.session.close()
コード例 #14
0
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    schemes = ("ws", "wss")

    def __init__(self, queue: BaseOutboundMessageQueue) -> None:
        """Initialize a `WsTransport` instance.

        Args:
            queue: outbound message queue exposed via the `queue` property
        """
        self.logger = logging.getLogger(__name__)
        self._queue = queue
        # Created in `__aenter__`, closed in `__aexit__`.
        self.client_session = None

    async def __aenter__(self):
        """Async context manager enter: open the HTTP client session."""
        self.client_session = ClientSession()
        return self

    async def __aexit__(self, *err):
        """Async context manager exit: close the session, log any exception.

        The exception (if any) is not suppressed.
        """
        await self.client_session.close()
        self.client_session = None
        # err is (exc_type, exc, traceback); the original logged an error even
        # on a clean exit, when all three are None.
        if err and err[0] is not None:
            self.logger.error(err)

    @property
    def queue(self):
        """Accessor for queue."""
        return self._queue

    async def handle_message(self, message: OutboundMessage):
        """
        Handle message from queue.

        Opens a websocket to ``message.endpoint``, sends the payload, and
        closes the socket.  Failures are logged, not raised.

        Args:
            message: `OutboundMessage` to send over transport implementation
        """
        try:
            # As an example, we can open a websocket channel, send a message, then
            # close the channel immediately. This is not optimal but it works.
            async with self.client_session.ws_connect(message.endpoint) as ws:
                if isinstance(message.payload, bytes):
                    await ws.send_bytes(message.payload)
                else:
                    await ws.send_str(message.payload)
        except Exception:
            # TODO: add retry logic
            self.logger.exception("Error handling outbound message")
コード例 #15
0
    async def _websocket_connect(self, endpoint: str,
                                 session: aiohttp.ClientSession) -> None:
        """
        Open a websocket to *endpoint* through *session* and run the subscription.

        The connection-init request is sent first.  A callback registered for
        CONNECTION_ACK then calls ``send_json`` with the start request.  Each
        TEXT frame is decoded as JSON, wrapped in a
        :class:`GraphQLSubscriptionEvent` and passed to :meth:`handle`.
        Receiving stops on an ERROR frame, when the socket closes, or after an
        event for which :meth:`is_stop_event` is true.

        If cancelled (or interrupted) while receiving, a stop request is sent
        and the cancellation is not re-raised.

        :param endpoint: Endpoint to use when creating the websocket connection.
        :param session: Session to use when creating the websocket connection.
        """
        async with session.ws_connect(endpoint) as socket:
            await socket.send_json(data=self.connection_init_request())

            on_ack = SimpleTriggerCallback(
                function=socket.send_json,
                data=self.connection_start_request(),
            )
            self.callbacks.register(
                GraphQLSubscriptionEventType.CONNECTION_ACK, on_ack)

            try:
                async for frame in socket:  # type: aiohttp.WSMessage
                    if frame.type == aiohttp.WSMsgType.ERROR:
                        break
                    if frame.type != aiohttp.WSMsgType.TEXT:
                        continue

                    event = GraphQLSubscriptionEvent(
                        subscription_id=self.id,
                        request=self.request,
                        json=frame.json(),
                    )
                    await self.handle(event=event)

                    if self.is_stop_event(event):
                        break
            except (asyncio.CancelledError, KeyboardInterrupt):
                # NOTE(review): the cancellation is swallowed here, so a
                # cancelled task finishes normally; confirm this is intended.
                await socket.send_json(data=self.connection_stop_request())
コード例 #16
0
class WsTransport(BaseOutboundTransport):
    """Websockets outbound transport class."""

    schemes = ("ws", "wss")

    def __init__(self) -> None:
        """Initialize an `WsTransport` instance."""
        super(WsTransport, self).__init__()
        self.logger = logging.getLogger(__name__)
        # Created by `start()`; `None` while the transport is not running.
        self.client_session = None

    async def start(self):
        """Start the outbound transport.

        Returns:
            this transport, so the call can be chained
        """
        self.client_session = ClientSession(cookie_jar=DummyCookieJar())
        return self

    async def stop(self):
        """Stop the outbound transport.

        Safe to call before `start()` or more than once.
        """
        if self.client_session is not None:
            await self.client_session.close()
        self.client_session = None

    async def handle_message(self, context: InjectionContext,
                             payload: Union[str, bytes], endpoint: str):
        """
        Handle message from queue.

        Opens a websocket to ``endpoint``, sends ``payload`` as a binary frame
        (``bytes``) or a text frame (``str``), then closes the socket.

        Args:
            context: the context that produced the message (not used here)
            payload: message payload in string or byte format
            endpoint: URI endpoint for delivery
        """
        # aiohttp should automatically handle websocket sessions
        async with self.client_session.ws_connect(endpoint) as ws:
            if isinstance(payload, bytes):
                await ws.send_bytes(payload)
            else:
                await ws.send_str(payload)
コード例 #17
0
class WsConn(Conn):
    """Websocket `Conn`: failures are reported as False/None instead of raised.

    ``url`` must have the form ``ws://hostname:port/…`` or
    ``wss://hostname:port/…``.
    """

    def __init__(
        self,
        url: str,
        receive_timeout: Optional[float] = None,
        session: Optional[ClientSession] = None,
        ws_receive_timeout: Optional[float] = None,  # used for automatic ping/pong
        ws_heartbeat: Optional[float] = None):  # used for automatic ping/pong
        super().__init__(receive_timeout)
        parsed = urlparse(url)
        assert parsed.scheme in ('ws', 'wss')
        self._url = url

        # A caller-provided session is shared and must not be closed by `clean()`.
        self._is_sharing_session = session is not None
        self._session = session if session is not None else ClientSession()
        self._ws_receive_timeout = ws_receive_timeout
        self._ws_heartbeat = ws_heartbeat
        self._ws = None

    async def open(self) -> bool:
        """Connect (giving up after 3 s); True on success, False on any failure."""
        try:
            connecting = self._session.ws_connect(
                self._url,
                receive_timeout=self._ws_receive_timeout,
                heartbeat=self._ws_heartbeat)
            self._ws = await asyncio.wait_for(connecting, timeout=3)
        except Exception:  # asyncio.TimeoutError included
            return False
        return True

    async def close(self) -> bool:
        """Close the websocket if one was opened; always returns True."""
        socket = self._ws
        if socket is not None:
            await socket.close()
        return True

    async def clean(self):
        """Close the HTTP session, but only if this object created it."""
        if self._is_sharing_session:
            return
        await self._session.close()

    async def send_bytes(self, bytes_data) -> bool:
        """Send one binary frame; False on any error or on cancellation."""
        try:
            await self._ws.send_bytes(bytes_data)
        except (asyncio.CancelledError, Exception):
            # NOTE(review): returning False on CancelledError hides the
            # cancellation from the calling task.
            return False
        return True

    async def read_bytes(self) -> Optional[bytes]:
        """Receive one binary frame within the receive timeout; None otherwise."""
        try:
            return await asyncio.wait_for(self._ws.receive_bytes(),
                                          timeout=self._receive_timeout)
        except Exception:  # asyncio.TimeoutError included
            return None

    async def read_json(self) -> Any:
        """Receive one frame and decode it as JSON.

        TEXT frames are decoded directly; BINARY frames are UTF-8 decoded first.
        Other frame types, timeouts and decode errors all yield None.
        """
        try:
            frame = await asyncio.wait_for(self._ws.receive(),
                                           timeout=self._receive_timeout)
            if frame.type == WSMsgType.TEXT:
                text = frame.data
            elif frame.type == WSMsgType.BINARY:
                text = frame.data.decode('utf8')
            else:
                return None
            return json.loads(text)
        except Exception:  # asyncio.TimeoutError included
            return None
コード例 #18
0
class TestAdminServer(AsyncTestCase):
    async def setUp(self):
        """Reset result buffers and open the HTTP client session used by tests."""
        self.message_results = []
        self.webhook_results = []
        self.port = 0

        self.connector = TCPConnector(limit=16, limit_per_host=4)
        # The original built this dict (and a cookie jar) and then never used it.
        session_args = {
            "cookie_jar": DummyCookieJar(),
            "connector": self.connector,
        }
        self.client_session = ClientSession(**session_args)

    async def tearDown(self):
        """Close the client session opened in `setUp`, if it is still set."""
        session = self.client_session
        if not session:
            return
        await session.close()
        self.client_session = None

    async def test_debug_middleware(self):
        """With debug enabled, the middleware checks the level once and logs three debug records."""
        request = async_mock.MagicMock(
            method="GET",
            path_qs="/hello/world?a=1&b=2",
            match_info={"match": "info"},
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        handler = async_mock.CoroutineMock()

        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as fake_logger:
            fake_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            fake_logger.debug = async_mock.MagicMock()

            await test_module.debug_middleware(request, handler)

        fake_logger.isEnabledFor.assert_called_once()
        assert fake_logger.debug.call_count == 3

    async def test_ready_middleware(self):
        """ready_middleware refuses requests until the app is ready.

        Once the app is ready, the handler's result is returned and its
        exceptions surface with the same type.
        """
        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as fake_logger:
            fake_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            for level in ("debug", "info", "error"):
                setattr(fake_logger, level, async_mock.MagicMock())

            request = async_mock.MagicMock(
                rel_url="/", app=async_mock.MagicMock(_state={"ready": False}))
            handler = async_mock.CoroutineMock(return_value="OK")

            # Not ready yet.
            with self.assertRaises(test_module.web.HTTPServiceUnavailable):
                await test_module.ready_middleware(request, handler)

            request.app._state["ready"] = True
            assert await test_module.ready_middleware(request, handler) == "OK"

            # Ready: handler failures must come back out with the same type.
            failing_cases = (
                (test_module.LedgerConfigError("Bad config"),
                 test_module.LedgerConfigError),
                (test_module.web.HTTPFound(location="/api/doc"),
                 test_module.web.HTTPFound),
                (test_module.asyncio.CancelledError("Cancelled"),
                 test_module.asyncio.CancelledError),
                (KeyError("No such thing"), KeyError),
            )
            for raised, expected_type in failing_cases:
                request.app._state["ready"] = True
                handler = async_mock.CoroutineMock(side_effect=raised)
                with self.assertRaises(expected_type):
                    await test_module.ready_middleware(request, handler)

    def get_admin_server(self,
                         settings: dict = None,
                         context: InjectionContext = None) -> AdminServer:
        """Build an AdminServer on a fresh unused port with mocked collaborators.

        Args:
            settings: optional settings applied to the context.  After they
                are applied, a ``"task_queue"`` key is popped from this dict;
                when truthy, the server gets a ``TaskQueue`` instead of a
                mocked ``conductor_stats`` coroutine.
            context: injection context to use; a new one is created if falsy.
        """
        context = context or InjectionContext()
        if settings:
            context.update_settings(settings)

        # middleware is task queue xor collector: cover both over test suite
        use_task_queue = (settings or {}).pop("task_queue", None)

        registry = async_mock.MagicMock(test_module.PluginRegistry,
                                        autospec=True)
        registry.post_process_routes = async_mock.MagicMock()
        context.injector.bind_instance(test_module.PluginRegistry, registry)

        context.injector.bind_instance(test_module.Collector, Collector())

        profile = InMemoryProfile.test_profile()

        self.port = unused_port()
        queue = TaskQueue(max_active=4) if use_task_queue else None
        stats = (None if use_task_queue else async_mock.CoroutineMock(
            return_value={"a": 1}))
        return AdminServer(
            "0.0.0.0",
            self.port,
            context,
            profile,
            self.outbound_message_router,
            self.webhook_router,
            conductor_stop=async_mock.CoroutineMock(),
            task_queue=queue,
            conductor_stats=stats,
        )

    async def outbound_message_router(self, *args):
        """Record each outbound-message call's positional arguments."""
        self.message_results += [args]

    def webhook_router(self, *args):
        """Record each webhook call's positional arguments."""
        self.webhook_results += [args]

    async def test_start_stop(self):
        """Insecure or incomplete admin settings refuse to start.

        A valid config starts and stops cleanly, and a bind failure surfaces
        as AdminSetupError.
        """
        rejected_settings = (
            None,
            {"admin.admin_insecure_mode": False},
            {
                "admin.admin_insecure_mode": True,
                "admin.admin_api_key": "test-api-key",
            },
        )
        for candidate in rejected_settings:
            with self.assertRaises(AssertionError):
                await self.get_admin_server(candidate).start()

        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_client_max_request_size": 4,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        # The request-size setting is expressed in MiB.
        assert server.app._client_max_size == 4 * 1024 * 1024
        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as fake_queues:
            stoppable = async_mock.MagicMock(stop=async_mock.MagicMock())
            fake_queues.values = async_mock.MagicMock(return_value=[stoppable])
            await server.stop()

        with async_mock.patch.object(web.TCPSite, "start",
                                     async_mock.CoroutineMock()) as fake_start:
            fake_start.side_effect = OSError("Failure to launch")
            with self.assertRaises(AdminSetupError):
                await self.get_admin_server(settings).start()

    async def test_import_routes(self):
        """Building the application imports every default admin route module."""
        # this test just imports all default admin routes
        # for routes with associated tests, this shouldn't make a difference in coverage
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server({"admin.admin_insecure_mode": True},
                                       context)
        # Only the side effect (route import) matters; the app is not inspected.
        await server.make_application()

    async def test_import_routes_multitenant_middleware(self):
        """Exercise the multitenant authorization and setup-context middlewares."""
        # imports all default admin routes
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        profile = InMemoryProfile.test_profile()
        context.injector.bind_instance(
            test_module.MultitenantManager,
            test_module.MultitenantManager(profile),
        )
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server(
            {
                "admin.admin_insecure_mode": False,
                "admin.admin_api_key": "test-api-key",
            },
            context,
        )

        # cover multitenancy start code
        app = await server.make_application()
        app["swagger_dict"] = {}
        await server.on_startup(app)

        # multitenant authz
        [mt_authz_middle] = [
            m for m in app.middlewares
            if ".check_multitenant_authorization" in str(m)
        ]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/multitenancy/etc",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await mt_authz_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        mock_handler = async_mock.CoroutineMock()
        await mt_authz_middle(mock_request, mock_handler)
        # `called_once_with` is not an assertion method: on a mock it returns a
        # truthy child mock (or raises AttributeError on Python 3.12+), so the
        # old `assert mock_handler.called_once_with(...)` could never fail.
        mock_handler.assert_called_once_with(mock_request)

        # multitenant setup context exception paths
        [setup_ctx_middle
         ] = [m for m in app.middlewares if ".setup_context" in str(m)]

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Non-bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with self.assertRaises(test_module.web.HTTPUnauthorized):
            await setup_ctx_middle(mock_request, None)

        mock_request = async_mock.MagicMock(
            method="GET",
            headers={"Authorization": "Bearer ..."},
            path="/protected/non-multitenancy/non-server",
            text=async_mock.CoroutineMock(return_value="abc123"),
        )
        with async_mock.patch.object(
                server.multitenant_manager,
                "get_profile_for_token",
                async_mock.CoroutineMock(),
        ) as mock_get_profile:
            # Both token-lookup failures must map to 401 Unauthorized.
            mock_get_profile.side_effect = [
                test_module.MultitenantManagerError("corrupt token"),
                test_module.StorageNotFoundError("out of memory"),
            ]
            for _ in range(2):
                with self.assertRaises(test_module.web.HTTPUnauthorized):
                    await setup_ctx_middle(mock_request, None)

    async def test_register_external_plugin_x(self):
        """Loading a non-existent external plugin module raises ValueError."""
        ctx = InjectionContext()
        ctx.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        bogus_settings = {"external_plugins": "aries_cloudagent.nosuchmodule"}
        with self.assertRaises(ValueError):
            await DefaultContextBuilder(
                settings=bogus_settings).load_plugins(ctx)

    async def test_visit_insecure_mode(self):
        """In insecure mode the admin routes answer 200 without an API key.

        Covers ``POST /status/reset``, the ``/ws`` websocket (whose first
        message has topic ``settings``) and a set of plain GET routes.
        """
        settings = {"admin.admin_insecure_mode": True, "task_queue": True}
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.post(
                    f"http://127.0.0.1:{self.port}/status/reset",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws") as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"

            for path in (
                    "",
                    "plugins",
                    "status",
                    "status/live",
                    "status/ready",
                    "shutdown",  # mock conductor has magic-mock stop()
            ):
                async with self.client_session.get(
                        f"http://127.0.0.1:{self.port}/{path}",
                        headers={}) as response:
                    assert response.status == 200
        finally:
            # Stop the server even when an assertion fails so its port is
            # released before the next test runs.
            await server.stop()

    async def test_visit_secure_mode(self):
        """With an API key configured, requests need the matching x-api-key.

        A wrong key yields 401; the right key yields 200 for ``/status`` and
        opens the ``/ws`` websocket, whose first message has topic
        ``settings``.
        """
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "wrong-key"}) as response:
                assert response.status == 401

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "test-api-key"},
            ) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws",
                    headers={"x-api-key": "test-api-key"}) as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"
        finally:
            # Always release the listening port, even on assertion failure.
            await server.stop()

    async def test_query_config(self):
        """``/status/config`` must not expose secret settings.

        Secret-bearing keys (API key, JWT secret, wallet key/rekey/seed and
        storage credentials) must be absent from the returned config, and
        the ``#secret`` fragment of a webhook URL must not be echoed back.
        """
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
            "admin.webhook_urls": [
                "localhost:8123/abc#secret",
                "localhost:8123/def",
            ],
            "multitenant.jwt_secret": "abc123",
            "wallet.key": "abc123",
            "wallet.rekey": "def456",
            "wallet.seed": "00000000000000000000000000000000",
            "wallet.storage.creds": "secret",
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/config",
                    headers={"x-api-key": "test-api-key"},
            ) as response:
                config = json.loads(await response.text())["config"]
                assert "admin.admin_insecure_mode" in config
                assert all(k not in config for k in [
                    "admin.admin_api_key",
                    "multitenant.jwt_secret",
                    "wallet.key",
                    "wallet.rekey",
                    "wallet.seed",
                    "wallet.storage.creds",
                ])
                assert config["admin.webhook_urls"] == [
                    "localhost:8123/abc",
                    "localhost:8123/def",
                ]
        finally:
            # Stop the server so its socket does not leak into later tests.
            await server.stop()

    async def test_visit_shutting_down(self):
        """After ``/shutdown``, ``/status`` returns 503 but ``/status/live``
        still returns 200."""
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/shutdown",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={}) as response:
                assert response.status == 503

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/live",
                    headers={}) as response:
                assert response.status == 200
        finally:
            # Always release the listening port, even on assertion failure.
            await server.stop()

    async def test_server_health_state(self):
        """Liveness/readiness endpoints flip to 503 after a fatal error.

        Before ``notify_fatal_error`` both endpoints return 200 with a truthy
        ``alive`` / ``ready`` flag; afterwards both return 503.
        """
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/live",
                    headers={}) as response:
                assert response.status == 200
                response_json = await response.json()
                assert response_json["alive"]

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/ready",
                    headers={}) as response:
                assert response.status == 200
                response_json = await response.json()
                assert response_json["ready"]

            server.notify_fatal_error()
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/live",
                    headers={}) as response:
                assert response.status == 503

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/ready",
                    headers={}) as response:
                assert response.status == 503
        finally:
            # Always release the listening port, even on assertion failure.
            await server.stop()
コード例 #19
0
class JupyterClient:
    """Drive a JupyterHub / JupyterLab deployment on behalf of one user.

    All requests go through one aiohttp ``ClientSession`` that carries the
    user's bearer token and an XSRF token. The XSRF token is sent both as the
    ``x-xsrftoken`` header and as the ``_xsrf`` cookie.
    """

    log: BoundLoggerLazyProxy
    user: User
    session: ClientSession
    headers: Dict[str, str]
    xsrftoken: str
    jupyter_url: str

    def __init__(self, user: User, log: BoundLoggerLazyProxy,
                 options: Dict[str, Any]):
        """Build the client and its HTTP session.

        :param user: user to act as; ``user.token`` becomes the bearer token
            and ``user.username`` is used to build per-user URLs.
        :param log: logger used for progress messages.
        :param options: reads ``nb_url`` (hub base path, default ``"/nb/"``)
            and ``jupyter_options_form`` (form data posted when spawning,
            default ``{}``).
        """
        self.user = user
        self.log = log
        self.jupyter_base = options.get("nb_url", "/nb/")
        self.jupyter_url = Configuration.environment_url + self.jupyter_base

        # Random 16-char token echoed in both header and cookie so the hub's
        # XSRF check (header must match cookie) passes.
        self.xsrftoken = "".join(
            random.choices(string.ascii_uppercase + string.digits, k=16))
        self.jupyter_options_form = options.get("jupyter_options_form", {})

        self.headers = {
            "Authorization": "Bearer " + user.token,
            "x-xsrftoken": self.xsrftoken,
        }

        self.session = ClientSession(headers=self.headers)
        self.session.cookie_jar.update_cookies(
            BaseCookie({"_xsrf": self.xsrftoken}))

    # Matches ANSI/VT100 escape sequences (CSI introduced by 0x9B or ESC '[').
    __ansi_reg_exp = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")

    @classmethod
    def _ansi_escape(cls, line: str) -> str:
        """Return ``line`` with ANSI escape sequences removed."""
        return cls.__ansi_reg_exp.sub("", line)

    async def hub_login(self) -> None:
        """GET the hub login page; raise if it does not answer 200."""
        async with self.session.get(self.jupyter_url + "hub/login") as r:
            if r.status != 200:
                await self._raise_error("Error logging into hub", r)

    async def ensure_lab(self) -> None:
        """Log into the user's lab if it is running, otherwise spawn it."""
        self.log.info("Ensure lab")
        running = await self.is_lab_running()
        if running:
            await self.lab_login()
        else:
            await self.spawn_lab()

    async def lab_login(self) -> None:
        """GET the user's lab page; raise if it does not answer 200."""
        self.log.info("Logging into lab")
        lab_url = self.jupyter_url + f"user/{self.user.username}/lab"
        async with self.session.get(lab_url) as r:
            if r.status != 200:
                await self._raise_error("Error logging into lab", r)

    async def is_lab_running(self) -> bool:
        """Return False iff visiting the hub redirects to ``hub/spawn``.

        A non-200 answer is logged but does not by itself change the result.
        """
        self.log.info("Is lab running?")
        hub_url = self.jupyter_url + "hub"
        async with self.session.get(hub_url) as r:
            if r.status != 200:
                self.log.error(f"Error {r.status} from {r.url}")

            spawn_url = self.jupyter_url + "hub/spawn"
            self.log.info(f"Going to {hub_url} redirected to {r.url}")
            if str(r.url) == spawn_url:
                return False

        return True

    async def spawn_lab(self) -> None:
        """Request a lab spawn and poll until the hub redirects to it.

        Raises if the spawn POST does not redirect (302) to the
        spawn-pending page, if polling gets a non-OK response, or if the lab
        is not up after ``max_poll_secs``.
        """
        spawn_url = self.jupyter_url + "hub/spawn"
        pending_url = (self.jupyter_url +
                       f"hub/spawn-pending/{self.user.username}")
        lab_url = self.jupyter_url + f"user/{self.user.username}/lab"

        # DM-23864: Do a get on the spawn URL even if I don't have to.
        async with self.session.get(spawn_url) as r:
            await r.text()

        async with self.session.post(spawn_url,
                                     data=self.jupyter_options_form,
                                     allow_redirects=False) as r:
            if r.status != 302:
                await self._raise_error("Spawn did not redirect", r)

            redirect_url = (self.jupyter_base +
                            f"hub/spawn-pending/{self.user.username}")
            if r.headers["Location"] != redirect_url:
                await self._raise_error("Spawn didn't redirect to pending", r)

        # Jupyterlab will give up a spawn after 900 seconds, so we shouldn't
        # wait longer than that.
        max_poll_secs = 900
        poll_interval = 15
        retries = max_poll_secs // poll_interval

        while retries > 0:
            async with self.session.get(pending_url) as r:
                if str(r.url) == lab_url:
                    self.log.info(f"Lab spawned, redirected to {r.url}")
                    return

                if not r.ok:
                    await self._raise_error("Error spawning", r)

                self.log.info(f"Still waiting for lab to spawn {r}")
            # Sleep only after the response context has exited, so the pooled
            # connection is not held open for the whole poll interval.
            retries -= 1
            await asyncio.sleep(poll_interval)

        raise Exception("Giving up waiting for lab to spawn!")

    async def delete_lab(self) -> None:
        """Delete the user's lab server through the hub REST API.

        Raises unless the hub answers 200, 202 or 204.
        """
        headers = {"Referer": self.jupyter_url + "hub/home"}

        server_url = (self.jupyter_url +
                      f"hub/api/users/{self.user.username}/server")
        self.log.info(f"Deleting lab for {self.user.username} at {server_url}")

        async with self.session.delete(server_url, headers=headers) as r:
            if r.status not in [200, 202, 204]:
                await self._raise_error("Error deleting lab", r)

    async def create_kernel(self, kernel_name: str = "LSST") -> str:
        """Start a kernel named ``kernel_name`` and return its id.

        Raises unless the server answers 201.
        """
        kernel_url = (self.jupyter_url +
                      f"user/{self.user.username}/api/kernels")
        body = {"name": kernel_name}

        async with self.session.post(kernel_url, json=body) as r:
            if r.status != 201:
                await self._raise_error("Error creating kernel", r)

            response = await r.json()
            return response["id"]

    async def run_python(self, kernel_id: str, code: str) -> str:
        """Execute ``code`` in kernel ``kernel_id`` over its websocket.

        Returns the text of the first ``stream`` message whose parent is this
        request, or ``""`` when an ``execute_reply`` with status ``ok``
        arrives first.

        :raises NotebookException: on an ``error`` message (traceback with
            ANSI escapes stripped) or a non-ok ``execute_reply``.
        """
        kernel_url = (
            self.jupyter_url +
            f"user/{self.user.username}/api/kernels/{kernel_id}/channels")

        msg_id = uuid4().hex

        msg = {
            "header": {
                "username": "",
                "version": "5.0",
                "session": "",
                "msg_id": msg_id,
                "msg_type": "execute_request",
            },
            "parent_header": {},
            "channel": "shell",
            "content": {
                "code": code,
                "silent": False,
                "store_history": False,
                "user_expressions": {},
                "allow_stdin": False,
            },
            "metadata": {},
            "buffers": {},
        }

        async with self.session.ws_connect(kernel_url) as ws:
            await ws.send_json(msg)

            while True:
                r = await ws.receive_json()
                self.log.debug(f"Received kernel message: {r}")
                msg_type = r["msg_type"]
                # NOTE(review): only "stream" is filtered by parent msg_id;
                # "error"/"execute_reply" from other requests would also be
                # acted on here — confirm that is intended.
                if msg_type == "error":
                    error_message = "".join(r["content"]["traceback"])
                    raise NotebookException(self._ansi_escape(error_message))
                elif (msg_type == "stream"
                      and msg_id == r["parent_header"]["msg_id"]):
                    return r["content"]["text"]
                elif msg_type == "execute_reply":
                    status = r["content"]["status"]
                    if status == "ok":
                        return ""
                    else:
                        raise NotebookException(
                            f"Error content status is {status}")

    def dump(self) -> dict:
        """Return a debug snapshot containing the session's cookies."""
        return {
            "cookies": [str(cookie) for cookie in self.session.cookie_jar],
        }

    async def _raise_error(self, msg: str, r: ClientResponse) -> None:
        """Always raise an Exception describing ``msg`` and the response."""
        raise Exception(f"{msg}: {r.status} {r.url}: {r.headers}")
コード例 #20
0
ファイル: client.py プロジェクト: SkyPicker/Skywall
class WebsocketClient:
    """Synchronous context manager around a websocket to the Skywall server.

    ``__enter__`` opens an aiohttp session and websocket on ``loop``.
    ``connect`` sends the client label and reports, then processes incoming
    actions. ``__exit__`` closes everything again. Lifecycle signals are
    emitted around start, stop, send and receive.
    """

    def __init__(self, loop):
        self.url = config.get('server.publicUrl')
        self.client_id = config.get('client.id')
        self.client_token = config.get('client.token')
        self.reports_frequency = config.get('client.reports.frequency')
        self.loop = loop
        self.session = None
        self.socket = None
        # Handle of the pending periodic send_reports() callback, kept so it
        # can be cancelled on shutdown instead of firing on a closed socket.
        self._reports_handle = None

    def __enter__(self):
        try:
            before_client_start.emit(client=self)
            headers = self._headers()
            self.session = ClientSession(loop=self.loop)
            self.socket = self.loop.run_until_complete(self.session.ws_connect(self.url, headers=headers))
            after_client_start.emit(client=self)
            return self
        except BaseException:
            # Release whatever was opened before the failure, then re-raise.
            self._close()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            before_client_stop.emit(client=self)
        finally:
            # Close the connection even if a before-stop handler raised.
            self._close()
        after_client_stop.emit(client=self)

    def _close(self):
        """Cancel the report timer and close the websocket and session.

        Safe to call more than once. The session is closed even if closing the
        websocket raises.
        """
        if self._reports_handle is not None:
            self._reports_handle.cancel()
            self._reports_handle = None
        # Detach first so a repeated call (or a failure below) never tries to
        # close the same object twice.
        socket, self.socket = self.socket, None
        session, self.session = self.session, None
        try:
            if socket:
                self.loop.run_until_complete(socket.close(code=WSCloseCode.GOING_AWAY))
        finally:
            if session:
                self.loop.run_until_complete(session.close())

    def _headers(self):
        """Return the authentication headers sent with the websocket handshake."""
        headers = {}
        headers[CLIENT_ID_HEADER] = str(self.client_id)
        headers[CLIENT_TOKEN_HEADER] = str(self.client_token)
        return headers

    def _process_confirm(self, action):
        """Run confirmation hooks for an action we sent; log failures."""
        try:
            print('Received confirmation of action "{}" with payload: {}'.format(action.name, action.payload),
                  flush=True)
            action.after_confirm.emit(client=self, action=action)
            after_server_action_confirm.emit(client=self, action=action)
        except Exception as e:
            print('Processing confirmation of action "{}" failed: {}'.format(action.name, e), flush=True)

    def _process_action(self, action):
        """Execute an action received from the server and send its confirmation.

        Failures are printed, not raised.
        """
        try:
            print('Received action "{}" with payload: {}'.format(action.name, action.payload), flush=True)
            before_client_action_receive.emit(client=self, action=action)
            action.before_receive.emit(client=self, action=action)
            action.execute(self)
            action.after_receive.emit(client=self, action=action)
            after_client_action_receive.emit(client=self, action=action)
            # NOTE(review): in aiohttp >= 3.0 send_json is a coroutine and must
            # be awaited; confirm which aiohttp version this targets.
            self.socket.send_json(action.send_confirm())
        except Exception as e:
            print('Executing action "{}" failed: {}'.format(action.name, e), flush=True)

    def _process_message(self, msg):
        """Parse a TEXT frame into an action and route it; ignore other frames."""
        if msg.type != WSMsgType.TEXT:
            return
        try:
            action = parse_client_action(msg.data)
        except Exception as e:
            print('Invalid message received: {}; Error: {}'.format(msg.data, e), flush=True)
            return
        if action.confirm:
            self._process_confirm(action)
        else:
            self._process_action(action)

    async def connect(self):
        """Send label and first report, then process messages until the socket ends."""
        self.send_label()
        self.send_reports()
        async for msg in self.socket:
            self._process_message(msg)

    def send_action(self, action):
        """Send ``action`` to the server, emitting the before/after send signals."""
        before_server_action_send.emit(client=self, action=action)
        action.before_send.emit(client=self, action=action)
        self.socket.send_json(action.send())
        action.after_send.emit(client=self, action=action)
        after_server_action_send.emit(client=self, action=action)

    async def check_send_action(self, action):
        """Send ``action`` and wait until the server confirms it.

        :raises asyncio.TimeoutError: if no matching confirmation arrives
            within ``ACTION_CONFIRM_TIMEOUT``.
        """
        future = asyncio.Future()
        sent_action = action

        def listener(client, action):
            # Resolve only for a confirmation of exactly the action we sent.
            if client is not self:
                return
            if action.name != sent_action.name:
                return
            if action.action_id != sent_action.action_id:
                return
            if not future.done():
                future.set_result(True)

        with after_server_action_confirm.connected(listener):
            self.send_action(sent_action)
            await asyncio.wait_for(future, ACTION_CONFIRM_TIMEOUT)

    def send_label(self):
        """Send the configured client label to the server."""
        label = config.get('client.label')
        self.send_action(SaveLabelServerAction(label=label))

    def send_reports(self):
        """Send a fresh report and schedule the next one."""
        report = collect_report()
        self.send_action(SaveReportServerAction(report=report))
        self._reports_handle = self.loop.call_later(self.reports_frequency, self.send_reports)
コード例 #21
0
class WsConn(Conn):
    """``Conn`` implementation backed by an aiohttp WebSocket.

    If no session is supplied, a private ``ClientSession`` is created here and
    closed by ``clean()``. A caller-supplied session is treated as shared and
    left open.
    """
    __slots__ = ('_is_sharing_session', '_session', '_ws_receive_timeout',
                 '_ws_heartbeat', '_ws')

    # url format: ws://hostname:port/... or wss://hostname:port/...
    def __init__(
        self,
        url: str,
        receive_timeout: Optional[float] = None,
        session: Optional[ClientSession] = None,
        ws_receive_timeout: Optional[float] = None,  # used for automatic ping/pong
        ws_heartbeat: Optional[float] = None):  # used for automatic ping/pong
        super().__init__(url, receive_timeout)
        scheme = urlparse(url).scheme
        if scheme not in ('ws', 'wss'):
            raise TypeError(f'url scheme must be websocket ({scheme})')
        self._url = url

        self._is_sharing_session = session is not None
        self._session = session if self._is_sharing_session else ClientSession()
        self._ws_receive_timeout = ws_receive_timeout
        self._ws_heartbeat = ws_heartbeat
        self._ws = None

    async def open(self) -> bool:
        """Connect the websocket (3 s limit); return False on failure."""
        try:
            connecting = self._session.ws_connect(
                self._url,
                receive_timeout=self._ws_receive_timeout,
                heartbeat=self._ws_heartbeat)
            self._ws = await asyncio.wait_for(connecting, timeout=3)
        except (ClientError, asyncio.TimeoutError):
            return False
        return True

    async def close(self) -> bool:
        """Close the websocket if one was opened; always returns True."""
        ws = self._ws
        if ws is not None:
            await ws.close()
        return True

    async def clean(self) -> None:
        """Close the session, but only if this instance created it."""
        if self._is_sharing_session:
            return
        await self._session.close()

    async def send_bytes(self, bytes_data: bytes) -> bool:
        """Send a binary frame; return False on client error or cancellation."""
        try:
            await self._ws.send_bytes(bytes_data)
        except (ClientError, asyncio.CancelledError):
            return False
        return True

    async def read_bytes(self) -> Optional[bytes]:
        """Receive one frame and return its payload if it is BINARY, else None.

        Client errors, timeouts and cancellation also yield None.
        """
        try:
            frame = await asyncio.wait_for(self._ws.receive(),
                                           timeout=self._receive_timeout)
        except (ClientError, asyncio.TimeoutError, asyncio.CancelledError):
            return None
        return frame.data if frame.type == WSMsgType.BINARY else None

    async def read_json(self) -> Any:
        """Receive one frame and decode it as JSON.

        TEXT frames are decoded directly; BINARY frames are first decoded as
        UTF-8. Other frame types, client errors, timeouts and cancellation
        yield None.
        """
        try:
            frame = await asyncio.wait_for(self._ws.receive(),
                                           timeout=self._receive_timeout)
        except (ClientError, asyncio.TimeoutError, asyncio.CancelledError):
            return None
        if frame.type == WSMsgType.TEXT:
            text = frame.data
        elif frame.type == WSMsgType.BINARY:
            text = frame.data.decode('utf8')
        else:
            return None
        return json.loads(text)

    async def read_exactly_bytes(self, n: int) -> Optional[bytes]:
        """Not supported: websocket frames are message-oriented."""
        raise NotImplementedError(
            "Sorry, but I don't think we need this in WebSocket.")

    async def read_exactly_json(self, n: int) -> Any:
        """Not supported: websocket frames are message-oriented."""
        raise NotImplementedError(
            "Sorry, but I don't think we need this in WebSocket.")
コード例 #22
0
ファイル: ws.py プロジェクト: panebinese/pybotters
 async def _run_forever(
     self,
     url: StrOrURL,
     session: aiohttp.ClientSession,
     *,
     send_str: Optional[Union[str, list[str]]] = None,
     send_bytes: Optional[Union[bytes, list[bytes]]] = None,
     send_json: Any = None,
     hdlr_str=None,
     hdlr_bytes=None,
     hdlr_json=None,
     auth=_Auth,
     **kwargs: Any,
 ) -> None:
     """Keep a websocket to ``url`` open, reconnecting, until ``session`` closes.

     On every (re)connection the optional ``send_str`` / ``send_bytes`` /
     ``send_json`` payloads are sent. A list payload is sent item by item
     concurrently. Each incoming TEXT frame is passed to ``hdlr_str`` and
     (if it parses as JSON) to ``hdlr_json``; BINARY frames go to
     ``hdlr_bytes``. Handlers are called as ``handler(data, ws)`` and may be
     plain callables or coroutine functions; their exceptions are logged,
     not propagated. If neither ``hdlr_str`` nor ``hdlr_json`` is given,
     JSON frames go to ``pybotters.print_handler``.

     Connection attempts start at most once every 60 seconds. Handshake
     and OS-level connection errors are logged and retried; any other
     exception propagates.
     """
     if hdlr_str is None and hdlr_json is None:
         hdlr_json = pybotters.print_handler
     iscorofunc_str = asyncio.iscoroutinefunction(hdlr_str)
     iscorofunc_bytes = asyncio.iscoroutinefunction(hdlr_bytes)
     iscorofunc_json = asyncio.iscoroutinefunction(hdlr_json)

     async def _send_all(send, payload):
         # Send one payload, or every item of a list concurrently.
         if isinstance(payload, list):
             await asyncio.gather(*[send(item) for item in payload])
         else:
             await send(payload)

     async def _dispatch(hdlr, iscorofunc, data, ws):
         # Call a sync or async handler; log its errors so one bad message
         # does not tear down the connection.
         try:
             if iscorofunc:
                 await hdlr(data, ws)
             else:
                 hdlr(data, ws)
         except Exception as e:
             logger.exception(f"{pretty_modulename(e)}: {e}")

     while not session.closed:
         # Started before connecting so reconnects are rate-limited.
         cooldown = asyncio.create_task(asyncio.sleep(60.0))
         try:
             try:
                 async with session.ws_connect(url, auth=auth, **kwargs) as ws:
                     # (sic) attribute name "conneted" kept as-is; other code
                     # may read it.
                     self.conneted = True
                     self._event.set()
                     if send_str is not None:
                         await _send_all(ws.send_str, send_str)
                     if send_bytes is not None:
                         await _send_all(ws.send_bytes, send_bytes)
                     if send_json is not None:
                         await _send_all(ws.send_json, send_json)
                     async for msg in ws:
                         if msg.type == aiohttp.WSMsgType.TEXT:
                             if hdlr_str is not None:
                                 await _dispatch(hdlr_str, iscorofunc_str, msg.data, ws)
                             if hdlr_json is not None:
                                 try:
                                     data = msg.json()
                                 except json.decoder.JSONDecodeError:
                                     pass
                                 else:
                                     await _dispatch(hdlr_json, iscorofunc_json, data, ws)
                         elif msg.type == aiohttp.WSMsgType.BINARY:
                             if hdlr_bytes is not None:
                                 await _dispatch(hdlr_bytes, iscorofunc_bytes, msg.data, ws)
                         elif msg.type == aiohttp.WSMsgType.ERROR:
                             break
             except (
                 aiohttp.WSServerHandshakeError,
                 aiohttp.ClientOSError,
                 ConnectionResetError,
             ) as e:
                 logger.warning(f"{pretty_modulename(e)}: {e}")
             finally:
                 # The socket is closed here no matter how we left it,
                 # including cancellation or an unexpected exception.
                 self.conneted = False
                 self._event.clear()
             await cooldown
         finally:
             # Don't leave the sleep task pending if we exit abnormally;
             # cancel() is a no-op once the task is done.
             cooldown.cancel()
コード例 #23
0
class TestAdminServer(AsyncTestCase):
    """Tests for ``AdminServer``, several of them against a live local port."""

    async def setUp(self):
        self.message_results = []
        self.webhook_results = []
        self.port = 0

        self.connector = TCPConnector(limit=16, limit_per_host=4)
        session_args = {
            "cookie_jar": DummyCookieJar(),
            "connector": self.connector,
        }
        self.client_session = ClientSession(**session_args)

    async def tearDown(self):
        if self.client_session:
            await self.client_session.close()
            self.client_session = None

    async def test_debug_middleware(self):
        """debug_middleware logs three debug lines when DEBUG is enabled."""
        with async_mock.patch.object(test_module, "LOGGER",
                                     async_mock.MagicMock()) as mock_logger:
            mock_logger.isEnabledFor = async_mock.MagicMock(return_value=True)
            mock_logger.debug = async_mock.MagicMock()

            request = async_mock.MagicMock(
                method="GET",
                path_qs="/hello/world?a=1&b=2",
                match_info={"match": "info"},
                text=async_mock.CoroutineMock(return_value="abc123"),
            )
            handler = async_mock.CoroutineMock()

            await test_module.debug_middleware(request, handler)
            mock_logger.isEnabledFor.assert_called_once()
            assert mock_logger.debug.call_count == 3

    def get_admin_server(self,
                         settings: dict = None,
                         context: InjectionContext = None) -> AdminServer:
        """Build an AdminServer on a fresh unused port (stored in self.port).

        A ``task_queue`` key in ``settings`` is popped from the dict and, when
        truthy, gives the server a TaskQueue instead of conductor_stats.
        """
        if not context:
            context = InjectionContext()
        if settings:
            context.update_settings(settings)

        # middleware is task queue xor collector: cover both over test suite
        task_queue = (settings or {}).pop("task_queue", None)

        plugin_registry = async_mock.MagicMock(test_module.PluginRegistry,
                                               autospec=True)
        plugin_registry.post_process_routes = async_mock.MagicMock()
        context.injector.bind_instance(test_module.PluginRegistry,
                                       plugin_registry)

        collector = Collector()
        context.injector.bind_instance(test_module.Collector, collector)

        self.port = unused_port()
        return AdminServer(
            "0.0.0.0",
            self.port,
            context,
            self.outbound_message_router,
            self.webhook_router,
            conductor_stop=async_mock.CoroutineMock(),
            task_queue=TaskQueue(max_active=4) if task_queue else None,
            conductor_stats=(None if task_queue else async_mock.CoroutineMock(
                return_value=[1, 2])),
        )

    async def outbound_message_router(self, *args):
        self.message_results.append(args)

    def webhook_router(self, *args):
        self.webhook_results.append(args)

    async def test_start_stop(self):
        """start() rejects invalid security settings; bind failures surface
        as AdminSetupError."""
        with self.assertRaises(AssertionError):
            await self.get_admin_server().start()

        settings = {"admin.admin_insecure_mode": False}
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": True,
            "admin.admin_api_key": "test-api-key",
        }
        with self.assertRaises(AssertionError):
            await self.get_admin_server(settings).start()

        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(return_value=[
                async_mock.MagicMock(stop=async_mock.MagicMock())
            ])
            await server.stop()

        with async_mock.patch.object(web.TCPSite, "start",
                                     async_mock.CoroutineMock()) as mock_start:
            mock_start.side_effect = OSError("Failure to launch")
            with self.assertRaises(AdminSetupError):
                await self.get_admin_server(settings).start()

    async def test_responder_send(self):
        message = OutboundMessage(payload="{}")
        server = self.get_admin_server()
        await server.responder.send_outbound(message)
        assert self.message_results == [(server.context, message)]

    async def test_responder_webhook(self):
        """Webhooks reach the registered target; targets can be removed."""
        server = self.get_admin_server()
        test_url = "target_url"
        test_attempts = 99
        server.add_webhook_target(
            target_url=test_url,
            topic_filter=["*"],  # cover vacuous filter
            max_attempts=test_attempts,
        )
        test_topic = "test_topic"
        test_payload = {"test": "TEST"}

        with async_mock.patch.object(server, "websocket_queues",
                                     async_mock.MagicMock()) as mock_wsq:
            mock_wsq.values = async_mock.MagicMock(return_value=[
                async_mock.MagicMock(authenticated=True,
                                     enqueue=async_mock.CoroutineMock())
            ])

            await server.responder.send_webhook(test_topic, test_payload)
            assert self.webhook_results == [(test_topic, test_payload,
                                             test_url, test_attempts)]

        server.remove_webhook_target(target_url=test_url)
        assert test_url not in server.webhook_targets

    async def test_import_routes(self):
        # this test just imports all default admin routes
        # for routes with associated tests, this shouldn't make a difference in coverage
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        await DefaultContextBuilder().load_plugins(context)
        server = self.get_admin_server({"admin.admin_insecure_mode": True},
                                       context)
        await server.make_application()

    async def test_register_external_plugin_x(self):
        """Loading a non-existent external plugin module raises ValueError."""
        context = InjectionContext()
        context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry())
        with self.assertRaises(ValueError):
            builder = DefaultContextBuilder(
                settings={"external_plugins": "aries_cloudagent.nosuchmodule"})
            await builder.load_plugins(context)

    async def test_visit_insecure_mode(self):
        """In insecure mode the admin routes answer 200 without an API key."""
        settings = {"admin.admin_insecure_mode": True, "task_queue": True}
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.post(
                    f"http://127.0.0.1:{self.port}/status/reset",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws") as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"

            for path in (
                    "",
                    "plugins",
                    "status",
                    "status/live",
                    "status/ready",
                    "shutdown",  # mock conductor has magic-mock stop()
            ):
                async with self.client_session.get(
                        f"http://127.0.0.1:{self.port}/{path}",
                        headers={}) as response:
                    assert response.status == 200
        finally:
            # Release the port even if an assertion above fails.
            await server.stop()

    async def test_visit_secure_mode(self):
        """With an API key configured, only the matching x-api-key is accepted."""
        settings = {
            "admin.admin_insecure_mode": False,
            "admin.admin_api_key": "test-api-key",
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "wrong-key"}) as response:
                assert response.status == 401

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={"x-api-key": "test-api-key"},
            ) as response:
                assert response.status == 200

            async with self.client_session.ws_connect(
                    f"http://127.0.0.1:{self.port}/ws",
                    headers={"x-api-key": "test-api-key"}) as ws:
                result = await ws.receive_json()
                assert result["topic"] == "settings"
        finally:
            # Release the port even if an assertion above fails.
            await server.stop()

    async def test_visit_shutting_down(self):
        """After /shutdown, /status is 503 while /status/live stays 200."""
        settings = {
            "admin.admin_insecure_mode": True,
        }
        server = self.get_admin_server(settings)
        await server.start()
        try:
            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/shutdown",
                    headers={}) as response:
                assert response.status == 200

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status",
                    headers={}) as response:
                assert response.status == 503

            async with self.client_session.get(
                    f"http://127.0.0.1:{self.port}/status/live",
                    headers={}) as response:
                assert response.status == 200
        finally:
            # Release the port even if an assertion above fails.
            await server.stop()
コード例 #24
0
def _is_websocket_upgrade(headers) -> bool:
    """Return True if *headers* ask to upgrade the connection to WebSocket.

    ``Connection`` is a comma-separated token list (browsers may send e.g.
    ``keep-alive, Upgrade``) and header values are case-insensitive, so the
    tokens are split and lower-cased instead of compared verbatim. Missing
    headers mean "no upgrade" rather than raising ``KeyError``.
    """
    connection_tokens = {
        token.strip().lower()
        for token in headers.get("Connection", "").split(",")
    }
    upgrade = headers.get("Upgrade", "").strip().lower()
    return "upgrade" in connection_tokens and upgrade == "websocket"


async def _relay_ws_messages(ws_from, ws_to) -> None:
    """Copy messages from *ws_from* to *ws_to* until *ws_from* stops.

    The ``async for`` loop ends when *ws_from* closes (aiohttp stops iteration
    on close frames); an ``ERROR`` message also ends the relay. Any other
    unexpected message type raises ``ValueError``.
    """
    async for msg in ws_from:
        print(f">>> msg: {msg}")
        if msg.type == WSMsgType.TEXT:
            await ws_to.send_str(msg.data)
        elif msg.type == WSMsgType.BINARY:
            await ws_to.send_bytes(msg.data)
        elif msg.type == WSMsgType.PING:
            # Keep the payload: the far end echoes it back in its PONG.
            await ws_to.ping(msg.data)
        elif msg.type == WSMsgType.PONG:
            await ws_to.pong(msg.data)
        elif msg.type == WSMsgType.ERROR:
            break
        else:
            raise ValueError(f"unexpected message type: {msg}")


async def proxy_handler(req: web.Request) -> web.Response:
    """Proxy *req* into the session's code-server container.

    Redirects to ``/login`` when the session has no ``container_name``.
    Otherwise the container is found or created and the request is forwarded
    to ``http://<container_name>:8080``. A WebSocket upgrade (GET) is relayed
    message by message in both directions. Any other request is proxied as
    one buffered HTTP request/response, with ``service-worker-allowed: /``
    added to the response headers.
    """
    sess = await get_session(req)
    if "container_name" not in sess.keys():
        raise web.HTTPFound("/login")
    container_name = sess["container_name"]
    code_server_manager = CodeServerManager(container_name)
    await code_server_manager.find_or_create_container()
    base_url = f"http://{container_name}:8080"

    if req.method == "GET" and _is_websocket_upgrade(req.headers):
        ws_server = web.WebSocketResponse()
        await ws_server.prepare(req)
        print(f"##### WS_SERVER {ws_server}")

        path_qs_cleaned = req.path_qs.removeprefix("/devenv")
        # Per-connection session; `async with` closes it (and its connector)
        # once the relay ends instead of leaking it.
        async with ClientSession(cookies=req.cookies) as client_session:
            async with client_session.ws_connect(
                base_url + path_qs_cleaned
            ) as ws_client:
                print(f"##### WS_CLIENT {ws_client}")
                # asyncio.wait() needs tasks; bare coroutines are rejected on
                # Python 3.11+.
                relays = [
                    asyncio.create_task(_relay_ws_messages(ws_server, ws_client)),
                    asyncio.create_task(_relay_ws_messages(ws_client, ws_server)),
                ]
                try:
                    await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Stop the other direction and collect outcomes so no
                    # task outlives the handler with an unretrieved exception.
                    for relay in relays:
                        relay.cancel()
                    results = await asyncio.gather(*relays, return_exceptions=True)
                for result in results:
                    if isinstance(result, Exception):
                        print(f"##### WS_RELAY_ERROR {result!r}")
        # No-op if the browser side already closed.
        await ws_server.close()
        return ws_server

    # Plain HTTP: strip the /devenv mount point and forward the request.
    proxy_path = req.path_qs
    if proxy_path != "":
        proxy_path = (
            proxy_path.removeprefix("/devenv")
            .removeprefix("devenv")
            .removeprefix("/")
        )
        proxy_path = "/" + proxy_path
    async with client.request(
        req.method,
        base_url + proxy_path,
        allow_redirects=False,
        data=await req.read(),
    ) as res:
        headers = res.headers.copy()
        # The body is re-sent in one piece and aiohttp computes its length;
        # copying the upstream framing headers would contradict that (e.g.
        # "Transfer-Encoding: chunked" on an unchunked body).
        headers.popall("Transfer-Encoding", None)
        headers.popall("Content-Length", None)
        # NOTE(review): if `client` auto-decompresses (aiohttp's default), the
        # copied Content-Encoding no longer matches the body -- confirm.
        headers["service-worker-allowed"] = "/"
        body = await res.read()
        return web.Response(headers=headers, status=res.status, body=body)
コード例 #25
0
ファイル: driver.py プロジェクト: limoo-im/python-sdk
class LimooDriver:
    """Async client for a Limoo server's HTTP API and WebSocket event stream.

    All requests go through one aiohttp ``ClientSession`` whose cookie jar is
    created with ``unsafe=True``. Methods decorated with ``_with_auth`` log in
    on demand when a request fails with ``LimooAuthenticationError`` and then
    retry. API groups are exposed as the ``conversations``, ``files``,
    ``messages``, ``users`` and ``workspaces`` attributes.
    """

    # _listen() stops retrying once its backoff delay reaches
    # _ALLOWED_CONNECTION_ATTEMPTS * (_RETRY_DELAY + 1) seconds.
    _ALLOWED_CONNECTION_ATTEMPTS = 1000000
    # Backoff step in seconds: the n-th reconnect waits n * _RETRY_DELAY.
    _RETRY_DELAY = 2

    @staticmethod
    async def _receive_event(ws):
        """Return the next JSON message from *ws*, skipping undecodable ones."""
        while True:
            try:
                return await ws.receive_json()
            except ValueError:
                # Not valid JSON: ignore it and wait for the next message.
                continue

    @staticmethod
    async def _get_text_body(response):
        """Read *response* as text, releasing it whether or not reading works.

        Raises:
            LimooError: if the connection fails or the payload is malformed
                while the body is being read.
        """
        try:
            return await response.text()
        except (ClientConnectionError, ClientPayloadError) as ex:
            raise LimooError from ex
        finally:
            await response.release()

    @staticmethod
    async def _get_json_body(response):
        """Read *response* and decode its body as JSON.

        Raises:
            LimooError: if reading fails, or if the body is not valid JSON (the
                offending body text is included in the message).
        """
        response_text = await LimooDriver._get_text_body(response)
        try:
            return json.loads(response_text)
        except json.JSONDecodeError as ex:
            raise LimooError(
                f'Response body is not valid json: {response_text}') from ex

    def _with_auth(coro):
        """Decorate a driver coroutine so it logs in and retries on auth failure.

        When the wrapped call raises ``LimooAuthenticationError``, the wrapper
        logs in, unless another call has already logged in since this one
        started (tracked via ``_successful_login_count`` under ``_authlock``),
        and then retries. If the call still fails after this wrapper itself
        logged in, the error is re-raised. Only used as a decorator inside this
        class body and deleted at its end.
        """
        @functools.wraps(coro)
        async def wrapper(self, *args, **kwargs):
            async with self._authlock:
                authenticated = False
                # Snapshot the login count to detect logins by other calls.
                previous_slc = self._successful_login_count
            while True:
                try:
                    return await coro(self, *args, **kwargs)
                except LimooAuthenticationError:
                    if authenticated:
                        # Still rejected after this call logged in: give up.
                        raise
                    async with self._authlock:
                        if self._successful_login_count == previous_slc:
                            await self._login()
                            self._successful_login_count += 1
                            authenticated = True
                        previous_slc = self._successful_login_count

        return wrapper

    def __init__(self, limoo_url, bot_username, bot_password, secure=True):
        """Create a driver for the Limoo server at *limoo_url*.

        Args:
            limoo_url: Server address without an ``http://``/``https://``
                scheme; a single trailing ``/`` is removed.
            bot_username: Login name, sent as ``j_username``.
            bot_password: Password, sent as ``j_password``.
            secure: Use ``https``/``wss`` when true, ``http``/``ws`` otherwise.
        """
        # Catch a relatively common mistake and report an informative error
        assert not limoo_url.startswith(('http://', 'https://')), (
            'The URL of the Limoo server should not start with'
            f' "http://" or "https://". The received URL was "{limoo_url}"')
        self._credentials = {
            'j_username': bot_username,
            'j_password': bot_password,
        }
        if limoo_url.endswith('/'):
            limoo_url = limoo_url[:-1]
        http_url = f'http{"s" if secure else ""}://{limoo_url}'
        ws_url = f'ws{"s" if secure else ""}://{limoo_url}'
        self._login_url = f'{http_url}/Limonad/j_spring_security_check'
        self._api_url_prefix = f'{http_url}/Limonad/api/v1'
        self._fileop_url = f'{http_url}/fileserver/api/v1/files'
        self._websocket_url = f'{ws_url}/Limonad/websocket'
        self._client_session = ClientSession(cookie_jar=CookieJar(unsafe=True))
        # Incremented on every login performed by _with_auth.
        self._successful_login_count = 0
        self._authlock = asyncio.Lock()
        self._listen_task = None
        self._event_handler = lambda event: None
        self.conversations = Conversations(self)
        self.files = Files(self)
        self.messages = Messages(self)
        self.users = Users(self)
        self.workspaces = Workspaces(self)

    async def close(self):
        """Cancel the event-listening task, if any, and close the HTTP session."""
        if self._listen_task:
            self._listen_task.cancel()
            try:
                await self._listen_task
            except asyncio.CancelledError:
                pass
        await self._client_session.close()

    async def _login(self):
        """POST the bot credentials to the login endpoint."""
        await self._execute_request('POST',
                                    self._login_url,
                                    data=self._credentials)

    @_with_auth
    async def _execute_api_get(self, endpoint):
        """GET ``<api prefix>/<endpoint>`` and return the decoded JSON body."""
        return await self._execute_json_request('GET', endpoint)

    @_with_auth
    async def _execute_api_post(self, endpoint, body):
        """POST *body* as JSON to ``<api prefix>/<endpoint>``; return decoded JSON."""
        return await self._execute_json_request('POST', endpoint, body=body)

    async def _execute_json_request(self, method, endpoint, *, body=None):
        """Send *body* as JSON to an API endpoint and decode the JSON reply."""
        return await self._get_json_body(await self._execute_request(
            method, f'{self._api_url_prefix}/{endpoint}', json=body))

    @_with_auth
    async def _upload_file(self, path, name, mime_type):
        """Upload the file at *path* as form field *name*; return the JSON reply."""
        formdata = FormData(quote_fields=False)
        run_async = asyncio.get_running_loop().run_in_executor
        async with contextlib.AsyncExitStack() as stack:
            # open()/close() run in the default executor, off the event loop.
            file = await run_async(None, open, path, 'rb')
            stack.push_async_callback(run_async, None, file.close)
            formdata.add_field(name,
                               file,
                               content_type=mime_type,
                               filename=name)
            return await self._get_json_body(await self._execute_request(
                'POST', self._fileop_url, data=formdata))

    @_with_auth
    async def _download_file(self, hash, name):
        """Request the file identified by *hash*/*name*; wrap the response in a StreamReader."""
        params = urllib.parse.urlencode({'hash': hash, 'name': name})
        return StreamReader(await self._execute_request(
            'GET', f'{self._fileop_url}?mode=download&{params}'))

    async def _execute_request(self, method, url, *, data=None, json=None):
        """Send a request (with ``is_bot=true``) and return it if status < 400.

        Successful responses are returned without reading their body.

        Raises:
            LimooError: on a connection error, or a status >= 400 other than 401.
            LimooAuthenticationError: on a 401 response.
        """
        try:
            response = await self._client_session.request(
                method, url, data=data, json=json, params={"is_bot": "true"})
        except ClientConnectionError as ex:
            raise LimooError('Connection Error') from ex
        status = response.status
        if status < 400:
            return response
        # Reading the text also releases the failed response.
        response_text = await self._get_text_body(response)
        if status == 401:
            raise LimooAuthenticationError
        else:
            raise LimooError(
                f'Request returned unsuccessfully with status {status} and body {response_text}'
            )

    def set_event_handler(self, event_handler):
        """Install *event_handler*, a one-argument callable, or None.

        A handler starts the WebSocket listen task if it is not already set;
        None cancels the running task.

        Raises:
            ValueError: if *event_handler* is neither callable nor None.
        """
        if event_handler is not None and not callable(event_handler):
            raise ValueError(
                'event_handler must either be a callable or None.')
        self._event_handler = event_handler
        if self._event_handler and not self._listen_task:
            self._listen_task = asyncio.create_task(self._listen())
        elif not self._event_handler and self._listen_task:
            self._listen_task.cancel()
            self._listen_task = None

    async def _listen(self):
        """Keep the event WebSocket connected and feed events to the handler.

        Failed connection attempts are retried with linearly growing delays;
        a broken connection is closed and re-established. Handler exceptions
        are logged and do not stop the loop. A cancellation is remembered,
        the socket is closed, and the CancelledError is re-raised at the end.
        """
        _LOGGER.info('WebSocket listen task started.')
        # Cancellation caught below is stored so cleanup can finish first.
        cancel_ex = None
        max_delay = self._ALLOWED_CONNECTION_ATTEMPTS * (self._RETRY_DELAY + 1)
        while not cancel_ex:
            ws = None
            # Delays grow as _RETRY_DELAY, 2 * _RETRY_DELAY, ... up to max_delay.
            for retry_delay in range(self._RETRY_DELAY,
                                     max_delay + self._RETRY_DELAY,
                                     self._RETRY_DELAY):
                try:
                    ws = await self._try_connecting()
                    break
                except asyncio.CancelledError as ex:
                    cancel_ex = ex
                    break
                except Exception as ex:
                    _LOGGER.error(
                        'Connecting the WebSocket failed with the following exception: %s',
                        ex)
                    if retry_delay == max_delay:
                        break
                    else:
                        _LOGGER.info(
                            'Going to sleep for %d seconds before trying to connect a WebSocket.',
                            retry_delay)
                        try:
                            await asyncio.sleep(retry_delay)
                        except asyncio.CancelledError as ex:
                            cancel_ex = ex
                            break
            if not ws:
                # Cancelled, or retries exhausted.
                break
            _LOGGER.info('Connected the WebSocket.')
            while True:
                try:
                    event = await self._receive_event(ws)
                except asyncio.CancelledError as ex:
                    cancel_ex = ex
                    break
                except Exception as ex:
                    _LOGGER.error(
                        'The connected WebSocket broke with the following exception: %s',
                        ex)
                    break
                _LOGGER.debug('Received an event.')
                try:
                    self._event_handler(event)
                except Exception as ex:
                    _LOGGER.error(
                        'Calling the event handler failed with the following exception: %s',
                        ex)
            try:
                await ws.close()
                _LOGGER.info('The WebSocket was closed.')
            except asyncio.CancelledError as ex:
                cancel_ex = ex
            except Exception as ex:
                _LOGGER.error(
                    'Closing the WebSocket failed with the following exception: %s',
                    ex)
        _LOGGER.info('WebSocket listen task ended.')
        if cancel_ex:
            raise cancel_ex

    @_with_auth
    async def _try_connecting(self):
        """Open the event WebSocket and verify the server accepted the session.

        The first event is read and discarded. If it is
        ``authentication_failed``, the socket is closed by the exit stack and
        ``LimooAuthenticationError`` is raised, so ``_with_auth`` logs in and
        retries. Otherwise the socket is detached from the stack and returned
        open.
        """
        async with contextlib.AsyncExitStack() as stack:
            ws = await stack.enter_async_context(
                self._client_session.ws_connect(self._websocket_url,
                                                receive_timeout=70,
                                                heartbeat=60))
            event = await self._receive_event(ws)
            if event.get('event') == 'authentication_failed':
                raise LimooAuthenticationError
            else:
                # Keep the socket open past the `async with`.
                stack.pop_all()
                return ws

    del _with_auth